SmallVector<Value> castedInitArgs;
for (const auto &it : llvm::enumerate(initArgs)) {
Value initArg = it.value();
- auto targetType =
- bufferization::getBufferType(forOp->getResult(it.index()), options);
+ Value result = forOp->getResult(it.index());
+ // If the type is not a tensor, bufferization doesn't need to touch it.
+ if (!result.getType().isa<TensorType>()) {
+ castedInitArgs.push_back(initArg);
+ continue;
+ }
+ auto targetType = bufferization::getBufferType(result, options);
if (failed(targetType))
return failure();
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
SmallVector<Value> castedInitArgs;
for (const auto &it : llvm::enumerate(initArgs)) {
Value initArg = it.value();
- auto targetType = bufferization::getBufferType(
- whileOp.getBeforeArguments()[it.index()], options);
+ Value beforeArg = whileOp.getBeforeArguments()[it.index()];
+ // If the type is not a tensor, bufferization doesn't need to touch it.
+ if (!beforeArg.getType().isa<TensorType>()) {
+ castedInitArgs.push_back(initArg);
+ continue;
+ }
+ auto targetType = bufferization::getBufferType(beforeArg, options);
if (failed(targetType))
return failure();
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
+ if (!bbArg.getType().isa<TensorType>())
+ return bbArg.getType();
// TODO: error handling
return bufferization::getBufferType(bbArg, options)->cast<Type>();
}));
// CHECK-SAME: %[[arg0:.*]]: memref<?xi1, #{{.*}}>
func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
// CHECK: scf.while : () -> () {
- %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+ %res:2 = scf.while (%arg1 = %arg0, %i = %idx) :
+ (tensor<?xi1>, index) -> (tensor<?xi1>, index) {
// CHECK: %[[condition:.*]] = memref.load %[[arg0]]
// CHECK: scf.condition(%[[condition]])
%condition = tensor.extract %arg1[%idx] : tensor<?xi1>
- scf.condition(%condition) %arg1 : tensor<?xi1>
+ scf.condition(%condition) %arg1, %idx : tensor<?xi1>, index
} do {
- ^bb0(%arg2: tensor<?xi1>):
+ ^bb0(%arg2: tensor<?xi1>, %i: index):
// CHECK: } do {
// CHECK: memref.store %{{.*}}, %[[arg0]]
// CHECK: scf.yield
%pos = "dummy.some_op"() : () -> (index)
%val = "dummy.another_op"() : () -> (i1)
%1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
- scf.yield %1 : tensor<?xi1>
+ scf.yield %1, %i : tensor<?xi1>, index
}
// CHECK: return
- return %res : tensor<?xi1>
+ return %res#0 : tensor<?xi1>
}
// -----
%x = tensor.extract %r[%c1] : tensor<?xf32>
return %x : f32
}
+
+// -----
+
+// CHECK-LABEL: func @non_tensor_for_arg
+func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true})
+ -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2.0 : f32
+ %c10 = arith.constant 10 : index
+ %r1:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%idx = %c1, %t = %A) -> (index, tensor<?xf32>) {
+ %t2 = tensor.insert %c2 into %t[%idx] : tensor<?xf32>
+ scf.yield %idx, %t2 : index, tensor<?xf32>
+ }
+ return %r1#1 : tensor<?xf32>
+}
\ No newline at end of file