From 213dda687b549aefcc535fe6b5653018d0819000 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 2 Apr 2019 13:11:20 -0700 Subject: [PATCH] Chapter 2 of the Toy tutorial This introduces a basic MLIRGen through straight AST traversal, without dialect registration at this point. -- PiperOrigin-RevId: 241588354 --- mlir/examples/toy/CMakeLists.txt | 1 + mlir/examples/toy/Ch2/CMakeLists.txt | 16 + mlir/examples/toy/Ch2/include/toy/AST.h | 256 ++++++++++++++ mlir/examples/toy/Ch2/include/toy/Lexer.h | 239 +++++++++++++ mlir/examples/toy/Ch2/include/toy/MLIRGen.h | 42 +++ mlir/examples/toy/Ch2/include/toy/Parser.h | 494 ++++++++++++++++++++++++++ mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 523 ++++++++++++++++++++++++++++ mlir/examples/toy/Ch2/parser/AST.cpp | 263 ++++++++++++++ mlir/examples/toy/Ch2/toyc.cpp | 135 +++++++ mlir/g3doc/Tutorials/Toy/Ch-2.md | 216 ++++++++++++ mlir/test/CMakeLists.txt | 1 + mlir/test/Examples/Toy/Ch1/ast.toy | 66 ++-- mlir/test/Examples/Toy/Ch2/ast.toy | 73 ++++ mlir/test/Examples/Toy/Ch2/codegen.toy | 32 ++ mlir/test/Examples/Toy/Ch2/invalid.mlir | 11 + mlir/test/lit.cfg.py | 1 + 16 files changed, 2337 insertions(+), 32 deletions(-) create mode 100644 mlir/examples/toy/Ch2/CMakeLists.txt create mode 100644 mlir/examples/toy/Ch2/include/toy/AST.h create mode 100644 mlir/examples/toy/Ch2/include/toy/Lexer.h create mode 100644 mlir/examples/toy/Ch2/include/toy/MLIRGen.h create mode 100644 mlir/examples/toy/Ch2/include/toy/Parser.h create mode 100644 mlir/examples/toy/Ch2/mlir/MLIRGen.cpp create mode 100644 mlir/examples/toy/Ch2/parser/AST.cpp create mode 100644 mlir/examples/toy/Ch2/toyc.cpp create mode 100644 mlir/g3doc/Tutorials/Toy/Ch-2.md create mode 100644 mlir/test/Examples/Toy/Ch2/ast.toy create mode 100644 mlir/test/Examples/Toy/Ch2/codegen.toy create mode 100644 mlir/test/Examples/Toy/Ch2/invalid.mlir diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt index 19f5293..b70c371 100644 --- 
a/mlir/examples/toy/CMakeLists.txt +++ b/mlir/examples/toy/CMakeLists.txt @@ -7,3 +7,4 @@ macro(add_toy_chapter name) endmacro(add_toy_chapter name) add_subdirectory(Ch1) +add_subdirectory(Ch2) diff --git a/mlir/examples/toy/Ch2/CMakeLists.txt b/mlir/examples/toy/Ch2/CMakeLists.txt new file mode 100644 index 0000000..1209963 --- /dev/null +++ b/mlir/examples/toy/Ch2/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LLVM_LINK_COMPONENTS + Support + ) + +add_toy_chapter(toyc-ch2 + toyc.cpp + parser/AST.cpp + mlir/MLIRGen.cpp + ) +include_directories(include/) +target_link_libraries(toyc-ch2 + PRIVATE + MLIRAnalysis + MLIRIR + MLIRParser + MLIRTransforms) diff --git a/mlir/examples/toy/Ch2/include/toy/AST.h b/mlir/examples/toy/Ch2/include/toy/AST.h new file mode 100644 index 0000000..456a323 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/AST.h @@ -0,0 +1,256 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable +struct VarType { + enum { TY_FLOAT, TY_INT } elt_ty; + std::vector shape; +}; + +/// Base class for all expression nodes. +class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, // builtin + Expr_If, + Expr_For, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } +}; + +/// +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + std::vector> &getValues() { return values; } + std::vector &getDims() { return dims; } + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". 
+class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, const std::string &name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } +}; + +/// +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, const std::string &name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } +}; + +/// +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::NoneType(); + } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char Op; + std::unique_ptr LHS, RHS; + +public: + char getOp() { return Op; } + ExprAST *getLHS() { return LHS.get(); } + ExprAST *getRHS() { return RHS.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr LHS, + std::unique_ptr RHS) + : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), + RHS(std::move(RHS)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string Callee; + std::vector> Args; + +public: + CallExprAST(Location loc, const std::string &Callee, + std::vector> Args) + : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + + llvm::StringRef getCallee() { return Callee; } + llvm::ArrayRef> getArgs() { return Args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr Arg; + +public: + PrintExprAST(Location loc, std::unique_ptr Arg) + : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + + ExprAST *getArg() { return Arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + const std::string &getName() const { return name; } + const std::vector> &getArgs() { + return args; + } +}; + +/// This class represents a function definition itself. 
+class FunctionAST { + std::unique_ptr Proto; + std::unique_ptr Body; + +public: + FunctionAST(std::unique_ptr Proto, + std::unique_ptr Body) + : Proto(std::move(Proto)), Body(std::move(Body)) {} + PrototypeAST *getProto() { return Proto.get(); } + ExprASTList *getBody() { return Body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch2/include/toy/Lexer.h b/mlir/examples/toy/Ch2/include/toy/Lexer.h new file mode 100644 index 0000000..d73adb9 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/Lexer.h @@ -0,0 +1,239 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple Lexer for the Toy language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. +struct Location { + std::shared_ptr file; ///< filename + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. 
+  void consume(Token tok) {
+    assert(tok == curTok && "consume Token mismatch expectation");
+    getNextToken();
+  }
+
+  /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+  llvm::StringRef getId() {
+    assert(curTok == tok_identifier);
+    return IdentifierStr;
+  }
+
+  /// Return the current number (prereq: getCurToken() == tok_number)
+  double getValue() {
+    assert(curTok == tok_number);
+    return NumVal;
+  }
+
+  /// Return the location for the beginning of the current token.
+  Location getLastLocation() { return lastLocation; }
+
+  // Return the current line in the file.
+  int getLine() { return curLineNum; }
+
+  // Return the current column in the file.
+  int getCol() { return curCol; }
+
+private:
+  /// Delegate to a derived class fetching the next line. Returns an empty
+  /// string to signal end of file (EOF). Lines are expected to always finish
+  /// with "\n"
+  virtual llvm::StringRef readNextLine() = 0;
+
+  /// Return the next character from the stream as an unsigned char value, or
+  /// EOF once the input is exhausted. Refills the current line buffer from the
+  /// derived class as needed and keeps the line/column counters up to date.
+  int getNextChar() {
+    // The current line buffer should not be empty unless it is the end of file.
+    if (curLineBuffer.empty())
+      return EOF;
+    ++curCol;
+    int nextchar = static_cast<unsigned char>(curLineBuffer.front());
+    curLineBuffer = curLineBuffer.drop_front();
+    if (curLineBuffer.empty())
+      curLineBuffer = readNextLine();
+    if (nextchar == '\n') {
+      ++curLineNum;
+      curCol = 0;
+    }
+    return nextchar;
+  }
+
+  /// Return the next token from the input stream.
+  Token getTok() {
+    // Skip any whitespace.
+    while (isspace(LastChar))
+      LastChar = Token(getNextChar());
+
+    // Save the current location before reading the token characters.
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* + IdentifierStr = (char)LastChar; + while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') + IdentifierStr += (char)LastChar; + + if (IdentifierStr == "return") + return tok_return; + if (IdentifierStr == "def") + return tok_def; + if (IdentifierStr == "var") + return tok_var; + return tok_identifier; + } + + if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ + std::string NumStr; + do { + NumStr += LastChar; + LastChar = Token(getNextChar()); + } while (isdigit(LastChar) || LastChar == '.'); + + NumVal = strtod(NumStr.c_str(), nullptr); + return tok_number; + } + + if (LastChar == '#') { + // Comment until end of line. + do + LastChar = Token(getNextChar()); + while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + + if (LastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (LastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token ThisChar = Token(LastChar); + LastChar = Token(getNextChar()); + return ThisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string IdentifierStr; + + /// If the current Token is a number, this contains the value. + double NumVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. 
+  Token LastChar = Token(' ');
+
+  /// Keep track of the current line number in the input stream
+  int curLineNum = 0;
+
+  /// Keep track of the current column number in the input stream
+  int curCol = 0;
+
+  /// Buffer supplied by the derived class on calls to `readNextLine()`
+  llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename)
+      : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+  /// Provide one line at a time to the Lexer, or an empty string at the end of
+  /// the buffer. `end` is taken as one past the last byte and is never read.
+  llvm::StringRef readNextLine() override {
+    auto *begin = current;
+    while (current < end && *current && *current != '\n')
+      ++current;
+    if (current < end && *current)
+      ++current;
+    llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+    return result;
+  }
+  const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch2/include/toy/MLIRGen.h b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h
new file mode 100644
index 0000000..21637bc
--- /dev/null
+++ b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h
@@ -0,0 +1,42 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================= +// +// This file declares a simple interface to perform IR generation targeting MLIR +// from a Module AST for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_ +#define MLIR_TUTORIAL_TOY_MLIRGEN_H_ + +#include + +namespace mlir { +class MLIRContext; +class Module; +} // namespace mlir + +namespace toy { +class ModuleAST; + +/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module +/// or nullptr on failure. +std::unique_ptr mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST); +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch2/include/toy/Parser.h b/mlir/examples/toy/Ch2/include/toy/Parser.h new file mode 100644 index 0000000..bc7aa52 --- /dev/null +++ b/mlir/examples/toy/Ch2/include/toy/Parser.h @@ -0,0 +1,494 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  std::unique_ptr<ModuleAST> ParseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+    std::vector<FunctionAST> functions;
+    while (auto F = ParseDefinition()) {
+      functions.push_back(std::move(*F));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+    // If we didn't reach EOF, there was an error during parsing
+    if (lexer.getCurToken() != tok_eof)
+      return parseError<ModuleAST>("nothing", "at end of module");
+
+    return llvm::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+  Lexer &lexer;
+
+  /// Parse a return statement. Returns nullptr if the operand fails to parse.
+  /// return :== return ; | return expr ;
+  std::unique_ptr<ReturnExprAST> ParseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    llvm::Optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      expr = ParseExpression();
+      if (!*expr) // The Optional is engaged here; test the pointer inside.
+        return nullptr;
+    }
+    return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a literal number.
+ /// numberexpr ::= number + std::unique_ptr ParseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto Result = + llvm::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(Result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr ParseTensorLitteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(ParseTensorLitteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. + } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(ParseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. 
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expression");
+
+      // Append the nested dimensions to the current level
+      auto &firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Every element must itself be a literal with the same nested shape.
+      for (auto &expr : values) {
+        auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+        if (exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expression");
+      }
+    }
+    return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                             std::move(dims));
+  }
+
+  /// parenexpr ::= '(' expression ')'
+  std::unique_ptr<ExprAST> ParseParenExpr() {
+    lexer.getNextToken(); // eat (.
+    auto V = ParseExpression();
+    if (!V)
+      return nullptr;
+
+    if (lexer.getCurToken() != ')')
+      return parseError<ExprAST>(")", "to close expression with parentheses");
+    lexer.consume(Token(')'));
+    return V;
+  }
+
+  /// identifierexpr
+  ///   ::= identifier
+  ///   ::= identifier '(' expression ')'
+  std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+    std::string name = lexer.getId();
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.
+
+    if (lexer.getCurToken() != '(') // Simple variable ref.
+      return llvm::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // This is a function call.
+ lexer.consume(Token('(')); + std::vector> Args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto Arg = ParseExpression()) + Args.push_back(std::move(Arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (Args.size() != 1) + return parseError("", "as argument to print()"); + + return llvm::make_unique(std::move(loc), + std::move(Args[0])); + } + + // Call to a user-defined function + return llvm::make_unique(std::move(loc), name, + std::move(Args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr ParsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return ParseIdentifierExpr(); + case tok_number: + return ParseNumberExpr(); + case '(': + return ParseParenExpr(); + case '[': + return ParseTensorLitteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr ParseBinOpRHS(int ExprPrec, + std::unique_ptr LHS) { + // If this is a binop, find its precedence. + while (true) { + int TokPrec = GetTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (TokPrec < ExprPrec) + return LHS; + + // Okay, we know this is a binop. 
+ int BinOp = lexer.getCurToken(); + lexer.consume(Token(BinOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto RHS = ParsePrimary(); + if (!RHS) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with RHS than the operator after RHS, let + // the pending operator take RHS as its LHS. + int NextPrec = GetTokPrecedence(); + if (TokPrec < NextPrec) { + RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); + if (!RHS) + return nullptr; + } + + // Merge LHS/RHS. + LHS = llvm::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); + } + } + + /// expression::= primary binoprhs + std::unique_ptr ParseExpression() { + auto LHS = ParsePrimary(); + if (!LHS) + return nullptr; + + return ParseBinOpRHS(0, std::move(LHS)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr ParseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = llvm::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+  /// decl ::= var identifier [ type ] = expr
+  std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identifier",
+                                        "after 'var' declaration");
+    std::string id = lexer.getId();
+    lexer.getNextToken(); // eat id
+
+    auto type = llvm::make_unique<VarType>(); // Optional; may be inferred.
+    if (lexer.getCurToken() == '<') {
+      type = ParseType();
+      if (!type)
+        return nullptr;
+    }
+    if (lexer.getCurToken() != '=')
+      return parseError<VarDeclExprAST>("=", "in variable declaration");
+    lexer.consume(Token('='));
+    auto expr = ParseExpression();
+    if (!expr) // Never build a declaration around a null initializer.
+      return nullptr;
+    return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                             std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expressions separated by semicolons and wrapped
+  /// in curly braces.
+  ///
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  std::unique_ptr<ExprASTList> ParseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    auto exprList = llvm::make_unique<ExprASTList>();
+
+    // Ignore empty expressions: swallow sequences of semicolons.
+    while (lexer.getCurToken() == ';')
+      lexer.consume(Token(';'));
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      if (lexer.getCurToken() == tok_var) {
+        // Variable declaration
+        auto varDecl = ParseDeclaration();
+        if (!varDecl)
+          return nullptr;
+        exprList->push_back(std::move(varDecl));
+      } else if (lexer.getCurToken() == tok_return) {
+        // Return statement
+        auto ret = ParseReturn();
+        if (!ret)
+          return nullptr;
+        exprList->push_back(std::move(ret));
+      } else {
+        // General expression
+        auto expr = ParseExpression();
+        if (!expr)
+          return nullptr;
+        exprList->push_back(std::move(expr));
+      }
+      // Ensure that elements are separated by a semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+
+      // Ignore empty expressions: swallow sequences of semicolons.
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return exprList;
+  }
+
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  std::unique_ptr<PrototypeAST> ParsePrototype() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_def);
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string FnName = lexer.getId();
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      do {
+        std::string name = lexer.getId();
+        auto loc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name);
+        args.push_back(std::move(decl));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>(")", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
+                                           std::move(args));
+  }
+
+  /// Parse a function definition, we expect a prototype initiated with the
+  /// `def` keyword, followed by a block containing a list of expressions.
+ /// + /// definition ::= prototype block + std::unique_ptr ParseDefinition() { + auto Proto = ParsePrototype(); + if (!Proto) + return nullptr; + + if (auto block = ParseBlock()) + return llvm::make_unique(std::move(Proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int GetTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp new file mode 100644 index 0000000..d21d629f --- /dev/null +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -0,0 +1,523 @@ +//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR generation targeting MLIR from a Module AST +// for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/AST.h" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/StandardOps/Ops.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace toy; +using llvm::cast; +using llvm::dyn_cast; +using llvm::isa; +using llvm::make_unique; +using llvm::ScopedHashTableScope; +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +/// Implementation of a simple MLIR emission from the Toy AST. +/// +/// This will emit operations that are specific to the Toy language, preserving +/// the semantics of the language and (hopefully) allow to perform accurate +/// analysis and transformation based on these high level semantics. +/// +/// At this point we take advantage of the "raw" MLIR APIs to create operations +/// that haven't been registered in any way with MLIR. 
These operations are +/// unknown to MLIR, custom passes could operate by string-matching the name of +/// these operations, but no other type checking or semantic is associated with +/// them natively by MLIR. +class MLIRGenImpl { +public: + MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} + + /// Public API: convert the AST for a Toy module (source file) to an MLIR + /// Module. + std::unique_ptr mlirGen(ModuleAST &moduleAST) { + // We create an empty MLIR module and codegen functions one at a time and + // add them to the module. + theModule = make_unique(&context); + + for (FunctionAST &F : moduleAST) { + auto func = mlirGen(F); + if (!func) + return nullptr; + theModule->getFunctions().push_back(func.release()); + } + + // FIXME: (in the next chapter...) without registering a dialect in MLIR, + // this won't do much, but it should at least check some structural + // properties. + if (failed(theModule->verify())) { + context.emitError(mlir::UnknownLoc::get(&context), + "Module verification error"); + return nullptr; + } + + return std::move(theModule); + } + +private: + /// In MLIR (like in LLVM) a "context" object holds the memory allocation and + /// the ownership of many internal structure of the IR and provide a level + /// of "uniquing" across multiple modules (types for instance). + mlir::MLIRContext &context; + + /// A "module" matches a source file: it contains a list of functions. + std::unique_ptr theModule; + + /// The builder is a helper class to create IR inside a function. It is + /// re-initialized every time we enter a function and kept around as a + /// convenience for emitting individual operations. + /// The builder is stateful, in particular it keeeps an "insertion point": + /// this is where the next operations will be introduced. + std::unique_ptr builder; + + /// The symbol table maps a variable name to a value in the current scope. 
+ /// Entering a function creates a new scope, and the function arguments are + /// added to the mapping. When the processing of a function is terminated, the + /// scope is destroyed and the mappings created in this scope are dropped. + llvm::ScopedHashTable symbolTable; + + /// Create a new scope in the symbol table. The scope lifetime is managed by + /// the returned RAII object. + ScopedHashTableScope create_scope() { + return ScopedHashTableScope(symbolTable); + } + + /// Helper conversion for a Toy AST location to an MLIR location. + mlir::FileLineColLoc loc(Location loc) { + return mlir::FileLineColLoc::get( + mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col, + &context); + } + + /// Declare a variable in the current scope, return true if the variable + /// wasn't declared yet. + bool declare(llvm::StringRef var, mlir::Value *value) { + if (symbolTable.count(var)) { + return false; + } + symbolTable.insert(var, value); + return true; + } + + /// Create the prototype for an MLIR function with as many arguments as the + /// provided Toy AST prototype. + mlir::Function *mlirGen(PrototypeAST &proto) { + // This is a generic function, the return type will be inferred later. + llvm::SmallVector ret_types; + // Arguments type is uniformly a generic array. + llvm::SmallVector arg_types(proto.getArgs().size(), + getType(VarType{})); + auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); + auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); + + // Mark the function as generic: it'll require type specialization for every + // call site. + if (function->getNumArguments()) + function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + + return function; + } + + /// Emit a new function and add it to the MLIR module. + std::unique_ptr mlirGen(FunctionAST &funcAST) { + // Create a scope in the symbol table to hold variable declarations. 
+ auto var_scope = create_scope(); + + // Create an MLIR function for the given prototype. + std::unique_ptr function(mlirGen(*funcAST.getProto())); + if (!function) + return nullptr; + + // Let's start the body of the function now! + // In MLIR the entry block of the function is special: it must have the same + // argument list as the function itself. + function->addEntryBlock(); + + auto &entryBlock = function->front(); + auto &protoArgs = funcAST.getProto()->getArgs(); + // Declare all the function arguments in the symbol table. + for (const auto &name_value : + llvm::zip(protoArgs, entryBlock.getArguments())) { + declare(std::get<0>(name_value)->getName(), std::get<1>(name_value)); + } + + // Create a builder for the function, it will be used throughout the codegen + // to create operations in this function. + builder = llvm::make_unique(function.get()); + + // Emit the body of the function. + if (!mlirGen(*funcAST.getBody())) + return nullptr; + + // Implicitly return void if no return statement was emited. + // FIXME: we may fix the parser instead to always return the last expression + // (this would possibly help the REPL case later) + if (function->getBlocks().back().back().getName().getStringRef() != + "toy.return") { + ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); + mlirGen(fakeRet); + } + + return function; + } + + /// Emit a binary operation + mlir::Value *mlirGen(BinaryExprAST &binop) { + // First emit the operations for each side of the operation before emitting + // the operation itself. For example if the expression is `a + foo(a)` + // 1) First it will visiting the LHS, which will return a reference to the + // value holding `a`. This value should have been emitted at declaration + // time and registered in the symbol table, so nothing would be + // codegen'd. If the value is not in the symbol table, an error has been + // emitted and nullptr is returned. 
+ // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted + // and the result value is returned. If an error occurs we get a nullptr + // and propagate. + // + mlir::Value *L = mlirGen(*binop.getLHS()); + if (!L) + return nullptr; + mlir::Value *R = mlirGen(*binop.getRHS()); + if (!R) + return nullptr; + auto location = loc(binop.loc()); + + // Derive the operation name from the binary operator. At the moment we only + // support '+' and '*'. + const char *op_name = nullptr; + switch (binop.getOp()) { + case '+': + op_name = "toy.add"; + break; + case '*': + op_name = "toy.mul"; + break; + default: + context.emitError(loc(binop.loc()), + Twine("Error: invalid binary operator '") + + Twine(binop.getOp()) + "'"); + return nullptr; + } + + // Build the MLIR operation from the name and the two operands. The return + // type is always a generic array for binary operators. + mlir::OperationState result(&context, location, op_name); + result.types.push_back(getType(VarType{})); + result.operands.push_back(L); + result.operands.push_back(R); + return builder->createOperation(result)->getResult(0); + } + + // This is a reference to a variable in an expression. The variable is + // expected to have been declared and so should have a value in the symbol + // table, otherwise emit an error and return nullptr. + mlir::Value *mlirGen(VariableExprAST &expr) { + if (symbolTable.count(expr.getName())) + return symbolTable.lookup(expr.getName()); + context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") + + expr.getName() + "'"); + return nullptr; + } + + // Emit a return operation, return true on success. + bool mlirGen(ReturnExprAST &ret) { + auto location = loc(ret.loc()); + // `return` takes an optional expression, we need to account for it here. 
+ mlir::OperationState result(&context, location, "toy.return"); + if (ret.getExpr().hasValue()) { + auto *expr = mlirGen(*ret.getExpr().getValue()); + if (!expr) + return false; + result.operands.push_back(expr); + } + builder->createOperation(result); + return true; + } + + // Emit a literal/constant array. It will be emitted as a flattened array of + // data in an Attribute attached to a `toy.constant` operation. + // See documentation on [Attributes](LangRef.md#attributes) for more details. + // Here is an excerpt: + // + // Attributes are the mechanism for specifying constant data in MLIR in + // places where a variable is never allowed [...]. They consist of a name + // and a [concrete attribute value](#attribute-values). It is possible to + // attach attributes to operations, functions, and function arguments. The + // set of expected attributes, their structure, and their interpretation + // are all contextually dependent on what they are attached to. + // + // Example, the source level statement: + // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + // will be converted to: + // %0 = "toy.constant"() {value: dense, + // [[1.000000e+00, 2.000000e+00, 3.000000e+00], + // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> + // + mlir::Value *mlirGen(LiteralExprAST &lit) { + auto location = loc(lit.loc()); + auto type = getType(lit.getDims()); + + // The attribute is a vector with an attribute per element (number) in the + // array, see `collectData()` below for more details. + std::vector data; + data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, + std::multiplies())); + collectData(lit, data); + + // FIXME: using a tensor type is a HACK here. + // Can we do differently without registering a dialect? Using a string blob? 
+ mlir::Type elementType = mlir::FloatType::getF64(&context); + auto dataType = builder->getTensorType(lit.getDims(), elementType); + + // This is the actual attribute that actually hold the list of values for + // this array literal. + auto dataAttribute = builder->getNamedAttr( + "value", builder->getDenseElementsAttr(dataType, data) + .cast()); + + // Build the MLIR op `toy.constant`, only boilerplate below. + mlir::OperationState result(&context, location, "toy.constant"); + result.types.push_back(type); + result.attributes.push_back(dataAttribute); + return builder->createOperation(result)->getResult(0); + } + + // Recursive helper function to accumulate the data that compose an array + // literal. It flattens the nested structure in the supplied vector. For + // example with this array: + // [[1, 2], [3, 4]] + // we will generate: + // [ 1, 2, 3, 4 ] + // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. + // Attributes are the way MLIR attaches constant to operations and functions. + void collectData(ExprAST &expr, std::vector &data) { + if (auto *lit = dyn_cast(&expr)) { + for (auto &value : lit->getValues()) + collectData(*value, data); + return; + } + assert(isa(expr) && "expected literal or number expr"); + mlir::Type elementType = mlir::FloatType::getF64(&context); + auto attr = mlir::FloatAttr::getChecked( + elementType, cast(expr).getValue(), loc(expr.loc())); + data.push_back(attr); + } + + // Emit a call expression. It emits specific operations for the `transpose` + // builtin. Other identifiers are assumed to be user-defined functions. + mlir::Value *mlirGen(CallExprAST &call) { + auto location = loc(call.loc()); + std::string callee = call.getCallee(); + // Codegen the operands first. + SmallVector operands; + for (auto &expr : call.getArgs()) { + auto *arg = mlirGen(*expr); + if (!arg) + return nullptr; + operands.push_back(arg); + } + // builtin have their custom operation, this is a straightforward emission. 
+ if (callee == "transpose") { + mlir::OperationState result(&context, location, "toy.transpose"); + result.types.push_back(getType(VarType{})); + result.operands = std::move(operands); + return builder->createOperation(result)->getResult(0); + } + + // Calls to user-defined functions are mapped to a custom call that takes + // the callee name as an attribute. + mlir::OperationState result(&context, location, "toy.generic_call"); + result.types.push_back(getType(VarType{})); + result.operands = std::move(operands); + for (auto &expr : call.getArgs()) { + auto *arg = mlirGen(*expr); + if (!arg) + return nullptr; + result.operands.push_back(arg); + } + auto calleeAttr = builder->getStringAttr(call.getCallee()); + result.attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); + return builder->createOperation(result)->getResult(0); + } + + // Emit a call expression. It emits specific operations for two builtins: + // transpose(x) and print(x). Other identifiers are assumed to be user-defined + // functions. Return false on failure. + bool mlirGen(PrintExprAST &call) { + auto *arg = mlirGen(*call.getArg()); + if (!arg) + return false; + auto location = loc(call.loc()); + mlir::OperationState result(&context, location, "toy.print"); + result.operands.push_back(arg); + builder->createOperation(result); + return true; + } + + // Emit a constant for a single number (FIXME: semantic? broadcast?) + mlir::Value *mlirGen(NumberExprAST &num) { + auto location = loc(num.loc()); + mlir::OperationState result(&context, location, "toy.constant"); + mlir::Type elementType = mlir::FloatType::getF64(&context); + result.types.push_back(builder->getMemRefType({1}, elementType)); + auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), + loc(num.loc())); + result.attributes.push_back(builder->getNamedAttr("value", attr)); + return builder->createOperation(result)->getResult(0); + } + + // Dispatch codegen for the right expression subclass using RTTI. 
+ mlir::Value *mlirGen(ExprAST &expr) { + switch (expr.getKind()) { + case toy::ExprAST::Expr_BinOp: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Var: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Literal: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Call: + return mlirGen(cast(expr)); + case toy::ExprAST::Expr_Num: + return mlirGen(cast(expr)); + default: + context.emitError( + loc(expr.loc()), + Twine("MLIR codegen encountered an unhandled expr kind '") + + Twine(expr.getKind()) + "'"); + return nullptr; + } + } + + // Handle a variable declaration, we'll codegen the expression that forms the + // initializer and record the value in the symbol table before returning it. + // Future expressions will be able to reference this variable through symbol + // table lookup. + mlir::Value *mlirGen(VarDeclExprAST &vardecl) { + mlir::Value *value = nullptr; + auto location = loc(vardecl.loc()); + if (auto init = vardecl.getInitVal()) { + value = mlirGen(*init); + if (!value) + return nullptr; + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + mlir::OperationState result(&context, location, "toy.reshape"); + result.types.push_back(getType(vardecl.getType())); + result.operands.push_back(value); + value = builder->createOperation(result)->getResult(0); + } + } else { + context.emitError(loc(vardecl.loc()), + "Missing initializer in variable declaration"); + return nullptr; + } + // Register the value in the symbol table + declare(vardecl.getName(), value); + return value; + } + + /// Codegen a list of expression, return false if one of them hit an error. + bool mlirGen(ExprASTList &blockAST) { + auto var_scope = create_scope(); + for (auto &expr : blockAST) { + // Specific handling for variable declarations, return statement, and + // print. 
These can only appear in block list and not in nested + // expressions. + if (auto *vardecl = dyn_cast(expr.get())) { + if (!mlirGen(*vardecl)) + return false; + continue; + } + if (auto *ret = dyn_cast(expr.get())) { + if (!mlirGen(*ret)) + return false; + return true; + } + if (auto *print = dyn_cast(expr.get())) { + if (!mlirGen(*print)) + return false; + return true; + } + // Generic expression dispatch codegen. + if (!mlirGen(*expr)) + return false; + } + return true; + } + + /// Build a type from a list of shape dimensions. Types are `array` followed + /// by an optional dimension list, example: array<2, 2> + /// They are wrapped in a `toy` dialect (see next chapter) and get printed: + /// !toy<"array<2, 2>"> + template mlir::Type getType(T shape) { + mlir::Type elementType = mlir::FloatType::getF64(&context); + std::string typeName = "array"; + if (!shape.empty()) { + typeName += "<"; + const char *sep = ""; + for (auto dim : shape) { + typeName += sep; + typeName += llvm::Twine(dim).str(); + sep = ", "; + } + typeName += ">"; + } + return mlir::UnknownType::get(mlir::Identifier::get("toy", &context), + typeName, &context); + } + + /// Build an MLIR type from a Toy AST variable type + /// (forward to the generic getType(T) above). + mlir::Type getType(const VarType &type) { return getType(type.shape); } +}; + +} // namespace + +namespace toy { + +// The public API for codegen. +std::unique_ptr mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { + return MLIRGenImpl(context).mlirGen(moduleAST); +} + +} // namespace toy diff --git a/mlir/examples/toy/Ch2/parser/AST.cpp b/mlir/examples/toy/Ch2/parser/AST.cpp new file mode 100644 index 0000000..869f2ef --- /dev/null +++ b/mlir/examples/toy/Ch2/parser/AST.cpp @@ -0,0 +1,263 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Copyright 2019 The MLIR Authors. 
// Scope guard for the indentation level used while traversing the AST: the
// referenced counter is bumped on construction and restored on destruction.
struct Indent {
  int &level;

  Indent(int &depth) : level(depth) { level += 1; }
  ~Indent() { level -= 1; }
};
+class ASTDumper { +public: + void dump(ModuleAST *Node); + +private: + void dump(VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *Node); + void dump(VariableExprAST *Node); + void dump(ReturnExprAST *Node); + void dump(BinaryExprAST *Node); + void dump(CallExprAST *Node); + void dump(PrintExprAST *Node); + void dump(PrototypeAST *Node); + void dump(FunctionAST *Node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *Node) { + const auto &loc = Node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { +#define dispatch(CLASS) \ + if (CLASS *node = llvm::dyn_cast(expr)) \ + return dump(node); + dispatch(VarDeclExprAST); + dispatch(LiteralExprAST); + dispatch(NumberExprAST); + dispatch(VariableExprAST); + dispatch(ReturnExprAST); + dispatch(BinaryExprAST); + dispatch(CallExprAST); + dispatch(PrintExprAST); + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. 
+void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recurisvely a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *lit_or_num) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(lit_or_num)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(lit_or_num); + + // Print the dimension for this literal first + llvm::errs() << "<"; + { + const char *sep = ""; + for (auto dim : literal->getDims()) { + llvm::errs() << sep << dim; + sep = ", "; + } + } + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + const char *sep = ""; + for (auto &elt : literal->getValues()) { + llvm::errs() << sep; + printLitHelper(elt.get()); + sep = ", "; + } + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. +void ASTDumper::dump(LiteralExprAST *Node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(Node); + llvm::errs() << " " << loc(Node) << "\n"; +} + +/// Print a variable reference (just a name). 
+void ASTDumper::dump(VariableExprAST *Node) { + INDENT(); + llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *Node) { + INDENT(); + llvm::errs() << "Return\n"; + if (Node->getExpr().hasValue()) + return dump(*Node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *Node) { + INDENT(); + llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; + dump(Node->getLHS()); + dump(Node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *Node) { + INDENT(); + llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; + for (auto &arg : Node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *Node) { + INDENT(); + llvm::errs() << "Print [ " << loc(Node) << "\n"; + dump(Node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(VarType &type) { + llvm::errs() << "<"; + const char *sep = ""; + for (auto shape : type.shape) { + llvm::errs() << sep << shape; + sep = ", "; + } + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. 
+void ASTDumper::dump(PrototypeAST *Node) { + INDENT(); + llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + const char *sep = ""; + for (auto &arg : Node->getArgs()) { + llvm::errs() << sep << arg->getName(); + sep = ", "; + } + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *Node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(Node->getProto()); + dump(Node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *Node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &F : *Node) + dump(&F); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp new file mode 100644 index 0000000..9846764 --- /dev/null +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -0,0 +1,135 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the entry point for the Toy compiler. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/MLIRGen.h" +#include "toy/Parser.h" +#include + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +namespace { +enum InputType { Toy, MLIR }; +} +static cl::opt inputType( + "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), + cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), + cl::values(clEnumValN(MLIR, "mlir", + "load the input file as an MLIR file"))); + +namespace { +enum Action { None, DumpAST, DumpMLIR }; +} +static cl::opt emitAction( + "emit", cl::desc("Select the kind of output desired"), + cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), + cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); + +/// Returns a Toy AST resulting from parsing the file or a nullptr on error. 
+std::unique_ptr parseInputFile(llvm::StringRef filename) { + llvm::ErrorOr> FileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code EC = FileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return nullptr; + } + auto buffer = FileOrErr.get()->getBuffer(); + LexerBuffer lexer(buffer.begin(), buffer.end(), filename); + Parser parser(lexer); + return parser.ParseModule(); +} + +int dumpMLIR() { + mlir::MLIRContext context; + std::unique_ptr module; + if (inputType == InputType::MLIR || + llvm::StringRef(inputFilename).endswith(".mlir")) { + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; + } + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module.reset(mlir::parseSourceFile(sourceMgr, &context)); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; + } + if (failed(module->verify())) { + llvm::errs() << "Error verifying MLIR module\n"; + return 4; + } + } else { + auto moduleAST = parseInputFile(inputFilename); + module = mlirGen(context, *moduleAST); + } + if (!module) + return 1; + module->dump(); + return 0; +} + +int dumpAST() { + if (inputType == InputType::MLIR) { + llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; + return 5; + } + + auto moduleAST = parseInputFile(inputFilename); + if (!moduleAST) + return 1; + + dump(*moduleAST); + return 0; +} + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); + + switch (emitAction) { + case Action::DumpAST: + return dumpAST(); + case Action::DumpMLIR: + return dumpMLIR(); + default: + llvm::errs() << "No action specified (parsing only?), use -emit=\n"; + } + + return 0; +} diff --git a/mlir/g3doc/Tutorials/Toy/Ch-2.md 
b/mlir/g3doc/Tutorials/Toy/Ch-2.md new file mode 100644 index 0000000..6ea86f1 --- /dev/null +++ b/mlir/g3doc/Tutorials/Toy/Ch-2.md @@ -0,0 +1,216 @@ +# Chapter 2: Emitting Basic MLIR + +[TOC] + +Now that we're familiar with our language and the AST, let's see how MLIR can help +to compile Toy. + +## Introduction: Multi-Level IR + +Other compilers like LLVM (see the +[Kaleidoscope tutorial](https://llvm.org/docs/tutorial/LangImpl01.html)) offer +a fixed set of predefined types and, usually *low-level* / RISC-like, +instructions. It is up to the frontend for a given language to perform any +language specific type-checking, analysis, or transformation before emitting +LLVM IR. For example, clang will use its AST to perform static analysis but also +transformations like C++ template instantiation through AST cloning and rewrite. +Finally, languages with constructs at a higher level than C/C++ may require +non-trivial lowering from their AST to generate LLVM IR. + +As a consequence, multiple frontends end up reimplementing significant pieces of +infrastructure to support the need for these analyses and transformations. MLIR +addresses this issue by being designed for extensibility. As such, there is +little to no pre-defined set of instructions (*operations* in MLIR +terminology) or types. + +## MLIR Module, Functions, Blocks, and Operations + +[Language reference](LangRef.md#operations) + +In MLIR (like in LLVM), the top level structure for the IR is a Module +(equivalent to a translation unit in C/C++). A module contains a list of +functions, and each function has a list of blocks forming a CFG. Each block is a +list of operations that execute in sequence. + +Operations in MLIR are similar to instructions in LLVM, however MLIR does not +have a closed set of operations. Instead, MLIR operations are fully extensible +and can have application-specific semantics. 
+ +Here is the MLIR assembly for the Toy 'transpose' operation: + +```MLIR(.mlir) +%t_array = "toy.transpose"(%array) { inplace: true } : (!toy<"array<2, 3>">) -> !toy<"array<3, 2>"> +``` + +Let's look at the anatomy of this MLIR operation: + +- it is identified by its name, which is expected to be a unique string (e.g. + `toy.transpose`). +- it takes as input zero or more operands (or arguments), which are SSA values + defined by other operations or referring to function and block arguments + (e.g. `%array`). +- it produces zero or more results (we will limit ourselves to a single result + in the context of Toy), which are SSA values (e.g. `%t_array`). +- it has zero or more attributes, which are special operands that are always + constant (e.g. `inplace: true`). +- Lastly, the type of the operation appears at the end in a functional form, + spelling the types of the arguments in parentheses and the type of the + return values afterward. + +Finally, in MLIR every operation has a mandatory source location associated with +it. Contrary to LLVM where debug info locations are metadata and can be dropped, +in MLIR the location is a core requirement which translates into APIs manipulating +operations requiring it. Dropping a location becomes an explicit choice and +cannot happen by mistake. + + +## Opaque Builder API + +Operations and types can be created with only their string names using the +raw builder API. This allows MLIR to parse, represent, and round-trip any valid +IR. For example, the following can round-trip through *mlir-opt*: + +```MLIR(.mlir) +func @some_func(%arg0: !random_dialect<"custom_type">) -> !another_dialect<"other_type"> { + %result = "custom.operation"(%arg0) : (!random_dialect<"custom_type">) -> !another_dialect<"other_type"> + return %result : !another_dialect<"other_type"> +} +``` + +Here MLIR will enforce some structural constraints (SSA, block termination, +return operand type coherent with function return type, etc.) 
but otherwise the +types and the operation are completely opaque. + +We will take advantage of this facility to emit MLIR for Toy by traversing the +AST. Our types will be prefixed with "!toy" and our operation names with "toy.". +MLIR refers to this prefix as a *dialect*; we will introduce this with more +details in the [next chapter](Ch-3.md). + +Programmatically creating an opaque operation like the one above involves using +the `mlir::OperationState` structure which groups all the basic elements needed to +build an operation with an `mlir::Builder`: + +- The name of the operation. +- A location for debugging purposes. It is mandatory, but can be explicitly set + to "unknown". +- The list of operand values. +- The types for returned values. +- The list of attributes. +- A list of successors (for branches mostly). + +Building the `custom.operation` from the listing above, assuming you have a +`Value *` handle to `%arg0`, is as simple as: + +```c++ +// The return type for the operation: `!another_dialect<"other_type">` +auto another_dialect_prefix = mlir::Identifier::get("another_dialect", &context); +auto returnType = mlir::UnknownType::get(another_dialect_prefix, + "other_type", &context); +// Creation of the state defining the operation: +mlir::OperationState state(&context, location, "custom.operation"); +state.types.push_back(returnType); +state.operands.push_back(arg0); +// Using a builder to create the operation and insert it where the builder +// insertion point is currently set. +auto customOperation = builder->createOperation(state); +// An operation is not an SSA value (unlike LLVM), because it can return +// multiple SSA values; the resulting value can be obtained with: +Value *result = customOperation->getResult(0); +``` + +This approach is used in `Ch2/mlir/MLIRGen.cpp` to implement a naive MLIR +generation through a simple depth-first search traversal of the Toy AST. 
Here is +how we create a `toy.transpose` operation: + +```c++ +mlir::Operation *createTransposeOp(FuncBuilder *builder, + mlir::Value *input_array) { + // We bundle our custom type in a `toy` dialect. + auto toyDialect = mlir::Identifier::get("toy", builder->getContext()); + // Create a custom type, in the MLIR assembly it is: !toy<"array<2, 2>"> + auto type = mlir::UnknownType::get(toyDialect, "array<2, 2>", builder->getContext()); + + // Fill the `OperationState` with the required fields + mlir::OperationState result(builder->getContext(), location, "toy.transpose"); + result.types.push_back(type); // return type + result.operands.push_back(input_array); // argument + mlir::Operation *newTransposeOp = builder->createOperation(result); + return newTransposeOp; +} +``` + +## Complete Toy Example + +FIXME: It would be nice to have an idea for the **need** of a custom **type** in +Toy? Right now `toy<array>` could be replaced directly by unranked `tensor<*>` +and `toy<array<YxZ>>` could be replaced by a `memref<YxZ>`. + +At this point we can already generate our "Toy IR" without having registered +anything with MLIR. A simplified version of the previous example: + +```Toy {.toy} +# User defined generic function that operates on unknown shaped arguments. 
+def multiply_transpose(a, b) { + return a * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} +``` + +Results in the following IR: + +```MLIR(.mlir) +func @multiply_transpose(%arg0: !toy<"array">, %arg1: !toy<"array">) + attributes {toy.generic: true} loc("test/codegen.toy":2:1) { + %0 = "toy.transpose"(%arg1) : (!toy<"array">) -> !toy<"array"> loc("test/codegen.toy":3:14) + %1 = "toy.mul"(%arg0, %0) : (!toy<"array">, !toy<"array">) -> !toy<"array"> loc("test/codegen.toy":3:14) + "toy.return"(%1) : (!toy<"array">) -> () loc("test/codegen.toy":3:3) +} + +func @main() loc("test/codegen.toy":6:1) { + %0 = "toy.constant"() {value: dense<tensor<2x3xf64>, [[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> !toy<"array<2, 3>"> loc("test/codegen.toy":7:17) + %1 = "toy.reshape"(%0) : (!toy<"array<2, 3>">) -> !toy<"array<2, 3>"> loc("test/codegen.toy":7:3) + %2 = "toy.constant"() {value: dense<tensor<6xf64>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]>} : () -> !toy<"array<6>"> loc("test/codegen.toy":8:17) + %3 = "toy.reshape"(%2) : (!toy<"array<6>">) -> !toy<"array<2, 3>"> loc("test/codegen.toy":8:3) + %4 = "toy.generic_call"(%1, %3, %1, %3) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> loc("test/codegen.toy":9:11) + %5 = "toy.generic_call"(%3, %1, %3, %1) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> loc("test/codegen.toy":10:11) + "toy.print"(%5) : (!toy<"array">) -> () loc("test/codegen.toy":11:3) + "toy.return"() : () -> () loc("test/codegen.toy":6:1) +} +``` + +You can build `toyc` and try it yourself: `toyc test/codegen.toy -emit=mlir +-mlir-print-debuginfo`. 
We can also check that it round-trips: `toyc test/codegen.toy +-emit=mlir -mlir-print-debuginfo > codegen.mlir` followed by `toyc codegen.mlir +-emit=mlir`. + +Notice how these MLIR operations are prefixed with `toy.`; by convention we use +this similarly to a "namespace" in order to avoid conflicting with other +operations with the same name. Similarly the syntax for types wraps an arbitrary +string representing our custom types within our "namespace" `!toy<...>`. Of +course at this point MLIR does not know anything about Toy, and so there are no +semantics associated with the operations and types; everything is opaque and +string-based. The only thing enforced by MLIR here is that the IR is in SSA +form: values are defined once, and uses appear after their definition. + +This can be observed by crafting what should be an invalid IR for Toy and seeing it +round-trip without tripping the verifier: + +```MLIR(.mlir) +// RUN: toyc %s -emit=mlir +func @main() { + %0 = "toy.print"() : () -> !toy<"array<2, 3>"> +} +``` + +There are multiple problems here: first the `toy.print` is not a terminator, +then it should take an operand, and not return any value. + +In the [next chapter](Ch-3.md) we will register our dialect and operations with +MLIR, plug in the verifier, and add nicer APIs to manipulate our operations. 
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 16e2cb4..f3737f2 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -29,6 +29,7 @@ set(MLIR_TEST_DEPENDS if(LLVM_BUILD_EXAMPLES) list(APPEND MLIR_TEST_DEPENDS toyc-ch1 + toyc-ch2 ) endif() diff --git a/mlir/test/Examples/Toy/Ch1/ast.toy b/mlir/test/Examples/Toy/Ch1/ast.toy index e8c8fe0..0069869 100644 --- a/mlir/test/Examples/Toy/Ch1/ast.toy +++ b/mlir/test/Examples/Toy/Ch1/ast.toy @@ -7,8 +7,11 @@ def multiply_transpose(a, b) { } def main() { - # Define a variable `a` with shape <2, 3>, initialized with the literal value - var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitely reshaped: defining new + # variables is the way to reshape arrays (element count must match). var b<2, 3> = [1, 2, 3, 4, 5, 6]; # This call will specialize `multiply_transpose` with <2, 3> for both # arguments and deduce a return type of <2, 2> in initialization of `c`. 
@@ -26,46 +29,45 @@ def main() { # CHECK: Module: -# CHECK-NEXT: Function -# CHECK-NEXT: Proto 'multiply_transpose' +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}Toy/Ch1/ast.toy:5:1' # CHECK-NEXT: Params: [a, b] # CHECK-NEXT: Block { # CHECK-NEXT: Retur -# CHECK-NEXT: BinOp: * -# CHECK-NEXT: var: a -# CHECK-NEXT: Call 'transpose' [ -# CHECK-NEXT: var: b +# CHECK-NEXT: BinOp: * @{{.*}}Toy/Ch1/ast.toy:6:14 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch1/ast.toy:6:10 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}Toy/Ch1/ast.toy:6:14 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch1/ast.toy:6:24 # CHECK-NEXT: ] # CHECK-NEXT: } // Block -# CHECK-NEXT: Function -# CHECK-NEXT: Proto 'main' +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}Toy/Ch1/ast.toy:9:1' # CHECK-NEXT: Params: [] # CHECK-NEXT: Block { -# CHECK-NEXT: VarDecl a<2, 3> -# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] -# CHECK-NEXT: VarDecl b<2, 3> -# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] -# CHECK-NEXT: VarDecl c<> -# CHECK-NEXT: Call 'multiply_transpose' [ -# CHECK-NEXT: var: a -# CHECK-NEXT: var: b +# CHECK-NEXT: VarDecl a<> @{{.*}}Toy/Ch1/ast.toy:12:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}Toy/Ch1/ast.toy:12:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}Toy/Ch1/ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}Toy/Ch1/ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}Toy/Ch1/ast.toy:18:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch1/ast.toy:18:11 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch1/ast.toy:18:30 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch1/ast.toy:18:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl d<> -# CHECK-NEXT: Call 'multiply_transpose' [ -# CHECK-NEXT: var: b -# CHECK-NEXT: 
var: a +# CHECK-NEXT: VarDecl d<> @{{.*}}Toy/Ch1/ast.toy:21:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch1/ast.toy:21:11 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch1/ast.toy:21:30 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch1/ast.toy:21:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl e<> -# CHECK-NEXT: Call 'multiply_transpose' [ -# CHECK-NEXT: var: b -# CHECK-NEXT: var: c +# CHECK-NEXT: VarDecl e<> @{{.*}}Toy/Ch1/ast.toy:24:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch1/ast.toy:24:11 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch1/ast.toy:24:30 +# CHECK-NEXT: var: c @{{.*}}Toy/Ch1/ast.toy:24:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl e<> -# CHECK-NEXT: Call 'multiply_transpose' [ -# CHECK-NEXT: Call 'transpose' [ -# CHECK-NEXT: var: a +# CHECK-NEXT: VarDecl e<> @{{.*}}Toy/Ch1/ast.toy:27:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch1/ast.toy:27:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}Toy/Ch1/ast.toy:27:30 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch1/ast.toy:27:40 # CHECK-NEXT: ] -# CHECK-NEXT: var: c +# CHECK-NEXT: var: c @{{.*}}Toy/Ch1/ast.toy:27:44 # CHECK-NEXT: ] -# CHECK-NEXT: } // Block diff --git a/mlir/test/Examples/Toy/Ch2/ast.toy b/mlir/test/Examples/Toy/Ch2/ast.toy new file mode 100644 index 0000000..91f26b7 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch2/ast.toy @@ -0,0 +1,73 @@ +# RUN: toyc-ch2 %s -emit=ast 2>&1 | FileCheck %s + + +# User defined generic function that operates solely on +def multiply_transpose(a, b) { + return a * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value. + # The shape is inferred from the supplied literal. + var a = [[1, 2, 3], [4, 5, 6]]; + # b is identical to a, the literal array is implicitely reshaped: defining new + # variables is the way to reshape arrays (element count must match). 
+ var b<2, 3> = [1, 2, 3, 4, 5, 6]; + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. + var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var e = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}Toy/Ch2/ast.toy:5:1' +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Retur +# CHECK-NEXT: BinOp: * @{{.*}}Toy/Ch2/ast.toy:6:14 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch2/ast.toy:6:10 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}Toy/Ch2/ast.toy:6:14 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch2/ast.toy:6:24 +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' @{{.*}}Toy/Ch2/ast.toy:9:1' +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<> @{{.*}}Toy/Ch2/ast.toy:12:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}Toy/Ch2/ast.toy:12:11 +# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}Toy/Ch2/ast.toy:15:3 +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}Toy/Ch2/ast.toy:15:17 +# CHECK-NEXT: VarDecl c<> @{{.*}}Toy/Ch2/ast.toy:18:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch2/ast.toy:18:11 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch2/ast.toy:18:30 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch2/ast.toy:18:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> 
@{{.*}}Toy/Ch2/ast.toy:21:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch2/ast.toy:21:11 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch2/ast.toy:21:30 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch2/ast.toy:21:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}Toy/Ch2/ast.toy:24:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch2/ast.toy:24:11 +# CHECK-NEXT: var: b @{{.*}}Toy/Ch2/ast.toy:24:30 +# CHECK-NEXT: var: c @{{.*}}Toy/Ch2/ast.toy:24:33 +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> @{{.*}}Toy/Ch2/ast.toy:27:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}Toy/Ch2/ast.toy:27:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}Toy/Ch2/ast.toy:27:30 +# CHECK-NEXT: var: a @{{.*}}Toy/Ch2/ast.toy:27:40 +# CHECK-NEXT: ] +# CHECK-NEXT: var: c @{{.*}}Toy/Ch2/ast.toy:27:44 +# CHECK-NEXT: ] + diff --git a/mlir/test/Examples/Toy/Ch2/codegen.toy b/mlir/test/Examples/Toy/Ch2/codegen.toy new file mode 100644 index 0000000..f2397e6 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch2/codegen.toy @@ -0,0 +1,32 @@ +# RUN: toyc-ch2 %s -emit=mlir 2>&1 | FileCheck %s + +# User defined generic function that operates on unknown shaped arguments +def multiply_transpose(a, b) { + return a * transpose(b); +} + +def main() { + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + var c = multiply_transpose(a, b); + var d = multiply_transpose(b, a); + print(d); +} + +# CHECK-LABEL: func @multiply_transpose(%arg0: !toy<"array">, %arg1: !toy<"array">) +# CHECK-NEXT: attributes {toy.generic: true} { +# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy<"array">) -> !toy<"array"> +# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy<"array">, !toy<"array">) -> !toy<"array"> +# CHECK-NEXT: "toy.return"(%1) : (!toy<"array">) -> () +# CHECK-NEXT: } + +# CHECK-LABEL: func @main() { +# CHECK-NEXT: %0 = "toy.constant"() {value: dense, {{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> !toy<"array<2, 3>"> +# CHECK-NEXT: %1 = "toy.reshape"(%0) 
: (!toy<"array<2, 3>">) -> !toy<"array<2, 3>"> +# CHECK-NEXT: %2 = "toy.constant"() {value: dense, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]>} : () -> !toy<"array<6>"> +# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy<"array<6>">) -> !toy<"array<2, 3>"> +# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3, %1, %3) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> +# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1, %3, %1) {callee: "multiply_transpose"} : (!toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> +# CHECK-NEXT: "toy.print"(%5) : (!toy<"array">) -> () +# CHECK-NEXT: "toy.return"() : () -> () + diff --git a/mlir/test/Examples/Toy/Ch2/invalid.mlir b/mlir/test/Examples/Toy/Ch2/invalid.mlir new file mode 100644 index 0000000..324d4ca --- /dev/null +++ b/mlir/test/Examples/Toy/Ch2/invalid.mlir @@ -0,0 +1,11 @@ +// RUN: toyc-ch2 %s -emit=mlir 2>&1 + + +// This IR is not "valid": +// - toy.print should not return a value. +// - toy.print should take an argument. +// - There should be a block terminator. +// This all round-trip since this is opaque for MLIR. +func @main() { + %0 = "toy.print"() : () -> !toy<"array<2, 3>"> +} diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 75459a1..172558a 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -58,6 +58,7 @@ tools = [ # The following tools are optional tools.extend([ ToolSubst('toy-ch1', unresolved='ignore'), + ToolSubst('toy-ch2', unresolved='ignore'), ]) llvm_config.add_tool_substitutions(tools, tool_dirs) -- 2.7.4