From adabce41185910227ca276a1cfd22e76443dd238 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 11 Oct 2022 22:33:45 -0700 Subject: [PATCH] Correctly model undefined behavior in {tensor|memref}.dim These operations have undefined behavior if the index is not less than the rank of the source tensor / memref, so they cannot be freely speculated like they were before this patch. After this patch we speculate them only if we can prove that they don't have UB. Depends on D135505. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D135748 --- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 5 +- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 5 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 14 +++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 14 +++ .../Transforms/loop-invariant-code-motion.mlir | 104 +++++++++++++++++++++ 5 files changed, 140 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index c94a5310..54394da 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -544,7 +544,7 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> { def MemRef_DimOp : MemRef_Op<"dim", [ DeclareOpInterfaceMethods, MemRefsNormalizable, - Pure, + ConditionallySpeculatable, NoMemoryEffect, ShapedDimOpInterface]> { let summary = "dimension index operation"; let description = [{ @@ -593,6 +593,9 @@ def MemRef_DimOp : MemRef_Op<"dim", [ /// Interface method of ShapedDimOpInterface: Return the dimension. OpFoldResult getDimension() { return getIndex(); } + + /// Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index bdc24fa..0088756 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -87,7 +87,7 @@ def Tensor_CastOp : Tensor_Op<"cast", [ def Tensor_DimOp : Tensor_Op<"dim", [ DeclareOpInterfaceMethods, - Pure, + ConditionallySpeculatable, NoMemoryEffect, ShapedDimOpInterface]> { let summary = "dimension index operation"; let description = [{ @@ -135,6 +135,9 @@ def Tensor_DimOp : Tensor_Op<"dim", [ /// Interface method of ShapedDimOpInterface: Return the dimension. OpFoldResult getDimension() { return getIndex(); } + + /// Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 292eb46..fbc1ead 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -819,6 +819,20 @@ Optional DimOp::getConstantIndex() { return {}; } +Speculation::Speculatability DimOp::getSpeculatability() { + auto constantIndex = getConstantIndex(); + if (!constantIndex) + return Speculation::NotSpeculatable; + + auto rankedSourceType = dyn_cast(getSource().getType()); + if (!rankedSourceType) + return Speculation::NotSpeculatable; + + // The verifier rejects operations that violate this assertion. + assert(constantIndex < rankedSourceType.getRank()); + return Speculation::Speculatable; +} + LogicalResult DimOp::verify() { // Assume unknown index to be in range. Optional index = getConstantIndex(); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 448e97c..0ee79a6 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -328,6 +328,20 @@ Optional DimOp::getConstantIndex() { return {}; } +Speculation::Speculatability DimOp::getSpeculatability() { + auto constantIndex = getConstantIndex(); + if (!constantIndex) + return Speculation::NotSpeculatable; + + auto rankedSourceType = dyn_cast(getSource().getType()); + if (!rankedSourceType) + return Speculation::NotSpeculatable; + + // The verifier rejects operations that violate this assertion. + assert(constantIndex < rankedSourceType.getRank()); + return Speculation::Speculatable; +} + LogicalResult DimOp::verify() { // Assume unknown index to be in range. Optional index = getConstantIndex(); diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index 0b74c81..b8d3450 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -503,3 +503,107 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste return } + +// ----- + +func.func @speculate_tensor_dim_unknown_rank_unknown_dim( +// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_unknown_dim + %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: tensor.dim + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %dim_idx : tensor<*xf32> + } + + return +} + +func.func @speculate_tensor_dim_known_rank_unknown_dim( +// CHECK-LABEL: @speculate_tensor_dim_known_rank_unknown_dim + %t: tensor, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: tensor.dim + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %dim_idx : tensor + } + + return +} + +func.func @speculate_tensor_dim_unknown_rank_known_dim( +// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_known_dim + %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : index + // CHECK: scf.for + // CHECK-NEXT: tensor.dim + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %c0 : tensor<*xf32> + } + + return +} + +func.func @speculate_tensor_dim_known_rank_known_dim_inbounds( +// CHECK-LABEL: @speculate_tensor_dim_known_rank_known_dim_inbounds + %t: tensor, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c1 = arith.constant 1 : index + // CHECK: tensor.dim + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %val = tensor.dim %t, %c1 : tensor + } + + return +} + +// ----- + +func.func @speculate_memref_dim_unknown_rank_unknown_dim( +// CHECK-LABEL: @speculate_memref_dim_unknown_rank_unknown_dim + %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: memref.dim + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %dim_idx : memref<*xf32> + } + + return +} + +func.func @speculate_memref_dim_known_rank_unknown_dim( +// CHECK-LABEL: @speculate_memref_dim_known_rank_unknown_dim + %t: memref, %dim_idx: index, %lb: index, %ub: index, %step: index) { + // CHECK: scf.for + // CHECK-NEXT: memref.dim + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %dim_idx : memref + } + + return +} + +func.func @speculate_memref_dim_unknown_rank_known_dim( +// CHECK-LABEL: @speculate_memref_dim_unknown_rank_known_dim + %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : index + // CHECK: scf.for + // CHECK-NEXT: memref.dim + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %c0 : memref<*xf32> + } + + return +} + +func.func @speculate_memref_dim_known_rank_known_dim_inbounds( +// CHECK-LABEL: @speculate_memref_dim_known_rank_known_dim_inbounds + %t: memref, %dim_idx: index, %lb: index, %ub: index, %step: index) { + %c1 = arith.constant 1 : index + // CHECK: memref.dim + // CHECK-NEXT: scf.for + scf.for %i = %lb to %ub step %step { + %val = memref.dim %t, %c1 : memref + } + + return +} -- 2.7.4