Correctly model undefined behavior in {tensor|memref}.dim
authorSanjoy Das <sanjoy.das@getcruise.com>
Wed, 12 Oct 2022 05:33:45 +0000 (22:33 -0700)
committerSanjoy Das <sanjoy.das@getcruise.com>
Thu, 13 Oct 2022 00:30:13 +0000 (17:30 -0700)
These operations have undefined behavior if the index is not less than the rank of the source tensor / memref, so they cannot be freely speculated like they were before this patch.  After this patch we speculate them only if we can prove that they don't have UB.

Depends on D135505.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D135748

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Transforms/loop-invariant-code-motion.mlir

index c94a531..54394da 100644 (file)
@@ -544,7 +544,7 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
 def MemRef_DimOp : MemRef_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     MemRefsNormalizable,
-    Pure,
+    ConditionallySpeculatable, NoMemoryEffect,
     ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
@@ -593,6 +593,9 @@ def MemRef_DimOp : MemRef_Op<"dim", [
 
     /// Interface method of ShapedDimOpInterface: Return the dimension.
     OpFoldResult getDimension() { return getIndex(); }
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
   }];
 
   let hasCanonicalizer = 1;
index bdc24fa..0088756 100644 (file)
@@ -87,7 +87,7 @@ def Tensor_CastOp : Tensor_Op<"cast", [
 
 def Tensor_DimOp : Tensor_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    Pure,
+    ConditionallySpeculatable, NoMemoryEffect,
     ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
@@ -135,6 +135,9 @@ def Tensor_DimOp : Tensor_Op<"dim", [
 
     /// Interface method of ShapedDimOpInterface: Return the dimension.
     OpFoldResult getDimension() { return getIndex(); }
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
   }];
 
   let hasCanonicalizer = 1;
index 292eb46..fbc1ead 100644 (file)
@@ -819,6 +819,20 @@ Optional<int64_t> DimOp::getConstantIndex() {
   return {};
 }
 
+Speculation::Speculatability DimOp::getSpeculatability() {
+  auto constantIndex = getConstantIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
+  if (!rankedSourceType)
+    return Speculation::NotSpeculatable;
+
+  // The verifier rejects operations that violate this assertion.
+  assert(constantIndex < rankedSourceType.getRank());
+  return Speculation::Speculatable;
+}
+
 LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
   Optional<int64_t> index = getConstantIndex();
index 448e97c..0ee79a6 100644 (file)
@@ -328,6 +328,20 @@ Optional<int64_t> DimOp::getConstantIndex() {
   return {};
 }
 
+Speculation::Speculatability DimOp::getSpeculatability() {
+  auto constantIndex = getConstantIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
+  if (!rankedSourceType)
+    return Speculation::NotSpeculatable;
+
+  // The verifier rejects operations that violate this assertion.
+  assert(constantIndex < rankedSourceType.getRank());
+  return Speculation::Speculatable;
+}
+
 LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
   Optional<int64_t> index = getConstantIndex();
index 0b74c81..b8d3450 100644 (file)
@@ -503,3 +503,107 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
 
   return
 }
+
+// -----
+
+func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
+// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_unknown_dim
+    %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  // CHECK: scf.for
+  // CHECK-NEXT: tensor.dim
+  scf.for %i = %lb to %ub step %step {
+    %val = tensor.dim %t, %dim_idx : tensor<*xf32>
+  }
+
+  return
+}
+
+func.func @speculate_tensor_dim_known_rank_unknown_dim(
+// CHECK-LABEL: @speculate_tensor_dim_known_rank_unknown_dim
+    %t: tensor<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  // CHECK: scf.for
+  // CHECK-NEXT: tensor.dim
+  scf.for %i = %lb to %ub step %step {
+    %val = tensor.dim %t, %dim_idx : tensor<?x?x?x?xf32>
+  }
+
+  return
+}
+
+func.func @speculate_tensor_dim_unknown_rank_known_dim(
+// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_known_dim
+    %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : index
+  // CHECK: scf.for
+  // CHECK-NEXT: tensor.dim
+  scf.for %i = %lb to %ub step %step {
+    %val = tensor.dim %t, %c0 : tensor<*xf32>
+  }
+
+  return
+}
+
+func.func @speculate_tensor_dim_known_rank_known_dim_inbounds(
+// CHECK-LABEL: @speculate_tensor_dim_known_rank_known_dim_inbounds
+    %t: tensor<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  %c1 = arith.constant 1 : index
+  // CHECK: tensor.dim
+  // CHECK-NEXT: scf.for
+  scf.for %i = %lb to %ub step %step {
+    %val = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+  }
+
+  return
+}
+
+// -----
+
+func.func @speculate_memref_dim_unknown_rank_unknown_dim(
+// CHECK-LABEL: @speculate_memref_dim_unknown_rank_unknown_dim
+    %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  // CHECK: scf.for
+  // CHECK-NEXT: memref.dim
+  scf.for %i = %lb to %ub step %step {
+    %val = memref.dim %t, %dim_idx : memref<*xf32>
+  }
+
+  return
+}
+
+func.func @speculate_memref_dim_known_rank_unknown_dim(
+// CHECK-LABEL: @speculate_memref_dim_known_rank_unknown_dim
+    %t: memref<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  // CHECK: scf.for
+  // CHECK-NEXT: memref.dim
+  scf.for %i = %lb to %ub step %step {
+    %val = memref.dim %t, %dim_idx : memref<?x?x?x?xf32>
+  }
+
+  return
+}
+
+func.func @speculate_memref_dim_unknown_rank_known_dim(
+// CHECK-LABEL: @speculate_memref_dim_unknown_rank_known_dim
+    %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  %c0 = arith.constant 0 : index
+  // CHECK: scf.for
+  // CHECK-NEXT: memref.dim
+  scf.for %i = %lb to %ub step %step {
+    %val = memref.dim %t, %c0 : memref<*xf32>
+  }
+
+  return
+}
+
+func.func @speculate_memref_dim_known_rank_known_dim_inbounds(
+// CHECK-LABEL: @speculate_memref_dim_known_rank_known_dim_inbounds
+    %t: memref<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+  %c1 = arith.constant 1 : index
+  // CHECK: memref.dim
+  // CHECK-NEXT: scf.for
+  scf.for %i = %lb to %ub step %step {
+    %val = memref.dim %t, %c1 : memref<?x?x?x?xf32>
+  }
+
+  return
+}