if (parentOp) {
if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
Value castSource = castOp.getSource();
- auto castSourceType = castSource.getType().cast<RankedTensorType>();
- if (castSourceType.hasStaticShape())
+ auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
+ if (castSourceType && castSourceType.hasStaticShape())
sourceShape = castSourceType.getShape();
}
}
// -----
-
// CHECK-LABEL: func @tensor.cast(
func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
-> tensor<3x?xf32>
// -----
+// CHECK-LABEL: func @tensor.cast.unranked(
+func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>)
+ -> tensor<*xf32>
+{
+ // CHECK: tensor.cast
+ // CHECK: tensor.cast
+ // CHECK: tensor.cast
+ %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32>
+ %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32>
+ %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32>
+
+ // CHECK: linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>)
+ // CHECK-SAME: outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32>
+ %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: tensor.cast
+ %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32>
+
+ return %1: tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @linalg_effects(
// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>