Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
OpBuilder &builder);
-/// Legalizes a function as an entry function.
-FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter,
- ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
- spirv::EntryPointABIAttr entryPointInfo);
-
/// Attribute name for specifying argument ABI information.
StringRef getInterfaceVarABIAttrName();
EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
MLIRContext *context);
+/// Legalizes a function as an entry function.
+FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter,
+ spirv::EntryPointABIAttr entryPointInfo,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo);
+
+/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
+/// arguments
+LogicalResult setABIAttrs(FuncOp funcOp,
+ spirv::EntryPointABIAttr entryPointInfo,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo);
+
} // namespace spirv
} // namespace mlir
auto entryPointAttr =
spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context);
FuncOp newFuncOp = spirv::lowerAsEntryFunction(
- funcOp, typeConverter, rewriter, argABI, entryPointAttr);
+ funcOp, typeConverter, rewriter, entryPointAttr, argABI);
if (!newFuncOp) {
return matchFailure();
}
FuncOp mlir::spirv::lowerAsEntryFunction(
FuncOp funcOp, SPIRVTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
- ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
- spirv::EntryPointABIAttr entryPointInfo) {
+ spirv::EntryPointABIAttr entryPointInfo,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
funcOp.emitError("SPIR-V lowering only supports entry functions"
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.eraseOp(funcOp);
+ spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo);
+ return newFuncOp;
+}
+
+LogicalResult
+mlir::spirv::setABIAttrs(FuncOp funcOp, spirv::EntryPointABIAttr entryPointInfo,
+ ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
// Set the attributes for argument and the function.
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
- for (auto argIndex : llvm::seq<unsigned>(0, newFuncOp.getNumArguments())) {
- newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
+ for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+ funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
}
- newFuncOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
- return newFuncOp;
+ funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
+ return success();
}