[mlir] Add support for tensor.extract to comprehensive bufferization
authorthomasraoux <thomasraoux@google.com>
Tue, 13 Jul 2021 16:34:48 +0000 (09:34 -0700)
committerthomasraoux <thomasraoux@google.com>
Tue, 13 Jul 2021 16:54:46 +0000 (09:54 -0700)
Differential Revision: https://reviews.llvm.org/D105870

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

index e49b020..1ff39fc 100644 (file)
@@ -2115,6 +2115,22 @@ static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
     return success();
   llvm_unreachable("unexpected yieldOp");
 }
+
+/// Bufferization for tensor::ExtractOp just translate to memref.load, it only
+/// reads the tensor.
+static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(extractOp);
+
+  Location loc = extractOp.getLoc();
+  Value srcMemref = lookup(bvm, extractOp.tensor());
+  Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
+  extractOp.replaceAllUsesWith(l);
+  return success();
+}
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -2310,6 +2326,7 @@ static LogicalResult bufferizeFuncOpInternals(
             scf::ForOp,
             InitTensorOp,
             InsertSliceOp,
+            tensor::ExtractOp,
             LinalgOp,
             ReturnOp,
             TiledLoopOp,
index 96fb8b4..bab2711 100644 (file)
@@ -34,6 +34,19 @@ func @fill_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf
 
 // -----
 
+// CHECK-LABEL: func @tensor_extract(%{{.*}}: memref<?xf32, #{{.*}}>) -> f32 {
+func @tensor_extract(%A : tensor<?xf32>) -> (f32) {
+  %c0 = constant 0 : index
+
+//       CHECK: %[[RES:.*]] = memref.load {{.*}} : memref<?xf32, #{{.*}}>
+  %0 = tensor.extract %A[%c0] : tensor<?xf32>
+
+//       CHECK: return %[[RES]] : f32
+  return %0 : f32
+}
+
+// -----
+
 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 /// No linalg.inplaceable flag, must allocate.