[mlir][tosa] TosaToLinalg: Lower TOSA.Cast via RoundEven according to TOSA spec 0...
authorMatthias Gehre <matthias.gehre@xilinx.com>
Thu, 23 Mar 2023 01:08:43 +0000 (01:08 +0000)
committerRobert Suderman <suderman@google.com>
Thu, 23 Mar 2023 01:52:26 +0000 (01:52 +0000)
TOSA now specifies rounding of ties to even in section 1.8.2., "Main Inference Profile"

Reviewed By: eric-k256, rsuderman

Differential Revision: https://reviews.llvm.org/D146617

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index 271a095..be24f5e 100644 (file)
@@ -471,11 +471,6 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     }
 
     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
-      auto zero = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getF32FloatAttr(0.0f));
-      auto half = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getF32FloatAttr(0.5f));
-
       auto intMin = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getF32FloatAttr(
                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
@@ -486,12 +481,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                        .getSExtValue()));
 
-      auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
-      auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
-      auto negative = rewriter.create<arith::CmpFOp>(
-          loc, arith::CmpFPredicate::OLT, args[0], zero);
-      auto rounded =
-          rewriter.create<arith::SelectOp>(loc, negative, subbed, added);
+      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
 
       auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
 
index 133999e..476131b 100644 (file)
@@ -237,14 +237,9 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   %19 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
-  // CHECK: arith.constant 0.000000e+00
-  // CHECK: arith.constant 5.000000e-01
   // CHECK: arith.constant -2.14748365E+9
   // CHECK: arith.constant 2.14748365E+9
-  // CHECK: arith.addf
-  // CHECK: arith.subf
-  // CHECK: arith.cmpf olt
-  // CHECK: select
+  // CHECK: math.roundeven
   // CHECK: arith.minf
   // CHECK: arith.maxf
   // CHECK: arith.fptosi