[mlir][spirv] Detach memory space mapping from type conversion
authorLei Zhang <antiagainst@google.com>
Tue, 9 Aug 2022 18:25:38 +0000 (14:25 -0400)
committerLei Zhang <antiagainst@google.com>
Tue, 9 Aug 2022 18:30:43 +0000 (14:30 -0400)
This commit moves MemRef memory space to SPIR-V storage class
conversion out of the main SPIR-V type converter. Now the mapping
should happen as a prelimiary step before performing the final
conversion to SPIR-V. Flows are expect to write their own memory
space mappings like the `MapMemRefStorageClassPass` to handle
memory space mappings according to their needs.

This is needed because SPIR-V is serving multiple client APIs,
including Vulkan and OpenCL. Different client APIs might want
to use different storage classes for buffers in a particular
memory space, e.g., `StorageBuffer` for Vulkan vs. `CrossWorkgroup`
for OpenCL when converting the default 0 memory space.  Hardcoding
a specific mapping makes that hard. While it's possible to embed
selection logic further inside the main type converter, it will
make the main type converter even complicated. So it's better to
separate the concerns, as mapping the memory space is really
concretizing the meaning of those numeric memory spaces in the
particular context of SPIR-V lowering.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131410

18 files changed:
mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir [moved from mlir/test/Conversion/GPUToSPIRV/simple.mlir with 83% similarity]
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir [moved from mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir with 84% similarity]
mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
mlir/test/Conversion/SCFToSPIRV/for.mlir
mlir/test/Conversion/SCFToSPIRV/if.mlir
mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

index dba8d27..8867c96 100644 (file)
@@ -21,9 +21,12 @@ class ModuleOp;
 template <typename T>
 class OperationPass;
 
-/// Creates a pass to convert GPU Ops to SPIR-V ops. For a gpu.func to be
-/// converted, it should have a spv.entry_point_abi attribute.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertGPUToSPIRVPass();
+/// Creates a pass to convert GPU kernel ops to corresponding SPIR-V ops. For a
+/// gpu.func to be converted, it should have a spv.entry_point_abi attribute.
+/// If `mapMemorySpace` is true, performs MemRef memory space to SPIR-V mapping
+/// according to default Vulkan rules first.
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertGPUToSPIRVPass(bool mapMemorySpace = false);
 
 } // namespace mlir
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRVPASS_H
index 2eede78..d0f09ea 100644 (file)
@@ -69,15 +69,6 @@ public:
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
-  /// Returns the corresponding memory space for memref given a SPIR-V storage
-  /// class.
-  static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);
-
-  /// Returns the SPIR-V storage class given a memory space for memref. Return
-  /// llvm::None if the memory space does not map to any SPIR-V storage class.
-  static Optional<spirv::StorageClass>
-  getStorageClassForMemorySpace(unsigned space);
-
   /// Returns the options controlling the SPIR-V type converter.
   const Options &getOptions() const;
 
index b9d4b3f..0e0083b 100644 (file)
@@ -204,6 +204,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
     for (const auto &argType :
          enumerate(funcOp.getFunctionType().getInputs())) {
       auto convertedType = typeConverter.convertType(argType.value());
+      if (!convertedType)
+        return nullptr;
       signatureConverter.addInputs(argType.index(), convertedType);
     }
   }
index 5eaa099..6d20e98 100644 (file)
@@ -35,8 +35,14 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-struct GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
+class GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
+public:
+  explicit GPUToSPIRVPass(bool mapMemorySpace)
+      : mapMemorySpace(mapMemorySpace) {}
   void runOnOperation() override;
+
+private:
+  bool mapMemorySpace;
 };
 } // namespace
 
@@ -44,16 +50,30 @@ void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
-  SmallVector<Operation *, 1> kernelModules;
+  SmallVector<Operation *, 1> gpuModules;
   OpBuilder builder(context);
-  module.walk([&builder, &kernelModules](gpu::GPUModuleOp moduleOp) {
-    // For each kernel module (should be only 1 for now, but that is not a
-    // requirement here), clone the module for conversion because the
-    // gpu.launch function still needs the kernel module.
+  module.walk([&](gpu::GPUModuleOp moduleOp) {
+    // Clone each GPU kernel module for conversion, given that the GPU
+    // launch op still needs the original GPU kernel module.
     builder.setInsertionPoint(moduleOp.getOperation());
-    kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
+    gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
+  // Map MemRef memory space to SPIR-V sotrage class first if requested.
+  if (mapMemorySpace) {
+    std::unique_ptr<ConversionTarget> target =
+        spirv::getMemorySpaceToStorageClassTarget(*context);
+    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+        spirv::getDefaultVulkanStorageClassMap();
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+
+    RewritePatternSet patterns(context);
+    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+
+    if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
+      return signalPassFailure();
+  }
+
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
@@ -68,10 +88,11 @@ void GPUToSPIRVPass::runOnOperation() {
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateFuncToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
+  if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {
-  return std::make_unique<GPUToSPIRVPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
+  return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
 }
index 55da39c..d72802b 100644 (file)
@@ -90,12 +90,12 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
 /// can be lowered to SPIR-V.
 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
   if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
-    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
+    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
       return false;
   } else if (isa<memref::AllocaOp>(allocOp)) {
-    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
+    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    if (!sc || sc.getValue() != spirv::StorageClass::Function)
       return false;
   } else {
     return false;
@@ -116,12 +116,8 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
 /// operations of unsupported integer bitwidths, based on the memref
 /// type. Returns None on failure.
 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass)
-    return {};
-  switch (*storageClass) {
+  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  switch (sc.getValue()) {
   case spirv::StorageClass::StorageBuffer:
     return spirv::Scope::Device;
   case spirv::StorageClass::Workgroup:
index 9a94f19..9154c81 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
@@ -117,65 +118,6 @@ Type SPIRVTypeConverter::getIndexType() const {
   return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
 }
 
-/// Mapping between SPIR-V storage classes to memref memory spaces.
-///
-/// Note: memref does not have a defined semantics for each memory space; it
-/// depends on the context where it is used. There are no particular reasons
-/// behind the number assignments; we try to follow NVVM conventions and largely
-/// give common storage classes a smaller number. The hope is use symbolic
-/// memory space representation eventually after memref supports it.
-// TODO: swap Generic and StorageBuffer assignment to be more akin
-// to NVVM.
-#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
-  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
-  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
-  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
-  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
-  MAP_FN(spirv::StorageClass::Private, 5)                                      \
-  MAP_FN(spirv::StorageClass::Function, 6)                                     \
-  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
-  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
-  MAP_FN(spirv::StorageClass::Input, 9)                                        \
-  MAP_FN(spirv::StorageClass::Output, 10)                                      \
-  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
-  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
-  MAP_FN(spirv::StorageClass::Image, 13)                                       \
-  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                             \
-  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                     \
-  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                               \
-  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                             \
-  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                       \
-  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                       \
-  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                       \
-  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                            \
-  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                             \
-  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
-
-unsigned
-SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
-#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
-  case storage:                                                                \
-    return space;
-
-  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
-#undef STORAGE_SPACE_MAP_FN
-  llvm_unreachable("unhandled storage class!");
-}
-
-Optional<spirv::StorageClass>
-SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
-#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
-  case space:                                                                  \
-    return storage;
-
-  switch (space) {
-    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
-  default:
-    return llvm::None;
-  }
-#undef STORAGE_SPACE_MAP_FN
-}
-
 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
   return options;
 }
@@ -184,8 +126,6 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
   return targetEnv.getAttr().getContext();
 }
 
-#undef STORAGE_SPACE_MAP_LIST
-
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t>
@@ -375,16 +315,8 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
 
 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                   const SPIRVTypeConverter::Options &options,
-                                  MemRefType type) {
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
-    return nullptr;
-  }
-
+                                  MemRefType type,
+                                  spirv::StorageClass storageClass) {
   unsigned numBoolBits = options.boolNumBits;
   if (numBoolBits != 8) {
     LLVM_DEBUG(llvm::dbgs()
@@ -407,34 +339,37 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   if (!type.hasStaticShape()) {
-    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
-    return wrapInStructAndGetPointer(arrayType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
   auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
-  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
-  return wrapInStructAndGetPointer(arrayType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                               const SPIRVTypeConverter::Options &options,
                               MemRefType type) {
-  if (type.getElementType().isa<IntegerType>() &&
-      type.getElementTypeBitWidth() == 1) {
-    return convertBoolMemrefType(targetEnv, options, type);
+  auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!attr) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << type
+        << " illegal: expected memory space to be a SPIR-V storage class "
+           "attribute; please use MemorySpaceToStorageClassConverter to map "
+           "numeric memory spaces beforehand\n");
+    return nullptr;
   }
+  spirv::StorageClass storageClass = attr.getValue();
 
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
-    return nullptr;
+  if (type.getElementType().isa<IntegerType>() &&
+      type.getElementTypeBitWidth() == 1) {
+    return convertBoolMemrefType(targetEnv, options, type, storageClass);
   }
 
   Type arrayElemType;
@@ -463,9 +398,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   if (!type.hasStaticShape()) {
-    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
-    return wrapInStructAndGetPointer(arrayType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
@@ -476,10 +411,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
-  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
-  return wrapInStructAndGetPointer(arrayType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
index a85642c..719af7f 100644 (file)
@@ -274,29 +274,51 @@ module attributes {
 // CHECK-SAME: Private
 // CHECK-SAME: Function
 func.func @memref_mem_space(
-    %arg0: memref<4xf32, 0>,
-    %arg1: memref<4xf32, 4>,
-    %arg2: memref<4xf32, 3>,
-    %arg3: memref<4xf32, 7>,
-    %arg4: memref<4xf32, 5>,
-    %arg5: memref<4xf32, 6>
+    %arg0: memref<4xf32, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4xf32, #spv.storage_class<Uniform>>,
+    %arg2: memref<4xf32, #spv.storage_class<Workgroup>>,
+    %arg3: memref<4xf32, #spv.storage_class<PushConstant>>,
+    %arg4: memref<4xf32, #spv.storage_class<Private>>,
+    %arg5: memref<4xf32, #spv.storage_class<Function>>
 ) { return }
 
 // CHECK-LABEL: func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32>)>, Function>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<4x8xi1>
-// NOEMU-SAME: memref<4x8xi1, 6>
+// NOEMU-SAME: memref<4x8xi1, #spv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<4x8xi1, #spv.storage_class<Function>>
 func.func @memref_1bit_type(
-    %arg0: memref<4x8xi1, 0>,
-    %arg1: memref<4x8xi1, 6>
+    %arg0: memref<4x8xi1, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4x8xi1, #spv.storage_class<Function>>
 ) { return }
 
 } // end module
 
 // -----
 
+// Reject memory spaces.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @numeric_memref_mem_space1
+// CHECK-SAME: memref<4xf32>
+// NOEMU-LABEL: func @numeric_memref_mem_space1
+// NOEMU-SAME: memref<4xf32>
+func.func @numeric_memref_mem_space1(%arg0: memref<4xf32>) { return }
+
+// CHECK-LABEL: func @numeric_memref_mem_space2
+// CHECK-SAME: memref<4xf32, 3>
+// NOEMU-LABEL: func @numeric_memref_mem_space2
+// NOEMU-SAME: memref<4xf32, 3>
+func.func @numeric_memref_mem_space2(%arg0: memref<4xf32, 3>) { return }
+
+} // end module
+
+// -----
+
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
 // satisfied.
@@ -308,86 +330,86 @@ module attributes {
 // CHECK-LABEL: spv.func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<2 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<5xi1>
-func.func @memref_1bit_type(%arg0: memref<5xi1>) { return }
+// NOEMU-SAME: memref<5xi1, #spv.storage_class<StorageBuffer>>
+func.func @memref_1bit_type(%arg0: memref<5xi1, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer
-// NOEMU-SAME: memref<16xi8>
-func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
+// NOEMU-SAME: memref<16xi8, #spv.storage_class<StorageBuffer>>
+func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_8bit_Uniform
-// NOEMU-SAME: memref<16xsi8, 4>
-func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }
+// NOEMU-SAME: memref<16xsi8, #spv.storage_class<Uniform>>
+func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_8bit_PushConstant
-// NOEMU-SAME: memref<16xui8, 7>
-func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }
+// NOEMU-SAME: memref<16xui8, #spv.storage_class<PushConstant>>
+func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_16bit_StorageBuffer
-// NOEMU-SAME: memref<16xi16>
-func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }
+// NOEMU-SAME: memref<16xi16, #spv.storage_class<StorageBuffer>>
+func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_16bit_Uniform
-// NOEMU-SAME: memref<16xsi16, 4>
-func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
+// NOEMU-SAME: memref<16xsi16, #spv.storage_class<Uniform>>
+func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_16bit_PushConstant
-// NOEMU-SAME: memref<16xui16, 7>
-func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
+// NOEMU-SAME: memref<16xui16, #spv.storage_class<PushConstant>>
+func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
-// NOEMU-SAME: memref<16xf16, 9>
-func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
+// NOEMU-SAME: memref<16xf16, #spv.storage_class<Input>>
+func.func @memref_16bit_Input(%arg3: memref<16xf16, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
-// NOEMU-SAME: memref<16xf16, 10>
-func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
+// NOEMU-SAME: memref<16xf16, #spv.storage_class<Output>>
+func.func @memref_16bit_Output(%arg4: memref<16xf16, #spv.storage_class<Output>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_64bit_StorageBuffer
-// NOEMU-SAME: memref<16xi64>
-func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, 0>) { return }
+// NOEMU-SAME: memref<16xi64, #spv.storage_class<StorageBuffer>>
+func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_64bit_Uniform
-// NOEMU-SAME: memref<16xsi64, 4>
-func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, 4>) { return }
+// NOEMU-SAME: memref<16xsi64, #spv.storage_class<Uniform>>
+func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_64bit_PushConstant
-// NOEMU-SAME: memref<16xui64, 7>
-func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, 7>) { return }
+// NOEMU-SAME: memref<16xui64, #spv.storage_class<PushConstant>>
+func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Input>
 // NOEMU-LABEL: func @memref_64bit_Input
-// NOEMU-SAME: memref<16xf64, 9>
-func.func @memref_64bit_Input(%arg3: memref<16xf64, 9>) { return }
+// NOEMU-SAME: memref<16xf64, #spv.storage_class<Input>>
+func.func @memref_64bit_Input(%arg3: memref<16xf64, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Output>
 // NOEMU-LABEL: func @memref_64bit_Output
-// NOEMU-SAME: memref<16xf64, 10>
-func.func @memref_64bit_Output(%arg4: memref<16xf64, 10>) { return }
+// NOEMU-SAME: memref<16xf64, #spv.storage_class<Output>>
+func.func @memref_64bit_Output(%arg4: memref<16xf64, #spv.storage_class<Output>>) { return }
 
 } // end module
 
@@ -406,7 +428,7 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>
 // NOEMU-LABEL: spv.func @memref_8bit_PushConstant
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>
-func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }
+func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, PushConstant>
@@ -415,8 +437,8 @@ func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, PushConstant>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, PushConstant>
 func.func @memref_16bit_PushConstant(
-  %arg0: memref<16xi16, 7>,
-  %arg1: memref<16xf16, 7>
+  %arg0: memref<16xi16, #spv.storage_class<PushConstant>>,
+  %arg1: memref<16xf16, #spv.storage_class<PushConstant>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_PushConstant
@@ -426,8 +448,8 @@ func.func @memref_16bit_PushConstant(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, PushConstant>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, PushConstant>
 func.func @memref_64bit_PushConstant(
-  %arg0: memref<16xi64, 7>,
-  %arg1: memref<16xf64, 7>
+  %arg0: memref<16xi64, #spv.storage_class<PushConstant>>,
+  %arg1: memref<16xf64, #spv.storage_class<PushConstant>>
 ) { return }
 
 } // end module
@@ -447,7 +469,7 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, StorageBuffer>
 // NOEMU-LABEL: spv.func @memref_8bit_StorageBuffer
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, StorageBuffer>
-func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
+func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, StorageBuffer>
@@ -456,8 +478,8 @@ func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, StorageBuffer>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, StorageBuffer>
 func.func @memref_16bit_StorageBuffer(
-  %arg0: memref<16xi16, 0>,
-  %arg1: memref<16xf16, 0>
+  %arg0: memref<16xi16, #spv.storage_class<StorageBuffer>>,
+  %arg1: memref<16xf16, #spv.storage_class<StorageBuffer>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
@@ -467,8 +489,8 @@ func.func @memref_16bit_StorageBuffer(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, StorageBuffer>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, StorageBuffer>
 func.func @memref_64bit_StorageBuffer(
-  %arg0: memref<16xi64, 0>,
-  %arg1: memref<16xf64, 0>
+  %arg0: memref<16xi64, #spv.storage_class<StorageBuffer>>,
+  %arg1: memref<16xf64, #spv.storage_class<StorageBuffer>>
 ) { return }
 
 } // end module
@@ -488,7 +510,7 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, Uniform>
 // NOEMU-LABEL: spv.func @memref_8bit_Uniform
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, Uniform>
-func.func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }
+func.func @memref_8bit_Uniform(%arg0: memref<16xi8, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Uniform>
@@ -497,8 +519,8 @@ func.func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Uniform>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Uniform>
 func.func @memref_16bit_Uniform(
-  %arg0: memref<16xi16, 4>,
-  %arg1: memref<16xf16, 4>
+  %arg0: memref<16xi16, #spv.storage_class<Uniform>>,
+  %arg1: memref<16xf16, #spv.storage_class<Uniform>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Uniform
@@ -508,8 +530,8 @@ func.func @memref_16bit_Uniform(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, Uniform>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, Uniform>
 func.func @memref_64bit_Uniform(
-  %arg0: memref<16xi64, 4>,
-  %arg1: memref<16xf64, 4>
+  %arg0: memref<16xi64, #spv.storage_class<Uniform>>,
+  %arg1: memref<16xf64, #spv.storage_class<Uniform>>
 ) { return }
 
 } // end module
@@ -528,13 +550,13 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16>)>, Input>
 // NOEMU-LABEL: spv.func @memref_16bit_Input
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16>)>, Input>
-func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
+func.func @memref_16bit_Input(%arg3: memref<16xf16, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
 // NOEMU-LABEL: spv.func @memref_16bit_Output
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
-func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
+func.func @memref_16bit_Output(%arg4: memref<16xi16, #spv.storage_class<Output>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
@@ -543,8 +565,8 @@ func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Input>
 func.func @memref_64bit_Input(
-  %arg0: memref<16xi64, 9>,
-  %arg1: memref<16xf64, 9>
+  %arg0: memref<16xi64, #spv.storage_class<Input>>,
+  %arg1: memref<16xf64, #spv.storage_class<Input>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Output
@@ -554,8 +576,8 @@ func.func @memref_64bit_Input(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Output>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Output>
 func.func @memref_64bit_Output(
-  %arg0: memref<16xi64, 10>,
-  %arg1: memref<16xf64, 10>
+  %arg0: memref<16xi64, #spv.storage_class<Output>>,
+  %arg1: memref<16xf64, #spv.storage_class<Output>>
 ) { return }
 
 } // end module
@@ -575,22 +597,22 @@ func.func @memref_offset_strides(
 // CHECK-SAME: !spv.array<256 x f32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<88 x f32, stride=4> [0])>, StorageBuffer>
-  %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>,  // tightly packed; row major
-  %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>,  // offset 8
-  %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row
-  %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major
-  %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col
+  %arg0: memref<16x4xf32, offset: 0, strides: [4, 1], #spv.storage_class<StorageBuffer>>,  // tightly packed; row major
+  %arg1: memref<16x4xf32, offset: 8, strides: [4, 1], #spv.storage_class<StorageBuffer>>,  // offset 8
+  %arg2: memref<16x4xf32, offset: 0, strides: [16, 1], #spv.storage_class<StorageBuffer>>, // pad 12 after each row
+  %arg3: memref<16x4xf32, offset: 0, strides: [1, 16], #spv.storage_class<StorageBuffer>>, // tightly packed; col major
+  %arg4: memref<16x4xf32, offset: 0, strides: [1, 22], #spv.storage_class<StorageBuffer>>, // pad 4 after each col
 
 // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<72 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<256 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<88 x f16, stride=2> [0])>, StorageBuffer>
-  %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>,
-  %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>,
-  %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>,
-  %arg8: memref<16x4xf16, offset: 0, strides: [1, 16]>,
-  %arg9: memref<16x4xf16, offset: 0, strides: [1, 22]>
+  %arg5: memref<16x4xf16, offset: 0, strides: [4, 1], #spv.storage_class<StorageBuffer>>,
+  %arg6: memref<16x4xf16, offset: 8, strides: [4, 1], #spv.storage_class<StorageBuffer>>,
+  %arg7: memref<16x4xf16, offset: 0, strides: [16, 1], #spv.storage_class<StorageBuffer>>,
+  %arg8: memref<16x4xf16, offset: 0, strides: [1, 16], #spv.storage_class<StorageBuffer>>,
+  %arg9: memref<16x4xf16, offset: 0, strides: [1, 22], #spv.storage_class<StorageBuffer>>
 ) { return }
 
 } // end module
@@ -610,14 +632,15 @@ func.func @unranked_memref(%arg0: memref<*xi32>) { return }
 // CHECK-LABEL: func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<?xi1>
-func.func @memref_1bit_type(%arg0: memref<?xi1>) { return }
+// NOEMU-SAME: memref<?xi1, #spv.storage_class<StorageBuffer>>
+func.func @memref_1bit_type(%arg0: memref<?xi1, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: func @dynamic_dim_memref
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
-func.func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
-                         %arg1: memref<?x?xf32>) { return }
+func.func @dynamic_dim_memref(
+    %arg0: memref<8x?xi32, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<?x?xf32, #spv.storage_class<StorageBuffer>>) { return }
 
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
@@ -626,50 +649,50 @@ func.func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer
-// NOEMU-SAME: memref<?xi8>
-func.func @memref_8bit_StorageBuffer(%arg0: memref<?xi8, 0>) { return }
+// NOEMU-SAME: memref<?xi8, #spv.storage_class<StorageBuffer>>
+func.func @memref_8bit_StorageBuffer(%arg0: memref<?xi8, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_8bit_Uniform
-// NOEMU-SAME: memref<?xsi8, 4>
-func.func @memref_8bit_Uniform(%arg0: memref<?xsi8, 4>) { return }
+// NOEMU-SAME: memref<?xsi8, #spv.storage_class<Uniform>>
+func.func @memref_8bit_Uniform(%arg0: memref<?xsi8, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_8bit_PushConstant
-// NOEMU-SAME: memref<?xui8, 7>
-func.func @memref_8bit_PushConstant(%arg0: memref<?xui8, 7>) { return }
+// NOEMU-SAME: memref<?xui8, #spv.storage_class<PushConstant>>
+func.func @memref_8bit_PushConstant(%arg0: memref<?xui8, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_16bit_StorageBuffer
-// NOEMU-SAME: memref<?xi16>
-func.func @memref_16bit_StorageBuffer(%arg0: memref<?xi16, 0>) { return }
+// NOEMU-SAME: memref<?xi16, #spv.storage_class<StorageBuffer>>
+func.func @memref_16bit_StorageBuffer(%arg0: memref<?xi16, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_16bit_Uniform
-// NOEMU-SAME: memref<?xsi16, 4>
-func.func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
+// NOEMU-SAME: memref<?xsi16, #spv.storage_class<Uniform>>
+func.func @memref_16bit_Uniform(%arg0: memref<?xsi16, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_16bit_PushConstant
-// NOEMU-SAME: memref<?xui16, 7>
-func.func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
+// NOEMU-SAME: memref<?xui16, #spv.storage_class<PushConstant>>
+func.func @memref_16bit_PushConstant(%arg0: memref<?xui16, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
-// NOEMU-SAME: memref<?xf16, 9>
-func.func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
+// NOEMU-SAME: memref<?xf16, #spv.storage_class<Input>>
+func.func @memref_16bit_Input(%arg3: memref<?xf16, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
-// NOEMU-SAME: memref<?xf16, 10>
-func.func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }
+// NOEMU-SAME: memref<?xf16, #spv.storage_class<Output>>
+func.func @memref_16bit_Output(%arg4: memref<?xf16, #spv.storage_class<Output>>) { return }
 
 } // end module
 
@@ -684,15 +707,16 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x vector<2xf32>, stride=8> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, Uniform>
 func.func @memref_vector(
-    %arg0: memref<4xvector<2xf32>, 0>,
-    %arg1: memref<4xvector<4xf32>, 4>)
+    %arg0: memref<4xvector<2xf32>, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4xvector<4xf32>, #spv.storage_class<Uniform>>)
 { return }
 
 // CHECK-LABEL: func @dynamic_dim_memref_vector
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xi32>, stride=16> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
-func.func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
-                         %arg1: memref<?x?xvector<2xf32>>)
+func.func @dynamic_dim_memref_vector(
+    %arg0: memref<8x?xvector<4xi32>, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<?x?xvector<2xf32>, #spv.storage_class<StorageBuffer>>)
 { return }
 
 } // end module
@@ -705,9 +729,9 @@ module attributes {
 } {
 
 // CHECK-LABEL: func @memref_vector_wrong_size
-// CHECK-SAME: memref<4xvector<5xf32>>
+// CHECK-SAME: memref<4xvector<5xf32>, #spv.storage_class<StorageBuffer>>
 func.func @memref_vector_wrong_size(
-    %arg0: memref<4xvector<5xf32>, 0>)
+    %arg0: memref<4xvector<5xf32>, #spv.storage_class<StorageBuffer>>)
 { return }
 
 } // end module
@@ -7,7 +7,7 @@ module attributes {gpu.container_module} {
     // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}
     // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4> [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
     // CHECK-SAME: spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
-    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) kernel
+    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       // CHECK: spv.Return
       gpu.return
@@ -16,11 +16,11 @@ module attributes {gpu.container_module} {
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<StorageBuffer>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@basic_module_structure
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<StorageBuffer>>)
     return
   }
 }
@@ -39,7 +39,7 @@ module attributes {gpu.container_module} {
     gpu.func @basic_module_structure_preset_ABI(
       %arg0 : f32
         {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
-      %arg1 : memref<12xf32>
+      %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>
         {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
       attributes
         {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
@@ -55,18 +55,18 @@ module attributes {gpu.container_module} {
   gpu.module @kernels {
     // expected-error @below {{failed to legalize operation 'gpu.func'}}
     // expected-remark @below {{match failure: missing 'spv.entry_point_abi' attribute}}
-    gpu.func @missing_entry_point_abi(%arg0 : f32, %arg1 : memref<12xf32>) kernel {
+    gpu.func @missing_entry_point_abi(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel {
       gpu.return
     }
   }
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<StorageBuffer>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@missing_entry_point_abi
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<StorageBuffer>>)
     return
   }
 }
@@ -80,7 +80,7 @@ module attributes {gpu.container_module} {
     gpu.func @missing_entry_point_abi(
       %arg0 : f32
         {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
-      %arg1 : memref<12xf32>) kernel
+      %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel
     attributes
       {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       gpu.return
@@ -96,7 +96,7 @@ module attributes {gpu.container_module} {
     // expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 0}}
     gpu.func @missing_entry_point_abi(
       %arg0 : f32,
-      %arg1 : memref<12xf32>
+      %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>
         {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
     attributes
       {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
@@ -110,7 +110,7 @@ module attributes {gpu.container_module} {
 module attributes {gpu.container_module} {
   gpu.module @kernels {
     // CHECK-LABEL: spv.func @barrier
-    gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel
+    gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       // CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
       gpu.barrier
@@ -120,11 +120,11 @@ module attributes {gpu.container_module} {
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<StorageBuffer>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@barrier
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<StorageBuffer>>)
     return
   }
 }
index bbc01bf..abce5d7 100644 (file)
@@ -5,7 +5,7 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
 } {
-  func.func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
+  func.func @load_store(%arg0: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg1: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg2: memref<12x4xf32, #spv.storage_class<StorageBuffer>>) {
     %c0 = arith.constant 0 : index
     %c12 = arith.constant 12 : index
     %0 = arith.subi %c12, %c0 : index
@@ -17,7 +17,7 @@ module attributes {
     %c1_2 = arith.constant 1 : index
     gpu.launch_func @kernels::@load_store_kernel
         blocks in (%0, %c1_2, %c1_2) threads in (%1, %c1_2, %c1_2)
-        args(%arg0 : memref<12x4xf32>, %arg1 : memref<12x4xf32>, %arg2 : memref<12x4xf32>,
+        args(%arg0 : memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg1 : memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg2 : memref<12x4xf32, #spv.storage_class<StorageBuffer>>,
              %c0 : index, %c0_0 : index, %c1 : index, %c1_1 : index)
     return
   }
@@ -35,7 +35,7 @@ module attributes {
     // CHECK-SAME: %[[ARG4:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}
     // CHECK-SAME: %[[ARG5:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}
     // CHECK-SAME: %[[ARG6:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}
-    gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel
+    gpu.func @load_store_kernel(%arg0: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg1: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg2: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[16, 1, 1]>: vector<3xi32>>} {
       // CHECK: %[[ADDRESSWORKGROUPID:.*]] = spv.mlir.addressof @[[$WORKGROUPIDVAR]]
       // CHECK: %[[WORKGROUPID:.*]] = spv.Load "Input" %[[ADDRESSWORKGROUPID]]
@@ -69,15 +69,15 @@ module attributes {
       // CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
       // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]]
-      %14 = memref.load %arg0[%12, %13] : memref<12x4xf32>
+      %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spv.storage_class<StorageBuffer>>
       // CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
       // CHECK-NEXT: %[[VAL2:.*]] = spv.Load "StorageBuffer" %[[PTR2]]
-      %15 = memref.load %arg1[%12, %13] : memref<12x4xf32>
+      %15 = memref.load %arg1[%12, %13] : memref<12x4xf32, #spv.storage_class<StorageBuffer>>
       // CHECK: %[[VAL3:.*]] = spv.FAdd %[[VAL1]], %[[VAL2]]
       %16 = arith.addf %14, %15 : f32
       // CHECK: %[[PTR3:.*]] = spv.AccessChain %[[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
       // CHECK-NEXT: spv.Store "StorageBuffer" %[[PTR3]], %[[VAL3]]
-      memref.store %16, %arg2[%12, %13] : memref<12x4xf32>
+      memref.store %16, %arg2[%12, %13] : memref<12x4xf32, #spv.storage_class<StorageBuffer>>
       gpu.return
     }
   }
@@ -12,7 +12,7 @@ module attributes {
     //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>
     //   CHECK-NOT:     spv.interface_var_abi
     //  CHECK-SAME:     spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
-    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel
+    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>) kernel
         attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       gpu.return
     }
@@ -20,11 +20,11 @@ module attributes {
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32, 11>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<CrossWorkgroup>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@basic_module_structure
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32, 11>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>)
     return
   }
 }
index 3b0af88..6aeb60c 100644 (file)
@@ -44,12 +44,12 @@ module attributes {
 // CHECK:        }
 // CHECK:        spv.Return
 
-func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
+func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spv.storage_class<StorageBuffer>>) attributes {
   spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[16, 1, 1]>: vector<3xi32>>
 } {
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16xi32>)
-     outs(%output : memref<1xi32>) {
+      ins(%input : memref<16xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<1xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
@@ -74,11 +74,11 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spv.resource_limits<>>
 } {
-func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
+func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spv.storage_class<StorageBuffer>>) {
   // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16xi32>)
-     outs(%output : memref<1xi32>) {
+      ins(%input : memref<16xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<1xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
@@ -103,13 +103,13 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spv.resource_limits<>>
 } {
-func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
+func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spv.storage_class<StorageBuffer>>) attributes {
   spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 1, 1]>: vector<3xi32>>
 } {
   // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16xi32>)
-     outs(%output : memref<1xi32>) {
+      ins(%input : memref<16xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<1xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
@@ -134,13 +134,13 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spv.resource_limits<>>
 } {
-func.func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes {
+func.func @single_workgroup_reduction(%input: memref<16x8xi32, #spv.storage_class<StorageBuffer>>, %output: memref<16xi32, #spv.storage_class<StorageBuffer>>) attributes {
   spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[16, 8, 1]>: vector<3xi32>>
 } {
   // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16x8xi32>)
-     outs(%output : memref<16xi32>) {
+      ins(%input : memref<16x8xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<16xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
index e07b382..2edc37e 100644 (file)
@@ -6,10 +6,10 @@ module attributes {
   }
 {
   func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
-    %0 = memref.alloc() : memref<4x5xf32, 3>
-    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 3>
-    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 3>
-    memref.dealloc %0 : memref<4x5xf32, 3>
+    %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xf32, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -31,10 +31,10 @@ module attributes {
   }
 {
   func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
-    %0 = memref.alloc() : memref<4x5xi16, 3>
-    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, 3>
-    memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
-    memref.dealloc %0 : memref<4x5xi16, 3>
+    %0 = memref.alloc() : memref<4x5xi16, #spv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, #spv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, #spv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xi16, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -60,8 +60,8 @@ module attributes {
   }
 {
   func.func @two_allocs() {
-    %0 = memref.alloc() : memref<4x5xf32, 3>
-    %1 = memref.alloc() : memref<2x3xi32, 3>
+    %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    %1 = memref.alloc() : memref<2x3xi32, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -80,8 +80,8 @@ module attributes {
   }
 {
   func.func @two_allocs_vector() {
-    %0 = memref.alloc() : memref<4xvector<4xf32>, 3>
-    %1 = memref.alloc() : memref<2xvector<2xi32>, 3>
+    %0 = memref.alloc() : memref<4xvector<4xf32>, #spv.storage_class<Workgroup>>
+    %1 = memref.alloc() : memref<2xvector<2xi32>, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -103,8 +103,8 @@ module attributes {
   // CHECK-LABEL: func @alloc_dynamic_size
   func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
     // CHECK: memref.alloc
-    %0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
-    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
+    %0 = memref.alloc(%arg0) : memref<4x?xf32, #spv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spv.storage_class<Workgroup>>
     return %1: f32
   }
 }
@@ -119,8 +119,8 @@ module attributes {
   // CHECK-LABEL: func @alloc_unsupported_memory_space
   func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
     // CHECK: memref.alloc
-    %0 = memref.alloc() : memref<4x5xf32>
-    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+    %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class<StorageBuffer>>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32, #spv.storage_class<StorageBuffer>>
     return %1: f32
   }
 }
@@ -134,9 +134,9 @@ module attributes {
   }
 {
   // CHECK-LABEL: func @dealloc_dynamic_size
-  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
+  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, #spv.storage_class<Workgroup>>) {
     // CHECK: memref.dealloc
-    memref.dealloc %arg0 : memref<4x?xf32, 3>
+    memref.dealloc %arg0 : memref<4x?xf32, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -149,9 +149,9 @@ module attributes {
   }
 {
   // CHECK-LABEL: func @dealloc_unsupported_memory_space
-  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
+  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32, #spv.storage_class<StorageBuffer>>) {
     // CHECK: memref.dealloc
-    memref.dealloc %arg0 : memref<4x5xf32>
+    memref.dealloc %arg0 : memref<4x5xf32, #spv.storage_class<StorageBuffer>>
     return
   }
 }
index e2cd90e..8008128 100644 (file)
@@ -2,9 +2,9 @@
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, #spv.resource_limits<>>} {
   func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
-    %0 = memref.alloca() : memref<4x5xf32, 6>
-    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
-    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
+    %0 = memref.alloca() : memref<4x5xf32, #spv.storage_class<Function>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Function>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Function>>
     return
   }
 }
@@ -21,8 +21,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, #spv.resource_limits<>>} {
   func.func @two_allocs() {
-    %0 = memref.alloca() : memref<4x5xf32, 6>
-    %1 = memref.alloca() : memref<2x3xi32, 6>
+    %0 = memref.alloca() : memref<4x5xf32, #spv.storage_class<Function>>
+    %1 = memref.alloca() : memref<2x3xi32, #spv.storage_class<Function>>
     return
   }
 }
@@ -35,8 +35,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, #spv.resource_limits<>>} {
   func.func @two_allocs_vector() {
-    %0 = memref.alloca() : memref<4xvector<4xf32>, 6>
-    %1 = memref.alloca() : memref<2xvector<2xi32>, 6>
+    %0 = memref.alloca() : memref<4xvector<4xf32>, #spv.storage_class<Function>>
+    %1 = memref.alloca() : memref<2xvector<2xi32>, #spv.storage_class<Function>>
     return
   }
 }
@@ -52,8 +52,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
   // CHECK-LABEL: func @alloc_dynamic_size
   func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
     // CHECK: memref.alloca
-    %0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
-    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
+    %0 = memref.alloca(%arg0) : memref<4x?xf32, #spv.storage_class<Function>>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spv.storage_class<Function>>
     return %1: f32
   }
 }
index 1cd94f9..212363c 100644 (file)
@@ -15,60 +15,60 @@ module attributes {
 } {
 
 // CHECK-LABEL: @load_store_zero_rank_float
-func.func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
-  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+func.func @load_store_zero_rank_float(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>, %arg1: memref<f32, #spv.storage_class<StorageBuffer>>) {
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
   //      CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG0]][
   // CHECK-SAME: [[ZERO1]], [[ZERO1]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Load "StorageBuffer" %{{.*}} : f32
-  %0 = memref.load %arg0[] : memref<f32>
+  %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   //      CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG1]][
   // CHECK-SAME: [[ZERO2]], [[ZERO2]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Store "StorageBuffer" %{{.*}} : f32
-  memref.store %0, %arg1[] : memref<f32>
+  memref.store %0, %arg1[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_store_zero_rank_int
-func.func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
-  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
-  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+func.func @load_store_zero_rank_int(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>, %arg1: memref<i32, #spv.storage_class<StorageBuffer>>) {
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
   //      CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG0]][
   // CHECK-SAME: [[ZERO1]], [[ZERO1]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
-  %0 = memref.load %arg0[] : memref<i32>
+  %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   //      CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG1]][
   // CHECK-SAME: [[ZERO2]], [[ZERO2]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Store "StorageBuffer" %{{.*}} : i32
-  memref.store %0, %arg1[] : memref<i32>
+  memref.store %0, %arg1[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: func @load_store_unknown_dim
-func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
-  // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
-  // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spv.storage_class<StorageBuffer>>, %dest: memref<?xi32, #spv.storage_class<StorageBuffer>>) {
+  // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+  // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
   // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
   // CHECK: spv.Load "StorageBuffer" %[[AC0]]
-  %0 = memref.load %source[%i] : memref<?xi32>
+  %0 = memref.load %source[%i] : memref<?xi32, #spv.storage_class<StorageBuffer>>
   // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
   // CHECK: spv.Store "StorageBuffer" %[[AC1]]
-  memref.store %0, %dest[%i]: memref<?xi32>
+  memref.store %0, %dest[%i]: memref<?xi32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: func @load_i1
-//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index)
-func.func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
-  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_i1(%src: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
   // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
@@ -79,17 +79,17 @@ func.func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
   // CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[ADDR]] : i8
   // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
   // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
-  %0 = memref.load %src[%i] : memref<4xi1>
+  %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class<StorageBuffer>>
   // CHECK: return %[[BOOL]]
   return %0: i1
 }
 
 // CHECK-LABEL: func @store_i1
-//  CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
+//  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class<StorageBuffer>>,
 //  CHECK-SAME: %[[IDX:.+]]: index
-func.func @store_i1(%dst: memref<4xi1>, %i: index) {
+func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i: index) {
   %true = arith.constant true
-  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
   // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
@@ -101,7 +101,7 @@ func.func @store_i1(%dst: memref<4xi1>, %i: index) {
   // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
   // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
   // CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
-  memref.store %true, %dst[%i]: memref<4xi1>
+  memref.store %true, %dst[%i]: memref<4xi1, #spv.storage_class<StorageBuffer>>
   return
 }
 
@@ -118,7 +118,7 @@ module attributes {
 } {
 
 // CHECK-LABEL: @load_i1
-func.func @load_i1(%arg0: memref<i1>) -> i1 {
+func.func @load_i1(%arg0: memref<i1, #spv.storage_class<StorageBuffer>>) -> i1 {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -138,12 +138,12 @@ func.func @load_i1(%arg0: memref<i1>) -> i1 {
   //     CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
   //     CHECK: %[[RES:.+]]  = spv.IEqual %[[T4]], %[[ONE]] : i32
   //     CHECK: return %[[RES]]
-  %0 = memref.load %arg0[] : memref<i1>
+  %0 = memref.load %arg0[] : memref<i1, #spv.storage_class<StorageBuffer>>
   return %0 : i1
 }
 
 // CHECK-LABEL: @load_i8
-func.func @load_i8(%arg0: memref<i8>) {
+func.func @load_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>) {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -159,13 +159,13 @@ func.func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[T2:.+]] = spv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[] : memref<i8>
+  %0 = memref.load %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_i16
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: index)
-func.func @load_i16(%arg0: memref<10xi16>, %index : index) {
+func.func @load_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index : index) {
   //     CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32
@@ -186,31 +186,31 @@ func.func @load_i16(%arg0: memref<10xi16>, %index : index) {
   //     CHECK: %[[T2:.+]] = spv.Constant 16 : i32
   //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[%index] : memref<10xi16>
+  %0 = memref.load %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_i32
-func.func @load_i32(%arg0: memref<i32>) {
+func.func @load_i32(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   //     CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<i32>
+  %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_f32
-func.func @load_f32(%arg0: memref<f32>) {
+func.func @load_f32(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   //     CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<f32>
+  %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i1
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
-func.func @store_i1(%arg0: memref<i1>, %value: i1) {
+func.func @store_i1(%arg0: memref<i1, #spv.storage_class<StorageBuffer>>, %value: i1) {
   //     CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
@@ -230,13 +230,13 @@ func.func @store_i1(%arg0: memref<i1>, %value: i1) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i1>
+  memref.store %value, %arg0[] : memref<i1, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i8
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
-func.func @store_i8(%arg0: memref<i8>, %value: i8) {
+func.func @store_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>, %value: i8) {
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@@ -254,13 +254,13 @@ func.func @store_i8(%arg0: memref<i8>, %value: i8) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i8>
+  memref.store %value, %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i16
-//       CHECK: (%[[ARG0:.+]]: memref<10xi16>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
-func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+//       CHECK: (%[[ARG0:.+]]: memref<10xi16, #spv.storage_class<StorageBuffer>>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
+func.func @store_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
   //     CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
@@ -283,25 +283,25 @@ func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[%index] : memref<10xi16>
+  memref.store %value, %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i32
-func.func @store_i32(%arg0: memref<i32>, %value: i32) {
+func.func @store_i32(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>, %value: i32) {
   //     CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[] : memref<i32>
+  memref.store %value, %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_f32
-func.func @store_f32(%arg0: memref<f32>, %value: f32) {
+func.func @store_f32(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>, %value: f32) {
   //     CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[] : memref<f32>
+  memref.store %value, %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }
 
@@ -318,7 +318,7 @@ module attributes {
 } {
 
 // CHECK-LABEL: @load_i8
-func.func @load_i8(%arg0: memref<i8>) {
+func.func @load_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>) {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -334,22 +334,22 @@ func.func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[T2:.+]] = spv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[] : memref<i8>
+  %0 = memref.load %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_i16
-func.func @load_i16(%arg0: memref<i16>) {
+func.func @load_i16(%arg0: memref<i16, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   //     CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<i16>
+  %0 = memref.load %arg0[] : memref<i16, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i8
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
-func.func @store_i8(%arg0: memref<i8>, %value: i8) {
+func.func @store_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>, %value: i8) {
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@@ -367,16 +367,16 @@ func.func @store_i8(%arg0: memref<i8>, %value: i8) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i8>
+  memref.store %value, %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i16
-func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+func.func @store_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
   //     CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[%index] : memref<10xi16>
+  memref.store %value, %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
 
index 5171637..54a8e93 100644 (file)
@@ -5,7 +5,7 @@ module attributes {
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
 } {
 
-func.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+func.func @loop_kernel(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[LB:.*]] = spv.Constant 4 : i32
   %lb = arith.constant 4 : index
   // CHECK: %[[UB:.*]] = spv.Constant 42 : i32
@@ -36,14 +36,14 @@ func.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
   // CHECK:        spv.mlir.merge
   // CHECK:      }
   scf.for %arg4 = %lb to %ub step %step {
-    %1 = memref.load %arg2[%arg4] : memref<10xf32>
-    memref.store %1, %arg3[%arg4] : memref<10xf32>
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
   return
 }
 
 // CHECK-LABEL: @loop_yield
-func.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+func.func @loop_yield(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[LB:.*]] = spv.Constant 4 : i32
   %lb = arith.constant 4 : index
   // CHECK: %[[UB:.*]] = spv.Constant 42 : i32
@@ -78,8 +78,8 @@ func.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
   // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
   // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
   // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
-  memref.store %result#0, %arg3[%lb] : memref<10xf32>
-  memref.store %result#1, %arg3[%ub] : memref<10xf32>
+  memref.store %result#0, %arg3[%lb] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+  memref.store %result#1, %arg3[%ub] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
 
index f937ac6..d8463b9 100644 (file)
@@ -6,7 +6,7 @@ module attributes {
 } {
 
 // CHECK-LABEL: @kernel_simple_selection
-func.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) {
+func.func @kernel_simple_selection(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : i1) {
   %value = arith.constant 0.0 : f32
   %i = arith.constant 0 : index
 
@@ -20,13 +20,13 @@ func.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) {
   // CHECK-NEXT:  spv.Return
 
   scf.if %arg3 {
-    memref.store %value, %arg2[%i] : memref<10xf32>
+    memref.store %value, %arg2[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
   return
 }
 
 // CHECK-LABEL: @kernel_nested_selection
-func.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) {
+func.func @kernel_nested_selection(%arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg4 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg5 : i1, %arg6 : i1) {
   %i = arith.constant 0 : index
   %j = arith.constant 9 : index
 
@@ -61,26 +61,26 @@ func.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32
 
   scf.if %arg5 {
     scf.if %arg6 {
-      %value = memref.load %arg3[%i] : memref<10xf32>
-      memref.store %value, %arg4[%i] : memref<10xf32>
+      %value = memref.load %arg3[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg4[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     } else {
-      %value = memref.load %arg4[%i] : memref<10xf32>
-      memref.store %value, %arg3[%i] : memref<10xf32>
+      %value = memref.load %arg4[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg3[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     }
   } else {
     scf.if %arg6 {
-      %value = memref.load %arg3[%j] : memref<10xf32>
-      memref.store %value, %arg4[%j] : memref<10xf32>
+      %value = memref.load %arg3[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg4[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     } else {
-      %value = memref.load %arg4[%j] : memref<10xf32>
-      memref.store %value, %arg3[%j] : memref<10xf32>
+      %value = memref.load %arg4[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg3[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     }
   }
   return
 }
 
 // CHECK-LABEL: @simple_if_yield
-func.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) {
+func.func @simple_if_yield(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : i1) {
   // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
   // CHECK:       spv.mlir.selection {
@@ -116,15 +116,15 @@ func.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) {
   }
   %i = arith.constant 0 : index
   %j = arith.constant 1 : index
-  memref.store %0#0, %arg2[%i] : memref<10xf32>
-  memref.store %0#1, %arg2[%j] : memref<10xf32>
+  memref.store %0#0, %arg2[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+  memref.store %0#1, %arg2[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // TODO: The transformation should only be legal if VariablePointer capability
 // is supported. This test is still useful to make sure we can handle scf op
 // result with type change.
-func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) {
+func.func @simple_if_yield_type_change(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg4 : i1) {
   // CHECK-LABEL: @simple_if_yield_type_change
   // CHECK:       %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
   // CHECK:       spv.mlir.selection {
@@ -144,12 +144,12 @@ func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10
   // CHECK:       spv.Return
   %i = arith.constant 0 : index
   %value = arith.constant 0.0 : f32
-  %0 = scf.if %arg4 -> (memref<10xf32>) {
-    scf.yield %arg2 : memref<10xf32>
+  %0 = scf.if %arg4 -> (memref<10xf32, #spv.storage_class<StorageBuffer>>) {
+    scf.yield %arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>
   } else {
-    scf.yield %arg3 : memref<10xf32>
+    scf.yield %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
-  memref.store %value, %0[%i] : memref<10xf32>
+  memref.store %value, %0[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
 
index 9a3c1be..b553d1a 100644 (file)
@@ -75,7 +75,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   PassManager passManager(module.getContext());
   applyPassManagerCLOptions(passManager);
   passManager.addPass(createGpuKernelOutliningPass());
-  passManager.addPass(createConvertGPUToSPIRVPass());
+  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
 
   OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
   nestedPM.addPass(spirv::createLowerABIAttributesPass());
index 7e61ea7..d942ef9 100644 (file)
@@ -47,10 +47,12 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
 
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(memref::createFoldSubViewOpsPass());
-  passManager.addPass(createConvertGPUToSPIRVPass());
+
+  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createLowerABIAttributesPass());
   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
+
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
   LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
   passManager.addPass(createMemRefToLLVMPass());
@@ -58,6 +60,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
   passManager.addPass(createReconcileUnrealizedCastsPass());
   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+
   return passManager.run(module);
 }