}
/// Type conversion from Standard Types to SPIR-V Types.
-class SPIRVTypeConverter : public TypeConverter {
+class SPIRVBasicTypeConverter : public TypeConverter {
public:
- explicit SPIRVTypeConverter(MLIRContext *context);
+ explicit SPIRVBasicTypeConverter(MLIRContext *context);
/// Converts types to SPIR-V supported types.
- Type convertType(Type t) override;
+ virtual Type convertType(Type t);
protected:
spirv::SPIRVDialect *spirvDialect;
/// Converts a function type according to the requirements of a SPIR-V entry
/// function. The arguments need to be converted to spv.Variables of spv.ptr
/// types so that they could be bound by the runtime.
-class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
+class SPIRVTypeConverter final : public TypeConverter {
public:
- using SPIRVTypeConverter::SPIRVTypeConverter;
+ explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
+ : basicTypeConverter(basicTypeConverter) {}
+
+ /// Convert types to SPIR-V types using the basic type converter.
+ Type convertType(Type t) override {
+ return basicTypeConverter->convertType(t);
+ }
/// Method to convert argument of a function. The `type` is converted to
/// spv.ptr<type, Uniform>.
// TODO(ravishankarm) : Support other storage classes.
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) override;
+
+ /// Get the basic type converter.
+ SPIRVBasicTypeConverter *getBasicTypeConverter() const {
+ return basicTypeConverter;
+ }
+
+private:
+ SPIRVBasicTypeConverter *basicTypeConverter;
};
/// Base class to define a conversion pattern to translate Ops into SPIR-V.
template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
public:
- SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
- SPIRVEntryFnTypeConverter &entryFnConverter)
+ SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
: ConversionPattern(OpTy::getOperationName(), 1, context),
- typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
+ typeConverter(typeConverter) {}
protected:
// Type lowering class.
SPIRVTypeConverter &typeConverter;
-
- // Entry function signature converter.
- SPIRVEntryFnTypeConverter &entryFnConverter;
};
-/// Base Class for legalize a FuncOp within a spv.module. This class can be
-/// extended to implement a ConversionPattern to lower a FuncOp. It provides
-/// hooks to legalize a FuncOp as a simple function, or as an entry function.
-class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
-public:
- using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
-
-protected:
- /// Method to legalize the function as a non-entry function.
- LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp) const;
-
- /// Method to legalize the function as an entry function.
- LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp) const;
-};
+/// Method to legalize a function as a non-entry function.
+LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ SPIRVTypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp);
+
+/// Method to legalize a function as an entry function.
+LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ SPIRVTypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp);
/// Appends to a pattern list additional patterns for translating StandardOps to
/// SPIR-V ops.
/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
/// attribute gpu.kernel) within a spv.module.
-class KernelFnConversion final : public SPIRVFnLowering {
+class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
public:
- using SPIRVFnLowering::SPIRVFnLowering;
+ using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
auto funcOp = cast<FuncOp>(op);
FuncOp newFuncOp;
if (!gpu::GPUDialect::isKernel(funcOp)) {
- return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
+ return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
+ newFuncOp))
? matchSuccess()
: matchFailure();
}
- if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
+ if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
+ newFuncOp))) {
return matchFailure();
}
newFuncOp.getOperation()->removeAttr(Identifier::get(
}
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
- SPIRVTypeConverter typeConverter(context);
- SPIRVEntryFnTypeConverter entryFnConverter(context);
+ SPIRVBasicTypeConverter basicTypeConverter(context);
+ SPIRVTypeConverter typeConverter(&basicTypeConverter);
OwningRewritePatternList patterns;
- patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
+ patterns.insert<KernelFnConversion>(context, typeConverter);
populateStandardToSPIRVPatterns(context, patterns);
ConversionTarget target(*context);
target.addLegalDialect<spirv::SPIRVDialect>();
- target.addDynamicallyLegalOp<FuncOp>(
- [&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
+ return basicTypeConverter.isSignatureLegal(Op.getType());
+ });
if (failed(applyFullConversion(spirvModules, target, patterns,
&typeConverter))) {
// Type Conversion
//===----------------------------------------------------------------------===//
-SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
+SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
: spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
-Type SPIRVTypeConverter::convertType(Type t) {
+Type SPIRVBasicTypeConverter::convertType(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirvDialect->isValidSPIRVType(t)) {
return t;
//===----------------------------------------------------------------------===//
LogicalResult
-SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
- SignatureConversion &result) {
+SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+ SignatureConversion &result) {
// Try to convert the given input type.
- auto convertedType = convertType(type);
+ auto convertedType = basicTypeConverter->convertType(type);
// TODO(ravishankarm) : Vulkan spec requires these to be a
// spirv::StructType. This is not a SPIR-V requirement, so just making this a
// pointer type for now.
return success();
}
-template <typename Converter>
-static LogicalResult
-lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter, Converter &typeConverter,
- TypeConverter::SignatureConversion &signatureConverter,
- FuncOp &newFuncOp) {
+static LogicalResult lowerFunctionImpl(
+ FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
+ TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
for (auto &argType : enumerate(fnType.getInputs())) {
// Get the type of the argument
- if (failed(typeConverter.convertSignatureArg(
+ if (failed(typeConverter->convertSignatureArg(
argType.index(), argType.value(), signatureConverter))) {
return funcOp.emitError("unable to convert argument type ")
<< argType.value() << " to SPIR-V type";
return success();
}
-LogicalResult
-SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp) const {
+namespace mlir {
+LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ SPIRVTypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
- return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
+ return lowerFunctionImpl(funcOp, operands, rewriter,
+ typeConverter->getBasicTypeConverter(),
signatureConverter, newFuncOp);
}
-LogicalResult
-SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp) const {
+LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ SPIRVTypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
- if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
+ if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
signatureConverter, newFuncOp))) {
return failure();
}
builder.getSymbolRefAttr(newFuncOp.getName()), interface);
return success();
}
+} // namespace mlir
//===----------------------------------------------------------------------===//
// Operation conversion