From 6af81ea1d6d36c7151a61f65e21b5c4ad9cf859d Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Fri, 20 Nov 2020 11:32:42 +0100 Subject: [PATCH] [mlir][std] Fold load(tensor_to_memref) into extract_element This canonicalization is useful to resolve loads into scalar values when doing partial bufferization. Differential Revision: https://reviews.llvm.org/D91855 --- mlir/include/mlir/Dialect/StandardOps/IR/Ops.td | 1 + mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 24 ++++++++++++++++++++++++ mlir/test/Dialect/Standard/canonicalize.mlir | 14 ++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 8512c93..1ad3df6 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load", operand_range getIndices() { return {operand_begin() + 1, operand_end()}; } }]; + let hasCanonicalizer = 1; let hasFolder = 1; let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 6e755daa..04efc25 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2293,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef cstOperands) { return OpFoldResult(); } +namespace { +/// Fold a load on a tensor_to_memref operation into an extract_element on the +/// corresponding tensor. +struct LoadOfTensorToMemref : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LoadOp load, + PatternRewriter &rewriter) const override { + auto tensorToMemref = load.memref().getDefiningOp(); + if (!tensorToMemref) + return failure(); + + rewriter.replaceOpWithNewOp(load, tensorToMemref.tensor(), + load.indices()); + return success(); + } +}; +} // end anonymous namespace. + +void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // MemRefCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index 5147537..ebc59c8 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref) -> index { return %1 : index } +// Test case: Folding of load(tensor_to_memref(%v, %idxs)) +// -> extract_element(%v, %idx) +// CHECK-LABEL: func @load_from_tensor_to_memref( +// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index +// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor +// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]] +// CHECK-NOT: load +// CHECK: return %[[RES]] : f32 +func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor) -> f32 { + %0 = tensor_to_memref %arg2 : memref + %1 = load %0[%arg0, %arg1] : memref + return %1 : f32 +} + // Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx // CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements( // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index -- 2.7.4