operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
return OpFoldResult();
}
+namespace {
+/// Fold a load on a tensor_to_memref operation into an extract_element on the
+/// corresponding tensor.
+struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
+ using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LoadOp load,
+ PatternRewriter &rewriter) const override {
+ // Only fire when the loaded memref is produced directly by a
+ // tensor_to_memref op; any other producer leaves the load untouched.
+ auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
+ if (!tensorToMemref)
+ return failure();
+
+ // Rewrite `load (tensor_to_memref %t)[%indices]` into
+ // `extract_element %t[%indices]`, reusing the load's indices verbatim.
+ rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
+ load.indices());
+ return success();
+ }
+};
+} // end anonymous namespace.
+
+/// Register LoadOp canonicalization patterns: currently only the
+/// load(tensor_to_memref) -> extract_element folding.
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<LoadOfTensorToMemref>(context);
+}
+
//===----------------------------------------------------------------------===//
// MemRefCastOp
//===----------------------------------------------------------------------===//
return %1 : index
}
+// Test case: Folding of load(tensor_to_memref(%v), %idxs)
+// -> extract_element(%v, %idxs)
+// CHECK-LABEL: func @load_from_tensor_to_memref(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+// CHECK-NOT: load
+// CHECK: return %[[RES]] : f32
+func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+ // The canonicalizer should collapse this tensor_to_memref + load pair into a
+ // single extract_element on %arg2, as asserted by the CHECK lines above.
+ %0 = tensor_to_memref %arg2 : memref<?x?xf32>
+ %1 = load %0[%arg0, %arg1] : memref<?x?xf32>
+ return %1 : f32
+}
+
// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index