From f5efe2807056d3fa525e51d35ea94c91e0945eb2 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Thu, 17 Feb 2022 10:20:50 +0100 Subject: [PATCH] [mlir] Propagate NaNs in PolynomialApproximation Previously, NaNs would be dropped in favor of bounded values, which was strictly incorrect. Now the min/max operations propagate this information. Not all uses of min/max need this, but the given change will help protect future additions, and this prevents the need for an additional cmpf and select operation to handle NaNs. Differential Revision: https://reviews.llvm.org/D120020 --- .../lib/Dialect/Math/Transforms/PolynomialApproximation.cpp | 13 +++++++++---- mlir/test/Dialect/Math/polynomial-approximation.mlir | 6 +++--- mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir | 5 +++++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 5d3f629..dbd611c 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -182,16 +182,21 @@ static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { // Helper functions to build math functions approximations. 
//----------------------------------------------------------------------------// -static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) { +// Return the minimum of the two values or NaN if value is NaN +static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) { return builder.create<arith::SelectOp>( - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b); + builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound), + value, bound); } -static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) { +// Return the maximum of the two values or NaN if value is NaN +static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) { return builder.create<arith::SelectOp>( - builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b); + builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound), + value, bound); } +// Return the clamped value or NaN if value is NaN static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound) { return max(builder, min(builder, value, upperBound), lowerBound); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir index 457e585..cff92f3 100644 --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -225,7 +225,7 @@ func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { // CHECK: %[[VAL_20:.*]] = arith.constant 1056964608 : i32 // CHECK: %[[VAL_21:.*]] = arith.constant 23 : i32 // CHECK: %[[VAL_22:.*]] = arith.constant 0.693147182 : f32 -// CHECK: %[[VAL_23:.*]] = arith.cmpf ogt, %[[X]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_23:.*]] = arith.cmpf ugt, %[[X]], %[[VAL_4]] : f32 // CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[X]], %[[VAL_4]] : f32 // CHECK-NOT: frexp // CHECK: %[[VAL_25:.*]] = arith.bitcast %[[VAL_24]] : f32 to i32 @@ -355,9 +355,9 @@ func @log1p_vector(%arg0: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[VAL_12:.*]] = arith.constant 0.00226843474 
: f32 // CHECK: %[[VAL_13:.*]] = arith.constant 1.18534706E-4 : f32 // CHECK: %[[VAL_14:.*]] = arith.constant 1.19825836E-6 : f32 -// CHECK: %[[VAL_15:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_15:.*]] = arith.cmpf ult, %[[VAL_0]], %[[VAL_2]] : f32 // CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_17:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_17:.*]] = arith.cmpf ugt, %[[VAL_16]], %[[VAL_1]] : f32 // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_16]], %[[VAL_1]] : f32 // CHECK: %[[VAL_19:.*]] = math.abs %[[VAL_0]] : f32 // CHECK: %[[VAL_20:.*]] = arith.cmpf olt, %[[VAL_19]], %[[VAL_3]] : f32 diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir index 413c04c..ccbd15b 100644 --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -29,6 +29,11 @@ func @tanh() { %5 = math.tanh %4 : vector<8xf32> vector.print %5 : vector<8xf32> + // CHECK-NEXT: nan + %nan = arith.constant 0x7fc00000 : f32 + %6 = math.tanh %nan : f32 + vector.print %6 : f32 + return } -- 2.7.4