namespace spirv {
/// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
-using MemorySpaceToStorageClassMap = DenseMap<unsigned, spirv::StorageClass>;
-/// Returns the default map for targeting Vulkan-flavored SPIR-V.
-MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap();
+using MemorySpaceToStorageClassMap =
+ std::function<Optional<spirv::StorageClass>(unsigned)>;
+
+/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
+/// using the default rule. Returns None if the memory space is unknown.
+Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
+/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
+/// using the default rule. Returns None if the storage class is unsupported.
+Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
/// Type converter for converting numeric MemRef memory spaces into SPIR-V
/// symbolic ones.
const MemorySpaceToStorageClassMap &memorySpaceMap);
private:
- const MemorySpaceToStorageClassMap &memorySpaceMap;
+ MemorySpaceToStorageClassMap memorySpaceMap;
};
/// Creates the target that populates legality of ops with MemRef types.
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
// Mappings
//===----------------------------------------------------------------------===//
-spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
/// Mapping between SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
MAP_FN(spirv::StorageClass::PushConstant, 7) \
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
MAP_FN(spirv::StorageClass::Input, 9) \
- MAP_FN(spirv::StorageClass::Output, 10) \
- MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
- MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
- MAP_FN(spirv::StorageClass::Image, 13) \
- MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \
- MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \
- MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \
- MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \
- MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \
- MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \
- MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \
- MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \
- MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \
- MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
-
-#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage},
-
- return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)};
+ MAP_FN(spirv::StorageClass::Output, 10)
+
+Optional<spirv::StorageClass>
+spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
+#define STORAGE_SPACE_MAP_FN(storage, space) \
+ case space: \
+ return storage;
+
+ switch (memorySpace) {
+ STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+ default:
+ break;
+ }
+ return llvm::None;
+
+#undef STORAGE_SPACE_MAP_FN
+}
+
+Optional<unsigned>
+spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
+#define STORAGE_SPACE_MAP_FN(storage, space) \
+ case storage: \
+ return space;
+
+ switch (storageClass) {
+ STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+ default:
+ break;
+ }
+ return llvm::None;
#undef STORAGE_SPACE_MAP_FN
-#undef STORAGE_SPACE_MAP_LIST
}
+#undef STORAGE_SPACE_MAP_LIST
+
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
}
unsigned space = memRefType.getMemorySpaceAsInt();
- auto it = this->memorySpaceMap.find(space);
- if (it == this->memorySpaceMap.end()) {
+ auto storage = this->memorySpaceMap(space);
+ if (!storage) {
LLVM_DEBUG(llvm::dbgs()
<< "cannot convert " << memRefType
<< " due to being unable to find memory space in map\n");
}
auto storageAttr =
- spirv::StorageClassAttr::get(memRefType.getContext(), it->second);
+ spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
rankedType.getLayout(), storageAttr);
: public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
public:
explicit MapMemRefStorageClassPass() {
- memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
-
- LLVM_DEBUG({
- llvm::dbgs() << "memory space to storage class mapping:\n";
- if (memorySpaceMap.empty())
- llvm::dbgs() << " [empty]\n";
- for (auto kv : memorySpaceMap)
- llvm::dbgs() << " " << kv.first << " -> "
- << spirv::stringifyStorageClass(kv.second) << "\n";
- });
+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
explicit MapMemRefStorageClassPass(
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)