From 3f1e827abd7fb893f7a33f467126d5d736ffa8d4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 21 Apr 2021 15:52:50 -0700 Subject: [PATCH] [mlir] Linalg : do not forward memrefs to outputs when do bufferization Example: ``` %0 = linalg.init_tensor : tensor<...> %1 = linalg.generic ... outs(%0: tensor<...>) %2 = linalg.generic ... outs(%0: tensor<...>) ``` Memref allocated as a result of `init_tensor` bufferization can be incorrectly overwritten by the second linalg.generic operation Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D100921 --- mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp | 4 ---- mlir/test/Dialect/Linalg/bufferize.mlir | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp index 3ab86be..892942e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -59,10 +59,6 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, continue; } - if (auto alloc = resultTensor.getDefiningOp()) { - resultBuffers.push_back(resultTensor); - continue; - } // Allocate buffers for statically-shaped results. if (memrefType.hasStaticShape()) { resultBuffers.push_back(b.create(loc, memrefType)); diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir index b9a4362..757b7a1 100644 --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -45,8 +45,8 @@ func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @init_tensor( // CHECK-SAME: %[[IN:.*]]: tensor, %[[SIZE:.*]]: index) -// CHECK: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref +// CHECK: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF]] : memref) // CHECK-SAME: outs(%[[OUT_BUF]] : memref) { -- 2.7.4