From 391456f33c7a2518721eda92b27630fb1c37e5d6 Mon Sep 17 00:00:00 2001 From: bakhtiyar Date: Mon, 9 Aug 2021 15:54:16 -0700 Subject: [PATCH] Fix a bug in algebraic simplification, and enable the tests. Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D107788 --- .../Math/Transforms/AlgebraicSimplification.cpp | 12 ++++++++--- .../Dialect/Math/algebraic-simplification.mlir | 24 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 2614fc7..8918b21 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -80,7 +80,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, return success(); } - // Replace `pow(x, 2.0)` with `x * x * x`. + // Replace `pow(x, 3.0)` with `x * x * x`. if (isExponentValue(3.0)) { Value square = rewriter.create(op.getLoc(), ValueRange({x, x})); rewriter.replaceOpWithNewOp(op, ValueRange({x, square})); @@ -95,12 +95,18 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, return success(); } - // Replace `pow(x, -2.0)` with `sqrt(x)`. - if (isExponentValue(-1.0)) { + // Replace `pow(x, 0.5)` with `sqrt(x)`. + if (isExponentValue(0.5)) { rewriter.replaceOpWithNewOp(op, x); return success(); } + // Replace `pow(x, -0.5)` with `rsqrt(x)`. + if (isExponentValue(-0.5)) { + rewriter.replaceOpWithNewOp(op, x); + return success(); + } + return failure(); } diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir index cb39bb7..8a81076 100644 --- a/mlir/test/Dialect/Math/algebraic-simplification.mlir +++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir @@ -49,3 +49,27 @@ func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { %1 = math.powf %arg1, %v : vector<4xf32> return %0, %1 : f32, vector<4xf32> } + +// CHECK-LABEL: @pow_sqrt +func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0 + // CHECK: %[[VECTOR:.*]] = math.sqrt %arg1 + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = constant 0.5 : f32 + %v = constant dense <0.5> : vector<4xf32> + %0 = math.powf %arg0, %c : f32 + %1 = math.powf %arg1, %v : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} + +// CHECK-LABEL: @pow_rsqrt +func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) { + // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0 + // CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1 + // CHECK: return %[[SCALAR]], %[[VECTOR]] + %c = constant -0.5 : f32 + %v = constant dense <-0.5> : vector<4xf32> + %0 = math.powf %arg0, %c : f32 + %1 = math.powf %arg1, %v : vector<4xf32> + return %0, %1 : f32, vector<4xf32> +} -- 2.7.4