[mlir][tosa] Handle tosa.resize nearest rounding correctly
authorRob Suderman <suderman@google.com>
Mon, 5 Dec 2022 20:45:06 +0000 (12:45 -0800)
committerRob Suderman <suderman@google.com>
Mon, 5 Dec 2022 21:10:08 +0000 (13:10 -0800)
Rounding of tosa.resize did not handle rounding to the nearest pixel correctly.
Rather than dividing the scale by 2 we should double the partial pixel to
guarantee we include a check on the lowest bit.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D139162

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

index ab21055..ade94e1 100644 (file)
@@ -1557,8 +1557,8 @@ public:
       y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
       x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
 
-      iy = rewriter.create<arith::DivUIOp>(loc, y, yScaleN);
-      ix = rewriter.create<arith::DivUIOp>(loc, x, xScaleN);
+      iy = rewriter.create<arith::DivSIOp>(loc, y, yScaleN);
+      ix = rewriter.create<arith::DivSIOp>(loc, x, xScaleN);
 
       Value tempY = rewriter.create<arith::MulIOp>(loc, iy, yScaleN);
       Value tempX = rewriter.create<arith::MulIOp>(loc, ix, xScaleN);
@@ -1583,14 +1583,12 @@ public:
         xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                                dx, halfVal);
       } else {
-        Value yScaleNHalfVal =
-            rewriter.create<arith::ShRSIOp>(loc, yScaleN, oneVal);
-        Value xScaleNHalfVal =
-            rewriter.create<arith::ShRSIOp>(loc, xScaleN, oneVal);
+        Value dyDoubled = rewriter.create<arith::ShLIOp>(loc, dy, oneVal);
+        Value dxDoubled = rewriter.create<arith::ShLIOp>(loc, dx, oneVal);
         yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
-                                               dy, yScaleNHalfVal);
+                                               dyDoubled, yScaleN);
         xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
-                                               dx, xScaleNHalfVal);
+                                               dxDoubled, xScaleN);
       }
 
       auto yOffset =
index 4453ad1..f9bfb4d 100644 (file)
@@ -145,8 +145,8 @@ func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
   // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
   // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
   // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
-  // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
+  // CHECK: %[[I_Y:.*]] = arith.divsi %[[Y]], %[[SCALE_Y_N]]
+  // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]]
   // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
   // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
   // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]
@@ -156,10 +156,10 @@ func.func @resize_nearest_int(%arg0: tensor<1x15x13x1xi8>) -> () {
 
   // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
   // CHECK-DAG: %[[ONE:.*]] = arith.constant 1
-  // CHECK: %[[SCALE_Y_N_HALF:.*]] = arith.shrsi %[[SCALE_Y_N]], %[[ONE]]
-  // CHECK: %[[SCALE_X_N_HALF:.*]] = arith.shrsi %[[SCALE_X_N]], %[[ONE]]
-  // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y]], %[[SCALE_Y_N_HALF]]
-  // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X]], %[[SCALE_X_N_HALF]]
+  // CHECK: %[[D_Y_DOUBLE:.*]] = arith.shli %[[D_Y]], %[[ONE]]
+  // CHECK: %[[D_X_DOUBLE:.*]] = arith.shli %[[D_X]], %[[ONE]]
+  // CHECK: %[[PRED_Y:.*]] = arith.cmpi sge, %[[D_Y_DOUBLE]], %[[SCALE_Y_N]]
+  // CHECK: %[[PRED_X:.*]] = arith.cmpi sge, %[[D_X_DOUBLE]], %[[SCALE_X_N]]
   // CHECK: %[[VAL_37:.*]] = arith.select %[[PRED_Y]], %[[ONE]], %[[ZERO]]
   // CHECK: %[[VAL_38:.*]] = arith.select %[[PRED_X]], %[[ONE]], %[[ZERO]]
   // CHECK: %[[VAL_39:.*]] = arith.addi %[[I_Y]], %[[VAL_37]]
@@ -217,8 +217,8 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x19x1xi8>) {
   // CHECK: %[[TEMP_X:.*]] = arith.muli %[[X]], %[[SCALE_X_D]]
   // CHECK: %[[Y:.*]] = arith.addi %[[TEMP_Y]], %[[OFFSET_Y]]
   // CHECK: %[[X:.*]] = arith.addi %[[TEMP_X]], %[[OFFSET_X]]
-  // CHECK: %[[I_Y:.*]] = arith.divui %[[Y]], %[[SCALE_Y_N]]
-  // CHECK: %[[I_X:.*]] = arith.divui %[[X]], %[[SCALE_X_N]]
+  // CHECK: %[[I_Y:.*]] = arith.divsi %[[Y]], %[[SCALE_Y_N]]
+  // CHECK: %[[I_X:.*]] = arith.divsi %[[X]], %[[SCALE_X_N]]
   // CHECK: %[[TEMP_Y:.*]] = arith.muli %[[I_Y]], %[[SCALE_Y_N]]
   // CHECK: %[[TEMP_X:.*]] = arith.muli %[[I_X]], %[[SCALE_X_N]]
   // CHECK: %[[D_Y:.*]] = arith.subi %[[Y]], %[[TEMP_Y]]