`getConstantIntValue` extracts constant values from all constant-like ops, not just `arith::ConstantIndexOp`.
Differential Revision: https://reviews.llvm.org/D154356
};
} // namespace
-// Return true if the value is obviously a constant "one".
-static bool isConstantOne(Value value) {
- if (auto def = value.getDefiningOp<arith::ConstantIndexOp>())
- return def.value() == 1;
- return false;
-}
-
// Collect ranges, bounds, steps and induction variables in preparation for
// mapping a loop nest of depth "numLoops" rooted at "forOp" to a GPU kernel.
// This may fail if the IR for computing loop bounds cannot be constructed, for
Value range = builder.create<arith::SubIOp>(currentLoop.getLoc(),
upperBound, lowerBound);
Value step = getOrCreateStep(currentLoop, builder);
- if (!isConstantOne(step))
+ if (getConstantIntValue(step) != static_cast<int64_t>(1))
range = builder.create<arith::DivSIOp>(currentLoop.getLoc(), range, step);
dims.push_back(range);
? getDim3Value(launchOp.getBlockIds(), en.index())
: getDim3Value(launchOp.getThreadIds(), en.index() - numBlockDims);
Value step = steps[en.index()];
- if (!isConstantOne(step))
+ if (getConstantIntValue(step) != static_cast<int64_t>(1))
id = builder.create<arith::MulIOp>(rootForOp.getLoc(), step, id);
Value ivReplacement =
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/LLVM.h"
// Add top level symbol.
appendSymbolVar(val);
// Check if the symbol is a constant.
- if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
+ if (std::optional<int64_t> constOp = getConstantIntValue(val))
addBound(BoundType::EQ, val, constOp.value());
}
assert(cst->containsVar(value) && "value expected to be present");
if (isValidSymbol(value)) {
// Check if the symbol is a constant.
- if (auto cOp = value.getDefiningOp<arith::ConstantIndexOp>())
+ if (std::optional<int64_t> cOp = getConstantIntValue(value))
cst->addBound(BoundType::EQ, value, cOp.value());
} else if (auto loop = getForInductionVarOwner(value)) {
if (failed(cst->addAffineForOpDomain(loop)))
buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
- auto lbConst = lb.getDefiningOp<arith::ConstantIndexOp>();
- auto ubConst = ub.getDefiningOp<arith::ConstantIndexOp>();
+ std::optional<int64_t> lbConst = getConstantIntValue(lb);
+ std::optional<int64_t> ubConst = getConstantIntValue(ub);
if (lbConst && ubConst)
return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
ubConst.value(), step, bodyBuilderFn);
LogicalResult matchAndRewrite(memref::DimOp dimOp,
PatternRewriter &rewriter) const override {
- auto index = dimOp.getIndex().getDefiningOp<arith::ConstantIndexOp>();
+ std::optional<int64_t> index = dimOp.getConstantIndex();
if (!index)
return failure();
alignmentAttr = b.getI64IntegerAttr(alignment.value());
// Static buffer.
- if (auto cst = allocSize.getDefiningOp<arith::ConstantIndexOp>()) {
+ if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) {
auto staticBufferType =
MemRefType::get(width * cst.value(), b.getIntegerType(8));
if (options.useAlloca) {
#define DEBUG_TYPE "linalg-tiling"
-static bool isZero(OpFoldResult v) {
- if (!v)
- return false;
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
- IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
- return intAttr && intAttr.getValue().isZero();
- }
- if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
- return cst.value() == 0;
- return false;
-}
-
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
ArrayRef<OpFoldResult> allShapeSizes,
// Traverse the tile sizes, which are in loop order, erase zeros everywhere.
LoopIndexToRangeIndexMap loopIndexToRangeIndex;
for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
- if (isZero(tileSizes[idx - zerosCount])) {
+ if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
+ static_cast<int64_t>(0)) {
shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
tileSizes.erase(tileSizes.begin() + idx - zerosCount);
++zerosCount;
// Initial tile sizes may be too big, only take the first nLoops.
tileSizes = tileSizes.take_front(nLoops);
- if (llvm::all_of(tileSizes, isZero)) {
+ if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
+ return getConstantIntValue(ofr) == static_cast<int64_t>(0);
+ })) {
TiledLinalgOp tiledOp;
tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
SmallVector<int64_t, 4> constantSteps;
constantSteps.reserve(steps.size());
for (Value v : steps) {
- auto op = v.getDefiningOp<arith::ConstantIndexOp>();
- assert(op && "Affine loops require constant steps");
- constantSteps.push_back(op.value());
+ auto constVal = getConstantIntValue(v);
+ assert(constVal.has_value() && "Affine loops require constant steps");
+ constantSteps.push_back(constVal.value());
}
affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
// Check whether all constant step values are positive.
for (Value stepValue : stepValues)
- if (auto cst = stepValue.getDefiningOp<arith::ConstantIndexOp>())
- if (cst.value() <= 0)
+ if (auto cst = getConstantIntValue(stepValue))
+ if (*cst <= 0)
return emitOpError("constant step operand must be positive");
// Check that the body defines the same number of block arguments as the
Value stepUnrolled;
bool generateEpilogueLoop = true;
- auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
- auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
- auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+ std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+ std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+ std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
if (lbCstOp && ubCstOp && stepCstOp) {
// Constant loop bounds computation.
int64_t lbCst = lbCstOp.value();
upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
loc, upperBoundUnrolledCst);
else
- upperBoundUnrolled = ubCstOp;
+ upperBoundUnrolled = forOp.getUpperBound();
// Create constant for 'stepUnrolled'.
stepUnrolled = stepCst == stepUnrolledCst
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
bool isZeroBased = false;
- if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
+ if (auto ubCst = getConstantIntValue(lowerBound))
isZeroBased = ubCst.value() == 0;
bool isStepOne = false;
- if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
+ if (auto stepCst = getConstantIntValue(step))
isStepOne = stepCst.value() == 1;
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialect
+ MLIRDialectUtils
MLIRIR
MLIRInferTypeOpInterface
MLIRSupport
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
LogicalResult PushBackOp::verify() {
if (Value n = getN()) {
- auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+ std::optional<int64_t> nValue = getConstantIntValue(n);
if (nValue && nValue.value() < 1)
return emitOpError("n must be not less than 1");
}
if (getXs().empty())
return emitError("need at least one xs buffer.");
- auto n = getN().getDefiningOp<arith::ConstantIndexOp>();
+ std::optional<int64_t> n = getConstantIntValue(getN());
Type xtp = getMemRefType(getXs().front()).getElementType();
auto checkTypes = [&](ValueRange operands,
}
LogicalResult SortCooOp::verify() {
- auto cn = getN().getDefiningOp<arith::ConstantIndexOp>();
+ std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
if (!cn)
Value source, ValueRange position) {
SmallVector<int64_t, 4> positionConstants =
llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
- return pos.getDefiningOp<arith::ConstantIndexOp>().value();
+ return getConstantIntValue(pos).value();
}));
build(builder, result, source, positionConstants);
}
Value dest, ValueRange position) {
SmallVector<int64_t, 4> positionConstants =
llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
- return pos.getDefiningOp<arith::ConstantIndexOp>().value();
+ return getConstantIntValue(pos).value();
}));
build(builder, result, source, dest, positionConstants);
}
// If all shape operands are produced by constant ops, verify that product
// of dimensions for input/output shape match.
auto isDefByConstant = [](Value operand) {
- return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
+ return getConstantIntValue(operand).has_value();
};
if (llvm::all_of(getInputShape(), isDefByConstant) &&
llvm::all_of(getOutputShape(), isDefByConstant)) {
int64_t numInputElements = 1;
for (auto operand : getInputShape())
- numInputElements *=
- cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
+ numInputElements *= getConstantIntValue(operand).value();
int64_t numOutputElements = 1;
for (auto operand : getOutputShape())
- numOutputElements *=
- cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
+ numOutputElements *= getConstantIntValue(operand).value();
if (numInputElements != numOutputElements)
return emitError("product of input and output shape sizes must match");
}
if (op.getShapedType().isDynamicDim(indicesIdx))
return false;
Value index = op.getIndices()[indicesIdx];
- auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
- if (!cstOp)
+ std::optional<int64_t> cstOp = getConstantIntValue(index);
+ if (!cstOp.has_value())
return false;
int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
return failure();
// If any index is nonzero.
auto isNotConstantZero = [](Value v) {
- auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
- return !cstOp || cstOp.value() != 0;
+ auto cstOp = getConstantIntValue(v);
+ return !cstOp.has_value() || cstOp.value() != 0;
};
if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
llvm::any_of(write.getIndices(), isNotConstantZero))
PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
auto isNotDefByConstant = [](Value operand) {
- return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
+ return !getConstantIntValue(operand).has_value();
};
if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
return failure();
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
- Operation *defOp = operand.getDefiningOp();
- int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
+ int64_t dimSize = getConstantIntValue(operand).value();
dimSize = std::min(dimSize, maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
if (dimSize <= 0) {
return reducedShape;
}
-/// Returns true if all values are `arith.constant 0 : index`
-static bool isZero(Value v) {
- auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
- return cst && cst.value() == 0;
-}
-
namespace {
/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
int vectorReducedRank = getReducedRank(vectorType.getShape());
if (reducedRank != vectorReducedRank)
return failure();
- if (llvm::any_of(transferReadOp.getIndices(),
- [](Value v) { return !isZero(v); }))
+ if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
+ return getConstantIntValue(v) != static_cast<int64_t>(0);
+ }))
return failure();
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
int vectorReducedRank = getReducedRank(vectorType.getShape());
if (reducedRank != vectorReducedRank)
return failure();
- if (llvm::any_of(transferWriteOp.getIndices(),
- [](Value v) { return !isZero(v); }))
+ if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
+ return getConstantIntValue(v) != static_cast<int64_t>(0);
+ }))
return failure();
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
if (firstDimToCollapse >= rank)
return failure();
for (int64_t i = firstDimToCollapse; i < rank; ++i) {
- arith::ConstantIndexOp cst =
- indices[i].getDefiningOp<arith::ConstantIndexOp>();
+ std::optional<int64_t> cst = getConstantIntValue(indices[i]);
if (!cst || cst.value() != 0)
return failure();
}
includes = ["include"],
deps = [
":ArithDialect",
+ ":DialectUtils",
":IR",
":InferTypeOpInterface",
":SparseTensorAttrDefsIncGen",