[mlir][spirv] Fix verification of nested array constants
authorSergei Grechanik <sergei.grechanik@intel.com>
Mon, 7 Feb 2022 20:42:23 +0000 (12:42 -0800)
committerSergei Grechanik <sergei.grechanik@intel.com>
Mon, 7 Feb 2022 21:48:53 +0000 (13:48 -0800)
Fix the verification function of spirv::ConstantOp to allow nesting
array attributes.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D118939

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Target/SPIRV/TranslateRegistration.cpp
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir

index b45b422..cb476dc 100644 (file)
@@ -1737,17 +1737,13 @@ static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
     printer << " : " << constOp.getType();
 }
 
-LogicalResult spirv::ConstantOp::verify() {
-  auto opType = getType();
-  auto value = valueAttr();
+static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
+                                        Type opType) {
   auto valueType = value.getType();
 
-  // ODS already generates checks to make sure the result type is valid. We just
-  // need to additionally check that the value's attribute type is consistent
-  // with the result type.
   if (value.isa<IntegerAttr, FloatAttr>()) {
     if (valueType != opType)
-      return emitOpError("result type (")
+      return op.emitOpError("result type (")
              << opType << ") does not match value type (" << valueType << ")";
     return success();
   }
@@ -1757,7 +1753,9 @@ LogicalResult spirv::ConstantOp::verify() {
     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
     auto shapedType = valueType.dyn_cast<ShapedType>();
     if (!arrayType)
-      return emitOpError("must have spv.array result type for array value");
+      return op.emitOpError("result or element type (")
+             << opType << ") does not match value type (" << valueType
+             << "), must be the same or spv.array";
 
     int numElements = arrayType.getNumElements();
     auto opElemType = arrayType.getElementType();
@@ -1766,37 +1764,42 @@ LogicalResult spirv::ConstantOp::verify() {
       opElemType = t.getElementType();
     }
     if (!opElemType.isIntOrFloat())
-      return emitOpError("only support nested array result type");
+      return op.emitOpError("only support nested array result type");
 
     auto valueElemType = shapedType.getElementType();
     if (valueElemType != opElemType) {
-      return emitOpError("result element type (")
+      return op.emitOpError("result element type (")
              << opElemType << ") does not match value element type ("
              << valueElemType << ")";
     }
 
     if (numElements != shapedType.getNumElements()) {
-      return emitOpError("result number of elements (")
+      return op.emitOpError("result number of elements (")
              << numElements << ") does not match value number of elements ("
              << shapedType.getNumElements() << ")";
     }
     return success();
   }
-  if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
+  if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
     if (!arrayType)
-      return emitOpError("must have spv.array result type for array value");
+      return op.emitOpError("must have spv.array result type for array value");
     Type elemType = arrayType.getElementType();
-    for (Attribute element : attayAttr.getValue()) {
-      if (element.getType() != elemType)
-        return emitOpError("has array element whose type (")
-               << element.getType()
-               << ") does not match the result element type (" << elemType
-               << ')';
+    for (Attribute element : arrayAttr.getValue()) {
+      // Verify array elements recursively.
+      if (failed(verifyConstantType(op, element, elemType)))
+        return failure();
     }
     return success();
   }
-  return emitOpError("cannot have value of type ") << valueType;
+  return op.emitOpError("cannot have value of type ") << valueType;
+}
+
+LogicalResult spirv::ConstantOp::verify() {
+  // ODS already generates checks to make sure the result type is valid. We just
+  // need to additionally check that the value's attribute type is consistent
+  // with the result type.
+  return verifyConstantType(*this, valueAttr(), getType());
 }
 
 bool spirv::ConstantOp::isBuildableWith(Type type) {
index aee4ad3..e9d4f34 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/Parser.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Target/SPIRV/Deserialization.h"
@@ -151,6 +152,8 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
       FileLineColLoc::get(&deserializationContext,
                           /*filename=*/"", /*line=*/0, /*column=*/0)));
   dstModule->getBody()->push_front(spirvModule.release());
+  if (failed(verify(*dstModule)))
+    return failure();
   dstModule->print(output);
 
   return mlir::success();
index 798b843..ee09bee 100644 (file)
@@ -72,6 +72,7 @@ func @const() -> () {
   %6 = spv.Constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
   %7 = spv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
   %8 = spv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
+  %9 = spv.Constant [[dense<3.0> : vector<2xf32>]] : !spv.array<1 x !spv.array<1xvector<2xf32>>>
   return
 }
 
@@ -86,7 +87,7 @@ func @unaccepted_std_attr() -> () {
 // -----
 
 func @array_constant() -> () {
-  // expected-error @+1 {{has array element whose type ('vector<2xi32>') does not match the result element type ('vector<2xf32>')}}
+  // expected-error @+1 {{result or element type ('vector<2xf32>') does not match value type ('vector<2xi32>')}}
   %0 = spv.Constant [dense<3.0> : vector<2xf32>, dense<4> : vector<2xi32>] : !spv.array<2xvector<2xf32>>
   return
 }
@@ -110,7 +111,7 @@ func @non_nested_array_constant() -> () {
 // -----
 
 func @value_result_type_mismatch() -> () {
-  // expected-error @+1 {{must have spv.array result type for array value}}
+  // expected-error @+1 {{result or element type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}}
   %0 = "spv.Constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
 }