From ae4cea38f18e32d4a106871d751af380032e16fe Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Tue, 13 Jul 2021 09:34:48 -0700 Subject: [PATCH] [mlir] Add support for tensor.extract to comprehensive bufferization Differential Revision: https://reviews.llvm.org/D105870 --- .../Linalg/Transforms/ComprehensiveBufferize.cpp | 17 +++++++++++++++++ .../Dialect/Linalg/comprehensive-module-bufferize.mlir | 13 +++++++++++++ 2 files changed, 30 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index e49b020..1ff39fc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -2115,6 +2115,22 @@ static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp, return success(); llvm_unreachable("unexpected yieldOp"); } + +/// Bufferization for tensor::ExtractOp just translate to memref.load, it only +/// reads the tensor. +static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(extractOp); + + Location loc = extractOp.getLoc(); + Value srcMemref = lookup(bvm, extractOp.tensor()); + Value l = b.create(loc, srcMemref, extractOp.indices()); + extractOp.replaceAllUsesWith(l); + return success(); +} //===----------------------------------------------------------------------===// // Bufferization analyses. //===----------------------------------------------------------------------===// @@ -2310,6 +2326,7 @@ static LogicalResult bufferizeFuncOpInternals( scf::ForOp, InitTensorOp, InsertSliceOp, + tensor::ExtractOp, LinalgOp, ReturnOp, TiledLoopOp, diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir index 96fb8b4..bab2711 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -34,6 +34,19 @@ func @fill_inplace(%A : tensor {linalg.inplaceable = true}) -> tensor) -> f32 { +func @tensor_extract(%A : tensor) -> (f32) { + %c0 = constant 0 : index + +// CHECK: %[[RES:.*]] = memref.load {{.*}} : memref + %0 = tensor.extract %A[%c0] : tensor + +// CHECK: return %[[RES]] : f32 + return %0 : f32 +} + +// ----- + // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> /// No linalg.inplaceable flag, must allocate. -- 2.7.4