// GenericOp
//===----------------------------------------------------------------------===//
+static void buildGenericRegion(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+ SmallVector<Type, 4> blockArgTypes;
+ SmallVector<Location, 4> blockArgLocs;
+ for (ValueRange container : {inputs, outputs}) {
+ for (Value v : container) {
+ blockArgTypes.push_back(getElementTypeOrSelf(v));
+ blockArgLocs.push_back(v.getLoc());
+ }
+ }
+
+ OpBuilder::InsertionGuard guard(builder);
+ auto ®ion = *result.regions.front();
+ Block *bodyBlock =
+ builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
+ bodyBuild(builder, result.location, bodyBlock->getArguments());
+}
+
void GenericOp::getAsmBlockArgumentNames(Region ®ion,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
iteratorTypes, doc, libraryCall);
result.addAttributes(attributes);
- if (!bodyBuild)
- return;
-
- SmallVector<Type, 4> blockArgTypes;
- SmallVector<Location, 4> blockArgLocs;
- for (ValueRange container : {inputs, outputs}) {
- for (Value v : container) {
- blockArgTypes.push_back(getElementTypeOrSelf(v));
- blockArgLocs.push_back(v.getLoc());
- }
- }
-
- OpBuilder::InsertionGuard guard(builder);
- auto ®ion = *result.regions.front();
- Block *bodyBlock =
- builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
- bodyBuild(builder, result.location, bodyBlock->getArguments());
+ if (bodyBuild)
+ buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
}
void GenericOp::build(
setNameFn(getResults().front(), "mapped");
}
+void MapOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+ ArrayRef<NamedAttribute> attributes) {
+ build(builder, result, TypeRange{}, inputs, init);
+ result.addAttributes(attributes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ Type initType = init.getType();
+ if (initType.isa<RankedTensorType>())
+ result.addTypes(initType);
+
+ if (bodyBuild)
+ buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
+}
+
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseDstStyleOp(parser, result))
return failure();
setNameFn(getResults().front(), "reduced");
}
+void ReduceOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange inits, ArrayRef<int64_t> dimensions,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+ ArrayRef<NamedAttribute> attributes) {
+ build(builder, result, TypeRange{}, inputs, inits, dimensions);
+ result.addAttributes(attributes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ for (Value init : inits) {
+ Type initType = init.getType();
+ if (initType.isa<RankedTensorType>())
+ result.addTypes(initType);
+ }
+
+ if (bodyBuild)
+ buildGenericRegion(builder, result, inputs, inits, bodyBuild);
+}
+
SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<StringRef> iteratorTypes(inputRank,
};
}
-void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
- ::mlir::OperationState &odsState) {
- Region *region = odsState.addRegion();
-
- SmallVector<Type> argTypes;
- SmallVector<Location> argLocs;
- for (auto t : odsState.operands) {
- argTypes.push_back(getElementTypeOrSelf(t));
- argLocs.push_back(opBuilder.getUnknownLoc());
- }
-
- // RAII.
- OpBuilder::InsertionGuard guard(opBuilder);
- Block *body =
- opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
-
- ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
- getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
-}
-
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState, Value input,
- Value init, DenseI64ArrayAttr permutation,
+void TransposeOp::build(::mlir::OpBuilder &builder,
+ ::mlir::OperationState &result, Value input, Value init,
+ DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attributes) {
- odsState.addOperands(input);
- odsState.addOperands(init);
- odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
- odsState.addAttributes(attributes);
- odsState.addTypes(init.getType());
+ result.addOperands(input);
+ result.addOperands(init);
+ result.addAttribute(getPermutationAttrName(result.name), permutation);
+ result.addAttributes(attributes);
+
+ // Add output types for `RankedTensorType` output arguments.
+ Type initType = init.getType();
+ if (initType.isa<RankedTensorType>())
+ result.addTypes(initType);
- createRegion(odsBuilder, odsState);
+ buildGenericRegion(builder, result, input, init,
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ b.create<linalg::YieldOp>(loc, args[0]);
+ });
}
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState, Value input,
- Value init, ArrayRef<int64_t> permutation,
+void TransposeOp::build(::mlir::OpBuilder &builder,
+ ::mlir::OperationState &result, Value input, Value init,
+ ArrayRef<int64_t> permutation,
ArrayRef<NamedAttribute> attributes) {
- build(odsBuilder, odsState, input, init,
- odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
+ build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
+ attributes);
}
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
})))
return failure();
- OpBuilder opBuilder(parser.getContext());
- createRegion(opBuilder, result);
+ (void)result.addRegion();
+ OpBuilder builder(parser.getContext());
+ buildGenericRegion(builder, result, /*inputs=*/result.operands,
+ /*outputs=*/{},
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ b.create<linalg::YieldOp>(loc, args[0]);
+ });
return success();
}