[mlir] Enable types to us custom assembly formats involving optional attributes.
authorNick Kreeger <nick.kreeger@gmail.com>
Fri, 23 Dec 2022 15:55:15 +0000 (09:55 -0600)
committerNick Kreeger <nick.kreeger@gmail.com>
Fri, 23 Dec 2022 15:55:15 +0000 (09:55 -0600)
Author: Laszlo Kindrat <laszlokindrat@gmail.com>
Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D140322

mlir/include/mlir/IR/DialectImplementation.h
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir

index e0c7c06..4603187 100644 (file)
@@ -105,6 +105,24 @@ struct FieldParser<std::string> {
   }
 };
 
+/// Parse an Optional attribute.
+template <typename AttributeT>
+struct FieldParser<
+    std::optional<AttributeT>,
+    std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
+                     std::optional<AttributeT>>> {
+  static FailureOr<std::optional<AttributeT>> parse(AsmParser &parser) {
+    AttributeT attr;
+    OptionalParseResult result = parser.parseOptionalAttribute(attr);
+    if (result.has_value()) {
+      if (succeeded(*result))
+        return {std::optional<AttributeT>(attr)};
+      return failure();
+    }
+    return {std::nullopt};
+  }
+};
+
 /// Parse an Optional integer.
 template <typename IntT>
 struct FieldParser<
index f061bd3..81280d8 100644 (file)
@@ -218,9 +218,14 @@ def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
 }
 
 def TestTypeOptionalParam : Test_Type<"TestTypeOptionalParam"> {
-  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a, "int":$b);
+  let parameters = (ins
+    OptionalParameter<"mlir::Optional<int>">:$a,
+    "int":$b,
+    DefaultValuedParameter<"std::optional<::mlir::Attribute>",
+                           "std::nullopt">:$c
+  );
   let mnemonic = "optional_param";
-  let assemblyFormat = "`<` $a `,` $b `>`";
+  let assemblyFormat = "`<` $a `,` $b ( `,` $c^)? `>`";
 }
 
 def TestTypeOptionalParams : Test_Type<"TestTypeOptionalParams"> {
index 80627e7..e63b5de 100644 (file)
@@ -34,6 +34,8 @@ attributes {
 // CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
 // CHECK: !test.optional_param<, 6>
 // CHECK: !test.optional_param<5, 6>
+// CHECK: !test.optional_param<5, 6, "foo">
+// CHECK: !test.optional_param<5, 6, {foo = "bar"}>
 // CHECK: !test.optional_params<"a">
 // CHECK: !test.optional_params<5, "a">
 // CHECK: !test.optional_struct<b = "a">
@@ -72,6 +74,8 @@ func.func private @test_roundtrip_default_parsers_struct(
   !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>,
   !test.optional_param<, 6>,
   !test.optional_param<5, 6>,
+  !test.optional_param<5, 6, "foo">,
+  !test.optional_param<5, 6, {foo = "bar"}>,
   !test.optional_params<"a">,
   !test.optional_params<5, "a">,
   !test.optional_struct<b = "a">,