}
};
+// Default conversion which applies the BinaryStandardOp separately on the real
+// and imaginary parts. Can for example be used for complex::AddOp and
+// complex::SubOp.
+template <typename BinaryComplexOp, typename BinaryStandardOp>
+struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
+ using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ typename BinaryComplexOp::Adaptor transformed(operands);
+ auto type = transformed.lhs().getType().template cast<ComplexType>();
+ auto elementType = type.getElementType().template cast<FloatType>();
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+ Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
+ Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
+ Value resultReal =
+ b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
+ Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
+ Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
+ Value resultImag =
+ b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
+ rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+ resultImag);
+ return success();
+ }
+};
+
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
using OpConversionPattern<complex::DivOp>::OpConversionPattern;
AbsOpConversion,
ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
+ BinaryComplexOpConversion<complex::AddOp, AddFOp>,
+ BinaryComplexOpConversion<complex::SubOp, SubFOp>,
DivOpConversion,
ExpOpConversion,
LogOpConversion,
populateComplexToStandardConversionPatterns(patterns);
ConversionTarget target(getContext());
- target.addLegalDialect<StandardOpsDialect, math::MathDialect,
- complex::ComplexDialect>();
- target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
- complex::ExpOp, complex::LogOp, complex::Log1pOp,
- complex::MulOp, complex::NegOp, complex::NotEqualOp,
- complex::SignOp>();
+ target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
+ target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: return %[[NORM]] : f32
+// CHECK-LABEL: func @complex_add
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %add = complex.add %lhs, %rhs: complex<f32>
+ return %add : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = addf %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
// CHECK-LABEL: func @complex_div
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+
+// CHECK-LABEL: func @complex_sub
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %sub = complex.sub %lhs, %rhs: complex<f32>
+ return %sub : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = subf %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>