From 6c6eddb6172f910c7e38d1327e5c6493b62c2950 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Wed, 8 Jun 2022 09:11:28 -0700 Subject: [PATCH] [mlir] Lower complex.power and complex.rsqrt to standard dialect. Add conversion tests and correctness tests. Reviewed By: pifon2a Differential Revision: https://reviews.llvm.org/D127255 --- .../ComplexToStandard/ComplexToStandard.cpp | 107 ++++++++++++++++++++- .../ComplexToStandard/convert-to-standard.mlir | 19 +++- .../Dialect/Complex/CPU/correctness.mlir | 64 ++++++++++++ 3 files changed, 188 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index e314f2e..0a5124a 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -906,6 +906,109 @@ struct ConjOpConversion : public OpConversionPattern { } }; +/// Coverts x^y = (a+bi)^(c+di) to +/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), +/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) +static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, + ComplexType type, Value a, Value b, Value c, + Value d) { + auto elementType = type.getElementType().cast(); + + // Compute (a*a+b*b)^(0.5c). + Value aaPbb = builder.create( + builder.create(a, a), builder.create(b, b)); + Value half = builder.create( + elementType, builder.getFloatAttr(elementType, 0.5)); + Value halfC = builder.create(half, c); + Value aaPbbTohalfC = builder.create(aaPbb, halfC); + + // Compute exp(-d*atan2(b,a)). + Value negD = builder.create(d); + Value argX = builder.create(b, a); + Value negDArgX = builder.create(negD, argX); + Value eToNegDArgX = builder.create(negDArgX); + + // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)). + Value coeff = builder.create(aaPbbTohalfC, eToNegDArgX); + + // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b). + Value lnAaPbb = builder.create(aaPbb); + Value halfD = builder.create(half, d); + Value q = builder.create( + builder.create(c, argX), + builder.create(halfD, lnAaPbb)); + + Value cosQ = builder.create(q); + Value sinQ = builder.create(q); + Value zero = builder.create( + elementType, builder.getFloatAttr(elementType, 0)); + Value one = builder.create( + elementType, builder.getFloatAttr(elementType, 1)); + + Value xEqZero = + builder.create(arith::CmpFPredicate::OEQ, aaPbb, zero); + Value yGeZero = builder.create( + builder.create(arith::CmpFPredicate::OGE, c, zero), + builder.create(arith::CmpFPredicate::OEQ, d, zero)); + Value cEqZero = + builder.create(arith::CmpFPredicate::OEQ, c, zero); + Value complexZero = builder.create(type, zero, zero); + Value complexOne = builder.create(type, one, zero); + Value complexOther = builder.create( + type, builder.create(coeff, cosQ), + builder.create(coeff, sinQ)); + + // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + return builder.create( + builder.create(xEqZero, yGeZero), + builder.create(cEqZero, complexOne, complexZero), + complexOther); +} + +struct PowOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); + auto type = adaptor.getLhs().getType().cast(); + auto elementType = type.getElementType().cast(); + + Value a = builder.create(elementType, adaptor.getLhs()); + Value b = builder.create(elementType, adaptor.getLhs()); + Value c = builder.create(elementType, adaptor.getRhs()); + Value d = builder.create(elementType, adaptor.getRhs()); + + rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); + return success(); + } +}; + +struct RsqrtOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); + auto type = adaptor.getComplex().getType().cast(); + auto elementType = type.getElementType().cast(); + + Value a = builder.create(elementType, adaptor.getComplex()); + Value b = builder.create(elementType, adaptor.getComplex()); + Value c = builder.create( + elementType, builder.getFloatAttr(elementType, -0.5)); + Value d = builder.create( + elementType, builder.getFloatAttr(elementType, 0)); + + rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); + return success(); + } +}; + } // namespace void mlir::populateComplexToStandardConversionPatterns( @@ -931,7 +1034,9 @@ void mlir::populateComplexToStandardConversionPatterns( SinOpConversion, SqrtOpConversion, TanOpConversion, - TanhOpConversion + TanhOpConversion, + PowOpConversion, + RsqrtOpConversion >(patterns.getContext()); // clang-format on } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 9687523..5b37899 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -676,4 +676,21 @@ func.func @complex_conj(%arg: complex) -> complex { // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex // CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex -// CHECK: return %[[RESULT]] : complex \ No newline at end of file +// CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func.func @complex_pow +func.func @complex_pow(%lhs: complex, + %rhs: complex) -> complex { + %pow = complex.pow %lhs, %rhs : complex + return %pow : complex +} + +// ----- + +// CHECK-LABEL: func.func @complex_rsqrt +func.func @complex_rsqrt(%arg: complex) -> complex { + %rsqrt = complex.rsqrt %arg : complex + return %rsqrt : complex +} diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir index 67867f3..00ab3ed 100644 --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -38,6 +38,11 @@ func.func @tanh(%arg: complex) -> complex { func.return %tanh : complex } +func.func @rsqrt(%arg: complex) -> complex { + %sqrt = complex.rsqrt %arg : complex + func.return %sqrt : complex +} + // %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...] func.func @test_binary(%input: tensor>, %func: (complex, complex) -> complex) { @@ -67,6 +72,10 @@ func.func @atan2(%lhs: complex, %rhs: complex) -> complex { func.return %atan2 : complex } +func.func @pow(%lhs: complex, %rhs: complex) -> complex { + %pow = complex.pow %lhs, %rhs : complex + func.return %pow : complex +} func.func @entry() { // complex.sqrt test @@ -121,6 +130,30 @@ func.func @entry() { : (tensor>, (complex, complex) -> complex) -> () + // complex.pow test + %pow_test = arith.constant dense<[ + (0.0, 0.0), (0.0, 0.0), + // CHECK: 1 + // CHECK-NEXT: 0 + (0.0, 0.0), (1.0, 0.0), + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + (0.0, 0.0), (-1.0, 0.0), + // CHECK-NEXT: -nan + // CHECK-NEXT: -nan + (1.0, 1.0), (1.0, 1.0) + // CHECK-NEXT: 0.273 + // CHECK-NEXT: 0.583 + ]> : tensor<8xcomplex> + %pow_test_cast = tensor.cast %pow_test + : tensor<8xcomplex> to tensor> + + %pow_func = func.constant @pow : (complex, complex) + -> complex + call @test_binary(%pow_test_cast, %pow_func) + : (tensor>, (complex, complex) + -> complex) -> () + // complex.tanh test %tanh_test = arith.constant dense<[ (-1.0, -1.0), @@ -152,5 +185,36 @@ func.func @entry() { call @test_unary(%tanh_test_cast, %tanh_func) : (tensor>, (complex) -> complex) -> () + // complex.rsqrt test + %rsqrt_test = arith.constant dense<[ + (-1.0, -1.0), + // CHECK: 0.321 + // CHECK-NEXT: 0.776 + (-1.0, 1.0), + // CHECK-NEXT: 0.321 + // CHECK-NEXT: -0.776 + (0.0, 0.0), + // CHECK-NEXT: nan + // CHECK-NEXT: nan + (0.0, 1.0), + // CHECK-NEXT: 0.707 + // CHECK-NEXT: -0.707 + (1.0, -1.0), + // CHECK-NEXT: 0.776 + // CHECK-NEXT: 0.321 + (1.0, 0.0), + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + (1.0, 1.0) + // CHECK-NEXT: 0.776 + // CHECK-NEXT: -0.321 + ]> : tensor<7xcomplex> + %rsqrt_test_cast = tensor.cast %rsqrt_test + : tensor<7xcomplex> to tensor> + + %rsqrt_func = func.constant @rsqrt : (complex) -> complex + call @test_unary(%rsqrt_test_cast, %rsqrt_func) + : (tensor>, (complex) -> complex) -> () + func.return } -- 2.7.4