return success();
}
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+ auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
if (!index)
return {};
return success();
}
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// If this is a splat elements attribute, simply return the value. All of
// the elements of a splat attribute are the same.
- if (Attribute tensor = operands.front())
+ if (Attribute tensor = adaptor.getTensor())
if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
return splatTensor.getSplatValue<Attribute>();
// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
- for (Attribute indice : llvm::drop_begin(operands, 1)) {
+ for (Attribute indice : adaptor.getIndices()) {
if (!indice || !indice.isa<IntegerAttr>())
return {};
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
// If this is an elements attribute, query the value at the given indices.
- if (Attribute tensor = operands.front()) {
+ if (Attribute tensor = adaptor.getTensor()) {
auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
if (elementsAttr && elementsAttr.isValidIndex(indices))
return elementsAttr.getValues<Attribute>()[indices];
build(builder, result, resultType, elements);
}
-OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
- if (!llvm::is_contained(operands, nullptr))
- return DenseElementsAttr::get(getType(), operands);
+OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+ if (!llvm::is_contained(adaptor.getElements(), nullptr))
+ return DenseElementsAttr::get(getType(), adaptor.getElements());
return {};
}
return success();
}
-OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
- Attribute scalar = operands[0];
- Attribute dest = operands[1];
+OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
+ Attribute scalar = adaptor.getScalar();
+ Attribute dest = adaptor.getDest();
if (scalar && dest)
if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
if (scalar == splatDest.getSplatValue<Attribute>())
setNameFn(getResult(), "rank");
}
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
auto shapedType = type.dyn_cast<ShapedType>();
context);
}
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
- return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+ adaptor.getOperands());
}
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
- return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+ adaptor.getOperands());
}
//===----------------------------------------------------------------------===//
return {};
}
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
- if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
+ if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
auto resultType = getResult().getType().cast<ShapedType>();
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
return extractOp.getSource();
}
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return padValue;
}
-OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult PadOp::fold(FoldAdaptor) {
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
!getNofold())
return getSource();
setNameFn(getResult(), "splat");
}
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+ auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
return {};