[mlir] Optimize the parsing of ElementsAttr hex strings
authorRiver Riddle <riddleriver@gmail.com>
Wed, 28 Oct 2020 23:46:38 +0000 (16:46 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 28 Oct 2020 23:58:06 +0000 (16:58 -0700)
This revision optimizes the parsing of hex strings by using the checked variant of llvm::fromHex, and adding a specialized method to Token for extracting hex strings. This leads a large decrease in compile time when parsing large hex constants (one example: 2.6 seconds -> 370 miliseconds)

Differential Revision: https://reviews.llvm.org/D90266

mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Token.cpp
mlir/lib/Parser/Token.h
mlir/test/IR/dense-elements-hex.mlir
mlir/test/IR/invalid.mlir

index 4e17ccd..606b46b 100644 (file)
@@ -410,22 +410,16 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
 // TensorLiteralParser
 //===----------------------------------------------------------------------===//
 
-/// Parse elements values stored within a hex etring. On success, the values are
+/// Parse elements values stored within a hex string. On success, the values are
 /// stored into 'result'.
 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
                                              std::string &result) {
-  std::string val = tok.getStringValue();
-  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
-    return parser.emitError(tok.getLoc(),
-                            "elements hex string should start with '0x'");
-
-  StringRef hexValues = StringRef(val).drop_front(2);
-  if (!llvm::all_of(hexValues, llvm::isHexDigit))
-    return parser.emitError(tok.getLoc(),
-                            "elements hex string only contains hex digits");
-
-  result = llvm::fromHex(hexValues);
-  return success();
+  if (Optional<std::string> value = tok.getHexStringValue()) {
+    result = std::move(*value);
+    return success();
+  }
+  return parser.emitError(
+      tok.getLoc(), "expected string containing hex digits starting with `0x`");
 }
 
 namespace {
index db6f716..b4ac30f 100644 (file)
@@ -124,6 +124,21 @@ std::string Token::getStringValue() const {
   return result;
 }
 
+/// Given a token containing a hex string literal, return its value or None if
+/// the token does not contain a valid hex string.
+Optional<std::string> Token::getHexStringValue() const {
+  assert(getKind() == string);
+
+  // Get the internal string data, without the quotes.
+  StringRef bytes = getSpelling().drop_front().drop_back();
+
+  // Try to extract the binary data from the hex string.
+  std::string hex;
+  if (!bytes.consume_front("0x") || !llvm::tryGetFromHex(bytes, hex))
+    return llvm::None;
+  return hex;
+}
+
 /// Given a token containing a symbol reference, return the unescaped string
 /// value.
 std::string Token::getSymbolReference() const {
index 8f37b2b..d4d09f9 100644 (file)
@@ -91,6 +91,11 @@ public:
   /// removing the quote characters and unescaping the contents of the string.
   std::string getStringValue() const;
 
+  /// Given a token containing a hex string literal, return its value or None if
+  /// the token does not contain a valid hex string. A hex string literal is a
+  /// string starting with `0x` and only containing hex digits.
+  Optional<std::string> getHexStringValue() const;
+
   /// Given a token containing a symbol reference, return the unescaped string
   /// value.
   std::string getSymbolReference() const;
index 4b53467..ff6d39d 100644 (file)
 
 // -----
 
-// expected-error@+1 {{elements hex string should start with '0x'}}
+// expected-error@+1 {{expected string containing hex digits starting with `0x`}}
 "foo.op"() {dense.attr = dense<"00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
 
 // -----
 
-// expected-error@+1 {{elements hex string only contains hex digits}}
+// expected-error@+1 {{expected string containing hex digits starting with `0x`}}
 "foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144X"> : tensor<2xf64>} : () -> ()
 
 // -----
index 8c09819..ae68636 100644 (file)
@@ -718,14 +718,14 @@ func @elementsattr_malformed_opaque() -> () {
 
 func @elementsattr_malformed_opaque1() -> () {
 ^bb0:
-  "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string only contains hex digits}}
+  "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
 }
 
 // -----
 
 func @elementsattr_malformed_opaque2() -> () {
 ^bb0:
-  "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string should start with '0x'}}
+  "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
 }
 
 // -----