[mlir][MemRef|Tensor] Fix the handling of DimOp
authorQuentin Colombet <quentin.colombet@gmail.com>
Fri, 27 Jan 2023 15:10:54 +0000 (16:10 +0100)
committerQuentin Colombet <quentin.colombet@gmail.com>
Thu, 16 Feb 2023 10:38:19 +0000 (11:38 +0100)
Although specifying an index that is out of bounds for both `memref.dim`
and `tensor.dim` produces an undefined behavior, this is still valid IR.
In particular, we could expose an out of bound index because of some
optimizations, for instance as demonstrated with
https://github.com/llvm/llvm-project/issues/60295, and this shouldn't
cause the compiler to abort.

This patch removes the overzealous verifier checks and properly handles
out of bound indices (as in it doesn't crash the compiler, but still
produces UB).

This fixes https://github.com/llvm/llvm-project/issues/60295.

Note: That `shape.dim` has a similar problem but we're not supposed to
produce UB in this case. Instead we're supposed to propagate an error in
the resulting value and I don't know how to do that at the moment. Hence I
left this part out of the patch.

Differential Revision: https://reviews.llvm.org/D143999

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
mlir/test/Dialect/Tensor/invalid.mlir

index fdc070f..45b5c9f 100644 (file)
@@ -604,7 +604,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 9e1c8bc..77d183e 100644 (file)
@@ -143,7 +143,6 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 700304d..35091a3 100644 (file)
@@ -495,14 +495,16 @@ private:
     MemRefType memRefType = operandType.cast<MemRefType>();
     if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
       int64_t i = *index;
-      if (memRefType.isDynamicDim(i)) {
-        // extract dynamic size from the memref descriptor.
-        MemRefDescriptor descriptor(adaptor.getSource());
-        return descriptor.size(rewriter, loc, i);
+      if (i >= 0 && i < memRefType.getRank()) {
+        if (memRefType.isDynamicDim(i)) {
+          // extract dynamic size from the memref descriptor.
+          MemRefDescriptor descriptor(adaptor.getSource());
+          return descriptor.size(rewriter, loc, i);
+        }
+        // Use constant for static size.
+        int64_t dimSize = memRefType.getDimSize(i);
+        return createIndexConstant(rewriter, loc, dimSize);
       }
-      // Use constant for static size.
-      int64_t dimSize = memRefType.getDimSize(i);
-      return createIndexConstant(rewriter, loc, dimSize);
     }
     Value index = adaptor.getIndex();
     int64_t rank = memRefType.getRank();
index 02f8019..6814aa5 100644 (file)
@@ -941,25 +941,6 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
-LogicalResult DimOp::verify() {
-  // Assume unknown index to be in range.
-  std::optional<int64_t> index = getConstantIndex();
-  if (!index)
-    return success();
-
-  // Check that constant index is not knowingly out of range.
-  auto type = getSource().getType();
-  if (auto memrefType = type.dyn_cast<MemRefType>()) {
-    if (*index >= memrefType.getRank())
-      return emitOpError("index is out of range");
-  } else if (type.isa<UnrankedMemRefType>()) {
-    // Assume index to be in range.
-  } else {
-    llvm_unreachable("expected operand with memref type");
-  }
-  return success();
-}
-
 /// Return a map with key being elements in `vals` and data being number of
 /// occurences of it. Use std::map, since the `vals` here are strides and the
 /// dynamic stride value is the same as the tombstone value for
@@ -1067,6 +1048,12 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   if (!memrefType)
     return {};
 
+  // Out of bound indices produce undefined behavior but are still valid IR.
+  // Don't choke on them.
+  int64_t indexVal = index.getInt();
+  if (indexVal < 0 || indexVal >= memrefType.getRank())
+    return {};
+
   // Fold if the shape extent along the given index is known.
   if (!memrefType.isDynamicDim(index.getInt())) {
     Builder builder(getContext());
index 74b3f93..d567949 100644 (file)
@@ -399,25 +399,6 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }
 
-LogicalResult DimOp::verify() {
-  // Assume unknown index to be in range.
-  std::optional<int64_t> index = getConstantIndex();
-  if (!index)
-    return success();
-
-  // Check that constant index is not knowingly out of range.
-  auto type = getSource().getType();
-  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
-    if (*index >= tensorType.getRank())
-      return emitOpError("index is out of range");
-  } else if (type.isa<UnrankedTensorType>()) {
-    // Assume index to be in range.
-  } else {
-    llvm_unreachable("expected operand with tensor type");
-  }
-  return success();
-}
-
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
   auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
@@ -429,6 +410,12 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   if (!tensorType)
     return {};
 
+  // Out of bound indices produce undefined behavior but are still valid IR.
+  // Don't choke on them.
+  int64_t indexVal = index.getInt();
+  if (indexVal < 0 || indexVal >= tensorType.getRank())
+    return {};
+
   // Fold if the shape extent along the given index is known.
   if (!tensorType.isDynamicDim(index.getInt())) {
     Builder builder(getContext());
index 24877cc..0f3f3e4 100644 (file)
@@ -188,6 +188,20 @@ func.func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @static_out_of_bound_memref_dim
+func.func @static_out_of_bound_memref_dim(%static : memref<42x32x15x13x27xf32>) -> index {
+// CHECK: %[[C_MINUS_7:.*]] = arith.constant -7 : index
+// CHECK: %[[C_MINUS_7_I64:.*]] = builtin.unrealized_conversion_cast %[[C_MINUS_7]] : index to i64
+// CHECK: %[[UB_IDX:.*]] = llvm.getelementptr %{{.*}}[0, %[[C_MINUS_7_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK: %[[UB_DIM_I64:.*]] = llvm.load %[[UB_IDX]] : !llvm.ptr
+// CHECK: %[[UB_DIM:.*]] = builtin.unrealized_conversion_cast %[[UB_DIM_I64]] : i64 to index
+// CHECK: return %[[UB_DIM]] : index
+  %c-7 = arith.constant -7 : index
+  %1 = memref.dim %static, %c-7 : memref<42x32x15x13x27xf32>
+  return %1 : index
+}
+// -----
+
 // Check that consistent types are emitted in address arithemic in presence of
 // a data layout specification.
 module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
index d15819f..f74bd94 100644 (file)
@@ -1,13 +1,5 @@
 // RUN: mlir-opt <%s -split-input-file -verify-diagnostics
 
-func.func @dim(%arg : tensor<1x?xf32>) {
-  %c2 = arith.constant 2 : index
-  tensor.dim %arg, %c2 : tensor<1x?xf32> // expected-error {{'tensor.dim' op index is out of range}}
-  return
-}
-
-// -----
-
 // Asking the dimension of a 0-D shape doesn't make sense.
 func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
   tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}