From fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 19 Oct 2022 05:49:08 +0000 Subject: [PATCH] [mlir][spirv] Consider target when converting one-element vector Vectors with just one element will be converted into scalars. However, we cannot just return the element types and assume it is supported in the target environment; we need to conver the element type again factoring in those considerations. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D136226 --- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 9 +++++---- mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | 4 ++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 083b997..2514cfe 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" @@ -239,8 +240,9 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, Optional storageClass = {}) { + auto scalarType = type.getElementType().cast(); if (type.getRank() <= 1 && type.getNumElements() == 1) - return type.getElementType(); + return convertScalarType(targetEnv, options, scalarType, storageClass); if (!spirv::CompositeType::isValid(type)) { // TODO: Vector types with more than four elements can be translated into @@ -260,9 +262,8 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv, succeeded(checkExtensionRequirements(type, targetEnv, extensions))) return type; - auto elementType = convertScalarType( - targetEnv, options, type.getElementType().cast(), - storageClass); + auto elementType = + convertScalarType(targetEnv, options, scalarType, storageClass); if (elementType) return VectorType::get(type.getShape(), elementType); return nullptr; diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 799d8c3..4f1cd09 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -207,6 +207,10 @@ func.func @float_vector( %arg1: vector<3xf64> ) { return } +// CHECK-LABEL: spirv.func @one_element_vector +// CHECK-SAME: %{{.+}}: i32 +func.func @one_element_vector(%arg0: vector<1xi8>) { return } + } // end module // ----- -- 2.7.4