Expose a minimal type parser to dialects.
authorNicolas Vasilache <ntv@google.com>
Mon, 10 Jun 2019 20:12:32 +0000 (13:12 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 11 Jun 2019 17:13:02 +0000 (10:13 -0700)
This CL exposes a parseType method which allows standalone reuse of the MLIR type parsing mechanism. This is a free function for now because the underlying MLIR parser is not guaranteed to receive a StringRef which lives in the proper MemBuffer. This requires building a new MemBuffer/SourceMgr and modifying the Parser constructor to not require an mlir::Module.

The error diagnostic emitted by parseType has context limited to the local string.
For now the dialect has the additional option to emit its own extra error that has the FileLineColLoc context.

In the future, both error messages should be combined into a single error.

PiperOrigin-RevId: 252468911

mlir/include/mlir/Parser.h
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/IR/LinalgTypes.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/Linalg/roundtrip.mlir

index 61d8fc0..2da728b 100644 (file)
@@ -29,8 +29,10 @@ class StringRef;
 } // end namespace llvm
 
 namespace mlir {
+class Location;
 class Module;
 class MLIRContext;
+class Type;
 
 /// This parses the file specified by the indicated SourceMgr and returns an
 /// MLIR module if it was valid.  If not, the error message is emitted through
@@ -43,10 +45,17 @@ Module *parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
 Module *parseSourceFile(llvm::StringRef filename, MLIRContext *context);
 
 /// This parses the module string to a MLIR module if it was valid.  If not, the
-/// error message is emitted through the error handler registered in the
+/// error message is emitted through the error handler registered in the
 /// context, and a null pointer is returned.
 Module *parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
 
+/// This parses a single MLIR type to an MLIR context if it was valid.  If not,
+/// an error message is emitted through a new SourceMgrDiagnosticHandler
+/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
+/// `typeStr`.
+// TODO(ntv) Improve diagnostic reporting.
+Type parseType(llvm::StringRef typeStr, MLIRContext *context);
+
 } // end namespace mlir
 
 #endif // MLIR_PARSER_H
index 7ccb0ff..7d41c86 100644 (file)
@@ -69,10 +69,6 @@ ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser,
   auto indexTy = parser->getBuilder().getIndexType();
   if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
     return failure();
-  if (bufferType.getElementType() != parser->getBuilder().getF32Type())
-    return parser->emitError(parser->getNameLoc(),
-                             "Only buffer<f32> supported until "
-                             "mlir::linalg::Parser pieces are exposed");
   return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
                  parser->addTypeToList(bufferType, result->types));
 }
index 4c8d87b..8496cfd 100644 (file)
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Parser.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/Support/raw_ostream.h"
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -88,23 +91,35 @@ Type mlir::linalg::BufferType::getElementType() {
 
 Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
                                             Location loc) const {
+  StringRef origSpec = spec;
   MLIRContext *context = getContext();
   if (spec == "range")
     return RangeType::get(getContext());
-  // TODO(ntv): reuse mlir Parser once exposed.
-  if (spec == "buffer<f32>")
-    return BufferType::get(getContext(), FloatType::getF32(getContext()));
-  // TODO(ntv): reuse mlir Parser once exposed.
-  if (spec.startswith("view")) {
-    spec.consume_front("view");
-    // Just count the number of ? to get the rank, the type must be f32 for now.
-    unsigned rank = 0;
-    for (unsigned i = 0, e = spec.size(); i < e; ++i)
-      if (spec[i] == '?')
-        ++rank;
-    return ViewType::get(context, FloatType::getF32(context), rank);
+  else if (spec.consume_front("buffer")) {
+    if (spec.consume_front("<") && spec.consume_back(">")) {
+      if (auto t = mlir::parseType(spec, context))
+        return BufferType::get(getContext(), t);
+    }
+  } else if (spec.consume_front("view")) {
+    if (spec.consume_front("<") && spec.consume_back(">")) {
+      // Just count the number of ? to get the rank.
+      unsigned rank = 0;
+      for (unsigned i = 0, e = spec.size(); i < e; ++i) {
+        if (spec.consume_front("?")) {
+          ++rank;
+          if (!spec.consume_front("x")) {
+            context->emitError(loc,
+                               "expected a list of '?x' dimension specifiers: ")
+                << spec;
+            return Type();
+          }
+        }
+      }
+      if (auto t = mlir::parseType(spec, context))
+        return ViewType::get(context, t, rank);
+    }
   }
-  return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
+  return (context->emitError(loc, "unknown Linalg type: " + origSpec), Type());
 }
 
 struct mlir::linalg::ViewTypeStorage : public TypeStorage {
index 67432d6..515e53e 100644 (file)
@@ -58,9 +58,8 @@ class Parser;
 /// methods to access this.
 class ParserState {
 public:
-  ParserState(const llvm::SourceMgr &sourceMgr, Module *module)
-      : context(module->getContext()), module(module), lex(sourceMgr, context),
-        curToken(lex.lexToken()) {}
+  ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+      : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {}
 
   // A map from attribute alias identifier to Attribute.
   llvm::StringMap<Attribute> attributeAliasDefinitions;
@@ -77,9 +76,6 @@ private:
   // The context we're parsing into.
   MLIRContext *const context;
 
-  // This is the module we are parsing into.
-  Module *const module;
-
   // The lexer for the source file we're parsing.
   Lexer lex;
 
@@ -103,7 +99,6 @@ public:
   // Helper methods to get stuff from the parser-global state.
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.context; }
-  Module *getModule() { return state.module; }
   const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
 
   /// Parse a comma-separated list of elements up until the specified end token.
@@ -3620,7 +3615,7 @@ class ModuleParser : public Parser {
 public:
   explicit ModuleParser(ParserState &state) : Parser(state) {}
 
-  ParseResult parseModule();
+  ParseResult parseModule(Module *module);
 
 private:
   /// Parse an attribute alias declaration.
@@ -3638,7 +3633,7 @@ private:
       StringRef &name, FunctionType &type,
       SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
       SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
-  ParseResult parseFunc();
+  ParseResult parseFunc(Module *module);
 };
 } // end anonymous namespace
 
@@ -3802,7 +3797,7 @@ ParseResult ModuleParser::parseFunctionSignature(
 ///   function-body ::= `{` block+ `}`
 ///   function-attributes ::= `attributes` attribute-dict
 ///
-ParseResult ModuleParser::parseFunc() {
+ParseResult ModuleParser::parseFunc(Module *module) {
   consumeToken();
 
   StringRef name;
@@ -3824,7 +3819,7 @@ ParseResult ModuleParser::parseFunc() {
   // Okay, the function signature was parsed correctly, create the function now.
   auto *function =
       new Function(getEncodedSourceLocation(loc), name, type, attrs);
-  getModule()->getFunctions().push_back(function);
+  module->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
   if (function->getName() != name)
@@ -3864,7 +3859,7 @@ ParseResult ModuleParser::parseFunc() {
 }
 
 /// This is the top-level module parser.
-ParseResult ModuleParser::parseModule() {
+ParseResult ModuleParser::parseModule(Module *module) {
   while (1) {
     switch (getToken().getKind()) {
     default:
@@ -3894,7 +3889,7 @@ ParseResult ModuleParser::parseModule() {
       break;
 
     case Token::kw_func:
-      if (parseFunc())
+      if (parseFunc(module))
         return failure();
       break;
     }
@@ -3912,8 +3907,8 @@ Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
   // This is the result module we are parsing into.
   std::unique_ptr<Module> module(new Module(context));
 
-  ParserState state(sourceMgr, module.get());
-  if (ModuleParser(state).parseModule()) {
+  ParserState state(sourceMgr, context);
+  if (ModuleParser(state).parseModule(module.get())) {
     return nullptr;
   }
 
@@ -3953,3 +3948,13 @@ Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   return parseSourceFile(sourceMgr, context);
 }
+
+Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) {
+  SourceMgr sourceMgr;
+  auto memBuffer = MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"",
+                                              /*RequiresNullTerminator=*/false);
+  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
+  ParserState state(sourceMgr, context);
+  return Parser(state).parseType();
+}
index d730a20..9b389d2 100644 (file)
@@ -9,14 +9,19 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
 
 func @buffer(%arg0: index, %arg1: index) {
   %0 = muli %arg0, %arg0 : index
-  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+  %1 = linalg.buffer_alloc %0 : !linalg.buffer<vector<4xi8>>
+  linalg.buffer_dealloc %1 : !linalg.buffer<vector<4xi8>>
   return
 }
 // CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
 //  CHECK-NEXT:  %0 = muli %arg0, %arg0 : index
-//  CHECK-NEXT:  %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-//  CHECK-NEXT:  linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+//  CHECK-NEXT:  %1 = linalg.buffer_alloc %0 : !linalg.buffer<vector<4xi8>>
+//  CHECK-NEXT:  linalg.buffer_dealloc %1 : !linalg.buffer<vector<4xi8>>
+
+func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
+  return
+}
+// CHECK-LABEL: func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
 
 func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
   %0 = muli %arg0, %arg0 : index