[mlir][spirv] Support attribute in MapMemRefStorageClassPass
authorLei Zhang <antiagainst@google.com>
Fri, 18 Nov 2022 04:46:18 +0000 (23:46 -0500)
committerLei Zhang <antiagainst@google.com>
Fri, 18 Nov 2022 04:55:52 +0000 (23:55 -0500)
MemRef memory space actually can be an attribute. Update the
map function signature to accept an attribute. The default
mappings can still only covers numeric ones, but this allows
downstream callers to extend with custom memory spaces.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D138257

mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

index 38c0a48..dbc6c64 100644 (file)
@@ -23,18 +23,18 @@ class SPIRVTypeConverter;
 namespace spirv {
 /// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
 using MemorySpaceToStorageClassMap =
-    std::function<Optional<spirv::StorageClass>(unsigned)>;
+    std::function<Optional<spirv::StorageClass>(Attribute)>;
 
 /// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
 /// using the default rule. Returns None if the memory space is unknown.
-Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
+Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(Attribute);
 /// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
 /// using the default rule. Returns None if the storage class is unsupported.
 Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
 
 /// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V
 /// using the default rule. Returns None if the memory space is unknown.
-Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(unsigned);
+Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(Attribute);
 /// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces
 /// using the default rule. Returns None if the storage class is unsupported.
 Optional<unsigned> mapOpenCLStorageClassToMemorySpace(spirv::StorageClass);
index d52eb4a..31276c6 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -56,7 +57,18 @@ using namespace mlir;
   MAP_FN(spirv::StorageClass::Output, 10)
 
 Optional<spirv::StorageClass>
-spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
+spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
+  // Handle null memory space attribute specially.
+  if (!memorySpaceAttr)
+    return spirv::StorageClass::StorageBuffer;
+
+  // Unknown dialect custom attributes are not supported by default.
+  // Downstream callers should plug in more specialized ones.
+  auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
+  if (!intAttr)
+    return llvm::None;
+  unsigned memorySpace = intAttr.getInt();
+
 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
   case space:                                                                  \
     return storage;
@@ -99,7 +111,18 @@ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
   MAP_FN(spirv::StorageClass::Image, 7)
 
 Optional<spirv::StorageClass>
-spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) {
+spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
+  // Handle null memory space attribute specially.
+  if (!memorySpaceAttr)
+    return spirv::StorageClass::CrossWorkgroup;
+
+  // Unknown dialect custom attributes are not supported by default.
+  // Downstream callers should plug in more specialized ones.
+  auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
+  if (!intAttr)
+    return llvm::None;
+  unsigned memorySpace = intAttr.getInt();
+
 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
   case space:                                                                  \
     return storage;
@@ -143,17 +166,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
   addConversion([](Type type) { return type; });
 
   addConversion([this](BaseMemRefType memRefType) -> Optional<Type> {
-    // Expect IntegerAttr memory spaces. The attribute can be missing for the
-    // case of memory space == 0.
-    Attribute spaceAttr = memRefType.getMemorySpace();
-    if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
-                              << " due to non-IntegerAttr memory space\n");
-      return llvm::None;
-    }
-
-    unsigned space = memRefType.getMemorySpaceAsInt();
-    auto storage = this->memorySpaceMap(space);
+    Optional<spirv::StorageClass> storage =
+        this->memorySpaceMap(memRefType.getMemorySpace());
     if (!storage) {
       LLVM_DEBUG(llvm::dbgs()
                  << "cannot convert " << memRefType