[mlir][OpAsmFormat] Add support for an "else" group on optional elements
authorRiver Riddle <riddleriver@gmail.com>
Tue, 23 Mar 2021 01:07:09 +0000 (18:07 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 23 Mar 2021 01:19:23 +0000 (18:19 -0700)
The "else" group of an optional element is a collection of elements that get parsed/printed when the anchor of the main element group is *not* present. This is useful when there is a special syntax when an element is not present. The new syntax for an optional element is shown below:

```
optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
```

An example of how this might be used is shown below:

```tablegen
def FooOp : ... {
  let arguments = (ins UnitAttr:$foo);

  let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?";
}
```

would be formatted as such:

```mlir
// When the `foo` attribute is present:
foo.op foo_is_present

// When the `foo` attribute is not present:
foo.op foo_is_absent
```

Differential Revision: https://reviews.llvm.org/D99129

mlir/docs/OpDefinitions.md
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp

index 63b727a..5f41358 100644 (file)
@@ -772,8 +772,13 @@ When a variable is optional, the provided value may be null.
 In certain situations operations may have "optional" information, e.g.
 attributes or an empty set of variadic operands. In these situations a section
 of the assembly format can be marked as `optional` based on the presence of this
-information. An optional group is defined by wrapping a set of elements within
-`()` followed by a `?` and has the following requirements:
+information. An optional group is defined as follows:
+
+```
+optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
+```
+
+The `elements` of an optional group have the following requirements:
 
 *   The first element of the group must either be a attribute, literal, operand,
     or region.
@@ -837,6 +842,32 @@ foo.op is_read_only
 foo.op
 ```
 
+##### Optional "else" Group
+
+Optional groups also have support for an "else" group of elements. These are
+elements that are parsed/printed if the `anchor` element of the optional group
+is *not* present. Unlike the main element group, the "else" group has no
+restriction on the first element and none of the elements may act as the
+`anchor` for the optional. An example is shown below:
+
+```tablegen
+def FooOp : ... {
+  let arguments = (ins UnitAttr:$foo);
+
+  let assemblyFormat = "attr-dict (`foo_is_present` $foo^):(`foo_is_absent`)?";
+}
+```
+
+would be formatted as such:
+
+```mlir
+// When the `foo` attribute is present:
+foo.op foo_is_present
+
+// When the `foo` attribute is not present:
+foo.op foo_is_absent
+```
+
 #### Requirements
 
 The format specification has a certain set of requirements that must be adhered
index 7d48f8d..8be84f2 100644 (file)
@@ -1651,6 +1651,11 @@ def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
   let assemblyFormat = "($attr^)? attr-dict";
 }
 
+def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
+  let arguments = (ins UnitAttr:$isFirstBranchPresent);
+  let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Custom Directives
 
index 4f5ca63..8c6bb09 100644 (file)
@@ -390,6 +390,18 @@ def OptionalInvalidL : TestFormat_Op<[{
 def OptionalInvalidM : TestFormat_Op<[{
   (` `^)?
 }]>, Arguments<(ins)>;
+// CHECK: error: expected '(' to start else branch of optional group
+def OptionalInvalidN : TestFormat_Op<[{
+  ($arg^):
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: expected directive, literal, variable, or optional group
+def OptionalInvalidO : TestFormat_Op<[{
+  ($arg^):(`test`
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: expected '?' after optional group
+def OptionalInvalidP : TestFormat_Op<[{
+  ($arg^):(`test`)
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
 
 // CHECK-NOT: error
 def OptionalValidA : TestFormat_Op<[{
index 8043786..e6f998f 100644 (file)
@@ -240,6 +240,16 @@ test.format_optional_result_b_op : i64 -> i64, i64
 test.format_optional_result_c_op : (i64) -> (i64, i64)
 
 //===----------------------------------------------------------------------===//
+// Format optional with else
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_optional_else then
+test.format_optional_else then
+
+// CHECK: test.format_optional_else else
+test.format_optional_else else
+
+//===----------------------------------------------------------------------===//
 // Format custom directives
 //===----------------------------------------------------------------------===//
 
index f474bbf..abf77a5 100644 (file)
@@ -348,29 +348,41 @@ private:
 
 namespace {
 /// This class represents a group of elements that are optionally emitted based
-/// upon an optional variable of the operation.
+/// upon an optional variable of the operation, and a group of elements that are
+/// emotted when the anchor element is not present.
 class OptionalElement : public Element {
 public:
-  OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
+  OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements,
+                  std::vector<std::unique_ptr<Element>> &&elseElements,
                   unsigned anchor, unsigned parseStart)
-      : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
+      : Element{Kind::Optional}, thenElements(std::move(thenElements)),
+        elseElements(std::move(elseElements)), anchor(anchor),
         parseStart(parseStart) {}
   static bool classof(const Element *element) {
     return element->getKind() == Kind::Optional;
   }
 
-  /// Return the nested elements of this grouping.
-  auto getElements() const { return llvm::make_pointee_range(elements); }
+  /// Return the `then` elements of this grouping.
+  auto getThenElements() const {
+    return llvm::make_pointee_range(thenElements);
+  }
+
+  /// Return the `else` elements of this grouping.
+  auto getElseElements() const {
+    return llvm::make_pointee_range(elseElements);
+  }
 
   /// Return the anchor of this optional group.
-  Element *getAnchor() const { return elements[anchor].get(); }
+  Element *getAnchor() const { return thenElements[anchor].get(); }
 
   /// Return the index of the first element that needs to be parsed.
   unsigned getParseStart() const { return parseStart; }
 
 private:
-  /// The child elements of this optional.
-  std::vector<std::unique_ptr<Element>> elements;
+  /// The child elements of `then` branch of this optional.
+  std::vector<std::unique_ptr<Element>> thenElements;
+  /// The child elements of `else` branch of this optional.
+  std::vector<std::unique_ptr<Element>> elseElements;
   /// The index of the element that acts as the anchor for the optional group.
   unsigned anchor;
   /// The index of the first element that is parsed (is not a
@@ -792,7 +804,7 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
 /// Generate the storage code required for parsing the given element.
 static void genElementParserStorage(Element *element, OpMethodBody &body) {
   if (auto *optional = dyn_cast<OptionalElement>(element)) {
-    auto elements = optional->getElements();
+    auto elements = optional->getThenElements();
 
     // If the anchor is a unit attribute, it won't be parsed directly so elide
     // it.
@@ -803,6 +815,8 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
     for (auto &childElement : elements)
       if (&childElement != elidedAnchorElement)
         genElementParserStorage(&childElement, body);
+    for (auto &childElement : optional->getElseElements())
+      genElementParserStorage(&childElement, body);
 
   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
     for (auto &paramElement : custom->getArguments())
@@ -1094,8 +1108,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
                                        FmtContext &attrTypeCtx) {
   /// Optional Group.
   if (auto *optional = dyn_cast<OptionalElement>(element)) {
-    auto elements =
-        llvm::drop_begin(optional->getElements(), optional->getParseStart());
+    auto elements = llvm::drop_begin(optional->getThenElements(),
+                                     optional->getParseStart());
 
     // Generate a special optional parser for the first element to gate the
     // parsing of the rest of the elements.
@@ -1140,7 +1154,17 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
       if (&childElement != elidedAnchorElement)
         genElementParser(&childElement, body, attrTypeCtx);
     }
-    body << "  }\n";
+    body << "  }";
+
+    // Generate the else elements.
+    auto elseElements = optional->getElseElements();
+    if (!elseElements.empty()) {
+      body << " else {\n";
+      for (Element &childElement : elseElements)
+        genElementParser(&childElement, body, attrTypeCtx);
+      body << "  }";
+    }
+    body << "\n";
 
     /// Literals.
   } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
@@ -1778,7 +1802,7 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
 
     // If the anchor is a unit attribute, we don't need to print it. When
     // parsing, we will add this attribute if this group is present.
-    auto elements = optional->getElements();
+    auto elements = optional->getThenElements();
     Element *elidedAnchorElement = nullptr;
     auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
     if (anchorAttr && anchorAttr != &*elements.begin() &&
@@ -1793,7 +1817,20 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
                           lastWasPunctuation);
       }
     }
-    body << "  }\n";
+    body << "  }";
+
+    // Emit each of the else elements.
+    auto elseElements = optional->getElseElements();
+    if (!elseElements.empty()) {
+      body << " else {\n";
+      for (Element &childElement : elseElements) {
+        genElementPrinter(&childElement, body, op, shouldEmitSpace,
+                          lastWasPunctuation);
+      }
+      body << "  }";
+    }
+
+    body << "\n";
     return;
   }
 
@@ -1911,6 +1948,7 @@ public:
     l_paren,
     r_paren,
     caret,
+    colon,
     comma,
     equal,
     less,
@@ -2065,6 +2103,8 @@ Token FormatLexer::lexToken() {
   // Lex punctuation.
   case '^':
     return formToken(Token::caret, tokStart);
+  case ':':
+    return formToken(Token::colon, tokStart);
   case ',':
     return formToken(Token::comma, tokStart);
   case '=':
@@ -2393,8 +2433,11 @@ LogicalResult FormatParser::verifyAttributes(
 
     // Traverse into optional groups.
     if (auto *optional = dyn_cast<OptionalElement>(element)) {
-      auto elements = optional->getElements();
-      iteratorStack.emplace_back(elements.begin(), elements.end());
+      auto thenElements = optional->getThenElements();
+      iteratorStack.emplace_back(thenElements.begin(), thenElements.end());
+
+      auto elseElements = optional->getElseElements();
+      iteratorStack.emplace_back(elseElements.begin(), elseElements.end());
       return ::mlir::success();
     }
 
@@ -2795,13 +2838,31 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
   consumeToken();
 
   // Parse the child elements for this optional group.
-  std::vector<std::unique_ptr<Element>> elements;
+  std::vector<std::unique_ptr<Element>> thenElements, elseElements;
   Optional<unsigned> anchorIdx;
   do {
-    if (failed(parseOptionalChildElement(elements, anchorIdx)))
+    if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
       return ::mlir::failure();
   } while (curToken.getKind() != Token::r_paren);
   consumeToken();
+
+  // Parse the `else` elements of this optional group.
+  if (curToken.getKind() == Token::colon) {
+    consumeToken();
+    if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
+                                          "of optional group")))
+      return failure();
+    do {
+      llvm::SMLoc childLoc = curToken.getLoc();
+      elseElements.push_back({});
+      if (failed(parseElement(elseElements.back(), TopLevelContext)) ||
+          failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
+                                            /*isAnchor=*/false)))
+        return failure();
+    } while (curToken.getKind() != Token::r_paren);
+    consumeToken();
+  }
+
   if (failed(parseToken(Token::question, "expected '?' after optional group")))
     return ::mlir::failure();
 
@@ -2811,7 +2872,7 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
 
   // The first parsable element of the group must be able to be parsed in an
   // optional fashion.
-  auto parseBegin = llvm::find_if_not(elements, [](auto &element) {
+  auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) {
     return isa<WhitespaceElement>(element.get());
   });
   Element *firstElement = parseBegin->get();
@@ -2822,9 +2883,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
                      "first parsable element of an operand group must be "
                      "an attribute, literal, operand, or region");
 
-  auto parseStart = parseBegin - elements.begin();
-  element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
-                                              parseStart);
+  auto parseStart = parseBegin - thenElements.begin();
+  element = std::make_unique<OptionalElement>(
+      std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart);
   return ::mlir::success();
 }