NFC: Split up Parser::parseAttribute into multiple smaller functions to improve reada...
authorRiver Riddle <riverriddle@google.com>
Mon, 3 Jun 2019 03:33:29 +0000 (20:33 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 4 Jun 2019 02:25:55 +0000 (19:25 -0700)
PiperOrigin-RevId: 251158192

mlir/lib/Parser/Parser.cpp

index 891289d..3f4bf04 100644 (file)
@@ -224,10 +224,25 @@ public:
   ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
 
   /// Parse an extended attribute.
-  Attribute parseExtendedAttribute(Type type);
+  Attribute parseExtendedAttr(Type type);
+
+  /// Parse a float attribute.
+  Attribute parseFloatAttr(Type type, bool isNegative);
+
+  /// Parse an integer attribute.
+  Attribute parseIntegerAttr(Type type, bool isSigned);
+
+  /// Parse an opaque elements attribute.
+  Attribute parseOpaqueElementsAttr();
+
+  /// Parse a sparse elements attribute.
+  Attribute parseSparseElementsAttr();
+
+  /// Parse a splat elements attribute.
+  Attribute parseSplatElementsAttr();
 
   /// Parse a dense elements attribute.
-  DenseElementsAttr parseDenseElementsAttr(ShapedType type);
+  Attribute parseDenseElementsAttr();
   DenseElementsAttr parseDenseElementsAttrAsTensor(Type eltType);
   ShapedType parseElementsLiteralType();
 
@@ -928,142 +943,7 @@ ParseResult Parser::parseXInDimensionList() {
 ///
 Attribute Parser::parseAttribute(Type type) {
   switch (getToken().getKind()) {
-  case Token::hash_identifier:
-    return parseExtendedAttribute(type);
-  case Token::kw_unit:
-    consumeToken(Token::kw_unit);
-    return builder.getUnitAttr();
-
-  case Token::kw_true:
-    consumeToken(Token::kw_true);
-    return builder.getBoolAttr(true);
-  case Token::kw_false:
-    consumeToken(Token::kw_false);
-    return builder.getBoolAttr(false);
-
-  case Token::floatliteral: {
-    auto val = getToken().getFloatingPointValue();
-    if (!val.hasValue())
-      return (emitError("floating point value too large for attribute"),
-              nullptr);
-    auto valTok = getToken().getLoc();
-    consumeToken(Token::floatliteral);
-    if (!type) {
-      if (consumeIf(Token::colon)) {
-        if (!(type = parseType()))
-          return nullptr;
-      } else {
-        // Default to F64 when no type is specified.
-        type = builder.getF64Type();
-      }
-    }
-    if (!type.isa<FloatType>())
-      return (emitError("floating point value not valid for specified type"),
-              nullptr);
-    return FloatAttr::getChecked(type, val.getValue(),
-                                 getEncodedSourceLocation(valTok));
-  }
-  case Token::integer: {
-    auto val = getToken().getUInt64IntegerValue();
-    if (!val.hasValue() || (int64_t)val.getValue() < 0)
-      return (emitError("integer constant out of range for attribute"),
-              nullptr);
-    consumeToken(Token::integer);
-    if (!type) {
-      if (consumeIf(Token::colon)) {
-        if (!(type = parseType()))
-          return nullptr;
-      } else {
-        // Default to i64 if not type is specified.
-        type = builder.getIntegerType(64);
-      }
-    }
-    if (!type.isIntOrIndex())
-      return (emitError("integer value not valid for specified type"), nullptr);
-    int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
-    APInt apInt(width, val.getValue());
-    if (apInt != *val)
-      return emitError("integer constant out of range for attribute"), nullptr;
-    return builder.getIntegerAttr(type, apInt);
-  }
-
-  case Token::minus: {
-    consumeToken(Token::minus);
-    if (getToken().is(Token::integer)) {
-      auto val = getToken().getUInt64IntegerValue();
-      if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
-        return (emitError("integer constant out of range for attribute"),
-                nullptr);
-      consumeToken(Token::integer);
-      if (!type) {
-        if (consumeIf(Token::colon)) {
-          if (!(type = parseType()))
-            return nullptr;
-        } else {
-          // Default to i64 if not type is specified.
-          type = builder.getIntegerType(64);
-        }
-      }
-      if (!type.isIntOrIndex())
-        return (emitError("integer value not valid for type"), nullptr);
-      int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
-      APInt apInt(width, *val, /*isSigned=*/true);
-      if (apInt != *val)
-        return (emitError("integer constant out of range for attribute"),
-                nullptr);
-      return builder.getIntegerAttr(type, -apInt);
-    }
-    if (getToken().is(Token::floatliteral)) {
-      auto val = getToken().getFloatingPointValue();
-      if (!val.hasValue())
-        return (emitError("floating point value too large for attribute"),
-                nullptr);
-      auto valTok = getToken().getLoc();
-      consumeToken(Token::floatliteral);
-      if (!type) {
-        if (consumeIf(Token::colon)) {
-          if (!(type = parseType()))
-            return nullptr;
-        } else {
-          // Default to F64 when no type is specified.
-          type = builder.getF64Type();
-        }
-      }
-      if (!type.isa<FloatType>())
-        return (emitError("floating point value not valid for type"), nullptr);
-      return FloatAttr::getChecked(type, -val.getValue(),
-                                   getEncodedSourceLocation(valTok));
-    }
-
-    return (emitError("expected constant integer or floating point value"),
-            nullptr);
-  }
-
-  case Token::string: {
-    auto val = getToken().getStringValue();
-    consumeToken(Token::string);
-    return builder.getStringAttr(val);
-  }
-
-  case Token::l_brace: {
-    SmallVector<NamedAttribute, 4> elements;
-    if (parseAttributeDict(elements))
-      return nullptr;
-    return builder.getDictionaryAttr(elements);
-  }
-  case Token::l_square: {
-    consumeToken(Token::l_square);
-    SmallVector<Attribute, 4> elements;
-
-    auto parseElt = [&]() -> ParseResult {
-      elements.push_back(parseAttribute());
-      return elements.back() ? success() : failure();
-    };
-
-    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
-      return nullptr;
-    return builder.getArrayAttr(elements);
-  }
+  // Parse an AffineMap or IntegerSet attribute.
   case Token::l_paren: {
     // Try to parse an affine map or an integer set reference.
     AffineMap map;
@@ -1076,170 +956,98 @@ Attribute Parser::parseAttribute(Type type) {
     return builder.getIntegerSetAttr(set);
   }
 
-  case Token::at_identifier: {
-    auto nameStr = getTokenSpelling();
-    consumeToken(Token::at_identifier);
-    return builder.getFunctionAttr(nameStr.drop_front());
-  }
-  case Token::kw_opaque: {
-    consumeToken(Token::kw_opaque);
-    if (parseToken(Token::less, "expected '<' after 'opaque'"))
-      return nullptr;
-
-    if (getToken().getKind() != Token::string)
-      return (emitError("expected dialect namespace"), nullptr);
-    auto name = getToken().getStringValue();
-    auto *dialect = builder.getContext()->getRegisteredDialect(name);
-    // TODO(shpeisman): Allow for having an unknown dialect on an opaque
-    // attribute. Otherwise, it can't be roundtripped without having the dialect
-    // registered.
-    if (!dialect)
-      return (emitError("no registered dialect with namespace '" + name + "'"),
-              nullptr);
-
-    consumeToken(Token::string);
-    if (parseToken(Token::comma, "expected ','"))
-      return nullptr;
-
-    auto type = parseElementsLiteralType();
-    if (!type)
-      return nullptr;
+  // Parse an array attribute.
+  case Token::l_square: {
+    consumeToken(Token::l_square);
 
-    if (parseToken(Token::comma, "expected ',' after elements literal type"))
-      return nullptr;
+    SmallVector<Attribute, 4> elements;
+    auto parseElt = [&]() -> ParseResult {
+      elements.push_back(parseAttribute());
+      return elements.back() ? success() : failure();
+    };
 
-    if (getToken().getKind() != Token::string)
-      return (emitError("opaque string should start with '0x'"), nullptr);
-    auto val = getToken().getStringValue();
-    if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
-      return (emitError("opaque string should start with '0x'"), nullptr);
-    val = val.substr(2);
-    if (!std::all_of(val.begin(), val.end(),
-                     [](char c) { return llvm::isHexDigit(c); })) {
-      return (emitError("opaque string only contains hex digits"), nullptr);
-    }
-    consumeToken(Token::string);
-    if (parseToken(Token::greater, "expected '>'"))
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
       return nullptr;
-    return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
+    return builder.getArrayAttr(elements);
   }
-  case Token::kw_splat: {
-    consumeToken(Token::kw_splat);
-    if (parseToken(Token::less, "expected '<' after 'splat'"))
-      return nullptr;
 
-    auto type = parseElementsLiteralType();
-    if (!type)
-      return nullptr;
-    if (parseToken(Token::comma, "expected ',' after elements literal type"))
-      return nullptr;
-    switch (getToken().getKind()) {
-    case Token::floatliteral:
-    case Token::integer:
-    case Token::kw_false:
-    case Token::kw_true:
-    case Token::minus: {
-      auto scalar = parseAttribute(type.getElementType());
-      if (!scalar)
-        return nullptr;
-      if (parseToken(Token::greater, "expected '>'"))
-        return nullptr;
-      return builder.getSplatElementsAttr(type, scalar);
-    }
-    default:
-      return (emitError("expected scalar constant inside tensor literal"),
-              nullptr);
-    }
-  }
-  case Token::kw_dense: {
-    consumeToken(Token::kw_dense);
-    if (parseToken(Token::less, "expected '<' after 'dense'"))
-      return nullptr;
+  // Parse a boolean attribute.
+  case Token::kw_false:
+    consumeToken(Token::kw_false);
+    return builder.getBoolAttr(false);
+  case Token::kw_true:
+    consumeToken(Token::kw_true);
+    return builder.getBoolAttr(true);
 
-    auto type = parseElementsLiteralType();
-    if (!type)
-      return nullptr;
+  // Parse a dense elements attribute.
+  case Token::kw_dense:
+    return parseDenseElementsAttr();
 
-    if (parseToken(Token::comma, "expected ',' after elements literal type"))
+  // Parse a dictionary attribute.
+  case Token::l_brace: {
+    SmallVector<NamedAttribute, 4> elements;
+    if (parseAttributeDict(elements))
       return nullptr;
+    return builder.getDictionaryAttr(elements);
+  }
 
-    auto attr = parseDenseElementsAttr(type);
-    if (!attr)
-      return nullptr;
+  // Parse an extended attribute, i.e. alias or dialect attribute.
+  case Token::hash_identifier:
+    return parseExtendedAttr(type);
 
-    if (parseToken(Token::greater, "expected '>'"))
-      return nullptr;
+  // Parse floating point and integer attributes.
+  case Token::floatliteral:
+    return parseFloatAttr(type, /*isNegative=*/false);
+  case Token::integer:
+    return parseIntegerAttr(type, /*isSigned=*/false);
+  case Token::minus: {
+    consumeToken(Token::minus);
+    if (getToken().is(Token::integer))
+      return parseIntegerAttr(type, /*isSigned=*/true);
+    if (getToken().is(Token::floatliteral))
+      return parseFloatAttr(type, /*isNegative=*/true);
 
-    return attr;
+    return (emitError("expected constant integer or floating point value"),
+            nullptr);
   }
-  case Token::kw_sparse: {
-    consumeToken(Token::kw_sparse);
-    if (parseToken(Token::less, "Expected '<' after 'sparse'"))
-      return nullptr;
 
-    auto type = parseElementsLiteralType();
-    if (!type)
-      return nullptr;
+  // Parse a function attribute.
+  case Token::at_identifier: {
+    auto nameStr = getTokenSpelling();
+    consumeToken(Token::at_identifier);
+    return builder.getFunctionAttr(nameStr.drop_front());
+  }
 
-    if (parseToken(Token::comma, "expected ',' after elements literal type"))
-      return nullptr;
+  // Parse an opaque elements attribute.
+  case Token::kw_opaque:
+    return parseOpaqueElementsAttr();
 
-    switch (getToken().getKind()) {
-    case Token::l_square: {
-      /// Parse indices
-      auto indicesEltType = builder.getIntegerType(64);
-      auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
-      if (!indices)
-        return nullptr;
-
-      if (parseToken(Token::comma, "expected ','"))
-        return nullptr;
+  // Parse a sparse elements attribute.
+  case Token::kw_sparse:
+    return parseSparseElementsAttr();
 
-      /// Parse values.
-      auto valuesEltType = type.getElementType();
-      auto values = parseDenseElementsAttrAsTensor(valuesEltType);
-      if (!values)
-        return nullptr;
+  // Parse a splat elements attribute.
+  case Token::kw_splat:
+    return parseSplatElementsAttr();
 
-      /// Sanity check.
-      auto valuesType = values.getType();
-      if (valuesType.getRank() != 1) {
-        return (emitError("expected 1-d tensor for values"), nullptr);
-      }
-      auto indicesType = indices.getType();
-      auto sameShape = (indicesType.getRank() == 1) ||
-                       (type.getRank() == indicesType.getDimSize(1));
-      auto sameElementNum =
-          indicesType.getDimSize(0) == valuesType.getDimSize(0);
-      if (!sameShape || !sameElementNum) {
-        emitError() << "expected shape ([" << type.getShape()
-                    << "]); inferred shape of indices literal (["
-                    << indicesType.getShape()
-                    << "]); inferred shape of values literal (["
-                    << valuesType.getShape() << "])";
-        return nullptr;
-      }
+  // Parse a string attribute.
+  case Token::string: {
+    auto val = getToken().getStringValue();
+    consumeToken(Token::string);
+    return builder.getStringAttr(val);
+  }
 
-      if (parseToken(Token::greater, "expected '>'"))
-        return nullptr;
+  // Parse a 'unit' attribute.
+  case Token::kw_unit:
+    consumeToken(Token::kw_unit);
+    return builder.getUnitAttr();
 
-      // Build the sparse elements attribute by the indices and values.
-      return builder.getSparseElementsAttr(
-          type, indices.cast<DenseIntElementsAttr>(), values);
-    }
-    default:
-      return (emitError("expected '[' to start sparse tensor literal"),
-              nullptr);
-    }
-    return (emitError("expected elements literal has a tensor or vector type"),
-            nullptr);
-  }
-  default: {
+  default:
+    // Parse a type attribute.
     if (Type type = parseType())
       return builder.getTypeAttr(type);
     return nullptr;
   }
-  }
 }
 
 /// Attribute dictionary.
@@ -1289,7 +1097,7 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
 ///   dialect-attribute  ::= `#` alias-name pretty-dialect-sym-body?
 ///   attribute-alias    ::= `#` alias-name
 ///
-Attribute Parser::parseExtendedAttribute(Type type) {
+Attribute Parser::parseExtendedAttr(Type type) {
   Attribute attr = parseExtendedSymbol<Attribute>(
       *this, Token::hash_identifier, state.attributeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData,
@@ -1313,6 +1121,183 @@ Attribute Parser::parseExtendedAttribute(Type type) {
   return attr;
 }
 
+/// Parse a float attribute.
+Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
+  auto val = getToken().getFloatingPointValue();
+  if (!val.hasValue())
+    return (emitError("floating point value too large for attribute"), nullptr);
+  auto valTok = getToken().getLoc();
+  consumeToken(Token::floatliteral);
+  if (!type) {
+    // Default to F64 when no type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getF64Type();
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+  if (!type.isa<FloatType>())
+    return (emitError("floating point value not valid for specified type"),
+            nullptr);
+  return FloatAttr::getChecked(type,
+                               isNegative ? -val.getValue() : val.getValue(),
+                               getEncodedSourceLocation(valTok));
+}
+
+/// Parse an integer attribute.
+Attribute Parser::parseIntegerAttr(Type type, bool isSigned) {
+  auto val = getToken().getUInt64IntegerValue();
+  if (!val.hasValue() ||
+      (isSigned ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0))
+    return (emitError("integer constant out of range for attribute"), nullptr);
+  consumeToken(Token::integer);
+  if (!type) {
+    // Default to i64 if not type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getIntegerType(64);
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+  if (!type.isIntOrIndex())
+    return (emitError("integer value not valid for specified type"), nullptr);
+
+  int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+  APInt apInt(width, *val, isSigned);
+  if (apInt != *val)
+    return (emitError("integer constant out of range for attribute"), nullptr);
+  return builder.getIntegerAttr(type, isSigned ? -apInt : apInt);
+}
+
+/// Parse an opaque elements attribute.
+Attribute Parser::parseOpaqueElementsAttr() {
+  consumeToken(Token::kw_opaque);
+  if (parseToken(Token::less, "expected '<' after 'opaque'"))
+    return nullptr;
+
+  if (getToken().isNot(Token::string))
+    return (emitError("expected dialect namespace"), nullptr);
+
+  auto name = getToken().getStringValue();
+  auto *dialect = builder.getContext()->getRegisteredDialect(name);
+  // TODO(shpeisman): Allow for having an unknown dialect on an opaque
+  // attribute. Otherwise, it can't be roundtripped without having the dialect
+  // registered.
+  if (!dialect)
+    return (emitError("no registered dialect with namespace '" + name + "'"),
+            nullptr);
+
+  consumeToken(Token::string);
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ',' after elements literal type"))
+    return nullptr;
+  if (getToken().getKind() != Token::string)
+    return (emitError("opaque string should start with '0x'"), nullptr);
+
+  auto val = getToken().getStringValue();
+  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+    return (emitError("opaque string should start with '0x'"), nullptr);
+
+  val = val.substr(2);
+  if (!llvm::all_of(val, llvm::isHexDigit))
+    return (emitError("opaque string only contains hex digits"), nullptr);
+
+  consumeToken(Token::string);
+  if (parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+  return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
+}
+
+/// Parse a sparse elements attribute.
+Attribute Parser::parseSparseElementsAttr() {
+  consumeToken(Token::kw_sparse);
+  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ',' after elements literal type"))
+    return nullptr;
+  if (getToken().isNot(Token::l_square))
+    return emitError("expected '[' to start sparse tensor literal"), nullptr;
+
+  /// Parse indices
+  auto indicesEltType = builder.getIntegerType(64);
+  auto indices = parseDenseElementsAttrAsTensor(indicesEltType);
+  if (!indices)
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  /// Parse values.
+  auto valuesEltType = type.getElementType();
+  auto values = parseDenseElementsAttrAsTensor(valuesEltType);
+  if (!values)
+    return nullptr;
+
+  /// Sanity check.
+  auto valuesType = values.getType();
+  if (valuesType.getRank() != 1)
+    return (emitError("expected 1-d tensor for values"), nullptr);
+
+  auto indicesType = indices.getType();
+  auto sameShape = (indicesType.getRank() == 1) ||
+                   (type.getRank() == indicesType.getDimSize(1));
+  auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
+  if (!sameShape || !sameElementNum) {
+    emitError() << "expected shape ([" << type.getShape()
+                << "]); inferred shape of indices literal (["
+                << indicesType.getShape()
+                << "]); inferred shape of values literal (["
+                << valuesType.getShape() << "])";
+    return nullptr;
+  }
+
+  if (parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  // Build the sparse elements attribute by the indices and values.
+  return builder.getSparseElementsAttr(
+      type, indices.cast<DenseIntElementsAttr>(), values);
+}
+
+/// Parse a splat elements attribute.
+Attribute Parser::parseSplatElementsAttr() {
+  consumeToken(Token::kw_splat);
+  if (parseToken(Token::less, "expected '<' after 'splat'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+  if (parseToken(Token::comma, "expected ',' after elements literal type"))
+    return nullptr;
+
+  switch (getToken().getKind()) {
+  case Token::floatliteral:
+  case Token::integer:
+  case Token::kw_false:
+  case Token::kw_true:
+  case Token::minus: {
+    auto scalar = parseAttribute(type.getElementType());
+    if (!scalar)
+      return nullptr;
+    if (parseToken(Token::greater, "expected '>'"))
+      return nullptr;
+    return builder.getSplatElementsAttr(type, scalar);
+  }
+  default:
+    return emitError("expected scalar constant inside tensor literal"), nullptr;
+  }
+}
+
 namespace {
 class TensorLiteralParser {
 public:
@@ -1457,9 +1442,19 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
 /// This method compares the shapes from the parsing result and that from the
 /// input argument. It returns a constructed dense elements attribute if both
 /// match.
-DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
-  auto eltTy = type.getElementType();
-  TensorLiteralParser literalParser(*this, eltTy);
+Attribute Parser::parseDenseElementsAttr() {
+  consumeToken(Token::kw_dense);
+  if (parseToken(Token::less, "expected '<' after 'dense'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType();
+  if (!type)
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ',' after elements literal type"))
+    return nullptr;
+
+  TensorLiteralParser literalParser(*this, type.getElementType());
   if (literalParser.parse())
     return nullptr;
 
@@ -1470,6 +1465,9 @@ DenseElementsAttr Parser::parseDenseElementsAttr(ShapedType type) {
     return nullptr;
   }
 
+  if (parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
   return builder.getDenseElementsAttr(type, literalParser.getValues())
       .cast<DenseElementsAttr>();
 }
@@ -1504,9 +1502,8 @@ ShapedType Parser::parseElementsLiteralType() {
     return nullptr;
 
   if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
-    return (
-        emitError("elements literal must be a ranked tensor or vector type"),
-        nullptr);
+    emitError("elements literal must be a ranked tensor or vector type");
+    return nullptr;
   }
 
   auto sType = type.cast<ShapedType>();