[mlir][linalg] Fix `FoldTensorCastProducerOp` for generic with memref output
authorIvan Butygin <ivan.butygin@gmail.com>
Thu, 10 Nov 2022 19:53:36 +0000 (20:53 +0100)
committerIvan Butygin <ivan.butygin@gmail.com>
Wed, 16 Nov 2022 21:59:54 +0000 (22:59 +0100)
Type should only be added to results if it is tensor.

Differential Revision: https://reviews.llvm.org/D137801

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index ec1c603..18e399e 100644 (file)
@@ -1799,7 +1799,8 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
       auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
       newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
-      newResultTypes.push_back(newOperands.back().getType());
+      if (!newOperands.back().getType().isa<MemRefType>())
+        newResultTypes.push_back(newOperands.back().getType());
     }
     // Clone op.
     Operation *newOp =
index 55013c4..c9f1726 100644 (file)
@@ -845,3 +845,28 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
   } -> tensor<4xf32>
   return
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
+  %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
+  linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%0 : tensor<?xf32>)
+    outs(%arg1 : memref<?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  }
+  return
+}
+
+// We need a mixed linalg as a bridge between tensor and memref worlds.
+// CHECK-LABEL: func @cast_producer_mixed
+//  CHECK-SAME:     (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
+//       CHECK:     linalg.generic {
+//  CHECK-SAME:    indexing_maps = [#map, #map],
+//  CHECK-SAME:    iterator_types = ["parallel"]
+//  CHECK-SAME:  } ins(%[[ARG1]] : tensor<5xf32>)
+//  CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {