[mlir][spirv] Fix nullptr dereference in UnifyAliasedResource
authorJakub Kuderski <kubak@google.com>
Fri, 28 Apr 2023 15:39:22 +0000 (11:39 -0400)
committerJakub Kuderski <kubak@google.com>
Fri, 28 Apr 2023 15:39:23 +0000 (11:39 -0400)
Fixes: https://github.com/llvm/llvm-project/issues/62368

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149376

mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

index 1713c44..97f16d1 100644 (file)
@@ -220,6 +220,9 @@ ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
 }
 
 bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
+  if (!op)
+    return false;
+
   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
     auto canonicalOp = getCanonicalResource(varOp);
     return canonicalOp && varOp != canonicalOp;
@@ -566,16 +569,15 @@ public:
 private:
   spirv::GetTargetEnvFn getTargetEnvFn;
 };
-} // namespace
 
 void UnifyAliasedResourcePass::runOnOperation() {
   spirv::ModuleOp moduleOp = getOperation();
   MLIRContext *context = &getContext();
 
   if (getTargetEnvFn) {
-    // This pass is only needed for targeting WebGPU, Metal, or layering Vulkan
-    // on Metal via MoltenVK, where we need to translate SPIR-V into WGSL or
-    // MSL. The translation has limitations.
+    // This pass is only needed for targeting WebGPU, Metal, or layering
+    // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
+    // WGSL or MSL. The translation has limitations.
     spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
     spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
     bool isVulkanOnAppleDevices =
@@ -614,6 +616,7 @@ void UnifyAliasedResourcePass::runOnOperation() {
       resources.front()->removeAttr("aliased");
   }
 }
+} // namespace
 
 std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
 spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
index 8801fdb..4f3df86 100644 (file)
@@ -506,3 +506,19 @@ spirv.module Logical GLSL450 {
 // CHECK:   %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
 // CHECK:   spirv.ReturnValue %[[CC]]
 
+// -----
+
+// Make sure we do not crash on function arguments.
+
+spirv.module Logical GLSL450 {
+  spirv.func @main(%arg0: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>) "None" {
+    %cst0_i32 = spirv.Constant 0 : i32
+    %0 = spirv.AccessChain %arg0[%cst0_i32, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    spirv.Return
+  }
+}
+
+// CHECK-LABEL: spirv.module
+// CHECK-LABEL: spirv.func @main
+// CHECK-SAME:  (%{{.+}}: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>) "None"
+// CHECK:       spirv.Return