[mlir][linalg] Canonicalize dim ops of tiled_loop block args
authorMatthias Springer <springerm@google.com>
Thu, 19 Aug 2021 02:23:36 +0000 (11:23 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 19 Aug 2021 02:24:33 +0000 (11:24 +0900)
E.g.:
```
%y = ... : tensor<...>
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  tensor.dim %x, %c0 : tensor<...>
}
```

is rewritten to:
```
%y = ... : tensor<...>
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  tensor.dim %y, %c0 : tensor<...>
}
```

Differential Revision: https://reviews.llvm.org/D108272

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir

index a58d91e..6b65a9e 100644 (file)
@@ -2064,6 +2064,57 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
   }
 };
 
+/// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
+/// to dim(y) where `y` is the initial input/output value of the argument.
+///
+/// E.g.:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+///   tensor.dim %x, %c0 : tensor<...>
+/// }
+///
+/// is folded to:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+///   tensor.dim %y, %c0 : tensor<...>
+/// }
+template <typename OpTy>
+struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto src = dimOp.source().template dyn_cast<BlockArgument>();
+    if (!src)
+      return failure();
+    auto loopOp = dyn_cast<TiledLoopOp>(
+        src.getOwner()->getParent()->getParentOp());
+    if (!loopOp)
+      return failure();
+
+    auto inputArgs = loopOp.getRegionInputArgs();
+    auto it1 = llvm::find(inputArgs, src);
+    if (it1 != inputArgs.end()) {
+      rewriter.updateRootInPlace(dimOp, [&] {
+        dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]);
+      });
+      return success();
+    }
+
+    auto outputArgs = loopOp.getRegionOutputArgs();
+    auto it2 = llvm::find(outputArgs, src);
+    if (it2 != outputArgs.end()) {
+      rewriter.updateRootInPlace(dimOp, [&] {
+        dimOp.sourceMutable().assign(
+            loopOp.outputs()[it2 - outputArgs.begin()]);
+      });
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2167,7 +2218,9 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 
 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
+  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
+                 DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
+                 DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context);
 }
 
 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
index 34dacd5..41a4bfe 100644 (file)
@@ -919,3 +919,30 @@ func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
   %r = tensor.dim %0, %c0 : tensor<?x?xf32>
   return %r : index
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_input(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   linalg.tiled_loop
+//       CHECK:     %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]]
+//       CHECK:     index_cast %[[dim]]
+func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %inner_dim = tensor.dim %in1, %c0 : tensor<?x?xf32>
+    %cast1 = std.index_cast %inner_dim : index to i32
+    %cast2 = std.sitofp %cast1 : i32 to f32
+    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+    linalg.yield %fill : tensor<?x?xf32>
+  }
+  return %r : tensor<?x?xf32>
+}
index de7c0b7..01dab48 100644 (file)
@@ -121,8 +121,8 @@ module {
 // TLOOP:    %[[AB_SUB:.*]] = linalg.matmul
 // TLOOP-SAME:  ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
 
-// TLOOP:    %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]]
-// TLOOP:    %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_C_1:.*]] = tensor.dim %[[C]], %[[C1]] : [[TY]]
 
 // TLOOP:    %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = 
 // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
@@ -300,7 +300,7 @@ module {
 // TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
 // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
 
-// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
 // TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
 // TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
 // TLOOP:    %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
@@ -371,7 +371,7 @@ module {
 // TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
 // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
 
-// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
 // TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
 // TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
 // TLOOP:    %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]