[linalg] When removing noop linalg.generics, check that inserting a cast is valid
authorBenjamin Kramer <benny.kra@googlemail.com>
Mon, 28 Mar 2022 12:10:26 +0000 (14:10 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Tue, 29 Mar 2022 21:05:54 +0000 (23:05 +0200)
linalg.generic can also take scalars instead of tensors, which
tensor.cast doesn't support. We don't have an easy way to cast between
scalars and tensors so just keep the linalg.generic in those cases.

Differential Revision: https://reviews.llvm.org/D122575

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index d62192d..c72b52b 100644 (file)
@@ -836,9 +836,13 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
             sparse_tensor::getSparseTensorEncoding(resultType))
           returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
               genericOp.getLoc(), resultType, returnedArg);
-        else
+        else {
+          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
+                                                 resultType))
+            return failure();
           returnedArg = rewriter.create<tensor::CastOp>(
               genericOp.getLoc(), resultType, returnedArg);
+        }
       }
       returnedArgs.push_back(returnedArg);
     }
index eee6ebc..56ce267 100644 (file)
@@ -175,6 +175,24 @@ func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
 
 // -----
 
+#map = affine_map<() -> ()>
+func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
+  %out = linalg.init_tensor [] : tensor<f32>
+  %g = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = []
+  } ins(%arg0 : f32)
+    outs(%out : tensor<f32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  } -> (tensor<f32>)
+  return %g : tensor<f32>
+}
+// CHECK-LABEL: func @cant_fold_to_tensor_cast
+//       CHECK:     linalg.generic
+
+// -----
+
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index