[mlir][complex] Add complex.conj op
authorlewuathe <lewuathe@me.com>
Tue, 7 Jun 2022 07:37:20 +0000 (09:37 +0200)
committerAlexander Belyaev <pifon@google.com>
Tue, 7 Jun 2022 07:38:35 +0000 (09:38 +0200)
Add complex.conj op to calculate the complex conjugate which is widely used for the mathematical operation on the complex space.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D127181

mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

index 054289f..21797d3 100644 (file)
@@ -564,4 +564,24 @@ def TanOp : ComplexUnaryOp<"tan", [SameOperandsAndResultType]> {
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Conj
+//===----------------------------------------------------------------------===//
+
+def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> {
+  let summary = "Calculate the complex conjugate";
+  let description = [{
+    The `conj` op takes a single complex number and computes the
+    complex conjugate.
+
+    Example:
+
+    ```mlir
+    %a = complex.conj %b: complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 #endif // COMPLEX_OPS
index 2e981c0..e314f2e 100644 (file)
@@ -885,6 +885,27 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
   }
 };
 
+struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
+  using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto type = adaptor.getComplex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+    Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -893,23 +914,25 @@ void mlir::populateComplexToStandardConversionPatterns(
   patterns.add<
       AbsOpConversion,
       Atan2OpConversion,
-      ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
-      ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
+      ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
+      ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
+      ConjOpConversion,
       CosOpConversion,
       DivOpConversion,
       ExpOpConversion,
       Expm1OpConversion,
-      LogOpConversion,
       Log1pOpConversion,
+      LogOpConversion,
       MulOpConversion,
       NegOpConversion,
       SignOpConversion,
       SinOpConversion,
       SqrtOpConversion,
       TanOpConversion,
-      TanhOpConversion>(patterns.getContext());
+      TanhOpConversion
+  >(patterns.getContext());
   // clang-format on
 }
 
index 319a443..9687523 100644 (file)
@@ -663,3 +663,17 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
   %sqrt = complex.sqrt %arg : complex<f32>
   return %sqrt : complex<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @complex_conj
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
+  %conj = complex.conj %arg: complex<f32>
+  return %conj : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file