Avoid infinity arithmetics when computing exp approximations
authorAhmed Taei <ataei@google.com>
Wed, 20 Oct 2021 00:56:55 +0000 (17:56 -0700)
committerAhmed Taei <ataei@google.com>
Thu, 21 Oct 2021 17:09:18 +0000 (10:09 -0700)
Otherwise this can result a poison value on some platforms see https://bugs.llvm.org/show_bug.cgi?id=51204

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D112115

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/Dialect/Math/polynomial-approximation.mlir
mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

index 29fd0b6..7dddcfa 100644 (file)
@@ -567,15 +567,20 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
 
   Value isNegInfinityX = builder.create<arith::CmpFOp>(
       arith::CmpFPredicate::OEQ, x, constNegIfinity);
+  Value isPosInfinityX = builder.create<arith::CmpFOp>(
+      arith::CmpFPredicate::OEQ, x, constPosInfinity);
   Value isPostiveX =
       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
   Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound);
 
   expY = builder.create<SelectOp>(
-      isComputable, expY,
+      isNegInfinityX, zerof32Const,
       builder.create<SelectOp>(
-          isPostiveX, constPosInfinity,
-          builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
+          isPosInfinityX, constPosInfinity,
+          builder.create<SelectOp>(isComputable, expY,
+                                   builder.create<SelectOp>(isPostiveX,
+                                                            constPosInfinity,
+                                                            underflow))));
 
   rewriter.replaceOp(op, expY);
 
index 7d9ec88..bc6f39b 100644 (file)
 // CHECK:           %[[VAL_31:.*]] = arith.cmpi sle, %[[VAL_26]], %[[VAL_13]] : i32
 // CHECK:           %[[VAL_32:.*]] = arith.cmpi sge, %[[VAL_26]], %[[VAL_14]] : i32
 // CHECK:           %[[VAL_33:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_11]] : f32
-// CHECK:           %[[VAL_34:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32
-// CHECK:           %[[VAL_35:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
-// CHECK:           %[[VAL_36:.*]] = select %[[VAL_33]], %[[VAL_9]], %[[VAL_12]] : f32
-// CHECK:           %[[VAL_37:.*]] = select %[[VAL_34]], %[[VAL_10]], %[[VAL_36]] : f32
-// CHECK:           %[[VAL_38:.*]] = select %[[VAL_35]], %[[VAL_30]], %[[VAL_37]] : f32
-// CHECK:           return %[[VAL_38]] : f32
-// CHECK:         }
+// CHECK:           %[[VAL_34:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_10]] : f32
+// CHECK:           %[[VAL_35:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32
+// CHECK:           %[[VAL_36:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
+// CHECK:           %[[VAL_37:.*]] = select %[[VAL_35]], %[[VAL_10]], %[[VAL_12]] : f32
+// CHECK:           %[[VAL_38:.*]] = select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32
+// CHECK:           %[[VAL_39:.*]] = select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32
+// CHECK:           %[[VAL_40:.*]] = select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32
+// CHECK:           return %[[VAL_40]] : f32
 func @exp_scalar(%arg0: f32) -> f32 {
   %0 = math.exp %arg0 : f32
   return %0 : f32
@@ -54,10 +55,9 @@ func @exp_scalar(%arg0: f32) -> f32 {
 // CHECK-SAME:                     %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32>
 // CHECK-NOT:       exp
-// CHECK-COUNT-2:   select
-// CHECK:           %[[VAL_38:.*]] = select
-// CHECK:           return %[[VAL_38]] : vector<8xf32>
-// CHECK:         }
+// CHECK-COUNT-3:   select
+// CHECK:           %[[VAL_40:.*]] = select
+// CHECK:           return %[[VAL_40]] : vector<8xf32>
 func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
   %0 = math.exp %arg0 : vector<8xf32>
   return %0 : vector<8xf32>
@@ -70,7 +70,7 @@ func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK-DAG:           %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK:           %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32
 // CHECK-NOT:       exp
-// CHECK-COUNT-2:   select
+// CHECK-COUNT-3:   select
 // CHECK:           %[[EXP_X:.*]] = select
 // CHECK:           %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32
 // CHECK:           %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32
@@ -95,7 +95,7 @@ func @expm1_scalar(%arg0: f32) -> f32 {
 // CHECK-SAME:                       %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8xf32>
 // CHECK-NOT:       exp
-// CHECK-COUNT-3:   select
+// CHECK-COUNT-4:   select
 // CHECK-NOT:       log
 // CHECK-COUNT-5:   select
 // CHECK-NOT:       expm1
index 8d66ec6..c0f319e 100644 (file)
@@ -10,9 +10,6 @@
 // RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext    \
 // RUN: | FileCheck %s
 
-// XFAIL: s390x
-// (see https://bugs.llvm.org/show_bug.cgi?id=51204)
-
 // -------------------------------------------------------------------------- //
 // Tanh.
 // -------------------------------------------------------------------------- //