[mlir][spirv] Consider target when converting one-element vector
authorLei Zhang <antiagainst@google.com>
Wed, 19 Oct 2022 05:49:08 +0000 (05:49 +0000)
committerLei Zhang <antiagainst@google.com>
Wed, 19 Oct 2022 05:49:32 +0000 (05:49 +0000)
Vectors with just one element will be converted into scalars.
However, we cannot just return the element types and assume it
is supported in the target environment; we need to conver the
element type again factoring in those considerations.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D136226

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

index 083b997..2514cfe 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
@@ -239,8 +240,9 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
                               const SPIRVConversionOptions &options,
                               VectorType type,
                               Optional<spirv::StorageClass> storageClass = {}) {
+  auto scalarType = type.getElementType().cast<spirv::ScalarType>();
   if (type.getRank() <= 1 && type.getNumElements() == 1)
-    return type.getElementType();
+    return convertScalarType(targetEnv, options, scalarType, storageClass);
 
   if (!spirv::CompositeType::isValid(type)) {
     // TODO: Vector types with more than four elements can be translated into
@@ -260,9 +262,8 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
     return type;
 
-  auto elementType = convertScalarType(
-      targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
-      storageClass);
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
   if (elementType)
     return VectorType::get(type.getShape(), elementType);
   return nullptr;
index 799d8c3..4f1cd09 100644 (file)
@@ -207,6 +207,10 @@ func.func @float_vector(
   %arg1: vector<3xf64>
 ) { return }
 
+// CHECK-LABEL: spirv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func.func @one_element_vector(%arg0: vector<1xi8>) { return }
+
 } // end module
 
 // -----