#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
const SPIRVConversionOptions &options,
VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
+ auto scalarType = type.getElementType().cast<spirv::ScalarType>();
if (type.getRank() <= 1 && type.getNumElements() == 1)
- return type.getElementType();
+ return convertScalarType(targetEnv, options, scalarType, storageClass);
if (!spirv::CompositeType::isValid(type)) {
// TODO: Vector types with more than four elements can be translated into
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
return type;
- auto elementType = convertScalarType(
- targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
- storageClass);
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
if (elementType)
return VectorType::get(type.getShape(), elementType);
return nullptr;