One-shot-bufferize: allow non-tensor arguments in scg.while/for.
authorJohannes Reifferscheid <jreiffers@google.com>
Wed, 7 Sep 2022 12:35:50 +0000 (14:35 +0200)
committerJohannes Reifferscheid <jreiffers@google.com>
Wed, 7 Sep 2022 13:54:25 +0000 (15:54 +0200)
Currently, one-shot-bufferize crashes as soon as there's
a mixture of tensor and non-tensor arguments. This seems
to happen for no good reason.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D133419

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir

index fd0ff88..fb8c8dd 100644 (file)
@@ -601,8 +601,13 @@ struct ForOpInterface
     SmallVector<Value> castedInitArgs;
     for (const auto &it : llvm::enumerate(initArgs)) {
       Value initArg = it.value();
-      auto targetType =
-          bufferization::getBufferType(forOp->getResult(it.index()), options);
+      Value result = forOp->getResult(it.index());
+      // If the type is not a tensor, bufferization doesn't need to touch it.
+      if (!result.getType().isa<TensorType>()) {
+        castedInitArgs.push_back(initArg);
+        continue;
+      }
+      auto targetType = bufferization::getBufferType(result, options);
       if (failed(targetType))
         return failure();
       castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -846,8 +851,13 @@ struct WhileOpInterface
     SmallVector<Value> castedInitArgs;
     for (const auto &it : llvm::enumerate(initArgs)) {
       Value initArg = it.value();
-      auto targetType = bufferization::getBufferType(
-          whileOp.getBeforeArguments()[it.index()], options);
+      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
+      // If the type is not a tensor, bufferization doesn't need to touch it.
+      if (!beforeArg.getType().isa<TensorType>()) {
+        castedInitArgs.push_back(initArg);
+        continue;
+      }
+      auto targetType = bufferization::getBufferType(beforeArg, options);
       if (failed(targetType))
         return failure();
       castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -856,6 +866,8 @@ struct WhileOpInterface
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
+          if (!bbArg.getType().isa<TensorType>())
+            return bbArg.getType();
           // TODO: error handling
           return bufferization::getBufferType(bbArg, options)->cast<Type>();
         }));
index a72c6d3..dab4331 100644 (file)
@@ -344,13 +344,14 @@ func.func @scf_for_swapping_yields(
 //  CHECK-SAME:     %[[arg0:.*]]: memref<?xi1, #{{.*}}>
 func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
   // CHECK: scf.while : () -> () {
-  %res = scf.while (%arg1 = %arg0) : (tensor<?xi1>) -> tensor<?xi1> {
+  %res:2 = scf.while (%arg1 = %arg0, %i = %idx) :
+      (tensor<?xi1>, index) -> (tensor<?xi1>, index) {
     // CHECK: %[[condition:.*]] = memref.load %[[arg0]]
     // CHECK: scf.condition(%[[condition]])
     %condition = tensor.extract %arg1[%idx] : tensor<?xi1>
-    scf.condition(%condition) %arg1 : tensor<?xi1>
+    scf.condition(%condition) %arg1, %idx : tensor<?xi1>, index
   } do {
-  ^bb0(%arg2: tensor<?xi1>):
+  ^bb0(%arg2: tensor<?xi1>, %i: index):
     // CHECK: } do {
     // CHECK: memref.store %{{.*}}, %[[arg0]]
     // CHECK: scf.yield
@@ -358,11 +359,11 @@ func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
     %pos = "dummy.some_op"() : () -> (index)
     %val = "dummy.another_op"() : () -> (i1)
     %1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
-    scf.yield %1 : tensor<?xi1>
+    scf.yield %1, %i : tensor<?xi1>, index
   }
 
   // CHECK: return
-  return %res : tensor<?xi1>
+  return %res#0 : tensor<?xi1>
 }
 
 // -----
@@ -853,3 +854,19 @@ func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
   %x = tensor.extract %r[%c1] : tensor<?xf32>
   return %x : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @non_tensor_for_arg
+func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true}) 
+    -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2.0 : f32
+  %c10 = arith.constant 10 : index
+  %r1:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%idx = %c1, %t = %A) -> (index, tensor<?xf32>) {
+    %t2 = tensor.insert %c2 into %t[%idx] : tensor<?xf32>
+    scf.yield %idx, %t2 : index, tensor<?xf32>
+  }
+  return %r1#1 : tensor<?xf32>
+}
\ No newline at end of file