From 48378a32af54af6ae656a3db14dc7c0d975d0f48 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 5 Aug 2020 10:06:00 -0400 Subject: [PATCH] [spirv] Fix bitwidth emulation for Workgroup storage class If Int16 is not available, 16-bit integers inside Workgroup storage class should be emulated via 32-bit integers. This was previously broken because the capability querying logic was incorrectly intercepting all storage classes where it meant to only handle interface storage classes. Adjusted where we return to fix this. Differential Revision: https://reviews.llvm.org/D85308 --- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp | 12 ++++--- mlir/test/Conversion/StandardToSPIRV/alloc.mlir | 47 +++++++++++++++---------- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index 93d0c43..583a779 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -772,8 +772,12 @@ void ScalarType::getCapabilities( ArrayRef ref(caps, llvm::array_lengthof(caps)); \ capabilities.push_back(ref); \ } \ - } break + /* No requirements for other bitwidths */ \ + return; \ + } + // This part only handles the cases where special bitwidths appearing in + // interface storage classes. if (storage) { switch (*storage) { STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16); @@ -782,17 +786,17 @@ void ScalarType::getCapabilities( STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess, StorageUniform16); case StorageClass::Input: - case StorageClass::Output: + case StorageClass::Output: { if (bitwidth == 16) { static const Capability caps[] = {Capability::StorageInputOutput16}; ArrayRef ref(caps, llvm::array_lengthof(caps)); capabilities.push_back(ref); } - break; + return; + } default: break; } - return; } #undef STORAGE_CASE diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir index fe4c9d1..14ce469 100644 --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -32,25 +32,34 @@ module attributes { // ----- -// TODO: Uncomment this test when the extension handling correctly -// converts an i16 type to i32 type and handles the load/stores -// correctly. - -// module attributes { -// spv.target_env = #spv.target_env< -// #spv.vce, -// {max_compute_workgroup_invocations = 128 : i32, -// max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> -// } -// { -// func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { -// %0 = alloc() : memref<4x5xi16, 3> -// %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3> -// store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3> -// dealloc %0 : memref<4x5xi16, 3> -// return -// } -// } +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> + } +{ + func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<4x5xi16, 3> + %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3> + store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3> + dealloc %0 : memref<4x5xi16, 3> + return + } +} + +// CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}} +// CHECK-SAME: !spv.ptr>, Workgroup> +// CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem +// CHECK: %[[VAR:.+]] = spv._address_of @__workgroup_mem__0 +// CHECK: %[[LOC:.+]] = spv.SDiv +// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]] +// CHECK: %{{.+}} = spv.Load "Workgroup" %[[PTR]] : i32 +// CHECK: %[[LOC:.+]] = spv.SDiv +// CHECK: %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]] +// CHECK: %{{.+}} = spv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr +// CHECK: %{{.+}} = spv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr + // ----- -- 2.7.4