[spirv] Fix bitwidth emulation for Workgroup storage class
authorLei Zhang <antiagainst@google.com>
Wed, 5 Aug 2020 14:06:00 +0000 (10:06 -0400)
committerLei Zhang <antiagainst@google.com>
Wed, 5 Aug 2020 18:44:03 +0000 (14:44 -0400)
If Int16 is not available, 16-bit integers inside Workgroup storage
class should be emulated via 32-bit integers. This was previously
broken because the capability querying logic was incorrectly
intercepting all storage classes where it meant to only handle
interface storage classes. Adjusted where we return to fix this.

Differential Revision: https://reviews.llvm.org/D85308

mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/test/Conversion/StandardToSPIRV/alloc.mlir

index 93d0c43..583a779 100644 (file)
@@ -772,8 +772,12 @@ void ScalarType::getCapabilities(
       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
       capabilities.push_back(ref);                                             \
     }                                                                          \
-  } break
+    /* No requirements for other bitwidths */                                  \
+    return;                                                                    \
+  }
 
+  // This part only handles the cases where special bitwidths appearing in
+  // interface storage classes.
   if (storage) {
     switch (*storage) {
       STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
@@ -782,17 +786,17 @@ void ScalarType::getCapabilities(
       STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
                    StorageUniform16);
     case StorageClass::Input:
-    case StorageClass::Output:
+    case StorageClass::Output: {
       if (bitwidth == 16) {
         static const Capability caps[] = {Capability::StorageInputOutput16};
         ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
         capabilities.push_back(ref);
       }
-      break;
+      return;
+    }
     default:
       break;
     }
-    return;
   }
 #undef STORAGE_CASE
 
index fe4c9d1..14ce469 100644 (file)
@@ -32,25 +32,34 @@ module attributes {
 
 // -----
 
-// TODO: Uncomment this test when the extension handling correctly
-// converts an i16 type to i32 type and handles the load/stores
-// correctly.
-
-// module attributes {
-//   spv.target_env = #spv.target_env<
-//     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-//     {max_compute_workgroup_invocations = 128 : i32,
-//      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
-//   }
-// {
-//   func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
-//     %0 = alloc() : memref<4x5xi16, 3>
-//     %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
-//     store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
-//     dealloc %0 : memref<4x5xi16, 3>
-//     return
-//   }
-// }
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+    %0 = alloc() : memref<4x5xi16, 3>
+    %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
+    store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
+    dealloc %0 : memref<4x5xi16, 3>
+    return
+  }
+}
+
+//       CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+//  CHECK-SAME:   !spv.ptr<!spv.struct<!spv.array<20 x i32, stride=4>>, Workgroup>
+// CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem
+//       CHECK:   %[[VAR:.+]] = spv._address_of @__workgroup_mem__0
+//       CHECK:   %[[LOC:.+]] = spv.SDiv
+//       CHECK:   %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]]
+//       CHECK:   %{{.+}} = spv.Load "Workgroup" %[[PTR]] : i32
+//       CHECK:   %[[LOC:.+]] = spv.SDiv
+//       CHECK:   %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]]
+//       CHECK:   %{{.+}} = spv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr<i32, Workgroup>
+//       CHECK:   %{{.+}} = spv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr<i32, Workgroup>
+
 
 // -----