[mlir][tosa] Fix tosa.cast UiToFp32 for tosa-to-linalg
authorRob Suderman <rob.suderman@gmail.com>
Thu, 14 Oct 2021 02:21:09 +0000 (19:21 -0700)
committerRob Suderman <rob.suderman@gmail.com>
Thu, 14 Oct 2021 18:34:10 +0000 (11:34 -0700)
Part of the arith update broke UiToFp32. Fixed the lowering and included a new
test to detect a regression.

Differential Revision: https://reviews.llvm.org/D111772

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

index ff0296d7a063c216f7d60ee14731d0440aabb43d..3d12042960634cdb1be6466af625670de826344a 100644 (file)
@@ -567,11 +567,6 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
       return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                              mlir::None);
 
-    // All other si-to-fp conversions should be handled by SIToFP.
-    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
-      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
-                                              mlir::None);
-
     // Unsigned integers need an unrealized cast so that they can be passed
     // to UIToFP.
     if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
@@ -585,6 +580,11 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                               unrealizedCast);
     }
 
+    // All other si-to-fp conversions should be handled by SIToFP.
+    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
+      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
+                                              mlir::None);
+
     // Casting to boolean, floats need to only be checked as not-equal to zero.
     if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
       Value zero = rewriter.create<arith::ConstantOp>(
index 4db36aab197dbbb3f57a96344b776e878b122d63..ab828b14f62441ea0821e879ffe4f7b7673ef425 100644 (file)
@@ -291,6 +291,15 @@ func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_simple_ui8
+func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
+  // CHECK: arith.uitofp
+  %0 = "tosa.cast"(%arg0) : (tensor<1xui8>) -> tensor<1xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_simple_i32
 func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: linalg.generic