[mlir][linalg] ComprehensiveBufferize: Do not copy InitTensorOps
author Matthias Springer <springerm@google.com>
Mon, 13 Sep 2021 13:26:40 +0000 (22:26 +0900)
committer Matthias Springer <springerm@google.com>
Mon, 13 Sep 2021 13:31:54 +0000 (22:31 +0900)
Do not copy InitTensorOps or casts thereof.

Differential Revision: https://reviews.llvm.org/D109656

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

index cb7f0c3..3d810bf 100644 (file)
@@ -172,6 +172,14 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
   return returnOp;
 }
 
+/// Return true if `value` is produced by an InitTensorOp, possibly via casts.
+static bool isInitTensorOp(Value value) {
+  tensor::CastOp castOp;
+  while ((castOp = value.getDefiningOp<tensor::CastOp>())) // Skip each cast.
+    value = castOp.source(); // Continue from the cast's input.
+  return value.getDefiningOp<InitTensorOp>(); // True iff such an op exists.
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -1781,7 +1789,7 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
       // uninitialized and we do not need to copy.
       // TODO: "matching bbArg does not bufferize to a read" is a more general
       // check.
-      if (!operand.getDefiningOp<linalg::InitTensorOp>())
+      if (!isInitTensorOp(operand))
         b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
     }
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
@@ -1908,7 +1916,7 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
       // uninitialized and we do not need to copy.
       // TODO: "matching bbArg does not bufferize to a read" is a more general
       // check.
-      if (!oldOutputTensor.getDefiningOp<linalg::InitTensorOp>()) {
+      if (!isInitTensorOp(oldOutputTensor)) {
         b.setInsertionPointAfter(alloc.getDefiningOp());
         b.create<linalg::CopyOp>(loc, outputBuffer, alloc);
       }