Add lowering for module with gpu.kernel_module attribute.
author Mahesh Ravishankar <ravishankarm@google.com>
Mon, 9 Dec 2019 17:51:25 +0000 (09:51 -0800)
committer A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Dec 2019 17:52:21 +0000 (09:52 -0800)
The existing GPU to SPIR-V lowering created a spv.module for every
function with the gpu.kernel attribute. A better approach is to lower
the module that the function lives in (which has the gpu.kernel_module
attribute) to a spv.module operation. This better captures the
host-device separation modeled by the GPU dialect and simplifies the
lowering as well.

PiperOrigin-RevId: 284574688

mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Conversion/GPUToSPIRV/simple.mlir

index 34b386e..d57a7f4 100644 (file)
@@ -329,13 +329,17 @@ def SPV_ModuleOp : SPV_Op<"module",
 
   let regions = (region SizedRegion<1>:$body);
 
-  let builders = [OpBuilder<"Builder *, OperationState &state">,
-                  OpBuilder<[{Builder *, OperationState &state,
-                              IntegerAttr addressing_model,
-                              IntegerAttr memory_model,
-                              /*optional*/ArrayAttr capabilities = nullptr,
-                              /*optional*/ArrayAttr extensions = nullptr,
-                              /*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
+  let builders =
+    [OpBuilder<"Builder *, OperationState &state">,
+     OpBuilder<[{Builder *, OperationState &state,
+                 IntegerAttr addressing_model,
+                 IntegerAttr memory_model}]>,
+     OpBuilder<[{Builder *, OperationState &state,
+                 spirv::AddressingModel addressing_model,
+                 spirv::MemoryModel memory_model,
+                 /*optional*/ ArrayRef<spirv::Capability> capabilities = {},
+                 /*optional*/ ArrayRef<spirv::Extension> extensions = {},
+                 /*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>];
 
   // We need to ensure the block inside the region is properly terminated;
   // the auto-generated builders do not guarantee that.
index 2c1847d..74d105e 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Module.h"
 
 using namespace mlir;
 
@@ -71,8 +72,36 @@ private:
   SmallVector<int32_t, 3> workGroupSizeAsInt32;
 };
 
+/// Pattern to convert a module with gpu.kernel_module attribute to a
+/// spv.module.
+class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
+public:
+  using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pattern to convert a module terminator op to a terminator of spv.module op.
+// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
+// in ODS.
+class KernelModuleTerminatorConversion final
+    : public SPIRVOpLowering<ModuleTerminatorOp> {
+public:
+  using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// loop::ForOp.
+//===----------------------------------------------------------------------===//
+
 PatternMatchResult
 ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
                                  ConversionPatternRewriter &rewriter) const {
@@ -142,6 +171,10 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
   return matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// Builtins.
+//===----------------------------------------------------------------------===//
+
 template <typename SourceOp, spirv::BuiltIn builtin>
 PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
     SourceOp op, ArrayRef<Value *> operands,
@@ -170,6 +203,10 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
   return this->matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp with gpu.kernel attribute.
+//===----------------------------------------------------------------------===//
+
 PatternMatchResult
 KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
@@ -196,6 +233,51 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
   return matchSuccess();
 }
 
+//===----------------------------------------------------------------------===//
+// ModuleOp with gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleConversion::matchAndRewrite(
+    ModuleOp moduleOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!moduleOp.getAttrOfType<UnitAttr>(
+          gpu::GPUDialect::getKernelModuleAttrName())) {
+    return matchFailure();
+  }
+  // TODO : Generalize this to account for different extensions,
+  // capabilities, extended_instruction_sets, other addressing models
+  // and memory models.
+  auto spvModule = rewriter.create<spirv::ModuleOp>(
+      moduleOp.getLoc(), spirv::AddressingModel::Logical,
+      spirv::MemoryModel::GLSL450, spirv::Capability::Shader,
+      spirv::Extension::SPV_KHR_storage_buffer_storage_class);
+  // Move the region from the module op into the SPIR-V module.
+  Region &spvModuleRegion = spvModule.getOperation()->getRegion(0);
+  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
+                              spvModuleRegion.begin());
+  // The spv.module build method adds a block with a terminator. Remove that
+  // block. The terminator of the module op in the remaining block will be
+  // legalized later.
+  spvModuleRegion.back().erase();
+  rewriter.eraseOp(moduleOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleTerminatorOp for gpu.kernel_module.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
+    ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU To SPIRV Patterns.
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 SPIRVTypeConverter &typeConverter,
@@ -203,7 +285,7 @@ void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 ArrayRef<int64_t> workGroupSize) {
   patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
   patterns.insert<
-      ForOpConversion,
+      ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
index cec71ca..b8fe27e 100644 (file)
@@ -67,34 +67,19 @@ void GPUToSPIRVPass::runOnModule() {
   auto context = &getContext();
   auto module = getModule();
 
-  SmallVector<Operation *, 4> spirvModules;
-  module.walk([&module, &spirvModules](FuncOp funcOp) {
-    if (!gpu::GPUDialect::isKernel(funcOp)) {
-      return;
+  SmallVector<Operation *, 1> kernelModules;
+  OpBuilder builder(context);
+  module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
+    if (moduleOp.getAttrOfType<UnitAttr>(
+            gpu::GPUDialect::getKernelModuleAttrName())) {
+      // For each kernel module (should be only 1 for now, but that is not a
+      // requirement here), clone the module for conversion because the
+      // gpu.launch function still needs the kernel module.
+      builder.setInsertionPoint(moduleOp.getOperation());
+      kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
     }
-    OpBuilder builder(funcOp.getOperation());
-    // Create a new spirv::ModuleOp for this function, and clone the
-    // function into it.
-    // TODO : Generalize this to account for different extensions,
-    // capabilities, extended_instruction_sets, other addressing models
-    // and memory models.
-    auto spvModule = builder.create<spirv::ModuleOp>(
-        funcOp.getLoc(),
-        builder.getI32IntegerAttr(
-            static_cast<int32_t>(spirv::AddressingModel::Logical)),
-        builder.getI32IntegerAttr(
-            static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
-        builder.getStrArrayAttr(
-            spirv::stringifyCapability(spirv::Capability::Shader)),
-        builder.getStrArrayAttr(spirv::stringifyExtension(
-            spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
-    // Hardwire the capability to be Shader.
-    OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
-    moduleBuilder.clone(*funcOp.getOperation());
-    spirvModules.push_back(spvModule);
   });
 
-  /// Dialect conversion to lower the functions with the spirv::ModuleOps.
   SPIRVTypeConverter typeConverter;
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
@@ -105,7 +90,7 @@ void GPUToSPIRVPass::runOnModule() {
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
 
-  if (failed(applyFullConversion(spirvModules, target, patterns,
+  if (failed(applyFullConversion(kernelModules, target, patterns,
                                  &typeConverter))) {
     return signalPassFailure();
   }
index 694a98f..bf17d10 100644 (file)
@@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
   newFuncOp.setType(rewriter.getFunctionType(
       signatureConverter.getConvertedTypes(), llvm::None));
   rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
-  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+  rewriter.eraseOp(funcOp);
 
   // Set the attributes for argument and the function.
   StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
index 66af430..7061200 100644 (file)
@@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op,
   return success();
 }
 
+template <typename Ty>
+static ArrayAttr
+getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
+                           llvm::function_ref<StringRef(Ty)> stringifyFn) {
+  if (enumValues.empty()) {
+    return nullptr;
+  }
+  SmallVector<StringRef, 1> enumValStrs;
+  enumValStrs.reserve(enumValues.size());
+  for (auto val : enumValues) {
+    enumValStrs.emplace_back(stringifyFn(val));
+  }
+  return builder.getStrArrayAttr(enumValStrs);
+}
+
 template <typename EnumClass>
 static ParseResult
 parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
@@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
   ensureTerminator(*state.addRegion(), *builder, state.location);
 }
 
+// TODO(ravishankarm): This is only here for resolving some dependency outside
+// of mlir. Remove once it is done.
 void spirv::ModuleOp::build(Builder *builder, OperationState &state,
                             IntegerAttr addressing_model,
-                            IntegerAttr memory_model, ArrayAttr capabilities,
-                            ArrayAttr extensions,
-                            ArrayAttr extended_instruction_sets) {
+                            IntegerAttr memory_model) {
   state.addAttribute("addressing_model", addressing_model);
   state.addAttribute("memory_model", memory_model);
-  if (capabilities)
-    state.addAttribute("capabilities", capabilities);
-  if (extensions)
-    state.addAttribute("extensions", extensions);
+  build(builder, state);
+}
+
+void spirv::ModuleOp::build(Builder *builder, OperationState &state,
+                            spirv::AddressingModel addressing_model,
+                            spirv::MemoryModel memory_model,
+                            ArrayRef<spirv::Capability> capabilities,
+                            ArrayRef<spirv::Extension> extensions,
+                            ArrayAttr extended_instruction_sets) {
+  state.addAttribute(
+      "addressing_model",
+      builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+  state.addAttribute("memory_model", builder->getI32IntegerAttr(
+                                         static_cast<int32_t>(memory_model)));
+  if (!capabilities.empty())
+    state.addAttribute("capabilities",
+                       getStrArrayAttrForEnumList<spirv::Capability>(
+                           *builder, capabilities, spirv::stringifyCapability));
+  if (!extensions.empty())
+    state.addAttribute("extensions",
+                       getStrArrayAttrForEnumList<spirv::Extension>(
+                           *builder, extensions, spirv::stringifyExtension));
   if (extended_instruction_sets)
     state.addAttribute("extended_instruction_sets", extended_instruction_sets);
-  ensureTerminator(*state.addRegion(), *builder, state.location);
+  build(builder, state);
 }
 
 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
index ef136b9..5ec78ed 100644 (file)
@@ -13,6 +13,7 @@ module attributes {gpu.container_module} {
       // CHECK: spv.Return
       return
     }
+    // CHECK: attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
   }
 
   func @foo() {