[mlir] Add canonicalization pattern out_tensor->linalg->dim to out_tensor->dim.
author: Alexander Belyaev <pifon@google.com>
Tue, 5 Jan 2021 12:52:25 +0000 (13:52 +0100)
committer: Alexander Belyaev <pifon@google.com>
Tue, 5 Jan 2021 14:15:21 +0000 (15:15 +0100)
Differential Revision: https://reviews.llvm.org/D94079

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index bcbd6d9..529ba35 100644 (file)
@@ -1958,14 +1958,33 @@ struct DeduplicateInputs : public RewritePattern {
     return success();
   }
 };
+
+/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
+/// with the corresponding output tensor argument of the linalg op.
+struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the dim operand is a result of a linalg op; otherwise
+    // there is no output tensor to forward the query to.
+    Value dimOpArg = dimOp.memrefOrTensor();
+    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    // Locate which result of the linalg op feeds the dim. `dimOpArg` is
+    // known to be one of `results` (it is defined by `linalgOp`), so
+    // `llvm::find` cannot return end() here. The rewrite relies on result
+    // #i of a linalg op on tensors having the same shape as output tensor
+    // operand #i — the premise of this canonicalization.
+    auto results = linalgOp.getOperation()->getResults();
+    int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
+    auto outputTensors = linalgOp.getOutputTensors();
+    // Re-root the dim on the output tensor, bypassing the linalg op.
+    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
+    return success();
+  }
+};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                         MLIRContext *context) {                \
-    results.insert<EraseDeadLinalgOp>();                                       \
-    results.insert<FoldTensorCastOp>();                                        \
-    results.insert<DeduplicateInputs>();                                       \
+    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>();  \
+    results.insert<ReplaceDimOfLinalgResult>(context);                         \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
index f015d5f..faac64c 100644 (file)
@@ -389,3 +389,31 @@ func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
 //      CHECK: func @init_tensor_dynamic_dim
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
 //      CHECK:   return %[[ARG0]]
+
+// -----
+
+// Identity map: one parallel dimension over 1-D tensors.
+#map = affine_map<(d0) -> (d0)>
+
+// Regression test for the dim(linalg-result) -> dim(output-tensor)
+// canonicalization: the two `dim` ops below query the shapes of the
+// linalg.generic results and should be rewritten to query the output
+// tensor arguments %arg_0 and %arg_1 directly (see CHECK lines below).
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  // Element-wise copy of %arg_0 into both outputs; results %0 and %1
+  // inherit the shapes of the outs operands.
+  %0, %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel"]
+  } ins(%arg_0 : tensor<?xf32>)
+    outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+    linalg.yield %in, %in : f32, f32
+  } -> tensor<?xf32>, tensor<?xf32>
+
+  // dim of result #0 — expected to fold to dim of %arg_0.
+  %c0 = constant 0 : index
+  %num_elem_0 = dim %0, %c0 : tensor<?xf32>
+  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
+
+  // dim of result #1 — expected to fold to dim of %arg_1.
+  %num_elem_1 = dim %1, %c0 : tensor<?xf32>
+  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
+  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+}
+// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
+// CHECK: dim [[ARG_0]]
+// CHECK: dim [[ARG_1]]