[mlir][TilingInterface] Enabling tiling `tensor.pad` using `TilingInterface`.
authorMahesh Ravishankar <ravishankarm@google.com>
Fri, 26 Aug 2022 16:29:02 +0000 (16:29 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Fri, 26 Aug 2022 16:29:32 +0000 (16:29 +0000)
Update the implementation of `TilingInterface` for `tensor.pad`
operations to allow tiling the op using the existing patterns for the
interface. Verify that tests that pass with existing pad tiling
patterns producer the same results through TilingInterface patterns.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132720

mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir [new file with mode: 0644]
mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

index 4bdc774..3649517 100644 (file)
@@ -54,7 +54,7 @@ Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
 /// ops on `affine.apply` and Affine dialect already depends on TensorOps. In
 /// order to break the cyclic dependency (TensorOps->AffineOps->TensorOps) the
 /// implementation is moved to a separate library.
-void registerTilingOpInterfaceExternalModels(mlir::DialectRegistry &registry);
+void registerTilingInterfaceExternalModels(mlir::DialectRegistry &registry);
 
 } // namespace tensor
 } // namespace mlir
index 0eceaa1..c91eb56 100644 (file)
@@ -125,7 +125,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
-  tensor::registerTilingOpInterfaceExternalModels(registry);
+  tensor::registerTilingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
 }
 
index 17def96..04cab1b 100644 (file)
@@ -22,6 +22,8 @@ namespace {
 struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
 
   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(op);
     ReifiedRankedShapedTypeDims reifiedShapes;
     ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
         dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
@@ -69,6 +71,17 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
       return {};
     return {result};
   }
+
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
 };
 
 } // namespace
@@ -281,7 +294,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
   return createPadOfExtractSlice();
 }
 
-void mlir::tensor::registerTilingOpInterfaceExternalModels(
+void mlir::tensor::registerTilingInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
new file mode 100644 (file)
index 0000000..2d60699
--- /dev/null
@@ -0,0 +1,140 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -resolve-shaped-type-result-dims -cse -split-input-file %s | FileCheck %s
+
+// 2D tiling of dynamic 2D pad tensor op.
+func.func @dynamic_2d_pad_tensor(%input_tensor: tensor<?x?xf32>,
+                         %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } {__internal_linalg_transform__ = "pad_2dtiling"}: tensor<?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 7)>
+//       CHECK: func @dynamic_2d_pad_tensor(
+//  CHECK-SAME:     %[[IN:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//       CHECK:   %[[DIM_IN0:.+]] = tensor.dim %[[IN]], %[[C0]]
+//       CHECK:   %[[DIM0:.+]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]]
+//       CHECK:   %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]]
+//       CHECK:   %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
+//       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[DIM0]] step %[[C2]]
+//       CHECK:     scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
+//       CHECK:         tensor.generate
+//       CHECK:       else
+//       CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:         %[[PAD:.*]] = tensor.pad %[[SLICE]]
+//       CHECK:       tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @dynamic_2d_pad_tensor_inner_tiling(%input_tensor: tensor<?x?xf32>,
+                         %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } {__internal_linalg_transform__ = "pad_inner_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
+//       CHECK: func @dynamic_2d_pad_tensor_inner_tiling(
+//  CHECK-SAME:     %[[IN:.*]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:   %[[DIM_IN0:.*]] = tensor.dim %[[IN]], %[[C0]]
+//       CHECK:   %[[DIM0:.*]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]]
+//       CHECK:   %[[DIM_IN1:.*]] = tensor.dim %[[IN]], %[[C1]]
+//       CHECK:   %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
+//       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       CHECK:     %[[SWAP_RESULT:.*]] = scf.if
+//       CHECK:       tensor.generate
+//       CHECK:     else
+//       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[{{.*}}, {{.*}}]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
+                        %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } {__internal_linalg_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32>
+  return %0 : tensor<15x16xf32>
+}
+// CHECK-LABEL: func @static_pad_tensor(
+//  CHECK-SAME:     %[[IN:.*]]: tensor<7x9xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+//       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
+//       CHECK:     scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
+//       CHECK:         tensor.generate
+//       CHECK:       else
+//       CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:         %[[PAD:.*]] = tensor.pad %[[SLICE]]
+//       CHECK:       tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @static_pad_tensor_inner_tiling(%input_tensor: tensor<7x9xf32>,
+                        %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
+  return %0 : tensor<15x16xf32>
+}
+// CHECK-LABEL: func @static_pad_tensor_inner_tiling(
+//  CHECK-SAME:     %[[IN:.*]]: tensor<7x9xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+//       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       CHECK:     %[[SWAP_RESULT:.*]] = scf.if
+//       CHECK:       tensor.generate
+//       CHECK:     else
+//       CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
+//       CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
+//       CHECK:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][%[[C0]], {{.*}}] [%[[C15]], {{.*}}] [1, 1]
+//       CHECK:   return %[[RESULT]]
+
+/// Rest of the tests only check that they dont fail.
+
+// -----
+
+func.func @dynamic_2d_pad_tensor_outer_tiling(%input_tensor: tensor<?x?xf32>,
+                         %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } {__internal_linalg_transform__ = "pad_outer_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @dynamic_2d_pad_tensor_outer_tiling
+
+// -----
+
+func.func @static_pad_tensor_outer_tiling(%input_tensor: tensor<7x9xf32>,
+                        %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
+  return %0 : tensor<15x16xf32>
+}
+// CHECK-LABEL: func @static_pad_tensor_outer_tiling
index 26234d1..b562718 100644 (file)
@@ -12,4 +12,5 @@ add_mlir_library(MLIRTilingInterfaceTestPasses
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRTensorDialect
+  MLIRTensorTilingInterfaceImpl
   )
index c535fca..2ba93ca 100644 (file)
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -117,9 +119,10 @@ struct TestTilingInterfacePass
   TestTilingInterfacePass(const TestTilingInterfacePass &pass)
       : PassWrapper(pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
-                    tensor::TensorDialect>();
+    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
+                    scf::SCFDialect, tensor::TensorDialect>();
     linalg::registerTilingInterfaceExternalModels(registry);
+    tensor::registerTilingInterfaceExternalModels(registry);
   }
   StringRef getArgument() const final { return "test-tiling-interface"; }
   StringRef getDescription() const final {
@@ -184,6 +187,16 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     // 6. Tiling + interchange of an operation
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
         context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0});
+    // 7. Tiling for 2D pad tensor operations.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, patterns, "pad_2dtiling", {2, 3});
+    // 8. Tiling inner dimension of 2d pad tensor operations.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, patterns, "pad_inner_tiling", {0, 3});
+    // 9. Tiling inner dimension of 2d pad tensor operations.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, patterns, "pad_outer_tiling", {2, 3});
+
     return;
   }
   if (testTileConsumerAndFuseProducer) {