[mlir][spirv] GPUToSPIRVPass: support case when `TargetEnv` attribute attached to...
authorIvan Butygin <ivan.butygin@gmail.com>
Thu, 13 Oct 2022 19:16:26 +0000 (21:16 +0200)
committerIvan Butygin <ivan.butygin@gmail.com>
Fri, 14 Oct 2022 10:33:01 +0000 (12:33 +0200)
Previously, only case when `TargetEnv` was attached to the top level `ModuleOp` was supported.

Differential Revision: https://reviews.llvm.org/D135907

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir

index 1b51161..b6c64f3 100644 (file)
@@ -338,6 +338,15 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
                               spvModuleRegion.begin());
   // The spirv.module build method adds a block. Remove that.
   rewriter.eraseBlock(&spvModuleRegion.back());
+
+  // Some of the patterns call `lookupTargetEnv` during conversion and they
+  // will fail if called after GPUModuleConversion and we don't preserve
+  // `TargetEnv` attribute.
+  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
+  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
+          spirv::getTargetEnvAttrName()))
+    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
+
   rewriter.eraseOp(moduleOp);
   return success();
 }
index 690c92f..c425346 100644 (file)
@@ -63,37 +63,41 @@ void GPUToSPIRVPass::runOnOperation() {
     gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
-  // Map MemRef memory space to SPIR-V storage class first if requested.
-  if (mapMemorySpace) {
+  // Run conversion for each module independently as they can have different
+  // TargetEnv attributes.
+  for (Operation *gpuModule : gpuModules) {
+    // Map MemRef memory space to SPIR-V storage class first if requested.
+    if (mapMemorySpace) {
+      std::unique_ptr<ConversionTarget> target =
+          spirv::getMemorySpaceToStorageClassTarget(*context);
+      spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+          spirv::mapMemorySpaceToVulkanStorageClass;
+      spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+
+      RewritePatternSet patterns(context);
+      spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+
+      if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
     std::unique_ptr<ConversionTarget> target =
-        spirv::getMemorySpaceToStorageClassTarget(*context);
-    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
-        spirv::mapMemorySpaceToVulkanStorageClass;
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+        SPIRVConversionTarget::get(targetAttr);
 
+    SPIRVTypeConverter typeConverter(targetAttr);
     RewritePatternSet patterns(context);
-    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+    populateGPUToSPIRVPatterns(typeConverter, patterns);
 
-    if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
+    // TODO: Change SPIR-V conversion to be progressive and remove the following
+    // patterns.
+    mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+    populateMemRefToSPIRVPatterns(typeConverter, patterns);
+    populateFuncToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
       return signalPassFailure();
   }
-
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
-  std::unique_ptr<ConversionTarget> target =
-      SPIRVConversionTarget::get(targetAttr);
-
-  SPIRVTypeConverter typeConverter(targetAttr);
-  RewritePatternSet patterns(context);
-  populateGPUToSPIRVPatterns(typeConverter, patterns);
-
-  // TODO: Change SPIR-V conversion to be progressive and remove the following
-  // patterns.
-  mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-  populateMemRefToSPIRVPatterns(typeConverter, patterns);
-  populateFuncToSPIRVPatterns(typeConverter, patterns);
-
-  if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
-    return signalPassFailure();
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
index e442ac6..fa554f9 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics -split-input-file %s -o - | FileCheck %s
 
 module attributes {
   gpu.container_module,
@@ -28,3 +28,36 @@ module attributes {
     return
   }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module
+} {
+  gpu.module @kernels attributes {
+    spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>>
+  } {
+    // CHECK-LABEL: spirv.module @{{.*}} Physical64 OpenCL
+    //  CHECK-SAME: spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>>
+    //       CHECK:   spirv.func
+    //  CHECK-SAME:     {{%.*}}: f32
+    //   CHECK-NOT:     spirv.interface_var_abi
+    //  CHECK-SAME:     {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
+    //   CHECK-NOT:     spirv.interface_var_abi
+    //  CHECK-SAME:     spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel
+        attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %0 = "op"() : () -> (f32)
+    %1 = "op"() : () -> (memref<12xf32, #spirv.storage_class<CrossWorkgroup>>)
+    %cst = arith.constant 1 : index
+    gpu.launch_func @kernels::@basic_module_structure
+        blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+        args(%0 : f32, %1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>)
+    return
+  }
+}