// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
-OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
 // Only the last operand is checked because AnyOp is commutative: constant
 // operands are canonicalized to the end, so if any input is constant, the
 // last one is.
- if (operands.back())
- return operands.back();
+ if (adaptor.getInputs().back())
+ return adaptor.getInputs().back();
return nullptr;
}
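
The change above sets the pattern for this whole diff: fold now takes the ODS-generated FoldAdaptor instead of a raw ArrayRef<Attribute>, so constant operands are reached through named accessors rather than positional indexing. As a rough sketch (not the verbatim tablegen output), the adaptor generated for AnyOp's variadic $inputs operand exposes approximately:

// Sketch of the generated adaptor interface; the real class is emitted by
// mlir-tblgen into the dialect's generated *Ops.h.inc.
class FoldAdaptor {
public:
  // Constant value of each operand, or a null Attribute when the operand
  // is not known to be constant.
  ArrayRef<Attribute> getInputs();   // named accessor for $inputs
  ArrayRef<Attribute> getOperands(); // all operand attributes, positional
};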
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
-OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
// add(x, 0) -> x
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+ adaptor.getOperands(),
+ [](APInt a, const APInt &b) { return std::move(a) + b; });
}
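
constFoldBinaryOp is the generic element-wise folding helper from mlir/Dialect/CommonFolders.h; it yields a null attribute unless both operands are constant. A condensed sketch of its contract (the real template carries additional parameters):

// Condensed sketch of the helper's contract, not its full declaration.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType>
Attribute constFoldBinaryOp(
    ArrayRef<Attribute> operands,
    function_ref<ElementValueT(ElementValueT, ElementValueT)> calculate);
// Returns a null Attribute when either operand is null (non-constant);
// otherwise applies the callback and rebuilds the result attribute.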
LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
-OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
 // Iterate in reverse to handle all constant operands first. They are
 // guaranteed to form the tail of the inputs because this op is commutative.
- for (int idx = operands.size() - 1; idx >= 0; idx--) {
- Attribute a = operands[idx];
+ for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
+ Attribute a = adaptor.getInputs()[idx];
 // Cannot fold if any input is not constant.
if (!a)
return nullptr;
// BroadcastOp
//===----------------------------------------------------------------------===//
-OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getShapes().size() == 1) {
 // If the types differ, we would need a cast, which is a canonicalization,
 // not a fold.
if (getShapes().front().getType() != getType())
if (getShapes().size() > 2)
return nullptr;
- if (!operands[0] || !operands[1])
+ if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
return nullptr;
auto lhsShape = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ adaptor.getShapes()[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
- operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ adaptor.getShapes()[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
// If the shapes are not compatible, we can't fold it.
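  // e.g. [1, 3] and [2, 1] broadcast to [2, 3], while [2] and [3] are
  // incompatible and cannot be folded.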
// ConcatOp
//===----------------------------------------------------------------------===//
-OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0] || !operands[1])
+OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+ if (!adaptor.getLhs() || !adaptor.getRhs())
return nullptr;
auto lhsShape = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
- operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
resultShape.append(lhsShape.begin(), lhsShape.end());
resultShape.append(rhsShape.begin(), rhsShape.end());
return success();
}
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
+OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
return true;
}
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
 // No broadcasting is needed if at most one operand is a non-scalar shape.
- if (hasAtMostSingleNonScalar(operands))
+ if (hasAtMostSingleNonScalar(adaptor.getShapes()))
return BoolAttr::get(getContext(), true);
if ([&] {
SmallVector<SmallVector<int64_t, 6>, 6> extents;
- for (const auto &operand : operands) {
+ for (const auto &operand : adaptor.getShapes()) {
if (!operand)
return false;
extents.push_back(llvm::to_vector<6>(
patterns.add<CstrEqEqOps>(context);
}
-OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
- if (llvm::all_of(operands,
- [&](Attribute a) { return a && a == operands[0]; }))
+OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
+ if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
+ return a && a == adaptor.getShapes().front();
+ }))
return BoolAttr::get(getContext(), true);
// Because a failing witness result here represents an eventual assertion
build(builder, result, builder.getIndexAttr(value));
}
-OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
+OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
void ConstSizeOp::getAsmResultNames(
llvm::function_ref<void(Value, StringRef)> setNameFn) {
// ConstWitnessOp
//===----------------------------------------------------------------------===//
-OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
- return getPassingAttr();
-}
+OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//
-OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
- return operands[0];
+OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
+ return adaptor.getPred();
}
//===----------------------------------------------------------------------===//
return std::nullopt;
}
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
Type valType = getValue().getType();
auto valShapedType = valType.dyn_cast<ShapedType>();
if (!valShapedType || !valShapedType.hasRank())
// DivOp
//===----------------------------------------------------------------------===//
-OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
- auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
+ auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return nullptr;
- auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+ auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
if (!rhs)
return nullptr;
// ShapeEqOp
//===----------------------------------------------------------------------===//
-OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
bool allSame = true;
- if (!operands.empty() && !operands[0])
+ if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
return {};
- for (Attribute operand : operands.drop_front(1)) {
+ for (Attribute operand : adaptor.getShapes().drop_front()) {
if (!operand)
return {};
- allSame = allSame && operand == operands[0];
+ allSame = allSame && operand == adaptor.getShapes().front();
}
return BoolAttr::get(getContext(), allSame);
}
// IndexToSizeOp
//===----------------------------------------------------------------------===//
-OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
 // Constant values of both types, `shape.size` and `index`, are represented as
 // `IntegerAttr`s, which makes constant folding simple.
- if (Attribute arg = operands[0])
+ if (Attribute arg = adaptor.getArg())
return arg;
return {};
}
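
Because shape.size and index constants share the IntegerAttr representation (ConstSizeOp::build above uses builder.getIndexAttr), this fold is a pure pass-through of the operand attribute; SizeToIndexOp::fold below mirrors it with an identical body. A minimal sketch of what the folder sees, with ctx and the constant 42 assumed for illustration:

// Minimal sketch; ctx and the value 42 are assumed for illustration.
// adaptor.getArg() holds exactly such an attribute when the operand
// is constant.
Builder b(ctx);
Attribute arg = b.getIndexAttr(42); // valid constant for both index and shape.size
OpFoldResult folded = arg;          // the pass-through completes the fold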
// FromExtentsOp
//===----------------------------------------------------------------------===//
-OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
- if (llvm::any_of(operands, [](Attribute a) { return !a; }))
+OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
+ if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
return nullptr;
SmallVector<int64_t, 6> extents;
- for (auto attr : operands)
+ for (auto attr : adaptor.getExtents())
extents.push_back(attr.cast<IntegerAttr>().getInt());
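  // e.g. constant extents 2 and 3 fold to the constant shape [2, 3].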
Builder builder(getContext());
return builder.getIndexTensorAttr(extents);
return std::nullopt;
}
-OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
- auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
+ auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
if (!elements)
return nullptr;
std::optional<int64_t> dim = getConstantDim();
patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}
-OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
// Can always broadcast fewer than two shapes.
- if (operands.size() < 2) {
+ if (adaptor.getShapes().size() < 2) {
return BoolAttr::get(getContext(), true);
}
// RankOp
//===----------------------------------------------------------------------===//
-OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
- auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
+ auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
if (!shape)
return {};
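  // A constant shape's rank is its number of extents, e.g. [2, 3, 4] has rank 3.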
int64_t rank = shape.getNumElements();
// NumElementsOp
//===----------------------------------------------------------------------===//
-OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
 // Fold only when the argument is constant.
- Attribute shape = operands[0];
+ Attribute shape = adaptor.getShape();
if (!shape)
return {};
// MaxOp
//===----------------------------------------------------------------------===//
-OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
// If operands are equal, just propagate one.
if (getLhs() == getRhs())
return getLhs();
// MinOp
//===----------------------------------------------------------------------===//
-OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
// If operands are equal, just propagate one.
if (getLhs() == getRhs())
return getLhs();
// MulOp
//===----------------------------------------------------------------------===//
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
- auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
+ auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return nullptr;
- auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+ auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
if (!rhs)
return nullptr;
APInt folded = lhs.getValue() * rhs.getValue();
// ShapeOfOp
//===----------------------------------------------------------------------===//
-OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
auto type = getOperand().getType().dyn_cast<ShapedType>();
if (!type || !type.hasStaticShape())
return nullptr;
// SizeToIndexOp
//===----------------------------------------------------------------------===//
-OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
 // Constant values of both types, `shape.size` and `index`, are represented as
 // `IntegerAttr`s, which makes constant folding simple.
- if (Attribute arg = operands[0])
+ if (Attribute arg = adaptor.getArg())
return arg;
return OpFoldResult();
}
// SplitAtOp
//===----------------------------------------------------------------------===//
-LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
+LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- if (!operands[0] || !operands[1])
+ if (!adaptor.getOperand() || !adaptor.getIndex())
return failure();
auto shapeVec = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
auto shape = llvm::ArrayRef(shapeVec);
- auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
+ auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
// Verify that the split point is in the correct range.
// TODO: Constant fold to an "error".
int64_t rank = shape.size();
return failure();
if (splitPoint < 0)
splitPoint += shape.size();
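  // e.g. shape = [2, 3, 4] with splitPoint = -1 normalizes to splitPoint = 2,
  // yielding head = [2, 3] and tail = [4] below.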
- Builder builder(operands[0].getContext());
+ Builder builder(adaptor.getOperand().getContext());
results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
return success();
// ToExtentTensorOp
//===----------------------------------------------------------------------===//
-OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0])
+OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
+ if (!adaptor.getInput())
return OpFoldResult();
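  // Fold the constant input to an index tensor, e.g. shape [2, 3] becomes
  // dense<[2, 3]> : tensor<2xindex>.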
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
builder.getIndexType());
return DenseIntElementsAttr::get(type, shape);