let regions = (region SizedRegion<1>:$body);
- let builders = [OpBuilder<"Builder *, OperationState &state">,
- OpBuilder<[{Builder *, OperationState &state,
- IntegerAttr addressing_model,
- IntegerAttr memory_model,
- /*optional*/ArrayAttr capabilities = nullptr,
- /*optional*/ArrayAttr extensions = nullptr,
- /*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
+ let builders =
+ [OpBuilder<"Builder *, OperationState &state">,
+ OpBuilder<[{Builder *, OperationState &state,
+ IntegerAttr addressing_model,
+ IntegerAttr memory_model}]>,
+ OpBuilder<[{Builder *, OperationState &state,
+ spirv::AddressingModel addressing_model,
+ spirv::MemoryModel memory_model,
+ /*optional*/ ArrayRef<spirv::Capability> capabilities = {},
+ /*optional*/ ArrayRef<spirv::Extension> extensions = {},
+ /*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>];
// We need to ensure the block inside the region is properly terminated;
// the auto-generated builders do not guarantee that.
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Module.h"
using namespace mlir;
SmallVector<int32_t, 3> workGroupSizeAsInt32;
};
+/// Pattern to convert a module with gpu.kernel_module attribute to a
+/// spv.module.
+class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
+public:
+ using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a module terminator op to a terminator of spv.module op.
+// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
+// in ODS.
+class KernelModuleTerminatorConversion final
+ : public SPIRVOpLowering<ModuleTerminatorOp> {
+public:
+ using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
+//===----------------------------------------------------------------------===//
+// loop::ForOp.
+//===----------------------------------------------------------------------===//
+
PatternMatchResult
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
return matchSuccess();
}
+//===----------------------------------------------------------------------===//
+// Builtins.
+//===----------------------------------------------------------------------===//
+
template <typename SourceOp, spirv::BuiltIn builtin>
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value *> operands,
return this->matchSuccess();
}
+//===----------------------------------------------------------------------===//
+// FuncOp with gpu.kernel attribute.
+//===----------------------------------------------------------------------===//
+
PatternMatchResult
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
return matchSuccess();
}
+//===----------------------------------------------------------------------===//
+// ModuleOp with gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleConversion::matchAndRewrite(
+ ModuleOp moduleOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!moduleOp.getAttrOfType<UnitAttr>(
+ gpu::GPUDialect::getKernelModuleAttrName())) {
+ return matchFailure();
+ }
+ // TODO : Generalize this to account for different extensions,
+ // capabilities, extended_instruction_sets, other addressing models
+ // and memory models.
+ auto spvModule = rewriter.create<spirv::ModuleOp>(
+ moduleOp.getLoc(), spirv::AddressingModel::Logical,
+ spirv::MemoryModel::GLSL450, spirv::Capability::Shader,
+ spirv::Extension::SPV_KHR_storage_buffer_storage_class);
+ // Move the region from the module op into the SPIR-V module.
+ Region &spvModuleRegion = spvModule.getOperation()->getRegion(0);
+ rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
+ spvModuleRegion.begin());
+ // The spv.module build method adds a block with a terminator. Remove that
+ // block. The terminator of the module op in the remaining block will be
+ // legalized later.
+ spvModuleRegion.back().erase();
+ rewriter.eraseOp(moduleOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleTerminatorOp for gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
+ ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU To SPIRV Patterns.
+//===----------------------------------------------------------------------===//
+
namespace mlir {
void populateGPUToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
ArrayRef<int64_t> workGroupSize) {
patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
patterns.insert<
- ForOpConversion,
+ ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
auto context = &getContext();
auto module = getModule();
- SmallVector<Operation *, 4> spirvModules;
- module.walk([&module, &spirvModules](FuncOp funcOp) {
- if (!gpu::GPUDialect::isKernel(funcOp)) {
- return;
+ SmallVector<Operation *, 1> kernelModules;
+ OpBuilder builder(context);
+ module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
+ if (moduleOp.getAttrOfType<UnitAttr>(
+ gpu::GPUDialect::getKernelModuleAttrName())) {
+ // For each kernel module (should be only 1 for now, but that is not a
+ // requirement here), clone the module for conversion because the
+ // gpu.launch function still needs the kernel module.
+ builder.setInsertionPoint(moduleOp.getOperation());
+ kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
}
- OpBuilder builder(funcOp.getOperation());
- // Create a new spirv::ModuleOp for this function, and clone the
- // function into it.
- // TODO : Generalize this to account for different extensions,
- // capabilities, extended_instruction_sets, other addressing models
- // and memory models.
- auto spvModule = builder.create<spirv::ModuleOp>(
- funcOp.getLoc(),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::AddressingModel::Logical)),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
- builder.getStrArrayAttr(
- spirv::stringifyCapability(spirv::Capability::Shader)),
- builder.getStrArrayAttr(spirv::stringifyExtension(
- spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
- // Hardwire the capability to be Shader.
- OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
- moduleBuilder.clone(*funcOp.getOperation());
- spirvModules.push_back(spvModule);
});
- /// Dialect conversion to lower the functions with the spirv::ModuleOps.
SPIRVTypeConverter typeConverter;
OwningRewritePatternList patterns;
populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
- if (failed(applyFullConversion(spirvModules, target, patterns,
+ if (failed(applyFullConversion(kernelModules, target, patterns,
&typeConverter))) {
return signalPassFailure();
}
newFuncOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+ rewriter.eraseOp(funcOp);
// Set the attributes for argument and the function.
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
return success();
}
+template <typename Ty>
+static ArrayAttr
+getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
+ llvm::function_ref<StringRef(Ty)> stringifyFn) {
+ if (enumValues.empty()) {
+ return nullptr;
+ }
+ SmallVector<StringRef, 1> enumValStrs;
+ enumValStrs.reserve(enumValues.size());
+ for (auto val : enumValues) {
+ enumValStrs.emplace_back(stringifyFn(val));
+ }
+ return builder.getStrArrayAttr(enumValStrs);
+}
+
template <typename EnumClass>
static ParseResult
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
ensureTerminator(*state.addRegion(), *builder, state.location);
}
+// TODO(ravishankarm): This is only here for resolving some dependency outside
+// of mlir. Remove once it is done.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
IntegerAttr addressing_model,
- IntegerAttr memory_model, ArrayAttr capabilities,
- ArrayAttr extensions,
- ArrayAttr extended_instruction_sets) {
+ IntegerAttr memory_model) {
state.addAttribute("addressing_model", addressing_model);
state.addAttribute("memory_model", memory_model);
- if (capabilities)
- state.addAttribute("capabilities", capabilities);
- if (extensions)
- state.addAttribute("extensions", extensions);
+ build(builder, state);
+}
+
+void spirv::ModuleOp::build(Builder *builder, OperationState &state,
+ spirv::AddressingModel addressing_model,
+ spirv::MemoryModel memory_model,
+ ArrayRef<spirv::Capability> capabilities,
+ ArrayRef<spirv::Extension> extensions,
+ ArrayAttr extended_instruction_sets) {
+ state.addAttribute(
+ "addressing_model",
+ builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+ state.addAttribute("memory_model", builder->getI32IntegerAttr(
+ static_cast<int32_t>(memory_model)));
+ if (!capabilities.empty())
+ state.addAttribute("capabilities",
+ getStrArrayAttrForEnumList<spirv::Capability>(
+ *builder, capabilities, spirv::stringifyCapability));
+ if (!extensions.empty())
+ state.addAttribute("extensions",
+ getStrArrayAttrForEnumList<spirv::Extension>(
+ *builder, extensions, spirv::stringifyExtension));
if (extended_instruction_sets)
state.addAttribute("extended_instruction_sets", extended_instruction_sets);
- ensureTerminator(*state.addRegion(), *builder, state.location);
+ build(builder, state);
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
// CHECK: spv.Return
return
}
+ // CHECK: attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
}
func @foo() {