From: Adrian Kuegel Date: Wed, 21 Jul 2021 10:44:34 +0000 (+0200) Subject: [mlir][Complex]: Add lowerings for AddOp and SubOp from Complex dialect to X-Git-Tag: llvmorg-14-init~492 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fb978f092c9c1eff56906c65123944140c89f9cd;p=platform%2Fupstream%2Fllvm.git [mlir][Complex]: Add lowerings for AddOp and SubOp from Complex dialect to Standard. Differential Revision: https://reviews.llvm.org/D106429 --- diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 4d3d522..f651eed 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -79,6 +79,35 @@ struct ComparisonOpConversion : public OpConversionPattern { } }; +// Default conversion which applies the BinaryStandardOp separately on the real +// and imaginary parts. Can for example be used for complex::AddOp and +// complex::SubOp. +template +struct BinaryComplexOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(BinaryComplexOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + typename BinaryComplexOp::Adaptor transformed(operands); + auto type = transformed.lhs().getType().template cast(); + auto elementType = type.getElementType().template cast(); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + Value realLhs = b.create(elementType, transformed.lhs()); + Value realRhs = b.create(elementType, transformed.rhs()); + Value resultReal = + b.create(elementType, realLhs, realRhs); + Value imagLhs = b.create(elementType, transformed.lhs()); + Value imagRhs = b.create(elementType, transformed.rhs()); + Value resultImag = + b.create(elementType, imagLhs, imagRhs); + rewriter.replaceOpWithNewOp(op, type, resultReal, + resultImag); + return success(); + } +}; + struct DivOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -554,6 +583,8 @@ void mlir::populateComplexToStandardConversionPatterns( AbsOpConversion, ComparisonOpConversion, ComparisonOpConversion, + BinaryComplexOpConversion, + BinaryComplexOpConversion, DivOpConversion, ExpOpConversion, LogOpConversion, @@ -578,12 +609,8 @@ void ConvertComplexToStandardPass::runOnFunction() { populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalOp(); + target.addLegalDialect(); + target.addLegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 765d79c..9d9593a 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -14,6 +14,21 @@ func @complex_abs(%arg: complex) -> f32 { // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 +// CHECK-LABEL: func @complex_add +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func @complex_add(%lhs: complex, %rhs: complex) -> complex { + %add = complex.add %lhs, %rhs: complex + return %add : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = addf %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + // CHECK-LABEL: func @complex_div // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func @complex_div(%lhs: complex, %rhs: complex) -> complex { @@ -366,3 +381,18 @@ func @complex_sign(%arg: complex) -> complex { // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex // CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex // CHECK: return %[[RESULT]] : complex + +// CHECK-LABEL: func @complex_sub +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func @complex_sub(%lhs: complex, %rhs: complex) -> complex { + %sub = complex.sub %lhs, %rhs: complex + return %sub : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = subf %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex