mlir/tosa: move tosa.pad from Linalg to Tensor conversion
authorRamkumar Ramachandra <r@artagnon.com>
Thu, 1 Dec 2022 10:48:24 +0000 (11:48 +0100)
committerRamkumar Ramachandra <r@artagnon.com>
Tue, 6 Dec 2022 06:39:29 +0000 (07:39 +0100)
Since tosa.pad is lowered strictly to artih and tensor ops, move
ConvertPad from TosaToLinalg to TosaToTensor, benefitting non-Linalg
Tosa targets. TensorToLinalg exists, and is trivial, so nothing is lost.

Signed-off-by: Ramkumar Ramachandra <r@artagnon.com>
Differential Revision: https://reviews.llvm.org/D139091

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

index ade94e1..3c74da1 100644 (file)
@@ -1932,81 +1932,6 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
   }
 };
 
-class PadConverter : public OpRewritePattern<tosa::PadOp> {
-public:
-  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::PadOp padOp,
-                                PatternRewriter &rewriter) const final {
-    auto loc = padOp.getLoc();
-    auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
-
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    Type elementTy = inputTy.getElementType();
-    int64_t rank = inputTy.getRank();
-
-    // Setup the default constantAttr.
-
-    Value padConstant;
-
-    if (padOp.getPadConst()) {
-      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padOp.getPadConst(), ValueRange({}));
-    } else {
-      Attribute constantAttr;
-      if (elementTy.isa<FloatType>()) {
-        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-      } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
-        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-      } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
-        int64_t value = padOp.getQuantizationInfo()->getInputZp();
-        constantAttr = rewriter.getIntegerAttr(elementTy, value);
-      }
-      if (constantAttr)
-        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
-    }
-
-    if (!padConstant) {
-      return rewriter.notifyMatchFailure(
-          padOp, "tosa.pad was unable to determine the pad constant value.");
-    }
-
-    Value lowIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-    Value highIndex =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
-
-    SmallVector<OpFoldResult, 3> lowValues;
-    SmallVector<OpFoldResult, 3> highValues;
-
-    lowValues.reserve(rank);
-    highValues.reserve(rank);
-
-    for (int i = 0; i < rank; i++) {
-      Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({inputIndex, highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
-      lowValues.push_back(lowVal);
-      highValues.push_back(highVal);
-    }
-
-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, padOp.getType(), input, lowValues, highValues, padConstant);
-
-    rewriter.replaceOp(padOp, newPadOp.getResult());
-    return success();
-  }
-};
-
 // Tosa argmax lowering represents the ArgMax op as an linalg.indexed_generic
 // op, producing two output buffers.
 //
@@ -2375,7 +2300,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ArgMaxConverter,
       ConcatConverter,
       GatherConverter,
-      PadConverter,
       ReshapeConverterCollapse,
       ReshapeConverterExpand,
       ReshapeConverterCollapseExpand,
index 5290923..4b1e351 100644 (file)
@@ -56,6 +56,7 @@ public:
     target.addLegalOp<tosa::ConstOp>();
     target.addLegalOp<tosa::WhileOp>();
     target.addLegalOp<tosa::SliceOp>();
+    target.addLegalOp<tosa::PadOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
index cb2eea2..047cb31 100644 (file)
@@ -22,7 +22,7 @@ using namespace tosa;
 
 namespace {
 
-class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
 public:
   using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
 
@@ -59,9 +59,84 @@ public:
   }
 };
 
+class PadConverter : public OpRewritePattern<tosa::PadOp> {
+public:
+  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = padOp.getLoc();
+    auto input = padOp.getInput1();
+    auto padding = padOp.getPadding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    // Setup the default constantAttr.
+
+    Value padConstant;
+
+    if (padOp.getPadConst()) {
+      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padOp.getPadConst(), ValueRange({}));
+    } else {
+      Attribute constantAttr;
+      if (elementTy.isa<FloatType>()) {
+        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+      } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
+        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+      } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
+        int64_t value = padOp.getQuantizationInfo()->getInputZp();
+        constantAttr = rewriter.getIntegerAttr(elementTy, value);
+      }
+      if (constantAttr)
+        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
+    }
+
+    if (!padConstant) {
+      return rewriter.notifyMatchFailure(
+          padOp, "tosa.pad was unable to determine the pad constant value.");
+    }
+
+    Value lowIndex =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value highIndex =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+
+    SmallVector<OpFoldResult, 3> lowValues;
+    SmallVector<OpFoldResult, 3> highValues;
+
+    lowValues.reserve(rank);
+    highValues.reserve(rank);
+
+    for (int i = 0; i < rank; i++) {
+      Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
+      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, lowIndex}));
+      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, highIndex}));
+
+      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
+          loc, rewriter.getIndexType(), lowVal);
+      highVal = rewriter.createOrFold<arith::IndexCastOp>(
+          loc, rewriter.getIndexType(), highVal);
+
+      lowValues.push_back(lowVal);
+      highValues.push_back(highVal);
+    }
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, padOp.getType(), input, lowValues, highValues, padConstant);
+
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceOpConverter>(patterns->getContext());
+  patterns->add<SliceConverter, PadConverter>(patterns->getContext());
 }
index bf8d709..af6a08e 100644 (file)
@@ -36,6 +36,7 @@ public:
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::SliceOp>();
+    target.addIllegalOp<tosa::PadOp>();
     target.addLegalDialect<arith::ArithDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
index 349ad7e..8406c05 100644 (file)
@@ -1301,93 +1301,6 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
 
 // -----
 
-// CHECK-LABEL: @pad_float
-// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
-  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
-  // CHECK:   tensor.yield [[CST]]
-  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)
-  return %1 : tensor<4x9xf32>
-}
-
-func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // CHECK: [[CST:%.+]] = arith.constant 0 : i32
-  // CHECK: tensor.pad
-  // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
-  return %1 : tensor<4x9xi32>
-}
-
-func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // CHECK: [[CST:%.+]] = arith.constant 42 : i32
-  // CHECK: tensor.pad
-  // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
-  return %1 : tensor<4x9xi32>
-}
-
-// -----
-
-func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
-  // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
-  // CHECK:   tensor.yield [[CST]]
-  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
-  %1 = arith.constant dense<42.0> : tensor<f32>
-  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
-  return %2 : tensor<4x9xf32>
-}
-
-// -----
-
-func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
-  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
-  // CHECK:   tensor.yield [[CST]]
-  // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
-  return %1 : tensor<?x9xf32>
-}
-
-func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
-  %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
-  // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
-  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
-  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
-  // CHECK:   tensor.yield [[CST]]
-  // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
-  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
-  return %1 : tensor<?x9xf32>
-}
-
-// -----
-
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
index 08105c7..b50af43 100644 (file)
@@ -19,3 +19,90 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>)  -> (tensor<?xf32>)
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @pad_float
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK:   tensor.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)
+  return %1 : tensor<4x9xf32>
+}
+
+func.func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = arith.constant 0 : i32
+  // CHECK: tensor.pad
+  // CHECK:   tensor.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = arith.constant 42 : i32
+  // CHECK: tensor.pad
+  // CHECK:   tensor.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+// -----
+
+func.func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK:   tensor.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = arith.constant dense<42.0> : tensor<f32>
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
+  return %2 : tensor<4x9xf32>
+}
+
+// -----
+
+func.func @pad_dyn_input(%arg0 : tensor<?x2xf32>) -> (tensor<?x9xf32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK:   tensor.yield [[CST]]
+  // CHECK: } : tensor<?x2xf32> to tensor<?x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<?x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+  return %1 : tensor<?x9xf32>
+}
+
+func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
+  %0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: tensor.pad %[[ARG0]] low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK:   tensor.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<?x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
+  return %1 : tensor<?x9xf32>
+}