#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include <functional>
+#include <optional>
#define DEBUG_TYPE "mlir-spirv-conversion"
return bitWidth / 8;
}
+ if (auto complexType = type.dyn_cast<ComplexType>()) {
+ auto elementSize = getTypeNumBytes(options, complexType.getElementType());
+ if (!elementSize)
+ return std::nullopt;
+ return 2 * *elementSize;
+ }
+
if (auto vecType = type.dyn_cast<VectorType>()) {
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
if (!elementSize)
return nullptr;
}
+static Type
+convertComplexType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options, ComplexType type,
+ std::optional<spirv::StorageClass> storageClass = {}) {
+ auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
+ if (!scalarType) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: cannot convert non-scalar element type\n");
+ return nullptr;
+ }
+
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
+ if (!elementType)
+ return nullptr;
+ if (elementType != type.getElementType()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: complex type emulation unsupported\n");
+ return nullptr;
+ }
+
+ return VectorType::get(2, elementType);
+}
+
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
return nullptr;
}
-
if (!type.hasStaticShape()) {
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
// to the element.
if (auto vecType = elementType.dyn_cast<VectorType>()) {
arrayElemType =
convertVectorType(targetEnv, options, vecType, storageClass);
+ } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ arrayElemType =
+ convertComplexType(targetEnv, options, complexType, storageClass);
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
return nullptr;
}
-
if (!type.hasStaticShape()) {
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
// to the element.
return Type();
});
+ addConversion([this](ComplexType complexType) {
+ return convertComplexType(this->targetEnv, this->options, complexType);
+ });
+
addConversion([this](VectorType vectorType) {
return convertVectorType(this->targetEnv, this->options, vectorType);
});
// -----
//===----------------------------------------------------------------------===//
+// Complex types
+//===----------------------------------------------------------------------===//
+
+// Check that capabilities for scalar types affects complex types too: having
+// special capabilities means keep vector types untouched.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [Float64, StorageUniform16, StorageBuffer16BitAccess],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @complex_types
+// CHECK-SAME: vector<2xf32>
+// CHECK-SAME: vector<2xf64>
+func.func @complex_types(
+ %arg0: complex<f32>,
+ %arg2: complex<f64>
+) { return }
+
+// CHECK-LABEL: func @memref_complex_types_with_cap
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<16 x vector<2xf16>, stride=4> [0])>, Uniform>
+func.func @memref_complex_types_with_cap(
+ %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+ %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that capabilities for scalar types affects complex types too: no special
+// capabilities available means widening element types to 32-bit.
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Emulation is unimplemented right now.
+// CHECK-LABEL: func @memref_complex_types_no_cap
+// CHECK-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+// NOEMU-LABEL: func @memref_complex_types_no_cap
+// NOEMU-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+func.func @memref_complex_types_no_cap(
+ %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+ %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
// Vector types
//===----------------------------------------------------------------------===//