return returnOp;
}
+/// Return true if `value` is the result of an InitTensorOp or a cast thereof.
+static bool isInitTensorOp(Value value) {
+ tensor::CastOp castOp;
+ while ((castOp = value.getDefiningOp<tensor::CastOp>()))
+ value = castOp.source();
+ return value.getDefiningOp<InitTensorOp>();
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
// unitialized and we do not need to copy.
// TODO: "matching bbArg does not bufferize to a read" is a more general
// check.
- if (!operand.getDefiningOp<linalg::InitTensorOp>())
+ if (!isInitTensorOp(operand))
b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
}
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
// unitialized and we do not need to copy.
// TODO: "matching bbArg does not bufferize to a read" is a more general
// check.
- if (!oldOutputTensor.getDefiningOp<linalg::InitTensorOp>()) {
+ if (!isInitTensorOp(oldOutputTensor)) {
b.setInsertionPointAfter(alloc.getDefiningOp());
b.create<linalg::CopyOp>(loc, outputBuffer, alloc);
}