[mlir][tosa] Make tosa.resize to linalg avoid redundant loads for unit width
authorRob Suderman <suderman@google.com>
Fri, 16 Dec 2022 00:07:31 +0000 (16:07 -0800)
committerRob Suderman <suderman@google.com>
Fri, 16 Dec 2022 00:22:46 +0000 (16:22 -0800)
When using a tosa resize for ?x1x1x? to ?x1x?x? we should avoid doing a 2D
interpolation as only two unique values are loaded. As the extract operation
performance numerical computation on its values the superfluous extracts may
fail to be coalesced. Instead we only interpolate between the values if there
are multiple values to interpolate between.

For the integer case we also perform scaling by the scaling-factor to apply
the same integer scaling behavior as interpolation.

Reviewed By: jpienaar, NatashaKnk

Differential Revision: https://reviews.llvm.org/D139979

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir

index d704b5e..ebc63bd 100644 (file)
@@ -1468,7 +1468,10 @@ public:
       Value x = b.create<linalg::IndexOp>(2);
       Value channel = b.create<linalg::IndexOp>(3);
 
-      Value zeroI32 = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
+      Value zeroI32 =
+          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
+      Value zeroFp32 =
+          b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
       Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
       Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
 
@@ -1498,6 +1501,11 @@ public:
       auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
+        if (size == 1) {
+          index = zeroI32;
+          delta = zeroFp32;
+          return;
+        }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         // dx = x / scale_n - ix
@@ -1517,6 +1525,11 @@ public:
       auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                      Value scaleN, Value scaleD, Value offset,
                                      int size, ImplicitLocOpBuilder &b) {
+        if (size == 1) {
+          index = zeroI32;
+          delta = zeroI32;
+          return;
+        }
         // x = x * scale_d + offset;
         // ix = floor(x / scale_n)
         //  dx = x - ix * scale_n;
@@ -1606,7 +1619,10 @@ public:
         if (floatingPointMode) {
           auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
           auto interpolate = [&](Value val0, Value val1, Value delta,
+                                 int inputSize,
                                  ImplicitLocOpBuilder &b) -> Value {
+            if (inputSize == 1)
+              return val0;
             Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
             Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
             Value mul1 = b.create<arith::MulFOp>(val1, delta);
@@ -1616,16 +1632,16 @@ public:
           // Linalg equivalent to the section below:
           //   topAcc = v00 * (unit_x - dx);
           //   topAcc += v01 * dx;
-          Value topAcc = interpolate(y0x0, y0x1, dx, b);
+          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
 
           // Linalg equivalent to the section below:
           //   bottomAcc = v10 * (unit_x - dx);
           //   bottomAcc += v11 * dx;
-          Value bottomAcc = interpolate(y1x0, y1x1, dx, b);
+          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
 
           // Linalg equivalent to the section below:
           //   result = topAcc * (unit_y - dy) + bottomAcc * dy
-          Value result = interpolate(topAcc, bottomAcc, dy, b);
+          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
           b.create<linalg::YieldOp>(result);
         } else {
           // Perform in quantized space.
@@ -1650,22 +1666,21 @@ public:
             xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
           }
 
-          auto interpolate = [](Value val0, Value val1, Value weight0,
-                                Value weight1,
+          auto interpolate = [](Value val0, Value val1, Value weight1,
+                                Value scale, int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
+            if (inputSize == 1)
+              return b.create<arith::MulIOp>(val0, scale);
+            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
             Value mul0 = b.create<arith::MulIOp>(val0, weight0);
             Value mul1 = b.create<arith::MulIOp>(val1, weight1);
             return b.create<arith::AddIOp>(mul0, mul1);
           };
 
-          Value weight0 = b.create<arith::SubIOp>(xScaleNExt, dx);
-          Value weight1 = dx;
-          Value topAcc = interpolate(y0x0, y0x1, weight0, weight1, b);
-          Value bottomAcc = interpolate(y1x0, y1x1, weight0, weight1, b);
-
-          weight0 = b.create<arith::SubIOp>(yScaleNExt, dy);
-          weight1 = dy;
-          Value result = interpolate(topAcc, bottomAcc, weight0, weight1, b);
+          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
+          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
+          Value result =
+              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
           b.create<linalg::YieldOp>(result);
         }
       }
index 3aa6d2a..382dae5 100644 (file)
@@ -278,6 +278,7 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
   // CHECK: %[[WLOLO:.+]] = arith.muli %[[XLOLO]], %[[NDX]]
   // CHECK: %[[WLOHI:.+]] = arith.muli %[[XLOHI]], %[[D_X_EXT]]
   // CHECK: %[[LO:.+]] = arith.addi %[[WLOLO]], %[[WLOHI]]
+  // CHECK: %[[NDX:.+]] = arith.subi %[[X_N_EXT]], %[[D_X_EXT]]
   // CHECK: %[[WHILO:.+]] = arith.muli %[[XHILO]], %[[NDX]]
   // CHECK: %[[WHIHI:.+]] = arith.muli %[[XHIHI]], %[[D_X_EXT]]
   // CHECK: %[[HI:.+]] = arith.addi %[[WHILO]], %[[WHIHI]]
@@ -492,3 +493,47 @@ func.func @resize_bilinear_int48(%arg0: tensor<1x19x19x1xi16>) {
   %0 = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [16, 1, 16, 1], offset = [0, 0], border = [0, 0]} : (tensor<1x19x19x1xi16>) -> tensor<1x289x289x1xi48>
            return
 }
+
+// -----
+
+// CHECK-LABEL: skip_interpolate_bilinear_i8
+func.func @skip_interpolate_bilinear_i8(%arg0 : tensor<3x1x2x7xi8>) -> tensor<3x1x5x7xi32> {
+  // CHECK:  %[[GENERIC:.+]] = linalg.generic
+  // CHECK:    %[[BATCH:.+]] = linalg.index 0
+  // CHECK:    %[[CHANNEL:.+]] = linalg.index 3
+  // CHECK-DAG:    %[[C3:.+]] = arith.constant 3
+  // CHECK-DAG:    %[[C2:.+]] = arith.constant 2
+  // CHECK:    %[[EXTRACT0:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xi8>
+  // CHECK:    %[[EXTRACT1:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xi8>
+  // CHECK:    %[[EXT0:.+]] = arith.extsi %[[EXTRACT0]] : i8 to i32
+  // CHECK:    %[[EXT1:.+]] = arith.extsi %[[EXTRACT1]] : i8 to i32
+  // CHECK:    %[[SUB:.+]] = arith.subi %[[C3]], %[[DX:.+]]
+  // CHECK:    %[[MUL0:.+]] = arith.muli %[[EXT0]], %[[SUB]]
+  // CHECK:    %[[MUL1:.+]] = arith.muli %[[EXT1]], %[[DX]]
+  // CHECK:    %[[ADD:.+]] = arith.addi %[[MUL0]], %[[MUL1]]
+  // CHECK:    %[[RES:.+]] = arith.muli %[[ADD]], %[[C2]]
+  // CHECK:    linalg.yield %[[RES]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x2x7xi8>) -> tensor<3x1x5x7xi32>
+
+  // CHECK:  return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xi32>
+}
+
+// CHECK-LABEL: skip_interpolate_bilinear_f32
+func.func @skip_interpolate_bilinear_f32(%arg0 : tensor<3x1x2x7xf32>) -> tensor<3x1x5x7xf32> {
+  // CHECK:  %[[GENERIC:.+]] = linalg.generic
+  // CHECK:    %[[BATCH:.+]] = linalg.index 0 : index
+  // CHECK:    %[[CHANNEL:.+]] = linalg.index 3 : index
+  // CHECK:    %[[EXTRACT0:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xf32>
+  // CHECK:    %[[EXTRACT1:.+]] = tensor.extract %arg0[%[[BATCH]], %{{.+}}, %{{.+}}, %[[CHANNEL]]] : tensor<3x1x2x7xf32>
+  // CHECK:    %[[C1:.+]] = arith.constant 1.000000e+00
+  // CHECK:    %[[SUB:.+]] = arith.subf %[[C1]], %[[DX:.+]]
+  // CHECK:    %[[MUL0:.+]] = arith.mulf %[[EXTRACT0]], %[[SUB]]
+  // CHECK:    %[[MUL1:.+]] = arith.mulf %[[EXTRACT1]], %[[DX]]
+  // CHECK:    %[[ADD:.+]] = arith.addf %[[MUL0]], %[[MUL1]]
+  // CHECK:    linalg.yield %[[ADD]]
+  %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = [2, 1, 3, 1], offset = [0, 0], border = [0, 0]} : (tensor<3x1x2x7xf32>) -> tensor<3x1x5x7xf32>
+
+  // CHECK:  return %[[GENERIC]]
+  return %resize : tensor<3x1x5x7xf32>
+}