[mlir] turn complex-to-llvm into a partial conversion
authorAlex Zinenko <zinenko@google.com>
Thu, 28 Jan 2021 16:42:41 +0000 (17:42 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 28 Jan 2021 18:14:01 +0000 (19:14 +0100)
It is no longer necessary to also convert other "standard" ops along with the
complex dialect: the element types are now built-in integers or floating point
types, and the top-level cast between complex and struct is automatically
inserted and removed in progressive lowering.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D95625

mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir [new file with mode: 0644]
mlir/test/Dialect/LLVMIR/dialect-cast.mlir

index 270b948..4421a9f 100644 (file)
@@ -286,12 +286,13 @@ void ConvertComplexToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;
   LLVMTypeConverter converter(&getContext());
-  populateStdToLLVMConversionPatterns(converter, patterns);
   populateComplexToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());
-  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-  if (failed(applyFullConversion(module, target, std::move(patterns))))
+  target.addLegalOp<ModuleOp, FuncOp>();
+  target.addLegalOp<LLVM::DialectCastOp>();
+  target.addIllegalDialect<complex::ComplexDialect>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
 
index a31402e..adf7ff7 100644 (file)
@@ -1375,6 +1375,17 @@ static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type,
     return success();
   }
 
+  // Complex types are compatible with the two-element structs.
+  if (auto complexType = type.dyn_cast<ComplexType>()) {
+    auto structType = llvmType.dyn_cast<LLVMStructType>();
+    if (!structType || structType.getBody().size() != 2 ||
+        structType.getBody()[0] != structType.getBody()[1] ||
+        structType.getBody()[0] != complexType.getElementType())
+      return op->emitOpError("expected 'complex' to map to two-element struct "
+                             "with identical element types");
+    return success();
+  }
+
   // Everything else is not supported.
   return op->emitError("unsupported cast");
 }
index ffc7bbf..6ad3595 100644 (file)
@@ -1,14 +1,13 @@
-// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-complex-to-llvm | FileCheck %s
 
-// CHECK-LABEL: llvm.func @complex_numbers()
-// CHECK-NEXT:    %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32
-// CHECK-NEXT:    %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32
+// CHECK-LABEL: func @complex_numbers
+// CHECK-NEXT:    %[[REAL0:.*]] = constant 1.200000e+00 : f32
+// CHECK-NEXT:    %[[IMAG0:.*]] = constant 3.400000e+00 : f32
 // CHECK-NEXT:    %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
 // CHECK-NEXT:    %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)>
 // CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)>
 // CHECK-NEXT:    %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)>
 // CHECK-NEXT:    %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)>
-// CHECK-NEXT:    llvm.return
 func @complex_numbers() {
   %real0 = constant 1.2 : f32
   %imag0 = constant 3.4 : f32
@@ -18,9 +17,7 @@ func @complex_numbers() {
   return
 }
 
-// -----
-
-// CHECK-LABEL: llvm.func @complex_addition()
+// CHECK-LABEL: func @complex_addition
 // CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
 // CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
 // CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
@@ -41,9 +38,7 @@ func @complex_addition() {
   return
 }
 
-// -----
-
-// CHECK-LABEL: llvm.func @complex_substraction()
+// CHECK-LABEL: func @complex_substraction
 // CHECK-DAG:     %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)>
 // CHECK-DAG:     %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)>
 // CHECK-DAG:     %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)>
@@ -64,18 +59,19 @@ func @complex_substraction() {
   return
 }
 
-// -----
-
-// CHECK-LABEL: llvm.func @complex_div
-// CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+// CHECK-LABEL: func @complex_div
+// CHECK-SAME:    %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
   %div = complex.div %lhs, %rhs : complex<f32>
   return %div : complex<f32>
 }
-// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
-// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
-// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
-// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[CASTED_LHS:.*]] = llvm.mlir.cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[CASTED_RHS:.*]] = llvm.mlir.cast %[[RHS]] : complex<f32> to ![[C_TY]]
+
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
 
 // CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
 
@@ -95,20 +91,23 @@ func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
 // CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]]  : f32
 // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
-// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
-
-// -----
+//
+// CHECK: %[[CASTED_RESULT:.*]] = llvm.mlir.cast %[[RESULT_2]] : ![[C_TY]] to complex<f32>
+// CHECK: return %[[CASTED_RESULT]] : complex<f32>
 
-// CHECK-LABEL: llvm.func @complex_mul
-// CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+// CHECK-LABEL: func @complex_mul
+// CHECK-SAME:    %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
   %mul = complex.mul %lhs, %rhs : complex<f32>
   return %mul : complex<f32>
 }
-// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
-// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
-// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
-// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[CASTED_LHS:.*]] = llvm.mlir.cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[CASTED_RHS:.*]] = llvm.mlir.cast %[[RHS]] : complex<f32> to ![[C_TY]]
+
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
 // CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
 
 // CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]]  : f32
@@ -121,21 +120,22 @@ func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 
 // CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
 // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
-// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
 
-// -----
+// CHECK: %[[CASTED_RESULT:.*]] = llvm.mlir.cast %[[RESULT_2]] : ![[C_TY]] to complex<f32>
+// CHECK: return %[[CASTED_RESULT]] : complex<f32>
 
-// CHECK-LABEL: llvm.func @complex_abs
-// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+// CHECK-LABEL: func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func @complex_abs(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg: complex<f32>
   return %abs : f32
 }
-// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
-// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK: %[[CASTED_ARG:.*]] = llvm.mlir.cast %[[ARG]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[CASTED_ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[CASTED_ARG]][1] : ![[C_TY]]
 // CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]]  : f32
 // CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]]  : f32
 // CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]]  : f32
 // CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
-// CHECK: llvm.return %[[NORM]] : f32
+// CHECK: return %[[NORM]] : f32
 
diff --git a/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir
new file mode 100644 (file)
index 0000000..6844f70
--- /dev/null
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -convert-complex-to-llvm -convert-std-to-llvm | FileCheck %s
+
+// CHECK-LABEL: llvm.func @complex_div
+// CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %div = complex.div %lhs, %rhs : complex<f32>
+  return %div : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]]  : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]]  : f32
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]]  : f32
+// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]]  : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]]  : f32
+// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]]  : f32
+
+// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]]  : f32
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]]  : f32
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// CHECK-LABEL: llvm.func @complex_mul
+// CHECK-SAME:    %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]]
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs : complex<f32>
+  return %mul : complex<f32>
+}
+// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]]
+
+// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]]  : f32
+// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]]  : f32
+// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]]  : f32
+
+// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]]  : f32
+// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]]  : f32
+// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]]  : f32
+
+// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
+// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
+// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]]
+
+// CHECK-LABEL: llvm.func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+func @complex_abs(%arg: complex<f32>) -> f32 {
+  %abs = complex.abs %arg: complex<f32>
+  return %abs : f32
+}
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]]  : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]]  : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]]  : f32
+// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: llvm.return %[[NORM]] : f32
+
index 90eaaa2..b72141d 100644 (file)
@@ -222,3 +222,31 @@ func @mlir_dialect_cast_unranked_rank(%0: memref<*xf32>) {
   // expected-error@+1 {{expected second element of a memref descriptor to be an !llvm.ptr<i8>}}
   llvm.mlir.cast %0 : memref<*xf32> to !llvm.struct<(i64, f32)>
 }
+
+// -----
+
+func @mlir_dialect_cast_complex_non_struct(%0: complex<f32>) {
+  // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}}
+  llvm.mlir.cast %0 : complex<f32> to f32
+}
+
+// -----
+
+func @mlir_dialect_cast_complex_bad_size(%0: complex<f32>) {
+  // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}}
+  llvm.mlir.cast %0 : complex<f32> to !llvm.struct<(f32, f32, f32)>
+}
+
+// -----
+
+func @mlir_dialect_cast_complex_mismatching_type_struct(%0: complex<f32>) {
+  // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}}
+  llvm.mlir.cast %0 : complex<f32> to !llvm.struct<(f32, f64)>
+}
+
+// -----
+
+func @mlir_dialect_cast_complex_mismatching_element(%0: complex<f32>) {
+  // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}}
+  llvm.mlir.cast %0 : complex<f32> to !llvm.struct<(f64, f64)>
+}