[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.
author MaheshRavishankar <ravishankarm@google.com>
Wed, 27 Jan 2021 07:21:33 +0000 (23:21 -0800)
committer MaheshRavishankar <ravishankarm@google.com>
Wed, 27 Jan 2021 07:22:28 +0000 (23:22 -0800)
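
A subtensor of the result of an init_tensor op is, like the init_tensor
itself, needed only for its shape, so it can be replaced by a new
init_tensor op of the subtensor's size. As a rough sketch (mirroring the
test added below):

  %0 = linalg.init_tensor [%arg0, 10, 40] : tensor<?x10x40xf32>
  %1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
    : tensor<?x10x40xf32> to tensor<5x?x20xf32>

canonicalizes to

  %1 = linalg.init_tensor [5, %arg1, 20] : tensor<5x?x20xf32>
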
Differential Revision: https://reviews.llvm.org/D95305

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a6f3576..2982132 100644
@@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder,
 }
 
 namespace {
-struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
+/// Since an `init_tensor` op creates a tensor that is needed only for its
+/// shape, a subtensor of it is likewise needed only for its shape. The result
+/// can therefore be replaced by a new `init_tensor` op with the same sizes as
+/// the subtensor op.
+struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subtensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (!subtensorOp.source().getDefiningOp<linalg::InitTensorOp>())
+      return failure();
+    rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
+        subtensorOp, subtensorOp.sizes(),
+        llvm::to_vector<4>(llvm::map_range(
+            subtensorOp.static_sizes(),
+            [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
+        subtensorOp.getSourceType().getElementType());
+    return success();
+  }
+};
+
+struct FoldInitTensorWithTensorReshapeOp
+    : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
@@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
 
 void InitTensorOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
-                 ReplaceStaticShapeDims>(context);
+  results
+      .insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
+              ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index cc00b98..418d9d2 100644
@@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
 // CHECK-LABEL: func @keep_not_noop
 //       CHECK:   %[[RESULT:.+]]:2 = linalg.generic
 //       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @fold_init_tensor_with_subtensor
+  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
+{
+  %0 = linalg.init_tensor [%arg0, 10, 40] : tensor<?x10x40xf32>
+  %1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
+    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+  return %1 : tensor<5x?x20xf32>
+}
+//      CHECK: func @fold_init_tensor_with_subtensor
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
+//      CHECK:   return %[[T0]]