From aa9ae76cac0443b7d70b27ae2c0bf9cf92f344d3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Markus=20B=C3=B6ck?= Date: Fri, 7 Apr 2023 12:29:55 +0200 Subject: [PATCH] [mlir][ODS] Verify type constraint in `TypeAttrOf` The current implementation does not verify the type constraint, meaning that any type that happens to be of the same C++ type would pass the verifier. E.g. a `TypeAttrOf` would happily accept a `i32` since both satisfy `isa()`. This patch fixes that by adding an optional type predicate parameter to `TypeAttrBase` that the type within `TypeAttr` has to satisfy. `TypeAttrOf` then simply passes the predicate of its type parameter as argument. Differential Revision: https://reviews.llvm.org/D147778 --- mlir/include/mlir/IR/OpBase.td | 10 +++++++--- mlir/test/IR/attribute.mlir | 8 ++++++++ mlir/test/lib/Dialect/Test/TestOps.td | 7 +++++++ mlir/test/mlir-tblgen/op-attribute.td | 4 ++-- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 98866c8..f4aa07f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1296,11 +1296,14 @@ class TypedStrAttr // Base class for attributes containing types. Example: // def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute"> // defines a type attribute containing an integer type. -class TypeAttrBase : +class TypeAttrBase> : Attr()">, CPred<"$_self.cast<::mlir::TypeAttr>().getValue().isa<" - # retType # ">()">]>, + # retType # ">()">, + SubstLeaves<"$_self", + "$_self.cast<::mlir::TypeAttr>().getValue()", typePred>]>, summary> { let storageType = [{ ::mlir::TypeAttr }]; let returnType = retType; @@ -1313,7 +1316,8 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> { } class TypeAttrOf - : TypeAttrBase { + : TypeAttrBase { let constBuilderCall = "::mlir::TypeAttr::get($0)"; } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index c2965078..25d237a 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -879,3 +879,11 @@ func.func @default_value_printing(%arg0 : i32) { "test.default_value_print"(%arg0) {"value_with_default" = 1 : i32} : (i32) -> () return } + +// ----- + +func.func @type_attr_of_fail() { + // expected-error @below {{failed to satisfy constraint: type attribute of 64-bit signless integer}} + test.type_attr_of i32 + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ff24ac9..0306f0e 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -277,6 +277,13 @@ def TypedAttrOp : TEST_Op<"typed_attr"> { }]; } +def TypeAttrOfOp : TEST_Op<"type_attr_of"> { + let arguments = (ins TypeAttrOf:$type); + let assemblyFormat = [{ + attr-dict $type + }]; +} + def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> { let arguments = (ins DenseBoolArrayAttr:$i1attr, diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 3dc426c..af1f622 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -318,10 +318,10 @@ def BOp : NS_Op<"b_op", []> { // DEF: if (tblgen_str_attr && !((tblgen_str_attr.isa<::mlir::StringAttr>()))) // DEF: if (tblgen_elements_attr && !((tblgen_elements_attr.isa<::mlir::ElementsAttr>()))) // DEF: if (tblgen_function_attr && !((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>()))) -// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa())))) +// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa())) && ((true)))) // DEF: if (tblgen_array_attr && !((tblgen_array_attr.isa<::mlir::ArrayAttr>()))) // DEF: if (tblgen_some_attr_array && !(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return attr && ((some-condition)); })))) -// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())))) +// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())) && ((true)))) // Test common attribute kind getters' return types // --- -- 2.7.4