[mlir][tensor|memref] Harden the checks on dim op
authorQuentin Colombet <quentin.colombet@gmail.com>
Tue, 24 Jan 2023 10:29:37 +0000 (11:29 +0100)
committerQuentin Colombet <quentin.colombet@gmail.com>
Thu, 2 Feb 2023 10:34:03 +0000 (11:34 +0100)
Prior to this patch it was possible to use the dim operation on a 0-D
memref/tensor.
Unless we want to change the semantic of a 0-D shape, this doesn't make
sense because, paraphrasing the dim op semantic, this is guaranteed to
produce something that is undefined. (The requested index is guaranteed
to be equal to or greater than the rank.)

Harden the type requirements for the dim op by disallowing 0-D shaped
types.

This "fixes" llvm.org/PR60195 by rejecting dim op on 0-D shapes instead of
crashing during LLVM conversion.

Differential Revision: https://reviews.llvm.org/D142445

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/invalid.mlir

index c2c0d0a..f5dab42 100644 (file)
@@ -572,7 +572,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     ```
   }];
 
-  let arguments = (ins AnyRankedOrUnrankedMemRef:$source,
+  let arguments = (ins AnyNon0RankedOrUnrankedMemRef:$source,
                        Index:$index);
   let results = (outs Index:$result);
 
index 9328b1f..e702189 100644 (file)
@@ -115,7 +115,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [
     ```
   }];
 
-  let arguments = (ins AnyTensor:$source,
+  let arguments = (ins AnyNon0RankedOrUnrankedTensor:$source,
                        Index:$index);
   let results = (outs Index:$result);
 
index 7fb583f..d307beb 100644 (file)
@@ -548,6 +548,12 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
                          == }]
                       # rank>)>]>;
 
+// Whether a shaped type has a rank greater than or equal of the specified rank.
+class HasRankGreaterOrEqualPred<int rank> : And<[
+    HasRankPred,
+    CPred<[{$_self.cast<::mlir::ShapedType>().getRank() >= }] # rank>
+]>;
+
 // Vector types.
 
 class VectorOf<list<Type> allowedTypes> :
@@ -748,7 +754,16 @@ class RankedTensorOf<
     string summary = "ranked tensor">
   : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;
 
+class Non0RankedTensorOf<list<Type> allowedTypes>
+  : TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>],
+      "non-0-ranked.tensor">;
+
 def AnyRankedTensor : RankedTensorOf<[AnyType]>;
+def AnyNon0RankedTensor  : Non0RankedTensorOf<[AnyType]>;
+def AnyUnrankedTensor  : UnrankedTensorOf<[AnyType]>;
+
+def AnyNon0RankedOrUnrankedTensor:
+    AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>;
 
 // Ranked tensor type with one of the specified types and ranks.
 class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
@@ -782,13 +797,20 @@ def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
 class MemRefOf<list<Type> allowedTypes> :
     ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
                         "::mlir::MemRefType">;
+class Non0RankedMemRefOf<list<Type> allowedTypes> :
+    ConfinedType<MemRefOf<allowedTypes>, [HasRankGreaterOrEqualPred<1>],
+         "non-0-ranked." # MemRefOf<allowedTypes>.summary,
+         "::mlir::MemRefType">;
 
 def AnyMemRef : MemRefOf<[AnyType]>;
+def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>;
 
 class RankedOrUnrankedMemRefOf<list<Type> allowedTypes>:
     AnyTypeOf<[UnrankedMemRefOf<allowedTypes>, MemRefOf<allowedTypes>]>;
 
 def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
+def AnyNon0RankedOrUnrankedMemRef:
+    AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;
 
 // Memref declarations handle any memref, independent of rank, size, (static or
 // dynamic), layout, or memory space.
index ccbf929..19874f0 100644 (file)
@@ -1040,3 +1040,11 @@ func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
   %0 = memref.realloc %src : memref<256xf32> to memref<?xi32>
   return %0 : memref<?xi32>
 }
+
+// -----
+
+// Asking the dimension of a 0-D shape doesn't make sense.
+func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
+  memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
+  return
+}
index cbcc1e3..fe665a3 100644 (file)
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -tensor-bufferize -cse -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL:   func @dim(
-// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:              %[[TENSOR:.*]]: tensor<*xf32>,
 // CHECK-SAME:              %[[INDEX:.*]]: index) -> index {
-// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<f32>
-// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<f32>
+// CHECK:           %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
+// CHECK:           %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref<*xf32>
 // CHECK:           return %[[EXTENT]] : index
-func.func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
-  %0 = tensor.dim %arg0, %arg1 : tensor<f32>
+func.func @dim(%arg0: tensor<*xf32>, %arg1: index) -> index {
+  %0 = tensor.dim %arg0, %arg1 : tensor<*xf32>
   return %0 : index
 }
 
index 36c4dfe..d15819f 100644 (file)
@@ -8,6 +8,14 @@ func.func @dim(%arg : tensor<1x?xf32>) {
 
 // -----
 
+// Asking the dimension of a 0-D shape doesn't make sense.
+func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
+  tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}
+  return
+}
+
+// -----
+
 func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
   // expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
   %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>