From 4a62019eb88f0f8fafe8f4f7ab1c984313b0b022 Mon Sep 17 00:00:00 2001
From: Mahesh Ravishankar
Date: Mon, 9 Dec 2019 09:51:25 -0800
Subject: [PATCH] Add lowering for module with gpu.kernel_module attribute.

The existing GPU to SPIR-V lowering created a spv.module for every
function with gpu.kernel attribute. A better approach is to lower the
module that the function lives in (which has the attribute
gpu.kernel_module) to a spv.module operation. This better captures the
host-device separation modeled by GPU dialect and simplifies the
lowering as well.

PiperOrigin-RevId: 284574688
---
Reviewer note (not part of the commit message): this patch was recovered
from a copy in which newlines had been collapsed and angle-bracketed text
(template argument lists such as ArrayRef<Value *>, and the author e-mail
on the From: line) had been dropped. The line structure below was rebuilt
so that every hunk matches its "@@" line counts and the diffstat. The
restored template arguments -- in particular those on unchanged context
lines and the SmallVector inline sizes -- and the exact indentation are
reconstructions; confirm against upstream commit 4a62019eb88f before
applying.

 .../mlir/Dialect/SPIRV/SPIRVStructureOps.td        | 18 +++--
 .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp    | 84 +++++++++++++++++++++-
 .../GPUToSPIRV/ConvertGPUToSPIRVPass.cpp           | 37 +++-------
 mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp           |  2 +-
 mlir/lib/Dialect/SPIRV/SPIRVOps.cpp                | 49 ++++++++++---
 mlir/test/Conversion/GPUToSPIRV/simple.mlir        |  1 +
 6 files changed, 148 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 34b386e..d57a7f4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -329,13 +329,17 @@ def SPV_ModuleOp : SPV_Op<"module",
 
   let regions = (region SizedRegion<1>:$body);
 
-  let builders = [OpBuilder<"Builder *, OperationState &state">,
-                  OpBuilder<[{Builder *, OperationState &state,
-                              IntegerAttr addressing_model,
-                              IntegerAttr memory_model,
-                              /*optional*/ArrayAttr capabilities = nullptr,
-                              /*optional*/ArrayAttr extensions = nullptr,
-                              /*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
+  let builders =
+    [OpBuilder<"Builder *, OperationState &state">,
+     OpBuilder<[{Builder *, OperationState &state,
+                 IntegerAttr addressing_model,
+                 IntegerAttr memory_model}]>,
+     OpBuilder<[{Builder *, OperationState &state,
+                 spirv::AddressingModel addressing_model,
+                 spirv::MemoryModel memory_model,
+                 /*optional*/ ArrayRef<spirv::Capability> capabilities = {},
+                 /*optional*/ ArrayRef<spirv::Extension> extensions = {},
+                 /*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>];
 
   // We need to ensure the block inside the region is properly terminated;
   // the auto-generated builders do not guarantee that.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 2c1847d..74d105e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Module.h"
 
 using namespace mlir;
 
@@ -71,8 +72,36 @@ private:
   SmallVector<int32_t, 3> workGroupSizeAsInt32;
 };
 
+/// Pattern to convert a module with gpu.kernel_module attribute to a
+/// spv.module.
+class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
+public:
+  using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a module terminator op to a terminator of spv.module op.
+// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
+// in ODS.
+class KernelModuleTerminatorConversion final
+    : public SPIRVOpLowering<ModuleTerminatorOp> {
+public:
+  using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// loop::ForOp.
+//===----------------------------------------------------------------------===//
+
 PatternMatchResult
 ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
                                  ConversionPatternRewriter &rewriter) const {
@@ -142,6 +171,10 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
   return matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// Builtins.
+//===----------------------------------------------------------------------===//
+
 template <typename SourceOp, spirv::BuiltIn builtin>
 PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
     SourceOp op, ArrayRef<Value *> operands,
@@ -170,6 +203,10 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
   return this->matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp with gpu.kernel attribute.
+//===----------------------------------------------------------------------===//
+
 PatternMatchResult
 KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
@@ -196,6 +233,51 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
   return matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// ModuleOp with gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleConversion::matchAndRewrite(
+    ModuleOp moduleOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!moduleOp.getAttrOfType<UnitAttr>(
+          gpu::GPUDialect::getKernelModuleAttrName())) {
+    return matchFailure();
+  }
+  // TODO : Generalize this to account for different extensions,
+  // capabilities, extended_instruction_sets, other addressing models
+  // and memory models.
+  auto spvModule = rewriter.create<spirv::ModuleOp>(
+      moduleOp.getLoc(), spirv::AddressingModel::Logical,
+      spirv::MemoryModel::GLSL450, spirv::Capability::Shader,
+      spirv::Extension::SPV_KHR_storage_buffer_storage_class);
+  // Move the region from the module op into the SPIR-V module.
+  Region &spvModuleRegion = spvModule.getOperation()->getRegion(0);
+  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
+                              spvModuleRegion.begin());
+  // The spv.module build method adds a block with a terminator. Remove that
+  // block. The terminator of the module op in the remaining block will be
+  // legalized later.
+  spvModuleRegion.back().erase();
+  rewriter.eraseOp(moduleOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleTerminatorOp for gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
+    ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU To SPIRV Patterns.
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 SPIRVTypeConverter &typeConverter,
@@ -203,7 +285,7 @@ void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 ArrayRef<int64_t> workGroupSize) {
   patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
   patterns.insert<
-      ForOpConversion,
+      ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index cec71ca..b8fe27e 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -67,34 +67,19 @@ void GPUToSPIRVPass::runOnModule() {
   auto context = &getContext();
   auto module = getModule();
 
-  SmallVector<Operation *, 4> spirvModules;
-  module.walk([&module, &spirvModules](FuncOp funcOp) {
-    if (!gpu::GPUDialect::isKernel(funcOp)) {
-      return;
+  SmallVector<Operation *, 1> kernelModules;
+  OpBuilder builder(context);
+  module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
+    if (moduleOp.getAttrOfType<UnitAttr>(
+            gpu::GPUDialect::getKernelModuleAttrName())) {
+      // For each kernel module (should be only 1 for now, but that is not a
+      // requirement here), clone the module for conversion because the
+      // gpu.launch function still needs the kernel module.
+      builder.setInsertionPoint(moduleOp.getOperation());
+      kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
     }
-    OpBuilder builder(funcOp.getOperation());
-    // Create a new spirv::ModuleOp for this function, and clone the
-    // function into it.
-    // TODO : Generalize this to account for different extensions,
-    // capabilities, extended_instruction_sets, other addressing models
-    // and memory models.
-    auto spvModule = builder.create<spirv::ModuleOp>(
-        funcOp.getLoc(),
-        builder.getI32IntegerAttr(
-            static_cast<int32_t>(spirv::AddressingModel::Logical)),
-        builder.getI32IntegerAttr(
-            static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
-        builder.getStrArrayAttr(
-            spirv::stringifyCapability(spirv::Capability::Shader)),
-        builder.getStrArrayAttr(spirv::stringifyExtension(
-            spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
-    // Hardwire the capability to be Shader.
-    OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
-    moduleBuilder.clone(*funcOp.getOperation());
-    spirvModules.push_back(spvModule);
   });
 
-  /// Dialect conversion to lower the functions with the spirv::ModuleOps.
   SPIRVTypeConverter typeConverter;
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
@@ -105,7 +90,7 @@ void GPUToSPIRVPass::runOnModule() {
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
 
-  if (failed(applyFullConversion(spirvModules, target, patterns,
+  if (failed(applyFullConversion(kernelModules, target, patterns,
                                  &typeConverter))) {
     return signalPassFailure();
   }
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 694a98f..bf17d10 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
   newFuncOp.setType(rewriter.getFunctionType(
       signatureConverter.getConvertedTypes(), llvm::None));
   rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
-  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+  rewriter.eraseOp(funcOp);
 
   // Set the attributes for argument and the function.
   StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 66af430..7061200 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op,
   return success();
 }
 
+template <typename EnumClass>
+static ArrayAttr
+getStrArrayAttrForEnumList(Builder &builder, ArrayRef<EnumClass> enumValues,
+                           llvm::function_ref<StringRef(EnumClass)> stringifyFn) {
+  if (enumValues.empty()) {
+    return nullptr;
+  }
+  SmallVector<StringRef, 1> enumValStrs;
+  enumValStrs.reserve(enumValues.size());
+  for (auto val : enumValues) {
+    enumValStrs.emplace_back(stringifyFn(val));
+  }
+  return builder.getStrArrayAttr(enumValStrs);
+}
+
 template <typename EnumClass>
 static ParseResult
 parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
@@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
   ensureTerminator(*state.addRegion(), *builder, state.location);
 }
 
+// TODO(ravishankarm): This is only here for resolving some dependency outside
+// of mlir. Remove once it is done.
 void spirv::ModuleOp::build(Builder *builder, OperationState &state,
                             IntegerAttr addressing_model,
-                            IntegerAttr memory_model, ArrayAttr capabilities,
-                            ArrayAttr extensions,
-                            ArrayAttr extended_instruction_sets) {
+                            IntegerAttr memory_model) {
   state.addAttribute("addressing_model", addressing_model);
   state.addAttribute("memory_model", memory_model);
-  if (capabilities)
-    state.addAttribute("capabilities", capabilities);
-  if (extensions)
-    state.addAttribute("extensions", extensions);
+  build(builder, state);
+}
+
+void spirv::ModuleOp::build(Builder *builder, OperationState &state,
+                            spirv::AddressingModel addressing_model,
+                            spirv::MemoryModel memory_model,
+                            ArrayRef<spirv::Capability> capabilities,
+                            ArrayRef<spirv::Extension> extensions,
+                            ArrayAttr extended_instruction_sets) {
+  state.addAttribute(
+      "addressing_model",
+      builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+  state.addAttribute("memory_model", builder->getI32IntegerAttr(
+                                         static_cast<int32_t>(memory_model)));
+  if (!capabilities.empty())
+    state.addAttribute("capabilities",
+                       getStrArrayAttrForEnumList<spirv::Capability>(
+                           *builder, capabilities, spirv::stringifyCapability));
+  if (!extensions.empty())
+    state.addAttribute("extensions",
+                       getStrArrayAttrForEnumList<spirv::Extension>(
+                           *builder, extensions, spirv::stringifyExtension));
   if (extended_instruction_sets)
     state.addAttribute("extended_instruction_sets", extended_instruction_sets);
-  ensureTerminator(*state.addRegion(), *builder, state.location);
+  build(builder, state);
 }
 
 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index ef136b9..5ec78ed 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -13,6 +13,7 @@ module attributes {gpu.container_module} {
       // CHECK: spv.Return
       return
     }
+    // CHECK: attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
   }
 
   func @foo() {
-- 
2.7.4