/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
static bool isAllocationSupported(MemRefType t) {
  // Currently only support workgroup local memory allocations with static
  // shape and int or float or vector of int or float element type.
  if (!(t.hasStaticShape() &&
        SPIRVTypeConverter::getMemorySpaceForStorageClass(
            spirv::StorageClass::Workgroup) == t.getMemorySpace()))
    return false;
  // Allow a one-level vector wrapper: peel it off and check the scalar
  // element underneath.
  Type elementType = t.getElementType();
  if (auto vecType = elementType.dyn_cast<VectorType>())
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}
/// Returns the scope to use for atomic operations use for emulating store
return llvm::None;
}
return bitWidth / 8;
- } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
+ }
+ if (auto vecType = t.dyn_cast<VectorType>()) {
+ auto elementSize = getTypeNumBytes(vecType.getElementType());
+ if (!elementSize)
+ return llvm::None;
+ return vecType.getNumElements() * *elementSize;
+ }
+ if (auto memRefType = t.dyn_cast<MemRefType>()) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
return llvm::None;
}
- auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
- if (!scalarType) {
- LLVM_DEBUG(llvm::dbgs()
- << type << " illegal: cannot convert non-scalar element type\n");
+ Optional<Type> arrayElemType;
+ Type elementType = type.getElementType();
+ if (auto vecType = elementType.dyn_cast<VectorType>()) {
+ arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
+ } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+ arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
+ } else {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << type
+ << " unhandled: can only convert scalar or vector element type\n");
return llvm::None;
}
-
- auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
if (!arrayElemType)
return llvm::None;
- Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
- if (!scalarSize) {
+ Optional<int64_t> elementSize = getTypeNumBytes(elementType);
+ if (!elementSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element size\n");
return llvm::None;
}
if (!type.hasStaticShape()) {
- auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
+ auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
return llvm::None;
}
- auto arrayElemCount = *memrefSize / *scalarSize;
+ auto arrayElemCount = *memrefSize / *elementSize;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
if (!arrayElemSize) {
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
}
{
+ func @two_allocs_vector() {
+ %0 = alloc() : memref<4xvector<4xf32>, 3>
+ %1 = alloc() : memref<2xvector<2xi32>, 3>
+ return
+ }
+}
+
+// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<2 x vector<2xi32>, stride=8>>, Workgroup>
+// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16>>, Workgroup>
+// CHECK: spv.func @two_allocs_vector()
+// CHECK: spv.Return
+
+
+// -----
+
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+ }
+{
func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
// expected-error @+2 {{unhandled allocation type}}
// expected-error @+1 {{'std.alloc' op operand #0 must be index}}
// -----
// Vector types
module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.0, [], []>,
    {max_compute_workgroup_invocations = 128 : i32,
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {

// Static-shape memrefs of vectors map to fixed-size spv.array with the
// vector's byte size as the stride.
// CHECK-LABEL: func @memref_vector
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<2xf32>, stride=8> [0]>, StorageBuffer>
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16> [0]>, Uniform>
func @memref_vector(
    %arg0: memref<4xvector<2xf32>, 0>,
    %arg1: memref<4xvector<4xf32>, 4>)
{ return }

// Dynamic-shape memrefs of vectors map to spv.rtarray (runtime array).
// CHECK-LABEL: func @dynamic_dim_memref_vector
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<4xi32>, stride=16> [0]>, StorageBuffer>
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<2xf32>, stride=8> [0]>, StorageBuffer>
func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
                                %arg1: memref<?x?xvector<2xf32>>)
{ return }

} // end module
+
+// -----
+
// Vector types, check that sizes not available in SPIR-V are not transformed.
module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.0, [], []>,
    {max_compute_workgroup_invocations = 128 : i32,
     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {

// A 5-element vector is not a valid SPIR-V vector type, so the memref
// must be left untouched by the conversion.
// CHECK-LABEL: func @memref_vector_wrong_size
// CHECK-SAME: memref<4xvector<5xf32>>
func @memref_vector_wrong_size(
    %arg0: memref<4xvector<5xf32>, 0>)
{ return }

} // end module
+
+// -----
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//