[spirv] Add array length check.
authorDenis Khalikov <khalikov.denis@huawei.com>
Mon, 30 Sep 2019 23:42:46 +0000 (16:42 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 30 Sep 2019 23:43:26 +0000 (16:43 -0700)
According to the SPIR-V spec:
"Length is the number of elements in the array. It must be at least 1."

Closes tensorflow/mlir#160

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/160 from denis0x0D:sandbox/array_len 0840dc0986ad0088a3aa7d5d8d3e97d489377ed9
PiperOrigin-RevId: 272094669

mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/test/Dialect/SPIRV/types.mlir

index 251d15b..4237f00 100644 (file)
@@ -194,6 +194,13 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
     return Type();
   }
 
+  // According to the SPIR-V spec:
+  // "Length is the number of elements in the array. It must be at least 1."
+  if (!count) {
+    emitError(loc, "expected array length greater than 0");
+    return Type();
+  }
+
   if (spec.trim().empty()) {
     emitError(loc, "expected element type");
     return Type();
index 8efae87..66765be 100644 (file)
@@ -56,12 +56,14 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
 };
 
 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
+  assert(elementCount && "ArrayType needs at least one element");
   return Base::get(elementType.getContext(), TypeKind::Array, elementType,
                    elementCount, 0);
 }
 
 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
                          ArrayType::LayoutInfo layoutInfo) {
+  assert(elementCount && "ArrayType needs at least one element");
   return Base::get(elementType.getContext(), TypeKind::Array, elementType,
                    elementCount, layoutInfo);
 }
index 4cf9835..ab60ddc 100644 (file)
@@ -82,6 +82,11 @@ func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
 
 // -----
 
+// expected-error @+1 {{expected array length greater than 0}}
+func @array_type_zero_length(!spv.array<0xf32>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // PointerType
 //===----------------------------------------------------------------------===//