/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val);
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
-/// Helper function to convert a Value into an OpFoldResult, if the Value is
-/// known to be a constant index value.
-static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
- return llvm::to_vector<4>(
- llvm::map_range(values, [](Value v) -> OpFoldResult {
- APInt intValue;
- if (v.getType().isa<IndexType>() &&
- matchPattern(v, m_ConstantInt(&intValue))) {
- return IntegerAttr::get(v.getType(), intValue.getSExtValue());
- }
- return v;
- }));
-}
-
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
return builder.create<ConstantIndexOp>(loc, *intVal);
}
-/// Given a value, try to extract a constant index-type integer as an Attribute.
-/// If this fails, return the original value.
-static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
- if (auto constInt = getConstantIntValue(val))
- return builder.getIndexAttr(*constInt);
- return val;
-}
-
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
// ExtractSliceOp length will be zero in that case. (Effectively reading no
// data from the source.)
Value newOffset = min(max(sub(offset, low), zero), srcSize);
- newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+ newOffsets.push_back(getAsOpFoldResult(newOffset));
// The original ExtractSliceOp was reading until position `offset + length`.
// Therefore, the corresponding position within the source tensor is:
// The new ExtractSliceOp length is `endLoc - newOffset`.
Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
Value newLength = sub(endLoc, newOffset);
- newLengths.push_back(asOpFoldResult(rewriter, newLength));
+ newLengths.push_back(getAsOpFoldResult(newLength));
// Check if newLength is zero. In that case, no SubTensorOp should be
// executed.
}));
}
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val) {
+ Attribute attr;
+ if (matchPattern(val, m_Constant(&attr)))
+ return attr;
+ return val;
+}
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+ return llvm::to_vector<4>(
+ llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
+}
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.