[mlir][linalg] Improve codegen when tiling PadTensor evenly
authorMatthias Springer <springerm@google.com>
Thu, 15 Jul 2021 02:27:52 +0000 (11:27 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 15 Jul 2021 02:29:21 +0000 (11:29 +0900)
Produce simpler IR with more static type information and fewer affine expressions.

Differential Revision: https://reviews.llvm.org/D105530

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
mlir/test/Dialect/Linalg/tile.mlir

index 4aa7792..5418bc3 100644 (file)
@@ -494,6 +494,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
+  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
@@ -513,7 +514,15 @@ static void insertTilingPatterns(RewritePatternSet &patterns,
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                      >::insert(patterns, options);
   patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
+}
+
+static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  RewritePatternSet patterns(ctx);
   patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  (void)applyPatternsAndFoldGreedily(
+      funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
 }
 
 static void
@@ -527,6 +536,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
   MLIRContext *ctx = funcOp.getContext();
   RewritePatternSet patterns(ctx);
   insertTilingPatterns(patterns, options);
+  patterns.add<AffineMinSCFCanonicalizationPattern>(patterns.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   (void)applyPatternsAndFoldGreedily(
       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
@@ -534,6 +544,10 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
   funcOp.walk([](LinalgOp op) {
     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
   });
+
+  // Apply swap pattern after generating loop nest and running
+  // canonicalizations.
+  applyExtractSliceOfPadTensorSwapPattern(funcOp);
 }
 
 namespace {
index 10f4dc3..36dc34a 100644 (file)
@@ -92,3 +92,33 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
     } : tensor<7x9xf32> to tensor<15x16xf32>
   return %0 : tensor<15x16xf32>
 }
+
+// -----
+
+// TILE1-LABEL: func @static_pad_tile_evenly(
+//  TILE1-SAME:     %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<14x15xf32>
+//   TILE1-DAG:   %[[C0:.*]] = constant 0 : index
+//   TILE1-DAG:   %[[C3:.*]] = constant 3 : index
+//   TILE1-DAG:   %[[C15:.*]] = constant 15 : index
+//       TILE1:   %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       TILE1:     %[[R2:.*]] = scf.if
+//       TILE1:       %[[GEN:.*]] = tensor.generate
+//       TILE1:       scf.yield %[[GEN]] : tensor<14x3xf32>
+//       TILE1:     else
+//       TILE1:       %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
+//       TILE1:       %[[PAD:.*]] = linalg.pad_tensor %8 low[0, 0] high[7, %{{.*}}]
+//       TILE1:       %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
+//       TILE1:       scf.yield %[[CAST]] : tensor<14x3xf32>
+//       TILE1:     %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
+//       TILE1:     scf.yield %[[R3]] : tensor<14x15xf32>
+//       TILE1:   return %[[RESULT]] : tensor<14x15xf32>
+func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>,
+                             %output_tensor: tensor<14x15xf32>,
+                             %pad_value: f32) -> tensor<14x15xf32> {
+  %0 = linalg.pad_tensor %input_tensor
+    low[0, 0] high[7, 6] into %output_tensor {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+    } : tensor<7x9xf32> to tensor<14x15xf32>
+  return %0 : tensor<14x15xf32>
+}
index 47d6dc1..97b17eb 100644 (file)
 // TILE-234-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
 // TILE-234-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
 
-//   TILE-2-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 10)>
-//  TILE-02-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 12)>
-// TILE-002-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 16)>
-
 //   TILE-2-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
 //  TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
 // TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
@@ -132,10 +128,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-2-DAG: %[[C2:.*]] = constant 2 : index
 //       TILE-2-DAG: %[[M:.*]] = constant 10 : index
 //       TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
-//       TILE-2:   %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[I]])
-//       TILE-2:   %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
-//       TILE-2:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
-//       TILE-2:   %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
+//       TILE-2:   %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x16xf32, #[[$strided2D]]>
+//       TILE-2:   %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
 //       TILE-2:   linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]
 
 // TILE-02-LABEL: func @matmul_static(
@@ -143,10 +137,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-02-DAG: %[[C2:.*]] = constant 2 : index
 //       TILE-02-DAG: %[[N:.*]] = constant 12 : index
 //       TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
-//       TILE-02:   %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[J]])
-//       TILE-02:   %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
-//       TILE-02:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
-//       TILE-02:   %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
+//       TILE-02:   %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, 2] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x2xf32, #[[$strided2D]]>
+//       TILE-02:   %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, 2] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
 //       TILE-02:   linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
 
 // TILE-002-LABEL: func @matmul_static(
@@ -154,10 +146,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-002-DAG: %[[C2:.*]] = constant 2 : index
 //       TILE-002-DAG: %[[C16:.*]] = constant 16 : index
 //       TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-//       TILE-002:   %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[K]])
-//       TILE-002:   %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-//       TILE-002:   %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
-//       TILE-002:   %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
+//       TILE-002:   %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, 2] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
+//       TILE-002:   %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [2, 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
 //       TILE-002:   linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
 
 // TILE-234-LABEL: func @matmul_static(
@@ -171,9 +161,9 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
 //       TILE-234:  scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[C10]] step %{{.*}} {
 //       TILE-234:    scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[C12]] step %{{.*}} {
 //       TILE-234:      scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-//       TILE-234:        %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-234:        %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-//       TILE-234:        %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
+//       TILE-234:        %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [2, 4] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x4xf32, #[[$strided2D]]>
+//       TILE-234:        %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [4, 3] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<4x3xf32, #[[$strided2D]]>
+//       TILE-234:        %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x3xf32, #[[$strided2D]]>
 //
 //       TILE-234:        linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
 
@@ -312,7 +302,7 @@ func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
 //       TILE-234:     for
 //   TILE-234-NOT:   for
 //       TILE-234:       memref.subview{{.*}} : memref<127x99xf32>
-//       TILE-234:       linalg.fill{{.*}} : f32, memref<?x?xf32, #[[$stride_99_1_layout_map]]>
+//       TILE-234:       linalg.fill{{.*}} : f32, memref<?x3xf32, #[[$stride_99_1_layout_map]]>
 
 
 func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {