[mlir][tosa] Fix tosa.transpose_conv2d decompositions for new version
authorRob Suderman <suderman@google.com>
Tue, 3 Jan 2023 19:21:25 +0000 (11:21 -0800)
committerRob Suderman <suderman@google.com>
Tue, 3 Jan 2023 19:36:13 +0000 (11:36 -0800)
The decomposition was no longer correct for transpose_conv2d to conv2d
after the updated TOSA specification. Specifically the behavior for
padding was changed to refer to padding the tranpsose_conv2d instead
of referencing the conv applied to the inverse transform.

Test was validated using the TOSA conformance tests.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D140503

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir

index 14ea18c..1a3ba90 100644 (file)
@@ -1293,7 +1293,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   if (!ShapedType::isDynamic(inputHeight) &&
       !ShapedType::isDynamic(weightHeight)) {
     int64_t calculateSize =
-        (inputHeight - 1) * stride[0] - padding[0] - padding[1] + weightHeight;
+        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
     outputShape[1] =
         ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
   }
@@ -1301,7 +1301,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   if (!ShapedType::isDynamic(inputWidth) &&
       !ShapedType::isDynamic(weightWidth)) {
     int64_t calculateSize =
-        (inputWidth - 1) * stride[1] - padding[2] - padding[3] + weightWidth;
+        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
     outputShape[2] =
         ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
   }
index 148130a..d7f326f 100644 (file)
@@ -115,10 +115,10 @@ public:
     int64_t kernelWidth = weightTy.getDimSize(2);
 
     llvm::SmallVector<int64_t> convPad(4, 0);
-    convPad[0] = kernelHeight - 1 - pad[0];
-    convPad[1] = kernelHeight - 1 - pad[1];
-    convPad[2] = kernelWidth - 1 - pad[2];
-    convPad[3] = kernelWidth - 1 - pad[3];
+    convPad[0] = kernelHeight - 1 + pad[0];
+    convPad[1] = kernelHeight - 1 + pad[1];
+    convPad[2] = kernelWidth - 1 + pad[2];
+    convPad[3] = kernelWidth - 1 + pad[3];
 
     auto reverse1 = rewriter.create<tosa::ReverseOp>(
         loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
@@ -176,7 +176,7 @@ public:
 
     // If strides are all 1 we dont need to use this one.
     if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "non-one stride found.");
 
     if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
         !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
@@ -337,24 +337,50 @@ public:
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
         rewriter.getI64ArrayAttr(convReshapeDims1));
 
-    // Slice out the final result.
-    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
-    llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
-                                            resultTy.getShape().begin());
-    sliceBegin[1] = pad[0];
-    sliceBegin[2] = pad[2];
+    // Determine the amount to slice / pad from the result start.
+    int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
+    int64_t resultSliceLeft = std::max<int64_t>(0, -pad[2]);
+    int64_t resultPadTop = std::max<int64_t>(0, pad[0]);
+    int64_t resultPadLeft = std::max<int64_t>(0, pad[2]);
+
+    // Try to slice the targetted result size, cap to the convolutions width.
+    int64_t resultSliceHeight =
+        std::min<int64_t>(convReshapeDims1[1] - resultSliceTop,
+                          resultTy.getDimSize(1) - resultPadTop);
+    int64_t resultSliceWidth =
+        std::min<int64_t>(convReshapeDims1[2] - resultSliceLeft,
+                          resultTy.getDimSize(2) - resultPadLeft);
+
+    llvm::SmallVector<int64_t, 4> sliceBegin = {0, resultSliceTop,
+                                                resultSliceLeft, 0};
+    llvm::SmallVector<int64_t, 4> sliceSize(convReshapeDims1.begin(),
+                                            convReshapeDims1.end());
+    sliceSize[1] = resultSliceHeight;
+    sliceSize[2] = resultSliceWidth;
 
     auto slice = createOpAndInfer<tosa::SliceOp>(
                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                      rewriter.getI64ArrayAttr(sliceBegin),
-                     rewriter.getI64ArrayAttr(resultTy.getShape()))
+                     rewriter.getI64ArrayAttr(sliceSize))
                      .getResult();
 
-    auto addBias =
-        createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
+    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    resultPadding[2] = resultPadTop;
+    resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
+    resultPadding[4] = resultPadLeft;
+    resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
+
+    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
+
+    Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
+        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
 
-    rewriter.replaceOp(op, addBias.getResult());
+    auto resultPad = createOpAndInfer<tosa::PadOp>(
+        rewriter, loc, UnrankedTensorType::get(resultETy), slice,
+        resultPaddingVal);
 
+    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, op.getType(), resultPad, bias);
     return success();
   }
 };
index 41d295a..6fe55ad 100644 (file)
@@ -1,18 +1,19 @@
 // RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
 
 // CHECK-LABEL: @transpose_conv2d
-func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> {
+func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x18x19x5xf32> {
   // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
   // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
-  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]}
+  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2)
+  // CHECK-SAME: dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
-  %1 = tensor.cast %0 : tensor<2x18x19x5xf32> to tensor<2x?x?x5xf32>
-  return %1 : tensor<2x?x?x5xf32>
+  return %0 : tensor<2x18x19x5xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @transpose_conv2d_quantized
+
 func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
   // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
   // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
@@ -24,12 +25,18 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
 // -----
 
 // CHECK-LABEL: @transpose_conv2d_quantized_padded
-func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x7x7x18xi8>, %arg1: tensor<12x3x5x18xi8>, %arg2: tensor<12xi32>) -> (tensor<2x7x7x12xi32>) {
-  // CHECK: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
-  // CHECK: %[[REV2:.+]] = "tosa.reverse"(%[[REV1]]) {axis = 2 : i64}
-  // CHECK: "tosa.conv2d"(%arg0, %[[REV2]], %arg2) {dilation = [1, 1], pad = [1, 1, 2, 2], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = [1, 1]}
-  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [1, 1, 2, 2], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x7x7x18xi8>, tensor<12x3x5x18xi8>, tensor<12xi32>) -> tensor<2x7x7x12xi32>
-  return %0 : tensor<2x7x7x12xi32>
+func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
+  // CHECK-DAG: %[[REV0:.+]] = "tosa.reverse"(%0) {axis = 2 : i64}
+  // CHECK-DAG: %[[REV1:.+]] = "tosa.reverse"(%arg1) {axis = 1 : i64}
+  // CHECK: "tosa.conv2d"(%arg0, %1, %arg2) 
+  // CHECK-SAME: dilation = [1, 1], pad = [3, 4, 8, 9],
+  // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = [1, 1]}
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {
+    out_pad = [1, 2, 3, 4],
+    quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>,
+    out_shape = [-1, -1, -1, -1],
+    stride = [1, 1]} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x21x26x5xi32>
+  return %0 : tensor<2x21x26x5xi32>
 }
 
 // -----
@@ -94,3 +101,38 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = [-1, -1, -1, -1], stride = [2, 3]} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
   return %0 : tensor<2x35x47x5xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided_overpad
+func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : tensor<1x2x1x1xi8>, %arg2 : tensor<1xi32>) -> (tensor<1x19x2x1xi32>) {
+  // CHECK: %[[WEIGHT_PAD:.+]] = "tosa.const"() 
+  // CHECK-SAME{literal}: value = dense<[[0, 0], [0, 0], [0, 1], [0, 0]]> : tensor<4x2xi32>
+  // CHECK: %[[WEIGHT_PERMS:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} : () -> tensor<6xi32>
+  // CHECK: %[[INPUT_PAD:.+]] = "tosa.const"() 
+  // CHECK-SAME{literal}: value = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: %[[RESULT_PERMS:.+]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
+  // CHECK: %[[RESULT_PAD:.+]] = "tosa.const"() 
+  // CHECK-SAME{literal}: value = dense<[[0, 0], [2, 0], [0, 0], [0, 0]]> : tensor<4x2xi32>}
+  // CHECK: %[[PAD_WEIGHT:.+]] = "tosa.pad"(%arg1, %[[WEIGHT_PAD]]) {quantization_info = #tosa.pad_quant<input_zp = 93>}
+  // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = "tosa.reshape"(%[[PAD_WEIGHT]]) {new_shape = [1, 2, 1, 1, 2, 1]}
+  // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = "tosa.transpose"(%[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]])
+  // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_WEIGHT]]) {new_shape = [2, 2, 1, 1]}
+  // CHECK: %[[REVERSE:.+]] = "tosa.reverse"(%[[RESHAPE_WEIGHT_1]]) {axis = 1 : i64}
+  // CHECK: %[[PAD_INPUT:.+]] = "tosa.pad"(%arg0, %[[INPUT_PAD]]) {quantization_info = #tosa.pad_quant<input_zp = -103>}
+  // CHECK: %[[CONV:.+]] = "tosa.conv2d"(%[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]) 
+  // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
+  // CHECK: %[[RESHAPE_RESULT_0:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [1, 17, 1, 1, 2, 1]}
+  // CHECK: %[[TRANSPOSE_RESULT:.+]] = "tosa.transpose"(%[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]])
+  // CHECK: %[[RESHAPE_RESULT_1:.+]] = "tosa.reshape"(%[[TRANSPOSE_RESULT]]) {new_shape = [1, 17, 2, 1]}
+  // CHECK: %[[PAD_RESULT:.+]] = "tosa.pad"(%[[RESHAPE_RESULT_1]], %[[RESULT_PAD]])
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%[[PAD_RESULT]], %arg2)
+  %2 =  "tosa.transpose_conv2d"(%arg0, %arg1, %arg2)  {
+    out_pad = [2, 0, 0, 1],
+    out_shape = [1, -1, -1, 1],
+    stride = [1, 2],
+    quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>} :
+    (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>) -> (tensor<1x19x2x1xi32>)
+  "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
+}