Fix crash in Spirv -lower-host-to-llvm pass
authorMehdi Amini <joker.eph@gmail.com>
Mon, 16 Jan 2023 20:59:09 +0000 (20:59 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 16 Jan 2023 21:06:49 +0000 (21:06 +0000)
When providing with a spirv module as input where no conversion happens
the code didn't defend against broken invariant.

We'll fail the pass here, but it's not clear if it is the right thing
or if the module should just be ignored.

Fixes #59971

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D141856

mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls_fail.mlir [new file with mode: 0644]

index a58072c..2017d55 100644 (file)
@@ -131,7 +131,12 @@ static LogicalResult encodeKernelName(spirv::ModuleOp module) {
   // based on `getKernelGlobalVariables()` call. Update this function's name
   // to:
   //   {spv_module_name}_{function_name}
-  auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
+  auto entryPoints = module.getOps<spirv::EntryPointOp>();
+  if (!llvm::hasSingleElement(entryPoints)) {
+    return module.emitError(
+        "The module must contain exactly one entry point function");
+  }
+  spirv::EntryPointOp entryPoint = *entryPoints.begin();
   StringRef funcName = entryPoint.getFn();
   auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
   StringAttr newFuncName =
@@ -313,8 +318,12 @@ public:
 
     // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
     // conflicts.
-    for (auto spvModule : module.getOps<spirv::ModuleOp>())
-      (void)encodeKernelName(spvModule);
+    for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
+      if (failed(encodeKernelName(spvModule))) {
+        signalPassFailure();
+        return;
+      }
+    }
   }
 };
 } // namespace
diff --git a/mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls_fail.mlir b/mlir/test/Conversion/SPIRVToLLVM/lower-host-to-llvm-calls_fail.mlir
new file mode 100644 (file)
index 0000000..e36d30b
--- /dev/null
@@ -0,0 +1,7 @@
+// RUN: mlir-opt --lower-host-to-llvm %s -verify-diagnostics
+
+module {
+// expected-error @+1 {{The module must contain exactly one entry point function}}
+  spirv.module Logical GLSL450 {
+  }
+}