From cb7bda2ace81226c5b33165411dd0316f93fa57e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 4 Jul 2023 14:03:02 +0200 Subject: [PATCH] [mlir][NFC] Use `getConstantIntValue` instead of casting to `ConstantIndexOp` `getConstantIntValue` extracts constant values from all constant-like ops, not just `arith::ConstantIndexOp`. Differential Revision: https://reviews.llvm.org/D154356 --- mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 11 ++-------- .../Dialect/Affine/Analysis/AffineStructures.cpp | 4 ++-- mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 4 ++-- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 19 +++++----------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 6 +++--- mlir/lib/Dialect/SCF/IR/SCF.cpp | 4 ++-- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 12 +++++------ mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt | 1 + .../SparseTensor/IR/SparseTensorDialect.cpp | 7 +++--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++------------ .../Transforms/VectorTransferOpTransforms.cpp | 19 ++++++---------- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 15 files changed, 49 insertions(+), 70 deletions(-) diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index f9245ad..aff620a 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -171,13 +171,6 @@ struct AffineLoopToGpuConverter { }; } // namespace -// Return true if the value is obviously a constant "one". -static bool isConstantOne(Value value) { - if (auto def = value.getDefiningOp()) - return def.value() == 1; - return false; -} - // Collect ranges, bounds, steps and induction variables in preparation for // mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel. // This may fail if the IR for computing loop bounds cannot be constructed, for @@ -201,7 +194,7 @@ AffineLoopToGpuConverter::collectBounds(AffineForOp forOp, unsigned numLoops) { Value range = builder.create(currentLoop.getLoc(), upperBound, lowerBound); Value step = getOrCreateStep(currentLoop, builder); - if (!isConstantOne(step)) + if (getConstantIntValue(step) != static_cast(1)) range = builder.create(currentLoop.getLoc(), range, step); dims.push_back(range); @@ -269,7 +262,7 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp, ? getDim3Value(launchOp.getBlockIds(), en.index()) : getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims); Value step = steps[en.index()]; - if (!isConstantOne(step)) + if (getConstantIntValue(step) != static_cast(1)) id = builder.create(rootForOp.getLoc(), step, id); Value ivReplacement = diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp index fc567b9..5f32505 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -16,7 +16,7 @@ #include "mlir/Analysis/Presburger/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" -#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Support/LLVM.h" @@ -61,7 +61,7 @@ void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) { // Add top level symbol. appendSymbolVar(val); // Check if the symbol is a constant. - if (auto constOp = val.getDefiningOp()) + if (std::optional constOp = getConstantIntValue(val)) addBound(BoundType::EQ, val, constOp.value()); } diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 3687b66..ce3ff0a 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -568,7 +568,7 @@ ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const { assert(cst->containsVar(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. - if (auto cOp = value.getDefiningOp()) + if (std::optional cOp = getConstantIntValue(value)) cst->addBound(BoundType::EQ, value, cOp.value()); } else if (auto loop = getForInductionVarOwner(value)) { if (failed(cst->addAffineForOpDomain(loop))) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index ca676d9..0e8e849 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2713,8 +2713,8 @@ static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { - auto lbConst = lb.getDefiningOp(); - auto ubConst = ub.getDefiningOp(); + std::optional lbConst = getConstantIntValue(lb); + std::optional ubConst = getConstantIntValue(ub); if (lbConst && ubConst) return buildAffineLoopFromConstants(builder, loc, lbConst.value(), ubConst.value(), step, bodyBuilderFn); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index da59b59..414b54f 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1739,7 +1739,7 @@ struct SimplifyDimOfAllocOp : public OpRewritePattern { LogicalResult matchAndRewrite(memref::DimOp dimOp, PatternRewriter &rewriter) const override { - auto index = dimOp.getIndex().getDefiningOp(); + std::optional index = dimOp.getConstantIndex(); if (!index) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 319f73b..8cf85eb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -55,7 +55,7 @@ static Value allocBuffer(ImplicitLocOpBuilder &b, alignmentAttr = b.getI64IntegerAttr(alignment.value()); // Static buffer. - if (auto cst = allocSize.getDefiningOp()) { + if (std::optional cst = getConstantIntValue(allocSize)) { auto staticBufferType = MemRefType::get(width * cst.value(), b.getIntegerType(8)); if (options.useAlloca) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index ebfdc6e..2c6afd4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -45,18 +45,6 @@ using namespace mlir::scf; #define DEBUG_TYPE "linalg-tiling" -static bool isZero(OpFoldResult v) { - if (!v) - return false; - if (auto attr = llvm::dyn_cast_if_present(v)) { - IntegerAttr intAttr = dyn_cast(attr); - return intAttr && intAttr.getValue().isZero(); - } - if (auto cst = v.get().getDefiningOp()) - return cst.value() == 0; - return false; -} - std::tuple, LoopIndexToRangeIndexMap> mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, ArrayRef allShapeSizes, @@ -70,7 +58,8 @@ mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, // Traverse the tile sizes, which are in loop order, erase zeros everywhere. LoopIndexToRangeIndexMap loopIndexToRangeIndex; for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) { - if (isZero(tileSizes[idx - zerosCount])) { + if (getConstantIntValue(tileSizes[idx - zerosCount]) == + static_cast(0)) { shapeSizes.erase(shapeSizes.begin() + idx - zerosCount); tileSizes.erase(tileSizes.begin() + idx - zerosCount); ++zerosCount; @@ -473,7 +462,9 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, // Initial tile sizes may be too big, only take the first nLoops. tileSizes = tileSizes.take_front(nLoops); - if (llvm::all_of(tileSizes, isZero)) { + if (llvm::all_of(tileSizes, [](OpFoldResult ofr) { + return getConstantIntValue(ofr) == static_cast(0); + })) { TiledLinalgOp tiledOp; tiledOp.op = cast(b.clone(*op.getOperation())); tiledOp.tensorResults.assign(tiledOp.op->result_begin(), diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index fe324ab..eb62dca 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -373,9 +373,9 @@ void GenerateLoopNest::doit( SmallVector constantSteps; constantSteps.reserve(steps.size()); for (Value v : steps) { - auto op = v.getDefiningOp(); - assert(op && "Affine loops require constant steps"); - constantSteps.push_back(op.value()); + auto constVal = getConstantIntValue(v); + assert(constVal.has_value() && "Affine loops require constant steps"); + constantSteps.push_back(constVal.value()); } affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps, diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 4f805d6..01cfa67 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2746,8 +2746,8 @@ LogicalResult ParallelOp::verify() { // Check whether all constant step values are positive. for (Value stepValue : stepValues) - if (auto cst = stepValue.getDefiningOp()) - if (cst.value() <= 0) + if (auto cst = getConstantIntValue(stepValue)) + if (*cst <= 0) return emitOpError("constant step operand must be positive"); // Check that the body defines the same number of block arguments as the diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 96e4f89..b3e8ef7 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -438,9 +438,9 @@ LogicalResult mlir::loopUnrollByFactor( Value stepUnrolled; bool generateEpilogueLoop = true; - auto lbCstOp = forOp.getLowerBound().getDefiningOp(); - auto ubCstOp = forOp.getUpperBound().getDefiningOp(); - auto stepCstOp = forOp.getStep().getDefiningOp(); + std::optional lbCstOp = getConstantIntValue(forOp.getLowerBound()); + std::optional ubCstOp = getConstantIntValue(forOp.getUpperBound()); + std::optional stepCstOp = getConstantIntValue(forOp.getStep()); if (lbCstOp && ubCstOp && stepCstOp) { // Constant loop bounds computation. int64_t lbCst = lbCstOp.value(); @@ -467,7 +467,7 @@ LogicalResult mlir::loopUnrollByFactor( upperBoundUnrolled = boundsBuilder.create( loc, upperBoundUnrolledCst); else - upperBoundUnrolled = ubCstOp; + upperBoundUnrolled = forOp.getUpperBound(); // Create constant for 'stepUnrolled'. stepUnrolled = stepCst == stepUnrolledCst @@ -550,11 +550,11 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder, // Check if the loop is already known to have a constant zero lower bound or // a constant one step. bool isZeroBased = false; - if (auto ubCst = lowerBound.getDefiningOp()) + if (auto ubCst = getConstantIntValue(lowerBound)) isZeroBased = ubCst.value() == 0; bool isStepOne = false; - if (auto stepCst = step.getDefiningOp()) + if (auto stepCst = getConstantIntValue(step)) isStepOne = stepCst.value() == 1; // Compute the number of iterations the loop executes: ceildiv(ub - lb, step) diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt index 474db31..22311ab 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt @@ -68,6 +68,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect LINK_LIBS PUBLIC MLIRArithDialect MLIRDialect + MLIRDialectUtils MLIRIR MLIRInferTypeOpInterface MLIRSupport diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 979a86d..3037429 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" @@ -1216,7 +1217,7 @@ void PushBackOp::build(OpBuilder &builder, OperationState &result, LogicalResult PushBackOp::verify() { if (Value n = getN()) { - auto nValue = dyn_cast_or_null(n.getDefiningOp()); + std::optional nValue = getConstantIntValue(n); if (nValue && nValue.value() < 1) return emitOpError("n must be not less than 1"); } @@ -1324,7 +1325,7 @@ LogicalResult SortOp::verify() { if (getXs().empty()) return emitError("need at least one xs buffer."); - auto n = getN().getDefiningOp(); + std::optional n = getConstantIntValue(getN()); Type xtp = getMemRefType(getXs().front()).getElementType(); auto checkTypes = [&](ValueRange operands, @@ -1349,7 +1350,7 @@ LogicalResult SortOp::verify() { } LogicalResult SortCooOp::verify() { - auto cn = getN().getDefiningOp(); + std::optional cn = getConstantIntValue(getN()); // We can't check the size of the buffers when n or buffer dimensions aren't // compile-time constants. if (!cn) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c2562af..e25e0f2 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1148,7 +1148,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, Value source, ValueRange position) { SmallVector positionConstants = llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return pos.getDefiningOp().value(); + return getConstantIntValue(pos).value(); })); build(builder, result, source, positionConstants); } @@ -2318,7 +2318,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, Value dest, ValueRange position) { SmallVector positionConstants = llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { - return pos.getDefiningOp().value(); + return getConstantIntValue(pos).value(); })); build(builder, result, source, dest, positionConstants); } @@ -2908,18 +2908,16 @@ LogicalResult ReshapeOp::verify() { // If all shape operands are produced by constant ops, verify that product // of dimensions for input/output shape match. auto isDefByConstant = [](Value operand) { - return isa_and_nonnull(operand.getDefiningOp()); + return getConstantIntValue(operand).has_value(); }; if (llvm::all_of(getInputShape(), isDefByConstant) && llvm::all_of(getOutputShape(), isDefByConstant)) { int64_t numInputElements = 1; for (auto operand : getInputShape()) - numInputElements *= - cast(operand.getDefiningOp()).value(); + numInputElements *= getConstantIntValue(operand).value(); int64_t numOutputElements = 1; for (auto operand : getOutputShape()) - numOutputElements *= - cast(operand.getDefiningOp()).value(); + numOutputElements *= getConstantIntValue(operand).value(); if (numInputElements != numOutputElements) return emitError("product of input and output shape sizes must match"); } @@ -3645,8 +3643,8 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { if (op.getShapedType().isDynamicDim(indicesIdx)) return false; Value index = op.getIndices()[indicesIdx]; - auto cstOp = index.getDefiningOp(); - if (!cstOp) + std::optional cstOp = getConstantIntValue(index); + if (!cstOp.has_value()) return false; int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); @@ -4031,8 +4029,8 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write, return failure(); // If any index is nonzero. auto isNotConstantZero = [](Value v) { - auto cstOp = v.getDefiningOp(); - return !cstOp || cstOp.value() != 0; + auto cstOp = getConstantIntValue(v); + return !cstOp.has_value() || cstOp.value() != 0; }; if (llvm::any_of(read.getIndices(), isNotConstantZero) || llvm::any_of(write.getIndices(), isNotConstantZero)) @@ -5269,7 +5267,7 @@ public: PatternRewriter &rewriter) const override { // Return if any of 'createMaskOp' operands are not defined by a constant. auto isNotDefByConstant = [](Value operand) { - return !isa_and_nonnull(operand.getDefiningOp()); + return !getConstantIntValue(operand).has_value(); }; if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant)) return failure(); @@ -5291,8 +5289,7 @@ public: maskDimSizes.reserve(createMaskOp->getNumOperands()); for (auto [operand, maxDimSize] : llvm::zip_equal( createMaskOp.getOperands(), createMaskOp.getType().getShape())) { - Operation *defOp = operand.getDefiningOp(); - int64_t dimSize = cast(defOp).value(); + int64_t dimSize = getConstantIntValue(operand).value(); dimSize = std::min(dimSize, maxDimSize); // If one of dim sizes is zero, set all dims to zero. if (dimSize <= 0) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index fa901d0..dd4948f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -297,12 +297,6 @@ static SmallVector getReducedShape(ArrayRef shape) { return reducedShape; } -/// Returns true if all values are `arith.constant 0 : index` -static bool isZero(Value v) { - auto cst = v.getDefiningOp(); - return cst && cst.value() == 0; -} - namespace { /// Rewrites `vector.transfer_read` ops where the source has unit dims, by @@ -338,8 +332,9 @@ class TransferReadDropUnitDimsPattern int vectorReducedRank = getReducedRank(vectorType.getShape()); if (reducedRank != vectorReducedRank) return failure(); - if (llvm::any_of(transferReadOp.getIndices(), - [](Value v) { return !isZero(v); })) + if (llvm::any_of(transferReadOp.getIndices(), [](Value v) { + return getConstantIntValue(v) != static_cast(0); + })) return failure(); Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); @@ -392,8 +387,9 @@ class TransferWriteDropUnitDimsPattern int vectorReducedRank = getReducedRank(vectorType.getShape()); if (reducedRank != vectorReducedRank) return failure(); - if (llvm::any_of(transferWriteOp.getIndices(), - [](Value v) { return !isZero(v); })) + if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) { + return getConstantIntValue(v) != static_cast(0); + })) return failure(); Value reducedShapeSource = rankReducingSubviewDroppingUnitDims(rewriter, loc, source); @@ -463,8 +459,7 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse, if (firstDimToCollapse >= rank) return failure(); for (int64_t i = firstDimToCollapse; i < rank; ++i) { - arith::ConstantIndexOp cst = - indices[i].getDefiningOp(); + std::optional cst = getConstantIntValue(indices[i]); if (!cst || cst.value() != 0) return failure(); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 17c1c2b..85d6bc0 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2583,6 +2583,7 @@ cc_library( includes = ["include"], deps = [ ":ArithDialect", + ":DialectUtils", ":IR", ":InferTypeOpInterface", ":SparseTensorAttrDefsIncGen", -- 2.7.4