[mlir][spirv] Validate float type bitwidths
authorJakub Kuderski <kubak@google.com>
Tue, 14 Feb 2023 22:48:43 +0000 (17:48 -0500)
committerJakub Kuderski <kubak@google.com>
Tue, 14 Feb 2023 22:50:38 +0000 (17:50 -0500)
Not all float types are supported in SPIR-V.

Fixes: https://github.com/llvm/llvm-project/issues/60199

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D144043

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

index 7dc3fde..d2f2cb8 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/STLExtras.h"
@@ -574,19 +575,12 @@ bool ScalarType::classof(Type type) {
   return false;
 }
 
-bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
+bool ScalarType::isValid(FloatType type) {
+  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+}
 
 bool ScalarType::isValid(IntegerType type) {
-  switch (type.getWidth()) {
-  case 1:
-  case 8:
-  case 16:
-  case 32:
-  case 64:
-    return true;
-  default:
-    return false;
-  }
+  return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
 }
 
 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
index d207ecd..37ea7bf 100644 (file)
@@ -37,12 +37,19 @@ func.func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }
 // CHECK-SAME: i64
 // CHECK-SAME: si64
 // CHECK-SAME: ui64
-// NOEMU-LABEL: func @integer64
+// NOEMU-LABEL: func.func @integer64
 // NOEMU-SAME: i64
 // NOEMU-SAME: si64
 // NOEMU-SAME: ui64
 func.func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }
 
+// i128 is not supported by SPIR-V.
+// CHECK-LABEL: func.func @integer128
+// CHECK-SAME: i128
+// NOEMU-LABEL: func.func @integer128
+// NOEMU-SAME: i128
+func.func @integer128(%arg0: i128) { return }
+
 } // end module
 
 // -----
@@ -143,6 +150,13 @@ func.func @float16(%arg0: f16) { return }
 // NOEMU-SAME: f64
 func.func @float64(%arg0: f64) { return }
 
+// f80 is not supported by SPIR-V.
+// CHECK-LABEL: func.func @float80
+// CHECK-SAME: f80
+// NOEMU-LABEL: func.func @float80
+// NOEMU-SAME: f80
+func.func @float80(%arg0: f80) { return }
+
 } // end module
 
 // -----