auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
- newResultTypes.push_back(newOperands.back().getType());
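+ // Memref outputs do not appear in the op's results, so their types must
+ // not be added to the result types of the cloned op.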
+ if (!newOperands.back().getType().isa<MemRefType>())
+   newResultTypes.push_back(newOperands.back().getType());
}
// Clone op.
Operation *newOp =
} -> tensor<4xf32>
return
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1 : memref<?xf32>) {
+ %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
+ linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]
+ } ins(%0 : tensor<?xf32>)
+ outs(%arg1 : memref<?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ }
+ return
+}
+
+// We need a mixed (tensor in, memref out) linalg.generic to bridge the tensor
+// and memref worlds: the tensor.cast on the input folds into the op, while the
+// memref output must not contribute a result type.
+// CHECK-LABEL: func @cast_producer_mixed
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<5xf32>, %[[ARG1:.*]]: memref<?xf32>)
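+// The cast becomes dead after folding and is erased by the canonicalizer,
+// so no tensor.cast may remain.
+// CHECK-NOT: tensor.cast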
+// CHECK: linalg.generic {
+// CHECK-SAME: indexing_maps = [#map, #map],
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: } ins(%[[ARG0]] : tensor<5xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : memref<?xf32>) {
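+// The region body is taken over unchanged by the folded op.
+// CHECK: linalg.yield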