[mlir][linalg] Fix bug in InferStaticShapeOfOperands pattern
authorVladislav Vinogradov <v.vinogradov@yadro.com>
Thu, 10 Nov 2022 10:23:44 +0000 (13:23 +0300)
committerVladislav Vinogradov <v.vinogradov@yadro.com>
Wed, 16 Nov 2022 09:19:16 +0000 (12:19 +0300)
The pattern tries to deduce static shape from `tensor.cast` producer of linalg operation operands.
The original code unconditionally casts type of the `tensor.cast` source to `RankedTensorType`.
But the `tensor.cast` can also operate on `UnrankedTensorType`, so this cast either fail on assertion
in debug build or introduce UB in release build.

The patch replaces unconditional cast with `dyn_cast` and check for the cast result.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D137775

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index 32c8dd6..63e2ef3 100644 (file)
@@ -2236,8 +2236,8 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
     if (parentOp) {
       if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
         Value castSource = castOp.getSource();
-        auto castSourceType = castSource.getType().cast<RankedTensorType>();
-        if (castSourceType.hasStaticShape())
+        auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
+        if (castSourceType && castSourceType.hasStaticShape())
           sourceShape = castSourceType.getShape();
       }
     }
index 3f11183..1fe5fe5 100644 (file)
@@ -47,7 +47,6 @@ func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tenso
 
 // -----
 
-
 // CHECK-LABEL: func @tensor.cast(
 func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
   -> tensor<3x?xf32>
@@ -68,6 +67,30 @@ func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3
 
 // -----
 
+// CHECK-LABEL: func @tensor.cast.unranked(
+func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>)
+  -> tensor<*xf32>
+{
+  //      CHECK:  tensor.cast
+  //      CHECK:  tensor.cast
+  //      CHECK:  tensor.cast
+  %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32>
+  %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32>
+  %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32>
+
+  //      CHECK:  linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>)
+  // CHECK-SAME:    outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  //      CHECK:  tensor.cast
+  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32>
+
+  return %1: tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_effects(
 //  CHECK-SAME:     %[[A:[a-z0-9]*]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[B:[a-z0-9]*]]: memref<?x?xf32>