[mlir][NFC] Use `getConstantIntValue` instead of casting to `ConstantIndexOp`
authorMatthias Springer <me@m-sp.org>
Tue, 4 Jul 2023 12:03:02 +0000 (14:03 +0200)
committerMatthias Springer <me@m-sp.org>
Tue, 4 Jul 2023 12:08:37 +0000 (14:08 +0200)
`getConstantIntValue` extracts constant values from all constant-like ops, not just `arith::ConstantIndexOp`.

Differential Revision: https://reviews.llvm.org/D154356

15 files changed:
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
mlir/lib/Dialect/Affine/Analysis/Utils.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index f9245ad..aff620a 100644 (file)
@@ -171,13 +171,6 @@ struct AffineLoopToGpuConverter {
 };
 } // namespace
 
-// Return true if the value is obviously a constant "one".
-static bool isConstantOne(Value value) {
-  if (auto def = value.getDefiningOp<arith::ConstantIndexOp>())
-    return def.value() == 1;
-  return false;
-}
-
 // Collect ranges, bounds, steps and induction variables in preparation for
 // mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel.
 // This may fail if the IR for computing loop bounds cannot be constructed, for
@@ -201,7 +194,7 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) {
     Value range = builder.create<arith::SubIOp>(currentLoop.getLoc(),
                                                 upperBound, lowerBound);
     Value step = getOrCreateStep(currentLoop, builder);
-    if (!isConstantOne(step))
+    if (getConstantIntValue(step) != static_cast<int64_t>(1))
       range = builder.create<arith::DivSIOp>(currentLoop.getLoc(), range, step);
     dims.push_back(range);
 
@@ -269,7 +262,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
             ? getDim3Value(launchOp.getBlockIds(), en.index())
             : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
     Value step = steps[en.index()];
-    if (!isConstantOne(step))
+    if (getConstantIntValue(step) != static_cast<int64_t>(1))
       id = builder.create<arith::MulIOp>(rootForOp.getLoc(), step, id);
 
     Value ivReplacement =
index fc567b9..5f32505 100644 (file)
@@ -16,7 +16,7 @@
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/LLVM.h"
@@ -61,7 +61,7 @@ void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
   // Add top level symbol.
   appendSymbolVar(val);
   // Check if the symbol is a constant.
-  if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
+  if (std::optional<int64_t> constOp = getConstantIntValue(val))
     addBound(BoundType::EQ, val, constOp.value());
 }
 
index 3687b66..ce3ff0a 100644 (file)
@@ -568,7 +568,7 @@ ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
     assert(cst->containsVar(value) && "value expected to be present");
     if (isValidSymbol(value)) {
       // Check if the symbol is a constant.
-      if (auto cOp = value.getDefiningOp<arith::ConstantIndexOp>())
+      if (std::optional<int64_t> cOp = getConstantIntValue(value))
         cst->addBound(BoundType::EQ, value, cOp.value());
     } else if (auto loop = getForInductionVarOwner(value)) {
       if (failed(cst->addAffineForOpDomain(loop)))
index ca676d9..0e8e849 100644 (file)
@@ -2713,8 +2713,8 @@ static AffineForOp
 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
                           int64_t step,
                           AffineForOp::BodyBuilderFn bodyBuilderFn) {
-  auto lbConst = lb.getDefiningOp<arith::ConstantIndexOp>();
-  auto ubConst = ub.getDefiningOp<arith::ConstantIndexOp>();
+  std::optional<int64_t> lbConst = getConstantIntValue(lb);
+  std::optional<int64_t> ubConst = getConstantIntValue(ub);
   if (lbConst && ubConst)
     return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
                                         ubConst.value(), step, bodyBuilderFn);
index da59b59..414b54f 100644 (file)
@@ -1739,7 +1739,7 @@ struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
 
   LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                 PatternRewriter &rewriter) const override {
-    auto index = dimOp.getIndex().getDefiningOp<arith::ConstantIndexOp>();
+    std::optional<int64_t> index = dimOp.getConstantIndex();
     if (!index)
       return failure();
 
index 319f73b..8cf85eb 100644 (file)
@@ -55,7 +55,7 @@ static Value allocBuffer(ImplicitLocOpBuilder &b,
     alignmentAttr = b.getI64IntegerAttr(alignment.value());
 
   // Static buffer.
-  if (auto cst = allocSize.getDefiningOp<arith::ConstantIndexOp>()) {
+  if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) {
     auto staticBufferType =
         MemRefType::get(width * cst.value(), b.getIntegerType(8));
     if (options.useAlloca) {
index ebfdc6e..2c6afd4 100644 (file)
@@ -45,18 +45,6 @@ using namespace mlir::scf;
 
 #define DEBUG_TYPE "linalg-tiling"
 
-static bool isZero(OpFoldResult v) {
-  if (!v)
-    return false;
-  if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
-    IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
-    return intAttr && intAttr.getValue().isZero();
-  }
-  if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
-    return cst.value() == 0;
-  return false;
-}
-
 std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                                   ArrayRef<OpFoldResult> allShapeSizes,
@@ -70,7 +58,8 @@ mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
-    if (isZero(tileSizes[idx - zerosCount])) {
+    if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
+        static_cast<int64_t>(0)) {
       shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
       ++zerosCount;
@@ -473,7 +462,9 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
   // Initial tile sizes may be too big, only take the first nLoops.
   tileSizes = tileSizes.take_front(nLoops);
 
-  if (llvm::all_of(tileSizes, isZero)) {
+  if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
+      })) {
     TiledLinalgOp tiledOp;
     tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
     tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
index fe324ab..eb62dca 100644 (file)
@@ -373,9 +373,9 @@ void GenerateLoopNest<AffineForOp>::doit(
   SmallVector<int64_t, 4> constantSteps;
   constantSteps.reserve(steps.size());
   for (Value v : steps) {
-    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-    assert(op && "Affine loops require constant steps");
-    constantSteps.push_back(op.value());
+    auto constVal = getConstantIntValue(v);
+    assert(constVal.has_value() && "Affine loops require constant steps");
+    constantSteps.push_back(constVal.value());
   }
 
   affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
index 4f805d6..01cfa67 100644 (file)
@@ -2746,8 +2746,8 @@ LogicalResult ParallelOp::verify() {
 
   // Check whether all constant step values are positive.
   for (Value stepValue : stepValues)
-    if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
-      if (cst.value() <= 0)
+    if (auto cst = getConstantIntValue(stepValue))
+      if (*cst <= 0)
         return emitOpError("constant step operand must be positive");
 
   // Check that the body defines the same number of block arguments as the
index 96e4f89..b3e8ef7 100644 (file)
@@ -438,9 +438,9 @@ LogicalResult mlir::loopUnrollByFactor(
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
 
-  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
   if (lbCstOp && ubCstOp && stepCstOp) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.value();
@@ -467,7 +467,7 @@ LogicalResult mlir::loopUnrollByFactor(
       upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
           loc, upperBoundUnrolledCst);
     else
-      upperBoundUnrolled = ubCstOp;
+      upperBoundUnrolled = forOp.getUpperBound();
 
     // Create constant for 'stepUnrolled'.
     stepUnrolled = stepCst == stepUnrolledCst
@@ -550,11 +550,11 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
   bool isZeroBased = false;
-  if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
+  if (auto ubCst = getConstantIntValue(lowerBound))
     isZeroBased = ubCst.value() == 0;
 
   bool isStepOne = false;
-  if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
+  if (auto stepCst = getConstantIntValue(step))
     isStepOne = stepCst.value() == 1;
 
   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
index 474db31..22311ab 100644 (file)
@@ -68,6 +68,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRDialect
+  MLIRDialectUtils
   MLIRIR
   MLIRInferTypeOpInterface
   MLIRSupport
index 979a86d..3037429 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -1216,7 +1217,7 @@ void PushBackOp::build(OpBuilder &builder, OperationState &result,
 
 LogicalResult PushBackOp::verify() {
   if (Value n = getN()) {
-    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+    std::optional<int64_t> nValue = getConstantIntValue(n);
     if (nValue && nValue.value() < 1)
       return emitOpError("n must be not less than 1");
   }
@@ -1324,7 +1325,7 @@ LogicalResult SortOp::verify() {
   if (getXs().empty())
     return emitError("need at least one xs buffer.");
 
-  auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
+  std::optional<int64_t> n = getConstantIntValue(getN());
 
   Type xtp = getMemRefType(getXs().front()).getElementType();
   auto checkTypes = [&](ValueRange operands,
@@ -1349,7 +1350,7 @@ LogicalResult SortOp::verify() {
 }
 
 LogicalResult SortCooOp::verify() {
-  auto cn = getN().getDefiningOp<arith::ConstantIndexOp>();
+  std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
   if (!cn)
index c2562af..e25e0f2 100644 (file)
@@ -1148,7 +1148,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, ValueRange position) {
   SmallVector<int64_t, 4> positionConstants =
       llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
-        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
+        return getConstantIntValue(pos).value();
       }));
   build(builder, result, source, positionConstants);
 }
@@ -2318,7 +2318,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
                      Value dest, ValueRange position) {
   SmallVector<int64_t, 4> positionConstants =
       llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
-        return pos.getDefiningOp<arith::ConstantIndexOp>().value();
+        return getConstantIntValue(pos).value();
       }));
   build(builder, result, source, dest, positionConstants);
 }
@@ -2908,18 +2908,16 @@ LogicalResult ReshapeOp::verify() {
   // If all shape operands are produced by constant ops, verify that product
   // of dimensions for input/output shape match.
   auto isDefByConstant = [](Value operand) {
-    return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
+    return getConstantIntValue(operand).has_value();
   };
   if (llvm::all_of(getInputShape(), isDefByConstant) &&
       llvm::all_of(getOutputShape(), isDefByConstant)) {
     int64_t numInputElements = 1;
     for (auto operand : getInputShape())
-      numInputElements *=
-          cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
+      numInputElements *= getConstantIntValue(operand).value();
     int64_t numOutputElements = 1;
     for (auto operand : getOutputShape())
-      numOutputElements *=
-          cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
+      numOutputElements *= getConstantIntValue(operand).value();
     if (numInputElements != numOutputElements)
       return emitError("product of input and output shape sizes must match");
   }
@@ -3645,8 +3643,8 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   if (op.getShapedType().isDynamicDim(indicesIdx))
     return false;
   Value index = op.getIndices()[indicesIdx];
-  auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
-  if (!cstOp)
+  std::optional<int64_t> cstOp = getConstantIntValue(index);
+  if (!cstOp.has_value())
     return false;
 
   int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
@@ -4031,8 +4029,8 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
     return failure();
   // If any index is nonzero.
   auto isNotConstantZero = [](Value v) {
-    auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
-    return !cstOp || cstOp.value() != 0;
+    auto cstOp = getConstantIntValue(v);
+    return !cstOp.has_value() || cstOp.value() != 0;
   };
   if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
       llvm::any_of(write.getIndices(), isNotConstantZero))
@@ -5269,7 +5267,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     // Return if any of 'createMaskOp' operands are not defined by a constant.
     auto isNotDefByConstant = [](Value operand) {
-      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
+      return !getConstantIntValue(operand).has_value();
     };
     if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
       return failure();
@@ -5291,8 +5289,7 @@ public:
     maskDimSizes.reserve(createMaskOp->getNumOperands());
     for (auto [operand, maxDimSize] : llvm::zip_equal(
              createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      Operation *defOp = operand.getDefiningOp();
-      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
+      int64_t dimSize = getConstantIntValue(operand).value();
       dimSize = std::min(dimSize, maxDimSize);
       // If one of dim sizes is zero, set all dims to zero.
       if (dimSize <= 0) {
index fa901d0..dd4948f 100644 (file)
@@ -297,12 +297,6 @@ static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
   return reducedShape;
 }
 
-/// Returns true if all values are `arith.constant 0 : index`
-static bool isZero(Value v) {
-  auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
-  return cst && cst.value() == 0;
-}
-
 namespace {
 
 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
@@ -338,8 +332,9 @@ class TransferReadDropUnitDimsPattern
     int vectorReducedRank = getReducedRank(vectorType.getShape());
     if (reducedRank != vectorReducedRank)
       return failure();
-    if (llvm::any_of(transferReadOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
+          return getConstantIntValue(v) != static_cast<int64_t>(0);
+        }))
       return failure();
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
@@ -392,8 +387,9 @@ class TransferWriteDropUnitDimsPattern
     int vectorReducedRank = getReducedRank(vectorType.getShape());
     if (reducedRank != vectorReducedRank)
       return failure();
-    if (llvm::any_of(transferWriteOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
+          return getConstantIntValue(v) != static_cast<int64_t>(0);
+        }))
       return failure();
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
@@ -463,8 +459,7 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
   if (firstDimToCollapse >= rank)
     return failure();
   for (int64_t i = firstDimToCollapse; i < rank; ++i) {
-    arith::ConstantIndexOp cst =
-        indices[i].getDefiningOp<arith::ConstantIndexOp>();
+    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
     if (!cst || cst.value() != 0)
       return failure();
   }
index 17c1c2b..85d6bc0 100644 (file)
@@ -2583,6 +2583,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":ArithDialect",
+        ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",
         ":SparseTensorAttrDefsIncGen",