NFC: Add support for parsing attributes programmatically via mlir::parseAttribute.
authorRiver Riddle <riverriddle@google.com>
Tue, 22 Oct 2019 04:34:21 +0000 (21:34 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 Oct 2019 04:34:51 +0000 (21:34 -0700)
This matches the behavior of the public mlir::parseType, and even uses the internal implementation.

PiperOrigin-RevId: 275989777

mlir/include/mlir/Parser.h
mlir/lib/Parser/Parser.cpp

index 2a98977..3a818ff 100644 (file)
@@ -31,6 +31,7 @@ class StringRef;
 } // end namespace llvm
 
 namespace mlir {
+class Attribute;
 class Location;
 class MLIRContext;
 class OwningModuleRef;
@@ -61,6 +62,24 @@ OwningModuleRef parseSourceFile(llvm::StringRef filename,
 OwningModuleRef parseSourceString(llvm::StringRef moduleStr,
                                   MLIRContext *context);
 
+/// This parses a single MLIR attribute to an MLIR context if it was valid.  If
+/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
+/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
+/// `attrStr`. If the passed `attrStr` has additional tokens that were not part
+/// of the type, an error is emitted.
+// TODO(ntv) Improve diagnostic reporting.
+Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context);
+Attribute parseAttribute(llvm::StringRef attrStr, Type type);
+
+/// This parses a single MLIR attribute to an MLIR context if it was valid.  If
+/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
+/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
+/// `attrStr`. The number of characters of `attrStr` parsed in the process is
+/// returned in `numRead`.
+Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
+                         size_t &numRead);
+Attribute parseAttribute(llvm::StringRef attrStr, Type type, size_t &numRead);
+
 /// This parses a single MLIR type to an MLIR context if it was valid.  If not,
 /// an error message is emitted through a new SourceMgrDiagnosticHandler
 /// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
index 60a0a90..8813cdc 100644 (file)
@@ -4333,28 +4333,60 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
   return parseSourceFile(sourceMgr, context);
 }
 
-Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context,
-                     size_t &numRead) {
+/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
+/// parsing failed, nullptr is returned. The number of bytes read from the input
+/// string is returned in 'numRead'.
+template <typename T, typename ParserFn>
+static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
+                     size_t &numRead, ParserFn &&parserFn) {
   SourceMgr sourceMgr;
-  auto memBuffer =
-      MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>",
-                                 /*RequiresNullTerminator=*/false);
+  auto memBuffer = MemoryBuffer::getMemBuffer(
+      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
+      /*RequiresNullTerminator=*/false);
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
   ParserState state(sourceMgr, context);
   Parser parser(state);
 
   auto start = parser.getToken().getLoc();
-  auto ty = parser.parseType();
-  if (!ty)
-    return Type();
+  T symbol = parserFn(parser);
+  if (!symbol)
+    return T();
 
   auto end = parser.getToken().getLoc();
   numRead = static_cast<size_t>(end.getPointer() - start.getPointer());
-  return ty;
+  return symbol;
+}
+
+Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context) {
+  size_t numRead = 0;
+  return parseAttribute(attrStr, context, numRead);
+}
+Attribute mlir::parseAttribute(llvm::StringRef attrStr, Type type) {
+  size_t numRead = 0;
+  return parseAttribute(attrStr, type, numRead);
+}
+
+Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
+                               size_t &numRead) {
+  return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
+    return parser.parseAttribute();
+  });
+}
+Attribute mlir::parseAttribute(llvm::StringRef attrStr, Type type,
+                               size_t &numRead) {
+  return parseSymbol<Attribute>(
+      attrStr, type.getContext(), numRead,
+      [type](Parser &parser) { return parser.parseAttribute(type); });
 }
 
 Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) {
   size_t numRead = 0;
   return parseType(typeStr, context, numRead);
 }
+
+Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context,
+                     size_t &numRead) {
+  return parseSymbol<Type>(typeStr, context, numRead,
+                           [](Parser &parser) { return parser.parseType(); });
+}