[mlir][ODS] Verify type constraint in `TypeAttrOf`
authorMarkus Böck <markus.boeck02@gmail.com>
Fri, 7 Apr 2023 10:29:55 +0000 (12:29 +0200)
committerMarkus Böck <markus.boeck02@gmail.com>
Fri, 7 Apr 2023 10:30:15 +0000 (12:30 +0200)
The current implementation does not verify the type constraint, meaning that any type that happens to be of the same C++ type would pass the verifier.
E.g. a `TypeAttrOf<I64>` would happily accept a `i32` since both satisfy `isa<IntegerType>()`.

This patch fixes that by adding an optional type predicate parameter to `TypeAttrBase` that the type within `TypeAttr` has to satisfy. `TypeAttrOf` then simply passes the predicate of its type parameter as argument.

Differential Revision: https://reviews.llvm.org/D147778

mlir/include/mlir/IR/OpBase.td
mlir/test/IR/attribute.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-attribute.td

index 98866c8..f4aa07f 100644 (file)
@@ -1296,11 +1296,14 @@ class TypedStrAttr<Type ty>
 // Base class for attributes containing types. Example:
 //   def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
 // defines a type attribute containing an integer type.
-class TypeAttrBase<string retType, string summary> :
+class TypeAttrBase<string retType, string summary,
+                        Pred typePred = CPred<"true">> :
     Attr<And<[
       CPred<"$_self.isa<::mlir::TypeAttr>()">,
       CPred<"$_self.cast<::mlir::TypeAttr>().getValue().isa<"
-            # retType # ">()">]>,
+            # retType # ">()">,
+      SubstLeaves<"$_self",
+                    "$_self.cast<::mlir::TypeAttr>().getValue()", typePred>]>,
     summary> {
   let storageType = [{ ::mlir::TypeAttr }];
   let returnType = retType;
@@ -1313,7 +1316,8 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
 }
 
 class TypeAttrOf<Type ty>
-   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary> {
+   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+                    ty.predicate> {
   let constBuilderCall = "::mlir::TypeAttr::get($0)";
 }
 
index c296507..25d237a 100644 (file)
@@ -879,3 +879,11 @@ func.func @default_value_printing(%arg0 : i32) {
   "test.default_value_print"(%arg0) {"value_with_default" = 1 : i32} : (i32) -> ()
   return
 }
+
+// -----
+
+func.func @type_attr_of_fail() {
+    // expected-error @below {{failed to satisfy constraint: type attribute of 64-bit signless integer}}
+    test.type_attr_of i32
+    return
+}
index ff24ac9..0306f0e 100644 (file)
@@ -277,6 +277,13 @@ def TypedAttrOp : TEST_Op<"typed_attr"> {
   }];
 }
 
+def TypeAttrOfOp : TEST_Op<"type_attr_of"> {
+  let arguments = (ins TypeAttrOf<I64>:$type);
+  let assemblyFormat = [{
+    attr-dict $type
+  }];
+}
+
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
     DenseBoolArrayAttr:$i1attr,
index 3dc426c..af1f622 100644 (file)
@@ -318,10 +318,10 @@ def BOp : NS_Op<"b_op", []> {
 // DEF: if (tblgen_str_attr && !((tblgen_str_attr.isa<::mlir::StringAttr>())))
 // DEF: if (tblgen_elements_attr && !((tblgen_elements_attr.isa<::mlir::ElementsAttr>())))
 // DEF: if (tblgen_function_attr && !((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
-// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
+// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>())) && ((true))))
 // DEF: if (tblgen_array_attr && !((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
 // DEF: if (tblgen_some_attr_array && !(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return attr && ((some-condition)); }))))
-// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
+// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())) && ((true))))
 
 // Test common attribute kind getters' return types
 // ---