sparse_tensor::getSparseTensorEncoding(resultType))
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
genericOp.getLoc(), resultType, returnedArg);
- else
+ else {
+ if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
+ resultType))
+ return failure();
returnedArg = rewriter.create<tensor::CastOp>(
genericOp.getLoc(), resultType, returnedArg);
+ }
}
returnedArgs.push_back(returnedArg);
}
// -----
+// Regression test: the region yields %arg2, which is the scalar f32 "ins"
+// operand, while the op result type is tensor<f32>. f32 is not
+// cast-compatible with tensor<f32>, so the fold-to-cast pattern must
+// return failure instead of creating an invalid tensor.cast; the
+// linalg.generic must therefore survive canonicalization.
+#map = affine_map<() -> ()>
+func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
+  %out = linalg.init_tensor [] : tensor<f32>
+  %g = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = []
+  } ins(%arg0 : f32)
+    outs(%out : tensor<f32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  } -> (tensor<f32>)
+  return %g : tensor<f32>
+}
+// CHECK-LABEL: func @cant_fold_to_tensor_cast
+//       CHECK:   linalg.generic
+
+// -----
+
#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index