MLIRContext *context = &getContext();
if (getTargetEnvFn) {
- // This pass is actually only needed for targeting Apple GPUs via MoltenVK,
- // where we need to translate SPIR-V into MSL. The translation has
- // limitations.
- if (getTargetEnvFn(moduleOp).getVendorID() != spirv::Vendor::Apple)
+ // This pass is only needed for targeting WebGPU, Metal, or layering Vulkan
+ // on Metal via MoltenVK, where we need to translate SPIR-V into WGSL or
+ // MSL. The translation has limitations.
+ spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
+ spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
+ bool isVulkanOnAppleDevices =
+ clientAPI == spirv::ClientAPI::Vulkan &&
+ targetEnv.getVendorID() == spirv::Vendor::Apple;
+ if (clientAPI != spirv::ClientAPI::WebGPU &&
+ clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
return;
}