printer << " : " << constOp.getType();
}
-LogicalResult spirv::ConstantOp::verify() {
- auto opType = getType();
- auto value = valueAttr();
+static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
+ Type opType) {
auto valueType = value.getType();
- // ODS already generates checks to make sure the result type is valid. We just
- // need to additionally check that the value's attribute type is consistent
- // with the result type.
if (value.isa<IntegerAttr, FloatAttr>()) {
if (valueType != opType)
- return emitOpError("result type (")
+ return op.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
}
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
auto shapedType = valueType.dyn_cast<ShapedType>();
if (!arrayType)
- return emitOpError("must have spv.array result type for array value");
+ return op.emitOpError("result or element type (")
+ << opType << ") does not match value type (" << valueType
+ << "), must be the same or spv.array";
int numElements = arrayType.getNumElements();
auto opElemType = arrayType.getElementType();
opElemType = t.getElementType();
}
if (!opElemType.isIntOrFloat())
- return emitOpError("only support nested array result type");
+ return op.emitOpError("only support nested array result type");
auto valueElemType = shapedType.getElementType();
if (valueElemType != opElemType) {
- return emitOpError("result element type (")
+ return op.emitOpError("result element type (")
<< opElemType << ") does not match value element type ("
<< valueElemType << ")";
}
if (numElements != shapedType.getNumElements()) {
- return emitOpError("result number of elements (")
+ return op.emitOpError("result number of elements (")
<< numElements << ") does not match value number of elements ("
<< shapedType.getNumElements() << ")";
}
return success();
}
- if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
+ if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
if (!arrayType)
- return emitOpError("must have spv.array result type for array value");
+ return op.emitOpError("must have spv.array result type for array value");
Type elemType = arrayType.getElementType();
- for (Attribute element : attayAttr.getValue()) {
- if (element.getType() != elemType)
- return emitOpError("has array element whose type (")
- << element.getType()
- << ") does not match the result element type (" << elemType
- << ')';
+ for (Attribute element : arrayAttr.getValue()) {
+ // Verify array elements recursively.
+ if (failed(verifyConstantType(op, element, elemType)))
+ return failure();
}
return success();
}
- return emitOpError("cannot have value of type ") << valueType;
+ return op.emitOpError("cannot have value of type ") << valueType;
+}
+
+LogicalResult spirv::ConstantOp::verify() {
+ // ODS already generates checks to make sure the result type is valid. We just
+ // need to additionally check that the value's attribute type is consistent
+ // with the result type.
+ return verifyConstantType(*this, valueAttr(), getType());
}
bool spirv::ConstantOp::isBuildableWith(Type type) {
%6 = spv.Constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
%7 = spv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
%8 = spv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
+ %9 = spv.Constant [[dense<3.0> : vector<2xf32>]] : !spv.array<1 x !spv.array<1xvector<2xf32>>>
return
}
// -----
func @array_constant() -> () {
- // expected-error @+1 {{has array element whose type ('vector<2xi32>') does not match the result element type ('vector<2xf32>')}}
+ // expected-error @+1 {{result or element type ('vector<2xf32>') does not match value type ('vector<2xi32>')}}
%0 = spv.Constant [dense<3.0> : vector<2xf32>, dense<4> : vector<2xi32>] : !spv.array<2xvector<2xf32>>
return
}
// -----
func @value_result_type_mismatch() -> () {
- // expected-error @+1 {{must have spv.array result type for array value}}
+ // expected-error @+1 {{result or element type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}}
%0 = "spv.Constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
}