[MLIR][SPIRVToLLVM] Implementation of SPIR-V module conversion pattern
authorGeorge Mitenkov <georgemitenk0v@gmail.com>
Thu, 25 Jun 2020 00:42:39 +0000 (20:42 -0400)
committerLei Zhang <antiagainst@google.com>
Thu, 25 Jun 2020 00:42:50 +0000 (20:42 -0400)
This patch introduces conversion patterns for `spv.module` and `spv._module_end`.
SPIR-V module is converted into `ModuleOp`. This will play a role of enclosing
scope to LLVM ops. At the moment, SPIR-V module attributes (such as memory model,
etc) are ignored.

Differential Revision: https://reviews.llvm.org/D82468

mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir [new file with mode: 0644]

index 178d27d..1753723 100644 (file)
@@ -43,6 +43,11 @@ void populateSPIRVToLLVMFunctionConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &typeConverter,
     OwningRewritePatternList &patterns);
 
+/// Populates the given patterns for module conversion from SPIR-V to LLVM.
+void populateSPIRVToLLVMModuleConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SPIRVTOLLVM_CONVERTSPIRVTOLLVM_H
index e32fdc5..297e73a 100644 (file)
@@ -278,6 +278,43 @@ public:
     return success();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// ModuleOp conversion
+//===----------------------------------------------------------------------===//
+
+class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto newModuleOp = rewriter.create<ModuleOp>(spvModuleOp.getLoc());
+    rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
+
+    // Remove the terminator block that was automatically added by builder
+    rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
+    rewriter.eraseOp(spvModuleOp);
+    return success();
+  }
+};
+
+class ModuleEndConversionPattern
+    : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -361,3 +398,10 @@ void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
     OwningRewritePatternList &patterns) {
   patterns.insert<FuncConversionPattern>(context, typeConverter);
 }
+
+void mlir::populateSPIRVToLLVMModuleConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
+      context, typeConverter);
+}
index 81a3a71..c512878 100644 (file)
@@ -34,6 +34,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
 
   OwningRewritePatternList patterns;
+  populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
   populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
   populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
 
@@ -45,6 +46,11 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addIllegalDialect<spirv::SPIRVDialect>();
   target.addLegalDialect<LLVM::LLVMDialect>();
+
+  // set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
+  // conversion.
+  target.addLegalOp<ModuleOp>();
+  target.addLegalOp<ModuleTerminatorOp>();
   if (failed(applyPartialConversion(module, target, patterns)))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
new file mode 100644 (file)
index 0000000..b8169a1
--- /dev/null
@@ -0,0 +1,26 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.module
+//===----------------------------------------------------------------------===//
+
+// CHECK: module
+spv.module Logical GLSL450 {}
+
+// CHECK: module
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]> {}
+
+// CHECK: module
+spv.module Logical GLSL450 {
+       // CHECK: }
+  spv._module_end
+}
+
+// CHECK: module
+spv.module Logical GLSL450 {
+       // CHECK-LABEL: llvm.func @empty()
+  spv.func @empty() -> () "None" {
+               // CHECK: llvm.return
+    spv.Return
+  }
+}