[mlir][std] Fold load(tensor_to_memref) into extract_element
authorStephan Herhut <herhut@google.com>
Fri, 20 Nov 2020 10:32:42 +0000 (11:32 +0100)
committerStephan Herhut <herhut@google.com>
Fri, 20 Nov 2020 12:42:11 +0000 (13:42 +0100)
This canonicalization is useful to resolve loads into scalar values when
doing partial bufferization.

Differential Revision: https://reviews.llvm.org/D91855

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

index 8512c93..1ad3df6 100644 (file)
@@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load",
     operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
index 6e755da..04efc25 100644 (file)
@@ -2293,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
   return OpFoldResult();
 }
 
+namespace {
+/// Fold a load on a tensor_to_memref operation into an extract_element on the
+/// corresponding tensor.
+struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadOp load,
+                                PatternRewriter &rewriter) const override {
+    auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
+    if (!tensorToMemref)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
+                                                  load.indices());
+    return success();
+  }
+};
+} // end anonymous namespace.
+
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<LoadOfTensorToMemref>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefCastOp
 //===----------------------------------------------------------------------===//
index 5147537..ebc59c8 100644 (file)
@@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
   return %1 : index
 }
 
+// Test case: Folding of load(tensor_to_memref(%v, %idxs))
+//            -> extract_element(%v, %idx)
+// CHECK-LABEL: func @load_from_tensor_to_memref(
+//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+//  CHECK-SAME:     %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+//   CHECK-NOT:   load
+//       CHECK:   return %[[RES]] : f32
+func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = tensor_to_memref %arg2 : memref<?x?xf32>
+  %1 = load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+
 // Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
 // CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
 //  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index