From 47e953e913ed6108cdb8badf4b0090f51b9e535f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 17 Nov 2022 23:46:18 -0500 Subject: [PATCH] [mlir][spirv] Support attribute in MapMemRefStorageClassPass MemRef memory space actually can be an attribute. Update the map function signature to accept an attribute. The default mappings can still only covers numeric ones, but this allows downstream callers to extend with custom memory spaces. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D138257 --- .../mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h | 6 ++-- .../MemRefToSPIRV/MapMemRefStorageClassPass.cpp | 40 +++++++++++++++------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h index 38c0a48..dbc6c64 100644 --- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h @@ -23,18 +23,18 @@ class SPIRVTypeConverter; namespace spirv { /// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones. using MemorySpaceToStorageClassMap = - std::function(unsigned)>; + std::function(Attribute)>; /// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V /// using the default rule. Returns None if the memory space is unknown. -Optional mapMemorySpaceToVulkanStorageClass(unsigned); +Optional mapMemorySpaceToVulkanStorageClass(Attribute); /// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces /// using the default rule. Returns None if the storage class is unsupported. Optional mapVulkanStorageClassToMemorySpace(spirv::StorageClass); /// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V /// using the default rule. Returns None if the memory space is unknown. -Optional mapMemorySpaceToOpenCLStorageClass(unsigned); +Optional mapMemorySpaceToOpenCLStorageClass(Attribute); /// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces /// using the default rule. Returns None if the storage class is unsupported. Optional mapOpenCLStorageClassToMemorySpace(spirv::StorageClass); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp index d52eb4a..31276c6 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -56,7 +57,18 @@ using namespace mlir; MAP_FN(spirv::StorageClass::Output, 10) Optional -spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) { +spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) { + // Handle null memory space attribute specially. + if (!memorySpaceAttr) + return spirv::StorageClass::StorageBuffer; + + // Unknown dialect custom attributes are not supported by default. + // Downstream callers should plug in more specialized ones. + auto intAttr = memorySpaceAttr.dyn_cast(); + if (!intAttr) + return llvm::None; + unsigned memorySpace = intAttr.getInt(); + #define STORAGE_SPACE_MAP_FN(storage, space) \ case space: \ return storage; @@ -99,7 +111,18 @@ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) { MAP_FN(spirv::StorageClass::Image, 7) Optional -spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) { +spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) { + // Handle null memory space attribute specially. + if (!memorySpaceAttr) + return spirv::StorageClass::CrossWorkgroup; + + // Unknown dialect custom attributes are not supported by default. + // Downstream callers should plug in more specialized ones. + auto intAttr = memorySpaceAttr.dyn_cast(); + if (!intAttr) + return llvm::None; + unsigned memorySpace = intAttr.getInt(); + #define STORAGE_SPACE_MAP_FN(storage, space) \ case space: \ return storage; @@ -143,17 +166,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter( addConversion([](Type type) { return type; }); addConversion([this](BaseMemRefType memRefType) -> Optional { - // Expect IntegerAttr memory spaces. The attribute can be missing for the - // case of memory space == 0. - Attribute spaceAttr = memRefType.getMemorySpace(); - if (spaceAttr && !spaceAttr.isa()) { - LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType - << " due to non-IntegerAttr memory space\n"); - return llvm::None; - } - - unsigned space = memRefType.getMemorySpaceAsInt(); - auto storage = this->memorySpaceMap(space); + Optional storage = + this->memorySpaceMap(memRefType.getMemorySpace()); if (!storage) { LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType -- 2.7.4