Initial version for chapter 1 of the Toy tutorial
authorMehdi Amini <aminim@google.com>
Tue, 2 Apr 2019 17:02:07 +0000 (10:02 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 2 Apr 2019 20:40:06 +0000 (13:40 -0700)
--

PiperOrigin-RevId: 241549247

15 files changed:
mlir/CMakeLists.txt
mlir/examples/CMakeLists.txt [new file with mode: 0644]
mlir/examples/toy/CMakeLists.txt [new file with mode: 0644]
mlir/examples/toy/Ch1/CMakeLists.txt [new file with mode: 0644]
mlir/examples/toy/Ch1/include/toy/AST.h [new file with mode: 0644]
mlir/examples/toy/Ch1/include/toy/Lexer.h [new file with mode: 0644]
mlir/examples/toy/Ch1/include/toy/Parser.h [new file with mode: 0644]
mlir/examples/toy/Ch1/parser/AST.cpp [new file with mode: 0644]
mlir/examples/toy/Ch1/toyc.cpp [new file with mode: 0644]
mlir/g3doc/Tutorials/Toy/Ch-1.md [new file with mode: 0644]
mlir/test/CMakeLists.txt
mlir/test/Examples/Toy/Ch1/ast.toy [new file with mode: 0644]
mlir/test/Examples/lit.local.cfg [new file with mode: 0644]
mlir/test/lit.cfg.py
mlir/test/lit.site.cfg.py.in

index 882c249..5dd7d45 100644 (file)
@@ -43,3 +43,7 @@ add_subdirectory(lib)
 add_subdirectory(tools)
 add_subdirectory(unittests)
 add_subdirectory(test)
+
+if( LLVM_INCLUDE_EXAMPLES )
+  add_subdirectory(examples)
+endif()
diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt
new file mode 100644 (file)
index 0000000..9dc01fe
--- /dev/null
@@ -0,0 +1,2 @@
+add_subdirectory(toy)
+
diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt
new file mode 100644 (file)
index 0000000..19f5293
--- /dev/null
@@ -0,0 +1,9 @@
+add_custom_target(Toy)
+set_target_properties(Toy PROPERTIES FOLDER Examples)
+
+macro(add_toy_chapter name)
+  add_dependencies(Toy ${name})
+  add_llvm_example(${name} ${ARGN})
+endmacro(add_toy_chapter name)
+
+add_subdirectory(Ch1)
diff --git a/mlir/examples/toy/Ch1/CMakeLists.txt b/mlir/examples/toy/Ch1/CMakeLists.txt
new file mode 100644 (file)
index 0000000..dd26cf1
--- /dev/null
@@ -0,0 +1,9 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+add_toy_chapter(toyc-ch1
+  toyc.cpp
+  parser/AST.cpp
+  )
+include_directories(include/)
\ No newline at end of file
diff --git a/mlir/examples/toy/Ch1/include/toy/AST.h b/mlir/examples/toy/Ch1/include/toy/AST.h
new file mode 100644 (file)
index 0000000..456a323
--- /dev/null
@@ -0,0 +1,256 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST for the Toy language. It is optimized for
+// simplicity, not efficiency. The AST forms a tree structure where each node
+// references its children using std::unique_ptr<>.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_AST_H_
+#define MLIR_TUTORIAL_TOY_AST_H_
+
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <vector>
+
+namespace toy {
+
+/// A variable
+struct VarType {
+  enum { TY_FLOAT, TY_INT } elt_ty;
+  std::vector<int> shape;
+};
+
+/// Base class for all expression nodes.
+class ExprAST {
+public:
+  enum ExprASTKind {
+    Expr_VarDecl,
+    Expr_Return,
+    Expr_Num,
+    Expr_Literal,
+    Expr_Var,
+    Expr_BinOp,
+    Expr_Call,
+    Expr_Print, // builtin
+    Expr_If,
+    Expr_For,
+  };
+
+  ExprAST(ExprASTKind kind, Location location)
+      : kind(kind), location(location) {}
+
+  virtual ~ExprAST() = default;
+
+  ExprASTKind getKind() const { return kind; }
+
+  const Location &loc() { return location; }
+
+private:
+  const ExprASTKind kind;
+  Location location;
+};
+
+/// A block-list of expressions.
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// Expression class for numeric literals like "1.0".
+class NumberExprAST : public ExprAST {
+  double Val;
+
+public:
+  NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+
+  double getValue() { return Val; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+};
+
+///
+class LiteralExprAST : public ExprAST {
+  std::vector<std::unique_ptr<ExprAST>> values;
+  std::vector<int64_t> dims;
+
+public:
+  LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+                 std::vector<int64_t> dims)
+      : ExprAST(Expr_Literal, loc), values(std::move(values)),
+        dims(std::move(dims)) {}
+
+  std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
+  std::vector<int64_t> &getDims() { return dims; }
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+};
+
+/// Expression class for referencing a variable, like "a".
+class VariableExprAST : public ExprAST {
+  std::string name;
+
+public:
+  VariableExprAST(Location loc, const std::string &name)
+      : ExprAST(Expr_Var, loc), name(name) {}
+
+  llvm::StringRef getName() { return name; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+};
+
+///
+class VarDeclExprAST : public ExprAST {
+  std::string name;
+  VarType type;
+  std::unique_ptr<ExprAST> initVal;
+
+public:
+  VarDeclExprAST(Location loc, const std::string &name, VarType type,
+                 std::unique_ptr<ExprAST> initVal)
+      : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
+        initVal(std::move(initVal)) {}
+
+  llvm::StringRef getName() { return name; }
+  ExprAST *getInitVal() { return initVal.get(); }
+  VarType &getType() { return type; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+};
+
+///
+class ReturnExprAST : public ExprAST {
+  llvm::Optional<std::unique_ptr<ExprAST>> expr;
+
+public:
+  ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+      : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+  llvm::Optional<ExprAST *> getExpr() {
+    if (expr.hasValue())
+      return expr->get();
+    return llvm::NoneType();
+  }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+};
+
+/// Expression class for a binary operator.
+class BinaryExprAST : public ExprAST {
+  char Op;
+  std::unique_ptr<ExprAST> LHS, RHS;
+
+public:
+  char getOp() { return Op; }
+  ExprAST *getLHS() { return LHS.get(); }
+  ExprAST *getRHS() { return RHS.get(); }
+
+  BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
+                std::unique_ptr<ExprAST> RHS)
+      : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
+        RHS(std::move(RHS)) {}
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+};
+
+/// Expression class for function calls.
+class CallExprAST : public ExprAST {
+  std::string Callee;
+  std::vector<std::unique_ptr<ExprAST>> Args;
+
+public:
+  CallExprAST(Location loc, const std::string &Callee,
+              std::vector<std::unique_ptr<ExprAST>> Args)
+      : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+
+  llvm::StringRef getCallee() { return Callee; }
+  llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+};
+
+/// Expression class for builtin print calls.
+class PrintExprAST : public ExprAST {
+  std::unique_ptr<ExprAST> Arg;
+
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
+      : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+
+  ExprAST *getArg() { return Arg.get(); }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+};
+
+/// This class represents the "prototype" for a function, which captures its
+/// name, and its argument names (thus implicitly the number of arguments the
+/// function takes).
+class PrototypeAST {
+  Location location;
+  std::string name;
+  std::vector<std::unique_ptr<VariableExprAST>> args;
+
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : location(location), name(name), args(std::move(args)) {}
+
+  const Location &loc() { return location; }
+  const std::string &getName() const { return name; }
+  const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
+    return args;
+  }
+};
+
+/// This class represents a function definition itself.
+class FunctionAST {
+  std::unique_ptr<PrototypeAST> Proto;
+  std::unique_ptr<ExprASTList> Body;
+
+public:
+  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
+              std::unique_ptr<ExprASTList> Body)
+      : Proto(std::move(Proto)), Body(std::move(Body)) {}
+  PrototypeAST *getProto() { return Proto.get(); }
+  ExprASTList *getBody() { return Body.get(); }
+};
+
+/// This class represents a list of functions to be processed together
+class ModuleAST {
+  std::vector<FunctionAST> functions;
+
+public:
+  ModuleAST(std::vector<FunctionAST> functions)
+      : functions(std::move(functions)) {}
+
+  auto begin() -> decltype(functions.begin()) { return functions.begin(); }
+  auto end() -> decltype(functions.end()) { return functions.end(); }
+};
+
+void dump(ModuleAST &);
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_AST_H_
diff --git a/mlir/examples/toy/Ch1/include/toy/Lexer.h b/mlir/examples/toy/Ch1/include/toy/Lexer.h
new file mode 100644 (file)
index 0000000..d73adb9
--- /dev/null
@@ -0,0 +1,239 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple Lexer for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
+#define MLIR_TUTORIAL_TOY_LEXER_H_
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace toy {
+
+/// Structure definition a location in a file.
+struct Location {
+  std::shared_ptr<std::string> file; ///< filename
+  int line;                          ///< line number.
+  int col;                           ///< column number.
+};
+
+// List of Token returned by the lexer.
+enum Token : int {
+  tok_semicolon = ';',
+  tok_parenthese_open = '(',
+  tok_parenthese_close = ')',
+  tok_bracket_open = '{',
+  tok_bracket_close = '}',
+  tok_sbracket_open = '[',
+  tok_sbracket_close = ']',
+
+  tok_eof = -1,
+
+  // commands
+  tok_return = -2,
+  tok_var = -3,
+  tok_def = -4,
+
+  // primary
+  tok_identifier = -5,
+  tok_number = -6,
+};
+
+/// The Lexer is an abstract base class providing all the facilities that the
+/// Parser expects. It goes through the stream one token at a time and keeps
+/// track of the location in the file for debugging purpose.
+/// It relies on a subclass to provide a `readNextLine()` method. The subclass
+/// can proceed by reading the next line from the standard input or from a
+/// memory mapped file.
+class Lexer {
+public:
+  /// Create a lexer for the given filename. The filename is kept only for
+  /// debugging purpose (attaching a location to a Token).
+  Lexer(std::string filename)
+      : lastLocation(
+            {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+  virtual ~Lexer() = default;
+
+  /// Look at the current token in the stream.
+  Token getCurToken() { return curTok; }
+
+  /// Move to the next token in the stream and return it.
+  Token getNextToken() { return curTok = getTok(); }
+
+  /// Move to the next token in the stream, asserting on the current token
+  /// matching the expectation.
+  void consume(Token tok) {
+    assert(tok == curTok && "consume Token mismatch expectation");
+    getNextToken();
+  }
+
+  /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+  llvm::StringRef getId() {
+    assert(curTok == tok_identifier);
+    return IdentifierStr;
+  }
+
+  /// Return the current number (prereq: getCurToken() == tok_number)
+  double getValue() {
+    assert(curTok == tok_number);
+    return NumVal;
+  }
+
+  /// Return the location for the beginning of the current token.
+  Location getLastLocation() { return lastLocation; }
+
+  // Return the current line in the file.
+  int getLine() { return curLineNum; }
+
+  // Return the current column in the file.
+  int getCol() { return curCol; }
+
+private:
+  /// Delegate to a derived class fetching the next line. Returns an empty
+  /// string to signal end of file (EOF). Lines are expected to always finish
+  /// with "\n"
+  virtual llvm::StringRef readNextLine() = 0;
+
+  /// Return the next character from the stream. This manages the buffer for the
+  /// current line and request the next line buffer to the derived class as
+  /// needed.
+  int getNextChar() {
+    // The current line buffer should not be empty unless it is the end of file.
+    if (curLineBuffer.empty())
+      return EOF;
+    ++curCol;
+    auto nextchar = curLineBuffer.front();
+    curLineBuffer = curLineBuffer.drop_front();
+    if (curLineBuffer.empty())
+      curLineBuffer = readNextLine();
+    if (nextchar == '\n') {
+      ++curLineNum;
+      curCol = 0;
+    }
+    return nextchar;
+  }
+
+  ///  Return the next token from standard input.
+  Token getTok() {
+    // Skip any whitespace.
+    while (isspace(LastChar))
+      LastChar = Token(getNextChar());
+
+    // Save the current location before reading the token characters.
+    lastLocation.line = curLineNum;
+    lastLocation.col = curCol;
+
+    if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
+      IdentifierStr = (char)LastChar;
+      while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
+        IdentifierStr += (char)LastChar;
+
+      if (IdentifierStr == "return")
+        return tok_return;
+      if (IdentifierStr == "def")
+        return tok_def;
+      if (IdentifierStr == "var")
+        return tok_var;
+      return tok_identifier;
+    }
+
+    if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
+      std::string NumStr;
+      do {
+        NumStr += LastChar;
+        LastChar = Token(getNextChar());
+      } while (isdigit(LastChar) || LastChar == '.');
+
+      NumVal = strtod(NumStr.c_str(), nullptr);
+      return tok_number;
+    }
+
+    if (LastChar == '#') {
+      // Comment until end of line.
+      do
+        LastChar = Token(getNextChar());
+      while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+
+      if (LastChar != EOF)
+        return getTok();
+    }
+
+    // Check for end of file.  Don't eat the EOF.
+    if (LastChar == EOF)
+      return tok_eof;
+
+    // Otherwise, just return the character as its ascii value.
+    Token ThisChar = Token(LastChar);
+    LastChar = Token(getNextChar());
+    return ThisChar;
+  }
+
+  /// The last token read from the input.
+  Token curTok = tok_eof;
+
+  /// Location for `curTok`.
+  Location lastLocation;
+
+  /// If the current Token is an identifier, this string contains the value.
+  std::string IdentifierStr;
+
+  /// If the current Token is a number, this contains the value.
+  double NumVal = 0;
+
+  /// The last value returned by getNextChar(). We need to keep it around as we
+  /// always need to read ahead one character to decide when to end a token and
+  /// we can't put it back in the stream after reading from it.
+  Token LastChar = Token(' ');
+
+  /// Keep track of the current line number in the input stream
+  int curLineNum = 0;
+
+  /// Keep track of the current column number in the input stream
+  int curCol = 0;
+
+  /// Buffer supplied by the derived class on calls to `readNextLine()`
+  llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename)
+      : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+  /// Provide one line at a time to the Lexer, return an empty string when
+  /// reaching the end of the buffer.
+  llvm::StringRef readNextLine() override {
+    auto *begin = current;
+    while (current <= end && *current && *current != '\n')
+      ++current;
+    if (current <= end && *current)
+      ++current;
+    llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+    return result;
+  }
+  const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h
new file mode 100644 (file)
index 0000000..bc7aa52
--- /dev/null
@@ -0,0 +1,494 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  std::unique_ptr<ModuleAST> ParseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+    std::vector<FunctionAST> functions;
+    while (auto F = ParseDefinition()) {
+      functions.push_back(std::move(*F));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+    // If we didn't reach EOF, there was an error during parsing
+    if (lexer.getCurToken() != tok_eof)
+      return parseError<ModuleAST>("nothing", "at end of module");
+
+    return llvm::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+  Lexer &lexer;
+
+  /// Parse a return statement.
+  /// return :== return ; | return expr ;
+  std::unique_ptr<ReturnExprAST> ParseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    llvm::Optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      expr = ParseExpression();
+      if (!expr)
+        return nullptr;
+    }
+    return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a literal number.
+  /// numberexpr ::= number
+  std::unique_ptr<ExprAST> ParseNumberExpr() {
+    auto loc = lexer.getLastLocation();
+    auto Result =
+        llvm::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
+    lexer.consume(tok_number);
+    return std::move(Result);
+  }
+
+  /// Parse a literal array expression.
+  /// tensorLiteral ::= [ literalList ] | number
+  /// literalList ::= tensorLiteral | tensorLiteral, literalList
+  std::unique_ptr<ExprAST> ParseTensorLitteralExpr() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(Token('['));
+
+    // Hold the list of values at this nesting level.
+    std::vector<std::unique_ptr<ExprAST>> values;
+    // Hold the dimensions for all the nesting inside this level.
+    std::vector<int64_t> dims;
+    do {
+      // We can have either another nested array or a number literal.
+      if (lexer.getCurToken() == '[') {
+        values.push_back(ParseTensorLitteralExpr());
+        if (!values.back())
+          return nullptr; // parse error in the nested array.
+      } else {
+        if (lexer.getCurToken() != tok_number)
+          return parseError<ExprAST>("<num> or [", "in literal expression");
+        values.push_back(ParseNumberExpr());
+      }
+
+      // End of this list on ']'
+      if (lexer.getCurToken() == ']')
+        break;
+
+      // Elements are separated by a comma.
+      if (lexer.getCurToken() != ',')
+        return parseError<ExprAST>("] or ,", "in literal expression");
+
+      lexer.getNextToken(); // eat ,
+    } while (true);
+    if (values.empty())
+      return parseError<ExprAST>("<something>", "to fill literal expression");
+    lexer.getNextToken(); // eat ]
+    /// Fill in the dimensions now. First the current nesting level:
+    dims.push_back(values.size());
+    /// If there is any nested array, process all of them and ensure that
+    /// dimensions are uniform.
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expession");
+
+      // Append the nested dimensions to the current level
+      auto &firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Sanity check that shape is uniform across all elements of the list.
+      for (auto &expr : values) {
+        auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expession");
+        if (exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expession");
+      }
+    }
+    return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                             std::move(dims));
+  }
+
+  /// parenexpr ::= '(' expression ')'
+  std::unique_ptr<ExprAST> ParseParenExpr() {
+    lexer.getNextToken(); // eat (.
+    auto V = ParseExpression();
+    if (!V)
+      return nullptr;
+
+    if (lexer.getCurToken() != ')')
+      return parseError<ExprAST>(")", "to close expression with parentheses");
+    lexer.consume(Token(')'));
+    return V;
+  }
+
+  /// identifierexpr
+  ///   ::= identifier
+  ///   ::= identifier '(' expression ')'
+  std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+    std::string name = lexer.getId();
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.
+
+    if (lexer.getCurToken() != '(') // Simple variable ref.
+      return llvm::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // This is a function call.
+    lexer.consume(Token('('));
+    std::vector<std::unique_ptr<ExprAST>> Args;
+    if (lexer.getCurToken() != ')') {
+      while (true) {
+        if (auto Arg = ParseExpression())
+          Args.push_back(std::move(Arg));
+        else
+          return nullptr;
+
+        if (lexer.getCurToken() == ')')
+          break;
+
+        if (lexer.getCurToken() != ',')
+          return parseError<ExprAST>(", or )", "in argument list");
+        lexer.getNextToken();
+      }
+    }
+    lexer.consume(Token(')'));
+
+    // It can be a builtin call to print
+    if (name == "print") {
+      if (Args.size() != 1)
+        return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+      return llvm::make_unique<PrintExprAST>(std::move(loc),
+                                             std::move(Args[0]));
+    }
+
+    // Call to a user-defined function
+    return llvm::make_unique<CallExprAST>(std::move(loc), name,
+                                          std::move(Args));
+  }
+
+  /// primary
+  ///   ::= identifierexpr
+  ///   ::= numberexpr
+  ///   ::= parenexpr
+  ///   ::= tensorliteral
+  std::unique_ptr<ExprAST> ParsePrimary() {
+    switch (lexer.getCurToken()) {
+    default:
+      llvm::errs() << "unknown token '" << lexer.getCurToken()
+                   << "' when expecting an expression\n";
+      return nullptr;
+    case tok_identifier:
+      return ParseIdentifierExpr();
+    case tok_number:
+      return ParseNumberExpr();
+    case '(':
+      return ParseParenExpr();
+    case '[':
+      return ParseTensorLitteralExpr();
+    case ';':
+      return nullptr;
+    case '}':
+      return nullptr;
+    }
+  }
+
+  /// Recursively parse the right hand side of a binary expression, the ExprPrec
+  /// argument indicates the precedence of the current binary operator.
+  ///
+  /// binoprhs ::= ('+' primary)*
+  std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+                                         std::unique_ptr<ExprAST> LHS) {
+    // If this is a binop, find its precedence.
+    while (true) {
+      int TokPrec = GetTokPrecedence();
+
+      // If this is a binop that binds at least as tightly as the current binop,
+      // consume it, otherwise we are done.
+      if (TokPrec < ExprPrec)
+        return LHS;
+
+      // Okay, we know this is a binop.
+      int BinOp = lexer.getCurToken();
+      lexer.consume(Token(BinOp));
+      auto loc = lexer.getLastLocation();
+
+      // Parse the primary expression after the binary operator.
+      auto RHS = ParsePrimary();
+      if (!RHS)
+        return parseError<ExprAST>("expression", "to complete binary operator");
+
+      // If BinOp binds less tightly with RHS than the operator after RHS, let
+      // the pending operator take RHS as its LHS.
+      int NextPrec = GetTokPrecedence();
+      if (TokPrec < NextPrec) {
+        RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
+        if (!RHS)
+          return nullptr;
+      }
+
+      // Merge LHS/RHS.
+      LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp,
+                                             std::move(LHS), std::move(RHS));
+    }
+  }
+
+  /// expression::= primary binoprhs
+  std::unique_ptr<ExprAST> ParseExpression() {
+    auto LHS = ParsePrimary();
+    if (!LHS)
+      return nullptr;
+
+    return ParseBinOpRHS(0, std::move(LHS));
+  }
+
+  /// type ::= < shape_list >
+  /// shape_list ::= num | num , shape_list
+  std::unique_ptr<VarType> ParseType() {
+    if (lexer.getCurToken() != '<')
+      return parseError<VarType>("<", "to begin type");
+    lexer.getNextToken(); // eat <
+
+    auto type = llvm::make_unique<VarType>();
+
+    while (lexer.getCurToken() == tok_number) {
+      type->shape.push_back(lexer.getValue());
+      lexer.getNextToken();
+      if (lexer.getCurToken() == ',')
+        lexer.getNextToken();
+    }
+
+    if (lexer.getCurToken() != '>')
+      return parseError<VarType>(">", "to end type");
+    lexer.getNextToken(); // eat >
+    return type;
+  }
+
+  /// Parse a variable declaration, it starts with a `var` keyword followed by
+  /// and identifier and an optional type (shape specification) before the
+  /// initializer.
+  /// decl ::= var identifier [ type ] = expr
+  std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identified",
+                                        "after 'var' declaration");
+    std::string id = lexer.getId();
+    lexer.getNextToken(); // eat id
+
+    std::unique_ptr<VarType> type; // Type is optional, it can be inferred
+    if (lexer.getCurToken() == '<') {
+      type = ParseType();
+      if (!type)
+        return nullptr;
+    }
+
+    if (!type)
+      type = llvm::make_unique<VarType>();
+    lexer.consume(Token('='));
+    auto expr = ParseExpression();
+    return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                             std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expression separated by semicolons and wrapped in
+  /// curly braces.
+  ///
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  std::unique_ptr<ExprASTList> ParseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    auto exprList = llvm::make_unique<ExprASTList>();
+
+    // Ignore empty expressions: swallow sequences of semicolons.
+    while (lexer.getCurToken() == ';')
+      lexer.consume(Token(';'));
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      if (lexer.getCurToken() == tok_var) {
+        // Variable declaration
+        auto varDecl = ParseDeclaration();
+        if (!varDecl)
+          return nullptr;
+        exprList->push_back(std::move(varDecl));
+      } else if (lexer.getCurToken() == tok_return) {
+        // Return statement
+        auto ret = ParseReturn();
+        if (!ret)
+          return nullptr;
+        exprList->push_back(std::move(ret));
+      } else {
+        // General expression
+        auto expr = ParseExpression();
+        if (!expr)
+          return nullptr;
+        exprList->push_back(std::move(expr));
+      }
+      // Ensure that elements are separated by a semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+
+      // Ignore empty expressions: swallow sequences of semicolons.
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return exprList;
+  }
+
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  std::unique_ptr<PrototypeAST> ParsePrototype() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_def);
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string FnName = lexer.getId();
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      do {
+        std::string name = lexer.getId();
+        auto loc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name);
+        args.push_back(std::move(decl));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>("}", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
+                                           std::move(args));
+  }
+
+  /// Parse a function definition, we expect a prototype initiated with the
+  /// `def` keyword, followed by a block containing a list of expressions.
+  ///
+  /// definition ::= prototype block
+  std::unique_ptr<FunctionAST> ParseDefinition() {
+    auto Proto = ParsePrototype();
+    if (!Proto)
+      return nullptr;
+
+    if (auto block = ParseBlock())
+      return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+    return nullptr;
+  }
+
+  /// Get the precedence of the pending binary operator token.
+  int GetTokPrecedence() {
+    if (!isascii(lexer.getCurToken()))
+      return -1;
+
+    // 1 is lowest precedence.
+    switch (static_cast<char>(lexer.getCurToken())) {
+    case '-':
+      return 20;
+    case '+':
+      return 20;
+    case '*':
+      return 40;
+    default:
+      return -1;
+    }
+  }
+
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    if (isprint(curToken))
+      llvm::errs() << " '" << (char)curToken << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp
new file mode 100644 (file)
index 0000000..869f2ef
--- /dev/null
@@ -0,0 +1,263 @@
+//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST dump for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/AST.h"
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+
+namespace {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST
+struct Indent {
+  Indent(int &level) : level(level) { ++level; }
+  ~Indent() { --level; }
+  int &level;
+};
+
+/// Helper class that implement the AST tree traversal and print the nodes along
+/// the way. The only data member is the current indentation level.
+class ASTDumper {
+public:
+  void dump(ModuleAST *Node);
+
+private:
+  void dump(VarType &type);
+  void dump(VarDeclExprAST *varDecl);
+  void dump(ExprAST *expr);
+  void dump(ExprASTList *exprList);
+  void dump(NumberExprAST *num);
+  void dump(LiteralExprAST *Node);
+  void dump(VariableExprAST *Node);
+  void dump(ReturnExprAST *Node);
+  void dump(BinaryExprAST *Node);
+  void dump(CallExprAST *Node);
+  void dump(PrintExprAST *Node);
+  void dump(PrototypeAST *Node);
+  void dump(FunctionAST *Node);
+
+  // Actually print spaces matching the current indentation level
+  void indent() {
+    for (int i = 0; i < curIndent; i++)
+      llvm::errs() << "  ";
+  }
+  int curIndent = 0;
+};
+
+} // namespace
+
+/// Return a formatted string for the location of any node
+template <typename T> static std::string loc(T *Node) {
+  const auto &loc = Node->loc();
+  return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
+          llvm::Twine(loc.col))
+      .str();
+}
+
+// Helper Macro to bump the indentation level and print the leading spaces for
+// the current indentations
+#define INDENT()                                                               \
+  Indent level_(curIndent);                                                    \
+  indent();
+
+/// Dispatch to a generic expressions to the appropriate subclass using RTTI
+void ASTDumper::dump(ExprAST *expr) {
+#define dispatch(CLASS)                                                        \
+  if (CLASS *node = llvm::dyn_cast<CLASS>(expr))                               \
+    return dump(node);
+  dispatch(VarDeclExprAST);
+  dispatch(LiteralExprAST);
+  dispatch(NumberExprAST);
+  dispatch(VariableExprAST);
+  dispatch(ReturnExprAST);
+  dispatch(BinaryExprAST);
+  dispatch(CallExprAST);
+  dispatch(PrintExprAST);
+  // No match, fallback to a generic message
+  INDENT();
+  llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+}
+
+/// A variable declaration is printing the variable name, the type, and then
+/// recurse in the initializer value.
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+  INDENT();
+  llvm::errs() << "VarDecl " << varDecl->getName();
+  dump(varDecl->getType());
+  llvm::errs() << " " << loc(varDecl) << "\n";
+  dump(varDecl->getInitVal());
+}
+
+/// A "block", or a list of expression
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  llvm::errs() << "Block {\n";
+  for (auto &expr : *exprList)
+    dump(expr.get());
+  indent();
+  llvm::errs() << "} // Block\n";
+}
+
+/// A literal number, just print the value.
+void ASTDumper::dump(NumberExprAST *num) {
+  INDENT();
+  llvm::errs() << num->getValue() << " " << loc(num) << "\n";
+}
+
+/// Helper to print recurisvely a literal. This handles nested array like:
+///    [ [ 1, 2 ], [ 3, 4 ] ]
+/// We print out such array with the dimensions spelled out at every level:
+///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *lit_or_num) {
+  // Inside a literal expression we can have either a number or another literal
+  if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+    llvm::errs() << num->getValue();
+    return;
+  }
+  auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+
+  // Print the dimension for this literal first
+  llvm::errs() << "<";
+  {
+    const char *sep = "";
+    for (auto dim : literal->getDims()) {
+      llvm::errs() << sep << dim;
+      sep = ", ";
+    }
+  }
+  llvm::errs() << ">";
+
+  // Now print the content, recursing on every element of the list
+  llvm::errs() << "[ ";
+  const char *sep = "";
+  for (auto &elt : literal->getValues()) {
+    llvm::errs() << sep;
+    printLitHelper(elt.get());
+    sep = ", ";
+  }
+  llvm::errs() << "]";
+}
+
+/// Print a literal, see the recursive helper above for the implementation.
+void ASTDumper::dump(LiteralExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Literal: ";
+  printLitHelper(Node);
+  llvm::errs() << " " << loc(Node) << "\n";
+}
+
+/// Print a variable reference (just a name).
+void ASTDumper::dump(VariableExprAST *Node) {
+  INDENT();
+  llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+}
+
+/// Return statement print the return and its (optional) argument.
+void ASTDumper::dump(ReturnExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  if (Node->getExpr().hasValue())
+    return dump(*Node->getExpr());
+  {
+    INDENT();
+    llvm::errs() << "(void)\n";
+  }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and RHS.
+void ASTDumper::dump(BinaryExprAST *Node) {
+  INDENT();
+  llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
+  dump(Node->getLHS());
+  dump(Node->getRHS());
+}
+
+/// Print a call expression, first the callee name and the list of args by
+/// recursing into each individual argument.
+void ASTDumper::dump(CallExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
+  for (auto &arg : Node->getArgs())
+    dump(arg.get());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print a builtin print call, first the builtin name and then the argument.
+void ASTDumper::dump(PrintExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Print [ " << loc(Node) << "\n";
+  dump(Node->getArg());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print type: only the shape is printed in between '<' and '>'
+void ASTDumper::dump(VarType &type) {
+  llvm::errs() << "<";
+  const char *sep = "";
+  for (auto shape : type.shape) {
+    llvm::errs() << sep << shape;
+    sep = ", ";
+  }
+  llvm::errs() << ">";
+}
+
+/// Print a function prototype, first the function name, and then the list of
+/// parameters names.
+void ASTDumper::dump(PrototypeAST *Node) {
+  INDENT();
+  llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+  indent();
+  llvm::errs() << "Params: [";
+  const char *sep = "";
+  for (auto &arg : Node->getArgs()) {
+    llvm::errs() << sep << arg->getName();
+    sep = ", ";
+  }
+  llvm::errs() << "]\n";
+}
+
+/// Print a function, first the prototype and then the body.
+void ASTDumper::dump(FunctionAST *Node) {
+  INDENT();
+  llvm::errs() << "Function \n";
+  dump(Node->getProto());
+  dump(Node->getBody());
+}
+
+/// Print a module, actually loop over the functions and print them in sequence.
+void ASTDumper::dump(ModuleAST *Node) {
+  INDENT();
+  llvm::errs() << "Module:\n";
+  for (auto &F : *Node)
+    dump(&F);
+}
+
+namespace toy {
+
+// Public API
+void dump(ModuleAST &module) { ASTDumper().dump(&module); }
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp
new file mode 100644 (file)
index 0000000..dd308ca
--- /dev/null
@@ -0,0 +1,75 @@
+//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the entry point for the Toy compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Parser.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+namespace cl = llvm::cl;
+
+static cl::opt<std::string> InputFilename(cl::Positional,
+                                          cl::desc("<input toy file>"),
+                                          cl::init("-"),
+                                          cl::value_desc("filename"));
+namespace {
+enum Action { None, DumpAST };
+}
+
+static cl::opt<enum Action>
+    emitAction("emit", cl::desc("Select the kind of output desired"),
+               cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
+
+/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
+std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code EC = FileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return nullptr;
+  }
+  auto buffer = FileOrErr.get()->getBuffer();
+  LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
+  Parser parser(lexer);
+  return parser.ParseModule();
+}
+
+int main(int argc, char **argv) {
+  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+  auto moduleAST = parseInputFile(InputFilename);
+  if (!moduleAST)
+    return 1;
+
+  switch (emitAction) {
+  case Action::DumpAST:
+    dump(*moduleAST);
+    return 0;
+  default:
+    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+  }
+
+  return 0;
+}
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-1.md b/mlir/g3doc/Tutorials/Toy/Ch-1.md
new file mode 100644 (file)
index 0000000..cd3bd33
--- /dev/null
@@ -0,0 +1,149 @@
+# Chapter 1: Toy Tutorial Introduction
+
+This tutorial runs through the implementation of a basic toy language on top of
+MLIR. The goal of this tutorial is to introduce the concepts of MLIR, and
+especially how *dialects* can help easily support language specific constructs
+and transformations, while still offering an easy path to lower to LLVM or other
+codegen infrastructure. This tutorial is based on the model of the
+[LLVM Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl01.html).
+
+This tutorial is divided in the following chapters:
+
+-   [Chapter #1](Ch-1.md): Introduction to the Toy language, and the definition
+    of its AST.
+-   [Chapter #2](Ch-2.md): Traversing the AST to emit custom MLIR, introducing
+    base MLIR concepts.
+-   [Chapter #3](Ch-3.md): Defining and registering a dialect in MLIR, showing
+    how we can start attaching semantics to our custom operations in MLIR.
+-   [Chapter #4](Ch-4.md): High-level language-specific analysis and
+    transformation, showcasing shape inference, generic function specialization,
+    and basic optimizations.
+-   [Chapter #5](Ch-5.md): Lowering to lower-level dialects. We'll convert our
+    high level language specific semantics towards a generic linear-algebra
+    oriented dialect for optimizations. Ultimately we will emit LLVM IR for code
+    generation.
+-   [Chapter #5](Ch-6.md): A REPL?
+-   [Chapter #6](Ch-7.md): Custom backends? GPU using LLVM? TPU? XLA
+
+## The Language
+
+This tutorial will be illustrated with a toy language that we’ll call “Toy”
+(naming is hard...). Toy is an array-based language that allows you to define
+functions, some math computation, and print results.
+
+Because we want to keep things simple, the codegen will be limited to arrays of
+rank <= 2 and the only datatype in Toy is a 64-bit floating point type (aka
+‘double’ in C parlance). As such, all values are implicitly double precision,
+Values are immutable: every operation returns a newly allocated value, and
+deallocation is automatically managed. But enough with the long description,
+nothing is better than walking through an example to get a better understanding:
+
+FIXME: update/modify matrix multiplication to use @ instead of *
+
+```Toy {.toy}
+def main() {
+  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+  # The shape is inferred from the supplied literal.
+  var a = [[1, 2, 3], [4, 5, 6]];
+  # b is identical to a, the literal array is implicitely reshaped: defining new
+  # variables is the way to reshape arrays (element count must match).
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  # transpose() and print() are the only builtin, the following will transpose
+  # b and perform a matrix multiplication before printing the result.
+  print(a * transpose(b));
+}
+```
+
+Type checking is statically performed through type inference, the language only
+requires type declarations to specify array shapes when needed. Function are
+generic: their parameters are unranked (in other word we know these are arrays
+but we don't know how many dimensions or the size of the dimensions). They are
+specialized for every newly discovered signature at call sites. Let's revisit
+the previous example by adding a user-defined function:
+
+```Toy {.toy}
+# User defined generic function that operates on unknown shaped arguments
+def multiply_transpose(a, b) {
+  return a * transpose(b);
+}
+
+def main() {
+  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+  var a = [[1, 2, 3], [4, 5, 6]];
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  # This call will specialize `multiply_transpose` with <2, 3> for both
+  # arguments and deduce a return type of <2, 2> in initialization of `c`.
+  var c = multiply_transpose(a, b);
+  # A second call to `multiply_transpose` with <2, 3> for both arguments will
+  # reuse the previously specialized and inferred version and return `<2, 2>`
+  var d = multiply_transpose(b, a);
+  # A new call with `<2, 2>` for both dimension will trigger another
+  # specialization of `multiply_transpose`.
+  var e = multiply_transpose(c, d);
+  # Finally, calling into `multiply_transpose` with incompatible shape will
+  # trigger a shape inference error.
+  var e = multiply_transpose(transpose(a), c);
+}
+```
+
+## The AST
+
+The AST is fairly straightforward from the above code, here is a dump of it:
+
+```
+Module:
+  Function
+    Proto 'multiply_transpose' @test/ast.toy:5:1'
+    Args: [a, b]
+    Block {
+      Return
+        BinOp: * @test/ast.toy:6:12
+          var: a @test/ast.toy:6:10
+          Call 'transpose' [ @test/ast.toy:6:14
+            var: b @test/ast.toy:6:24
+          ]
+    } // Block
+  Function
+    Proto 'main' @test/ast.toy:9:1'
+    Args: []
+    Block {
+      VarDecl a<2, 3> @test/ast.toy:11:3
+        Literal: <2, 3>[<3>[1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[4.000000e+00, 5.000000e+00, 6.000000e+00]] @test/ast.toy:11:17
+      VarDecl b<2, 3> @test/ast.toy:12:3
+        Literal: <6>[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @test/ast.toy:12:17
+      VarDecl c<> @test/ast.toy:15:3
+        Call 'multiply_transpose' [ @test/ast.toy:15:11
+          var: a @test/ast.toy:15:30
+          var: b @test/ast.toy:15:33
+        ]
+      VarDecl d<> @test/ast.toy:18:3
+        Call 'multiply_transpose' [ @test/ast.toy:18:11
+          var: b @test/ast.toy:18:30
+          var: a @test/ast.toy:18:33
+        ]
+      VarDecl e<> @test/ast.toy:21:3
+        Call 'multiply_transpose' [ @test/ast.toy:21:11
+          var: b @test/ast.toy:21:30
+          var: c @test/ast.toy:21:33
+        ]
+      VarDecl e<> @test/ast.toy:24:3
+        Call 'multiply_transpose' [ @test/ast.toy:24:11
+          Call 'transpose' [ @test/ast.toy:24:30
+            var: a @test/ast.toy:24:40
+          ]
+          var: c @test/ast.toy:24:44
+        ]
+    } // Block
+```
+
+You can reproduce this result and play with the example in the `examples/Ch1/`
+directory, try running `path/to/BUILD/bin/toyc test/ast.toy -emit=ast`.
+
+The code for the lexer is fairly straighforward, it is all in a single header:
+`examples/toy/Ch1/include/toy/Lexer.h`. The parser can be found in
+`examples/toy/Ch1/include/toy/Parser.h`, it is a recursive descent parser. If
+you are not familiar with such Lexer/Parser, these are very similar to the LLVM
+Kaleidoscope equivalent that are detailed in the first two chapters of the
+[Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl02.html#the-abstract-syntax-tree-ast).
+
+The [next chapter](Ch-2.md) will demonstrate how to convert this AST into MLIR.
index 8e3f594..16e2cb4 100644 (file)
@@ -1,3 +1,8 @@
+llvm_canonicalize_cmake_booleans(
+  LLVM_BUILD_EXAMPLES
+  )
+
+
 configure_lit_site_cfg(
   ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
   ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
@@ -20,6 +25,13 @@ set(MLIR_TEST_DEPENDS
   mlir-translate
   )
 
+
+if(LLVM_BUILD_EXAMPLES)
+  list(APPEND MLIR_TEST_DEPENDS
+    toyc-ch1
+    )
+endif()
+
 add_lit_testsuite(check-mlir "Running the MLIR regression tests"
   ${CMAKE_CURRENT_BINARY_DIR}
   DEPENDS ${MLIR_TEST_DEPENDS}
diff --git a/mlir/test/Examples/Toy/Ch1/ast.toy b/mlir/test/Examples/Toy/Ch1/ast.toy
new file mode 100644 (file)
index 0000000..e8c8fe0
--- /dev/null
@@ -0,0 +1,71 @@
+# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s
+
+
+# User defined generic function that operates solely on 
+def multiply_transpose(a, b) {
+  return a * transpose(b);
+}
+
+def main() {
+  # Define a variable `a` with shape <2, 3>, initialized with the literal value
+  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  # This call will specialize `multiply_transpose` with <2, 3> for both
+  # arguments and deduce a return type of <2, 2> in initialization of `c`.
+  var c = multiply_transpose(a, b);
+  # A second call to `multiply_transpose` with <2, 3> for both arguments will
+  # reuse the previously specialized and inferred version and return `<2, 2>`
+  var d = multiply_transpose(b, a);
+  # A new call with `<2, 2>` for both dimension will trigger another
+  # specialization of `multiply_transpose`.
+  var e = multiply_transpose(b, c);
+  # Finally, calling into `multiply_transpose` with incompatible shape will
+  # trigger a shape inference error.
+  var e = multiply_transpose(transpose(a), c);
+}
+
+
+# CHECK: Module:
+# CHECK-NEXT:     Function 
+# CHECK-NEXT:       Proto 'multiply_transpose' 
+# CHECK-NEXT:       Params: [a, b]
+# CHECK-NEXT:       Block {
+# CHECK-NEXT:         Retur
+# CHECK-NEXT:           BinOp: * 
+# CHECK-NEXT:             var: a 
+# CHECK-NEXT:             Call 'transpose' [ 
+# CHECK-NEXT:               var: b 
+# CHECK-NEXT:             ]
+# CHECK-NEXT:       } // Block
+# CHECK-NEXT:     Function 
+# CHECK-NEXT:       Proto 'main' 
+# CHECK-NEXT:       Params: []
+# CHECK-NEXT:       Block {
+# CHECK-NEXT:         VarDecl a<2, 3> 
+# CHECK-NEXT:           Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] 
+# CHECK-NEXT:         VarDecl b<2, 3> 
+# CHECK-NEXT:           Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] 
+# CHECK-NEXT:         VarDecl c<> 
+# CHECK-NEXT:           Call 'multiply_transpose' [ 
+# CHECK-NEXT:             var: a 
+# CHECK-NEXT:             var: b 
+# CHECK-NEXT:           ]
+# CHECK-NEXT:         VarDecl d<> 
+# CHECK-NEXT:           Call 'multiply_transpose' [ 
+# CHECK-NEXT:             var: b 
+# CHECK-NEXT:             var: a 
+# CHECK-NEXT:           ]
+# CHECK-NEXT:         VarDecl e<> 
+# CHECK-NEXT:           Call 'multiply_transpose' [ 
+# CHECK-NEXT:             var: b 
+# CHECK-NEXT:             var: c 
+# CHECK-NEXT:           ]
+# CHECK-NEXT:         VarDecl e<> 
+# CHECK-NEXT:           Call 'multiply_transpose' [ 
+# CHECK-NEXT:             Call 'transpose' [ 
+# CHECK-NEXT:               var: a 
+# CHECK-NEXT:             ]
+# CHECK-NEXT:             var: c 
+# CHECK-NEXT:           ]
+# CHECK-NEXT:       } // Block
+
diff --git a/mlir/test/Examples/lit.local.cfg b/mlir/test/Examples/lit.local.cfg
new file mode 100644 (file)
index 0000000..97db322
--- /dev/null
@@ -0,0 +1,2 @@
+if not config.build_examples:
+  config.unsupported = True
index 51075f1..75459a1 100644 (file)
@@ -21,7 +21,7 @@ config.name = 'MLIR'
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.td', '.mlir']
+config.suffixes = ['.td', '.mlir', '.toy']
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
@@ -54,4 +54,10 @@ tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
 tools = [
     'mlir-opt', 'mlir-tblgen', 'mlir-translate',
 ]
+
+# The following tools are optional
+tools.extend([
+    ToolSubst('toy-ch1', unresolved='ignore'),
+])
+
 llvm_config.add_tool_substitutions(tools, tool_dirs)
index 4cbdb3a..c701b04 100644 (file)
@@ -30,6 +30,7 @@ config.host_arch = "@HOST_ARCH@"
 config.mlir_src_root = "@MLIR_SOURCE_DIR@"
 config.mlir_obj_root = "@MLIR_BINARY_DIR@"
 config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
+config.build_examples = @LLVM_BUILD_EXAMPLES@
 
 # Support substitution of the tools_dir with user parameters. This is
 # used when we can't determine the tool dir at configuration time.