From: Mehdi Amini Date: Tue, 2 Apr 2019 17:02:07 +0000 (-0700) Subject: Initial version for chapter 1 of the Toy tutorial X-Git-Tag: llvmorg-11-init~1466^2~2050 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=38b71d6b8468c436f5dcdb88f937fa0f4e62c22e;p=platform%2Fupstream%2Fllvm.git Initial version for chapter 1 of the Toy tutorial -- PiperOrigin-RevId: 241549247 --- diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 882c249..5dd7d45 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -43,3 +43,7 @@ add_subdirectory(lib) add_subdirectory(tools) add_subdirectory(unittests) add_subdirectory(test) + +if( LLVM_INCLUDE_EXAMPLES ) + add_subdirectory(examples) +endif() diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt new file mode 100644 index 0000000..9dc01fe --- /dev/null +++ b/mlir/examples/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(toy) + diff --git a/mlir/examples/toy/CMakeLists.txt b/mlir/examples/toy/CMakeLists.txt new file mode 100644 index 0000000..19f5293 --- /dev/null +++ b/mlir/examples/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +add_custom_target(Toy) +set_target_properties(Toy PROPERTIES FOLDER Examples) + +macro(add_toy_chapter name) + add_dependencies(Toy ${name}) + add_llvm_example(${name} ${ARGN}) +endmacro(add_toy_chapter name) + +add_subdirectory(Ch1) diff --git a/mlir/examples/toy/Ch1/CMakeLists.txt b/mlir/examples/toy/Ch1/CMakeLists.txt new file mode 100644 index 0000000..dd26cf1 --- /dev/null +++ b/mlir/examples/toy/Ch1/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_LINK_COMPONENTS + Support + ) + +add_toy_chapter(toyc-ch1 + toyc.cpp + parser/AST.cpp + ) +include_directories(include/) \ No newline at end of file diff --git a/mlir/examples/toy/Ch1/include/toy/AST.h b/mlir/examples/toy/Ch1/include/toy/AST.h new file mode 100644 index 0000000..456a323 --- /dev/null +++ b/mlir/examples/toy/Ch1/include/toy/AST.h @@ -0,0 +1,256 @@ +//===- AST.h - Node definition for the Toy AST ----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST for the Toy language. It is optimized for +// simplicity, not efficiency. The AST forms a tree structure where each node +// references its children using std::unique_ptr<>. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_AST_H_ +#define MLIR_TUTORIAL_TOY_AST_H_ + +#include "toy/Lexer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include + +namespace toy { + +/// A variable +struct VarType { + enum { TY_FLOAT, TY_INT } elt_ty; + std::vector shape; +}; + +/// Base class for all expression nodes. 
+class ExprAST { +public: + enum ExprASTKind { + Expr_VarDecl, + Expr_Return, + Expr_Num, + Expr_Literal, + Expr_Var, + Expr_BinOp, + Expr_Call, + Expr_Print, // builtin + Expr_If, + Expr_For, + }; + + ExprAST(ExprASTKind kind, Location location) + : kind(kind), location(location) {} + + virtual ~ExprAST() = default; + + ExprASTKind getKind() const { return kind; } + + const Location &loc() { return location; } + +private: + const ExprASTKind kind; + Location location; +}; + +/// A block-list of expressions. +using ExprASTList = std::vector>; + +/// Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {} + + double getValue() { return Val; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } +}; + +/// +class LiteralExprAST : public ExprAST { + std::vector> values; + std::vector dims; + +public: + LiteralExprAST(Location loc, std::vector> values, + std::vector dims) + : ExprAST(Expr_Literal, loc), values(std::move(values)), + dims(std::move(dims)) {} + + std::vector> &getValues() { return values; } + std::vector &getDims() { return dims; } + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; } +}; + +/// Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string name; + +public: + VariableExprAST(Location loc, const std::string &name) + : ExprAST(Expr_Var, loc), name(name) {} + + llvm::StringRef getName() { return name; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } +}; + +/// +class VarDeclExprAST : public ExprAST { + std::string name; + VarType type; + std::unique_ptr initVal; + +public: + VarDeclExprAST(Location loc, const std::string &name, VarType type, + std::unique_ptr initVal) + : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)), + initVal(std::move(initVal)) {} + + llvm::StringRef getName() { return name; } + ExprAST *getInitVal() { return initVal.get(); } + VarType &getType() { return type; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } +}; + +/// +class ReturnExprAST : public ExprAST { + llvm::Optional> expr; + +public: + ReturnExprAST(Location loc, llvm::Optional> expr) + : ExprAST(Expr_Return, loc), expr(std::move(expr)) {} + + llvm::Optional getExpr() { + if (expr.hasValue()) + return expr->get(); + return llvm::NoneType(); + } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; } +}; + +/// Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char Op; + std::unique_ptr LHS, RHS; + +public: + char getOp() { return Op; } + ExprAST *getLHS() { return LHS.get(); } + ExprAST *getRHS() { return RHS.get(); } + + BinaryExprAST(Location loc, char Op, std::unique_ptr LHS, + std::unique_ptr RHS) + : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)), + RHS(std::move(RHS)) {} + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; } +}; + +/// Expression class for function calls. 
+class CallExprAST : public ExprAST { + std::string Callee; + std::vector> Args; + +public: + CallExprAST(Location loc, const std::string &Callee, + std::vector> Args) + : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {} + + llvm::StringRef getCallee() { return Callee; } + llvm::ArrayRef> getArgs() { return Args; } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; } +}; + +/// Expression class for builtin print calls. +class PrintExprAST : public ExprAST { + std::unique_ptr Arg; + +public: + PrintExprAST(Location loc, std::unique_ptr Arg) + : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {} + + ExprAST *getArg() { return Arg.get(); } + + /// LLVM style RTTI + static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; } +}; + +/// This class represents the "prototype" for a function, which captures its +/// name, and its argument names (thus implicitly the number of arguments the +/// function takes). +class PrototypeAST { + Location location; + std::string name; + std::vector> args; + +public: + PrototypeAST(Location location, const std::string &name, + std::vector> args) + : location(location), name(name), args(std::move(args)) {} + + const Location &loc() { return location; } + const std::string &getName() const { return name; } + const std::vector> &getArgs() { + return args; + } +}; + +/// This class represents a function definition itself. +class FunctionAST { + std::unique_ptr Proto; + std::unique_ptr Body; + +public: + FunctionAST(std::unique_ptr Proto, + std::unique_ptr Body) + : Proto(std::move(Proto)), Body(std::move(Body)) {} + PrototypeAST *getProto() { return Proto.get(); } + ExprASTList *getBody() { return Body.get(); } +}; + +/// This class represents a list of functions to be processed together +class ModuleAST { + std::vector functions; + +public: + ModuleAST(std::vector functions) + : functions(std::move(functions)) {} + + auto begin() -> decltype(functions.begin()) { return functions.begin(); } + auto end() -> decltype(functions.end()) { return functions.end(); } +}; + +void dump(ModuleAST &); + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_AST_H_ diff --git a/mlir/examples/toy/Ch1/include/toy/Lexer.h b/mlir/examples/toy/Ch1/include/toy/Lexer.h new file mode 100644 index 0000000..d73adb9 --- /dev/null +++ b/mlir/examples/toy/Ch1/include/toy/Lexer.h @@ -0,0 +1,239 @@ +//===- Lexer.h - Lexer for the Toy language -------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple Lexer for the Toy language. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_LEXER_H_ +#define MLIR_TUTORIAL_TOY_LEXER_H_ + +#include "llvm/ADT/StringRef.h" + +#include +#include + +namespace toy { + +/// Structure definition a location in a file. 
+struct Location { + std::shared_ptr file; ///< filename + int line; ///< line number. + int col; ///< column number. +}; + +// List of Token returned by the lexer. +enum Token : int { + tok_semicolon = ';', + tok_parenthese_open = '(', + tok_parenthese_close = ')', + tok_bracket_open = '{', + tok_bracket_close = '}', + tok_sbracket_open = '[', + tok_sbracket_close = ']', + + tok_eof = -1, + + // commands + tok_return = -2, + tok_var = -3, + tok_def = -4, + + // primary + tok_identifier = -5, + tok_number = -6, +}; + +/// The Lexer is an abstract base class providing all the facilities that the +/// Parser expects. It goes through the stream one token at a time and keeps +/// track of the location in the file for debugging purpose. +/// It relies on a subclass to provide a `readNextLine()` method. The subclass +/// can proceed by reading the next line from the standard input or from a +/// memory mapped file. +class Lexer { +public: + /// Create a lexer for the given filename. The filename is kept only for + /// debugging purpose (attaching a location to a Token). + Lexer(std::string filename) + : lastLocation( + {std::make_shared(std::move(filename)), 0, 0}) {} + virtual ~Lexer() = default; + + /// Look at the current token in the stream. + Token getCurToken() { return curTok; } + + /// Move to the next token in the stream and return it. + Token getNextToken() { return curTok = getTok(); } + + /// Move to the next token in the stream, asserting on the current token + /// matching the expectation. + void consume(Token tok) { + assert(tok == curTok && "consume Token mismatch expectation"); + getNextToken(); + } + + /// Return the current identifier (prereq: getCurToken() == tok_identifier) + llvm::StringRef getId() { + assert(curTok == tok_identifier); + return IdentifierStr; + } + + /// Return the current number (prereq: getCurToken() == tok_number) + double getValue() { + assert(curTok == tok_number); + return NumVal; + } + + /// Return the location for the beginning of the current token. + Location getLastLocation() { return lastLocation; } + + // Return the current line in the file. + int getLine() { return curLineNum; } + + // Return the current column in the file. + int getCol() { return curCol; } + +private: + /// Delegate to a derived class fetching the next line. Returns an empty + /// string to signal end of file (EOF). Lines are expected to always finish + /// with "\n" + virtual llvm::StringRef readNextLine() = 0; + + /// Return the next character from the stream. This manages the buffer for the + /// current line and request the next line buffer to the derived class as + /// needed. + int getNextChar() { + // The current line buffer should not be empty unless it is the end of file. + if (curLineBuffer.empty()) + return EOF; + ++curCol; + auto nextchar = curLineBuffer.front(); + curLineBuffer = curLineBuffer.drop_front(); + if (curLineBuffer.empty()) + curLineBuffer = readNextLine(); + if (nextchar == '\n') { + ++curLineNum; + curCol = 0; + } + return nextchar; + } + + /// Return the next token from standard input. + Token getTok() { + // Skip any whitespace. + while (isspace(LastChar)) + LastChar = Token(getNextChar()); + + // Save the current location before reading the token characters. 
+ lastLocation.line = curLineNum; + lastLocation.col = curCol; + + if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]* + IdentifierStr = (char)LastChar; + while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_') + IdentifierStr += (char)LastChar; + + if (IdentifierStr == "return") + return tok_return; + if (IdentifierStr == "def") + return tok_def; + if (IdentifierStr == "var") + return tok_var; + return tok_identifier; + } + + if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ + std::string NumStr; + do { + NumStr += LastChar; + LastChar = Token(getNextChar()); + } while (isdigit(LastChar) || LastChar == '.'); + + NumVal = strtod(NumStr.c_str(), nullptr); + return tok_number; + } + + if (LastChar == '#') { + // Comment until end of line. + do + LastChar = Token(getNextChar()); + while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + + if (LastChar != EOF) + return getTok(); + } + + // Check for end of file. Don't eat the EOF. + if (LastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + Token ThisChar = Token(LastChar); + LastChar = Token(getNextChar()); + return ThisChar; + } + + /// The last token read from the input. + Token curTok = tok_eof; + + /// Location for `curTok`. + Location lastLocation; + + /// If the current Token is an identifier, this string contains the value. + std::string IdentifierStr; + + /// If the current Token is a number, this contains the value. + double NumVal = 0; + + /// The last value returned by getNextChar(). We need to keep it around as we + /// always need to read ahead one character to decide when to end a token and + /// we can't put it back in the stream after reading from it. + Token LastChar = Token(' '); + + /// Keep track of the current line number in the input stream + int curLineNum = 0; + + /// Keep track of the current column number in the input stream + int curCol = 0; + + /// Buffer supplied by the derived class on calls to `readNextLine()` + llvm::StringRef curLineBuffer = "\n"; +}; + +/// A lexer implementation operating on a buffer in memory. +class LexerBuffer final : public Lexer { +public: + LexerBuffer(const char *begin, const char *end, std::string filename) + : Lexer(std::move(filename)), current(begin), end(end) {} + +private: + /// Provide one line at a time to the Lexer, return an empty string when + /// reaching the end of the buffer. + llvm::StringRef readNextLine() override { + auto *begin = current; + while (current <= end && *current && *current != '\n') + ++current; + if (current <= end && *current) + ++current; + llvm::StringRef result{begin, static_cast(current - begin)}; + return result; + } + const char *current, *end; +}; +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_LEXER_H_ diff --git a/mlir/examples/toy/Ch1/include/toy/Parser.h b/mlir/examples/toy/Ch1/include/toy/Parser.h new file mode 100644 index 0000000..bc7aa52 --- /dev/null +++ b/mlir/examples/toy/Ch1/include/toy/Parser.h @@ -0,0 +1,494 @@ +//===- Parser.h - Toy Language Parser -------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the parser for the Toy language. It processes the Token +// provided by the Lexer and returns an AST. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_PARSER_H +#define MLIR_TUTORIAL_TOY_PARSER_H + +#include "toy/AST.h" +#include "toy/Lexer.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace toy { + +/// This is a simple recursive parser for the Toy language. It produces a well +/// formed AST from a stream of Token supplied by the Lexer. No semantic checks +/// or symbol resolution is performed. For example, variables are referenced by +/// string and the code could reference an undeclared variable and the parsing +/// succeeds. +class Parser { +public: + /// Create a Parser for the supplied lexer. + Parser(Lexer &lexer) : lexer(lexer) {} + + /// Parse a full Module. A module is a list of function definitions. + std::unique_ptr ParseModule() { + lexer.getNextToken(); // prime the lexer + + // Parse functions one at a time and accumulate in this vector. + std::vector functions; + while (auto F = ParseDefinition()) { + functions.push_back(std::move(*F)); + if (lexer.getCurToken() == tok_eof) + break; + } + // If we didn't reach EOF, there was an error during parsing + if (lexer.getCurToken() != tok_eof) + return parseError("nothing", "at end of module"); + + return llvm::make_unique(std::move(functions)); + } + +private: + Lexer &lexer; + + /// Parse a return statement. + /// return :== return ; | return expr ; + std::unique_ptr ParseReturn() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_return); + + // return takes an optional argument + llvm::Optional> expr; + if (lexer.getCurToken() != ';') { + expr = ParseExpression(); + if (!expr) + return nullptr; + } + return llvm::make_unique(std::move(loc), std::move(expr)); + } + + /// Parse a literal number. + /// numberexpr ::= number + std::unique_ptr ParseNumberExpr() { + auto loc = lexer.getLastLocation(); + auto Result = + llvm::make_unique(std::move(loc), lexer.getValue()); + lexer.consume(tok_number); + return std::move(Result); + } + + /// Parse a literal array expression. + /// tensorLiteral ::= [ literalList ] | number + /// literalList ::= tensorLiteral | tensorLiteral, literalList + std::unique_ptr ParseTensorLitteralExpr() { + auto loc = lexer.getLastLocation(); + lexer.consume(Token('[')); + + // Hold the list of values at this nesting level. + std::vector> values; + // Hold the dimensions for all the nesting inside this level. + std::vector dims; + do { + // We can have either another nested array or a number literal. + if (lexer.getCurToken() == '[') { + values.push_back(ParseTensorLitteralExpr()); + if (!values.back()) + return nullptr; // parse error in the nested array. 
+ } else { + if (lexer.getCurToken() != tok_number) + return parseError(" or [", "in literal expression"); + values.push_back(ParseNumberExpr()); + } + + // End of this list on ']' + if (lexer.getCurToken() == ']') + break; + + // Elements are separated by a comma. + if (lexer.getCurToken() != ',') + return parseError("] or ,", "in literal expression"); + + lexer.getNextToken(); // eat , + } while (true); + if (values.empty()) + return parseError("", "to fill literal expression"); + lexer.getNextToken(); // eat ] + /// Fill in the dimensions now. First the current nesting level: + dims.push_back(values.size()); + /// If there is any nested array, process all of them and ensure that + /// dimensions are uniform. + if (llvm::any_of(values, [](std::unique_ptr &expr) { + return llvm::isa(expr.get()); + })) { + auto *firstLiteral = llvm::dyn_cast(values.front().get()); + if (!firstLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expession"); + + // Append the nested dimensions to the current level + auto &firstDims = firstLiteral->getDims(); + dims.insert(dims.end(), firstDims.begin(), firstDims.end()); + + // Sanity check that shape is uniform across all elements of the list. + for (auto &expr : values) { + auto *exprLiteral = llvm::cast(expr.get()); + if (!exprLiteral) + return parseError("uniform well-nested dimensions", + "inside literal expession"); + if (exprLiteral->getDims() != firstDims) + return parseError("uniform well-nested dimensions", + "inside literal expession"); + } + } + return llvm::make_unique(std::move(loc), std::move(values), + std::move(dims)); + } + + /// parenexpr ::= '(' expression ')' + std::unique_ptr ParseParenExpr() { + lexer.getNextToken(); // eat (. + auto V = ParseExpression(); + if (!V) + return nullptr; + + if (lexer.getCurToken() != ')') + return parseError(")", "to close expression with parentheses"); + lexer.consume(Token(')')); + return V; + } + + /// identifierexpr + /// ::= identifier + /// ::= identifier '(' expression ')' + std::unique_ptr ParseIdentifierExpr() { + std::string name = lexer.getId(); + + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat identifier. + + if (lexer.getCurToken() != '(') // Simple variable ref. + return llvm::make_unique(std::move(loc), name); + + // This is a function call. 
+ lexer.consume(Token('(')); + std::vector> Args; + if (lexer.getCurToken() != ')') { + while (true) { + if (auto Arg = ParseExpression()) + Args.push_back(std::move(Arg)); + else + return nullptr; + + if (lexer.getCurToken() == ')') + break; + + if (lexer.getCurToken() != ',') + return parseError(", or )", "in argument list"); + lexer.getNextToken(); + } + } + lexer.consume(Token(')')); + + // It can be a builtin call to print + if (name == "print") { + if (Args.size() != 1) + return parseError("", "as argument to print()"); + + return llvm::make_unique(std::move(loc), + std::move(Args[0])); + } + + // Call to a user-defined function + return llvm::make_unique(std::move(loc), name, + std::move(Args)); + } + + /// primary + /// ::= identifierexpr + /// ::= numberexpr + /// ::= parenexpr + /// ::= tensorliteral + std::unique_ptr ParsePrimary() { + switch (lexer.getCurToken()) { + default: + llvm::errs() << "unknown token '" << lexer.getCurToken() + << "' when expecting an expression\n"; + return nullptr; + case tok_identifier: + return ParseIdentifierExpr(); + case tok_number: + return ParseNumberExpr(); + case '(': + return ParseParenExpr(); + case '[': + return ParseTensorLitteralExpr(); + case ';': + return nullptr; + case '}': + return nullptr; + } + } + + /// Recursively parse the right hand side of a binary expression, the ExprPrec + /// argument indicates the precedence of the current binary operator. + /// + /// binoprhs ::= ('+' primary)* + std::unique_ptr ParseBinOpRHS(int ExprPrec, + std::unique_ptr LHS) { + // If this is a binop, find its precedence. + while (true) { + int TokPrec = GetTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (TokPrec < ExprPrec) + return LHS; + + // Okay, we know this is a binop. + int BinOp = lexer.getCurToken(); + lexer.consume(Token(BinOp)); + auto loc = lexer.getLastLocation(); + + // Parse the primary expression after the binary operator. + auto RHS = ParsePrimary(); + if (!RHS) + return parseError("expression", "to complete binary operator"); + + // If BinOp binds less tightly with RHS than the operator after RHS, let + // the pending operator take RHS as its LHS. + int NextPrec = GetTokPrecedence(); + if (TokPrec < NextPrec) { + RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); + if (!RHS) + return nullptr; + } + + // Merge LHS/RHS. + LHS = llvm::make_unique(std::move(loc), BinOp, + std::move(LHS), std::move(RHS)); + } + } + + /// expression::= primary binoprhs + std::unique_ptr ParseExpression() { + auto LHS = ParsePrimary(); + if (!LHS) + return nullptr; + + return ParseBinOpRHS(0, std::move(LHS)); + } + + /// type ::= < shape_list > + /// shape_list ::= num | num , shape_list + std::unique_ptr ParseType() { + if (lexer.getCurToken() != '<') + return parseError("<", "to begin type"); + lexer.getNextToken(); // eat < + + auto type = llvm::make_unique(); + + while (lexer.getCurToken() == tok_number) { + type->shape.push_back(lexer.getValue()); + lexer.getNextToken(); + if (lexer.getCurToken() == ',') + lexer.getNextToken(); + } + + if (lexer.getCurToken() != '>') + return parseError(">", "to end type"); + lexer.getNextToken(); // eat > + return type; + } + + /// Parse a variable declaration, it starts with a `var` keyword followed by + /// and identifier and an optional type (shape specification) before the + /// initializer. 
+ /// decl ::= var identifier [ type ] = expr + std::unique_ptr ParseDeclaration() { + if (lexer.getCurToken() != tok_var) + return parseError("var", "to begin declaration"); + auto loc = lexer.getLastLocation(); + lexer.getNextToken(); // eat var + + if (lexer.getCurToken() != tok_identifier) + return parseError("identified", + "after 'var' declaration"); + std::string id = lexer.getId(); + lexer.getNextToken(); // eat id + + std::unique_ptr type; // Type is optional, it can be inferred + if (lexer.getCurToken() == '<') { + type = ParseType(); + if (!type) + return nullptr; + } + + if (!type) + type = llvm::make_unique(); + lexer.consume(Token('=')); + auto expr = ParseExpression(); + return llvm::make_unique(std::move(loc), std::move(id), + std::move(*type), std::move(expr)); + } + + /// Parse a block: a list of expression separated by semicolons and wrapped in + /// curly braces. + /// + /// block ::= { expression_list } + /// expression_list ::= block_expr ; expression_list + /// block_expr ::= decl | "return" | expr + std::unique_ptr ParseBlock() { + if (lexer.getCurToken() != '{') + return parseError("{", "to begin block"); + lexer.consume(Token('{')); + + auto exprList = llvm::make_unique(); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + + while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { + if (lexer.getCurToken() == tok_var) { + // Variable declaration + auto varDecl = ParseDeclaration(); + if (!varDecl) + return nullptr; + exprList->push_back(std::move(varDecl)); + } else if (lexer.getCurToken() == tok_return) { + // Return statement + auto ret = ParseReturn(); + if (!ret) + return nullptr; + exprList->push_back(std::move(ret)); + } else { + // General expression + auto expr = ParseExpression(); + if (!expr) + return nullptr; + exprList->push_back(std::move(expr)); + } + // Ensure that elements are separated by a semicolon. + if (lexer.getCurToken() != ';') + return parseError(";", "after expression"); + + // Ignore empty expressions: swallow sequences of semicolons. + while (lexer.getCurToken() == ';') + lexer.consume(Token(';')); + } + + if (lexer.getCurToken() != '}') + return parseError("}", "to close block"); + + lexer.consume(Token('}')); + return exprList; + } + + /// prototype ::= def id '(' decl_list ')' + /// decl_list ::= identifier | identifier, decl_list + std::unique_ptr ParsePrototype() { + auto loc = lexer.getLastLocation(); + lexer.consume(tok_def); + if (lexer.getCurToken() != tok_identifier) + return parseError("function name", "in prototype"); + + std::string FnName = lexer.getId(); + lexer.consume(tok_identifier); + + if (lexer.getCurToken() != '(') + return parseError("(", "in prototype"); + lexer.consume(Token('(')); + + std::vector> args; + if (lexer.getCurToken() != ')') { + do { + std::string name = lexer.getId(); + auto loc = lexer.getLastLocation(); + lexer.consume(tok_identifier); + auto decl = llvm::make_unique(std::move(loc), name); + args.push_back(std::move(decl)); + if (lexer.getCurToken() != ',') + break; + lexer.consume(Token(',')); + if (lexer.getCurToken() != tok_identifier) + return parseError( + "identifier", "after ',' in function parameter list"); + } while (true); + } + if (lexer.getCurToken() != ')') + return parseError("}", "to end function prototype"); + + // success. 
+ lexer.consume(Token(')')); + return llvm::make_unique(std::move(loc), FnName, + std::move(args)); + } + + /// Parse a function definition, we expect a prototype initiated with the + /// `def` keyword, followed by a block containing a list of expressions. + /// + /// definition ::= prototype block + std::unique_ptr ParseDefinition() { + auto Proto = ParsePrototype(); + if (!Proto) + return nullptr; + + if (auto block = ParseBlock()) + return llvm::make_unique(std::move(Proto), std::move(block)); + return nullptr; + } + + /// Get the precedence of the pending binary operator token. + int GetTokPrecedence() { + if (!isascii(lexer.getCurToken())) + return -1; + + // 1 is lowest precedence. + switch (static_cast(lexer.getCurToken())) { + case '-': + return 20; + case '+': + return 20; + case '*': + return 40; + default: + return -1; + } + } + + /// Helper function to signal errors while parsing, it takes an argument + /// indicating the expected token and another argument giving more context. + /// Location is retrieved from the lexer to enrich the error message. + template + std::unique_ptr parseError(T &&expected, U &&context = "") { + auto curToken = lexer.getCurToken(); + llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " + << lexer.getLastLocation().col << "): expected '" << expected + << "' " << context << " but has Token " << curToken; + if (isprint(curToken)) + llvm::errs() << " '" << (char)curToken << "'"; + llvm::errs() << "\n"; + return nullptr; + } +}; + +} // namespace toy + +#endif // MLIR_TUTORIAL_TOY_PARSER_H diff --git a/mlir/examples/toy/Ch1/parser/AST.cpp b/mlir/examples/toy/Ch1/parser/AST.cpp new file mode 100644 index 0000000..869f2ef --- /dev/null +++ b/mlir/examples/toy/Ch1/parser/AST.cpp @@ -0,0 +1,263 @@ +//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the AST dump for the Toy language. +// +//===----------------------------------------------------------------------===// + +#include "toy/AST.h" + +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace toy; + +namespace { + +// RAII helper to manage increasing/decreasing the indentation as we traverse +// the AST +struct Indent { + Indent(int &level) : level(level) { ++level; } + ~Indent() { --level; } + int &level; +}; + +/// Helper class that implement the AST tree traversal and print the nodes along +/// the way. The only data member is the current indentation level. 
+class ASTDumper { +public: + void dump(ModuleAST *Node); + +private: + void dump(VarType &type); + void dump(VarDeclExprAST *varDecl); + void dump(ExprAST *expr); + void dump(ExprASTList *exprList); + void dump(NumberExprAST *num); + void dump(LiteralExprAST *Node); + void dump(VariableExprAST *Node); + void dump(ReturnExprAST *Node); + void dump(BinaryExprAST *Node); + void dump(CallExprAST *Node); + void dump(PrintExprAST *Node); + void dump(PrototypeAST *Node); + void dump(FunctionAST *Node); + + // Actually print spaces matching the current indentation level + void indent() { + for (int i = 0; i < curIndent; i++) + llvm::errs() << " "; + } + int curIndent = 0; +}; + +} // namespace + +/// Return a formatted string for the location of any node +template static std::string loc(T *Node) { + const auto &loc = Node->loc(); + return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + + llvm::Twine(loc.col)) + .str(); +} + +// Helper Macro to bump the indentation level and print the leading spaces for +// the current indentations +#define INDENT() \ + Indent level_(curIndent); \ + indent(); + +/// Dispatch to a generic expressions to the appropriate subclass using RTTI +void ASTDumper::dump(ExprAST *expr) { +#define dispatch(CLASS) \ + if (CLASS *node = llvm::dyn_cast(expr)) \ + return dump(node); + dispatch(VarDeclExprAST); + dispatch(LiteralExprAST); + dispatch(NumberExprAST); + dispatch(VariableExprAST); + dispatch(ReturnExprAST); + dispatch(BinaryExprAST); + dispatch(CallExprAST); + dispatch(PrintExprAST); + // No match, fallback to a generic message + INDENT(); + llvm::errs() << "getKind() << ">\n"; +} + +/// A variable declaration is printing the variable name, the type, and then +/// recurse in the initializer value. +void ASTDumper::dump(VarDeclExprAST *varDecl) { + INDENT(); + llvm::errs() << "VarDecl " << varDecl->getName(); + dump(varDecl->getType()); + llvm::errs() << " " << loc(varDecl) << "\n"; + dump(varDecl->getInitVal()); +} + +/// A "block", or a list of expression +void ASTDumper::dump(ExprASTList *exprList) { + INDENT(); + llvm::errs() << "Block {\n"; + for (auto &expr : *exprList) + dump(expr.get()); + indent(); + llvm::errs() << "} // Block\n"; +} + +/// A literal number, just print the value. +void ASTDumper::dump(NumberExprAST *num) { + INDENT(); + llvm::errs() << num->getValue() << " " << loc(num) << "\n"; +} + +/// Helper to print recurisvely a literal. This handles nested array like: +/// [ [ 1, 2 ], [ 3, 4 ] ] +/// We print out such array with the dimensions spelled out at every level: +/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] +void printLitHelper(ExprAST *lit_or_num) { + // Inside a literal expression we can have either a number or another literal + if (auto num = llvm::dyn_cast(lit_or_num)) { + llvm::errs() << num->getValue(); + return; + } + auto *literal = llvm::cast(lit_or_num); + + // Print the dimension for this literal first + llvm::errs() << "<"; + { + const char *sep = ""; + for (auto dim : literal->getDims()) { + llvm::errs() << sep << dim; + sep = ", "; + } + } + llvm::errs() << ">"; + + // Now print the content, recursing on every element of the list + llvm::errs() << "[ "; + const char *sep = ""; + for (auto &elt : literal->getValues()) { + llvm::errs() << sep; + printLitHelper(elt.get()); + sep = ", "; + } + llvm::errs() << "]"; +} + +/// Print a literal, see the recursive helper above for the implementation. 
+void ASTDumper::dump(LiteralExprAST *Node) { + INDENT(); + llvm::errs() << "Literal: "; + printLitHelper(Node); + llvm::errs() << " " << loc(Node) << "\n"; +} + +/// Print a variable reference (just a name). +void ASTDumper::dump(VariableExprAST *Node) { + INDENT(); + llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n"; +} + +/// Return statement print the return and its (optional) argument. +void ASTDumper::dump(ReturnExprAST *Node) { + INDENT(); + llvm::errs() << "Return\n"; + if (Node->getExpr().hasValue()) + return dump(*Node->getExpr()); + { + INDENT(); + llvm::errs() << "(void)\n"; + } +} + +/// Print a binary operation, first the operator, then recurse into LHS and RHS. +void ASTDumper::dump(BinaryExprAST *Node) { + INDENT(); + llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n"; + dump(Node->getLHS()); + dump(Node->getRHS()); +} + +/// Print a call expression, first the callee name and the list of args by +/// recursing into each individual argument. +void ASTDumper::dump(CallExprAST *Node) { + INDENT(); + llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n"; + for (auto &arg : Node->getArgs()) + dump(arg.get()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print a builtin print call, first the builtin name and then the argument. +void ASTDumper::dump(PrintExprAST *Node) { + INDENT(); + llvm::errs() << "Print [ " << loc(Node) << "\n"; + dump(Node->getArg()); + indent(); + llvm::errs() << "]\n"; +} + +/// Print type: only the shape is printed in between '<' and '>' +void ASTDumper::dump(VarType &type) { + llvm::errs() << "<"; + const char *sep = ""; + for (auto shape : type.shape) { + llvm::errs() << sep << shape; + sep = ", "; + } + llvm::errs() << ">"; +} + +/// Print a function prototype, first the function name, and then the list of +/// parameters names. +void ASTDumper::dump(PrototypeAST *Node) { + INDENT(); + llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n"; + indent(); + llvm::errs() << "Params: ["; + const char *sep = ""; + for (auto &arg : Node->getArgs()) { + llvm::errs() << sep << arg->getName(); + sep = ", "; + } + llvm::errs() << "]\n"; +} + +/// Print a function, first the prototype and then the body. +void ASTDumper::dump(FunctionAST *Node) { + INDENT(); + llvm::errs() << "Function \n"; + dump(Node->getProto()); + dump(Node->getBody()); +} + +/// Print a module, actually loop over the functions and print them in sequence. +void ASTDumper::dump(ModuleAST *Node) { + INDENT(); + llvm::errs() << "Module:\n"; + for (auto &F : *Node) + dump(&F); +} + +namespace toy { + +// Public API +void dump(ModuleAST &module) { ASTDumper().dump(&module); } + +} // namespace toy diff --git a/mlir/examples/toy/Ch1/toyc.cpp b/mlir/examples/toy/Ch1/toyc.cpp new file mode 100644 index 0000000..dd308ca --- /dev/null +++ b/mlir/examples/toy/Ch1/toyc.cpp @@ -0,0 +1,75 @@ +//===- toyc.cpp - The Toy Compiler ----------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the entry point for the Toy compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Parser.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+namespace cl = llvm::cl;
+
+static cl::opt<std::string> InputFilename(cl::Positional,
+                                          cl::desc("<input toy file>"),
+                                          cl::init("-"),
+                                          cl::value_desc("filename"));
+namespace {
+enum Action { None, DumpAST };
+}
+
+static cl::opt<enum Action>
+    emitAction("emit", cl::desc("Select the kind of output desired"),
+               cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));
+
+/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
+std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code EC = FileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return nullptr;
+  }
+  auto buffer = FileOrErr.get()->getBuffer();
+  LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
+  Parser parser(lexer);
+  return parser.ParseModule();
+}
+
+int main(int argc, char **argv) {
+  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+  auto moduleAST = parseInputFile(InputFilename);
+  if (!moduleAST)
+    return 1;
+
+  switch (emitAction) {
+  case Action::DumpAST:
+    dump(*moduleAST);
+    return 0;
+  default:
+    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+  }
+
+  return 0;
+}
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-1.md b/mlir/g3doc/Tutorials/Toy/Ch-1.md
new file mode 100644
index 0000000..cd3bd33
--- /dev/null
+++ b/mlir/g3doc/Tutorials/Toy/Ch-1.md
@@ -0,0 +1,149 @@
+# Chapter 1: Toy Tutorial Introduction
+
+This tutorial runs through the implementation of a basic toy language on top of
+MLIR. The goal of this tutorial is to introduce the concepts of MLIR, and
+especially how *dialects* can help easily support language-specific constructs
+and transformations while still offering an easy path to lower to LLVM or other
+codegen infrastructure. This tutorial is based on the model of the
+[LLVM Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl01.html).
+
+This tutorial is divided into the following chapters:
+
+- [Chapter #1](Ch-1.md): Introduction to the Toy language, and the definition
+  of its AST.
+- [Chapter #2](Ch-2.md): Traversing the AST to emit custom MLIR, introducing
+  base MLIR concepts.
+- [Chapter #3](Ch-3.md): Defining and registering a dialect in MLIR, showing
+  how we can start attaching semantics to our custom operations in MLIR.
+- [Chapter #4](Ch-4.md): High-level language-specific analysis and
+  transformation, showcasing shape inference, generic function specialization,
+  and basic optimizations.
+- [Chapter #5](Ch-5.md): Lowering to lower-level dialects. We'll convert our
+  high-level, language-specific semantics towards a generic linear-algebra
+  oriented dialect for optimizations. Ultimately we will emit LLVM IR for code
+  generation.
+- [Chapter #6](Ch-6.md): A REPL?
+- [Chapter #7](Ch-7.md): Custom backends? GPU using LLVM? TPU? XLA
+
+## The Language
+
+This tutorial will be illustrated with a toy language that we’ll call “Toy”
+(naming is hard...). Toy is an array-based language that allows you to define
+functions, perform some math computation, and print results.
+
+Because we want to keep things simple, the codegen will be limited to arrays of
+rank <= 2, and the only datatype in Toy is a 64-bit floating point type (aka
+‘double’ in C parlance). As such, all values are implicitly double precision.
+Values are immutable: every operation returns a newly allocated value, and
+deallocation is automatically managed. But enough with the long description;
+nothing is better than walking through an example to get a better
+understanding:
+
+FIXME: update/modify matrix multiplication to use @ instead of *
+
+```Toy {.toy}
+def main() {
+  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+  # The shape is inferred from the supplied literal.
+  var a = [[1, 2, 3], [4, 5, 6]];
+  # b is identical to a, the literal array is implicitly reshaped: defining new
+  # variables is the way to reshape arrays (element count must match).
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  # transpose() and print() are the only builtin functions; the following will
+  # transpose b and perform a matrix multiplication before printing the result.
+  print(a * transpose(b));
+}
+```
+
+Type checking is statically performed through type inference; the language only
+requires type declarations to specify array shapes when needed. Functions are
+generic: their parameters are unranked (in other words, we know these are
+arrays, but we don't know their number of dimensions or the size of each
+dimension). Functions are specialized for every newly discovered signature at
+call sites. Let's revisit the previous example by adding a user-defined
+function:
+
+```Toy {.toy}
+# User defined generic function that operates on unknown shaped arguments.
+def multiply_transpose(a, b) {
+  return a * transpose(b);
+}
+
+def main() {
+  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+  var a = [[1, 2, 3], [4, 5, 6]];
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  # This call will specialize `multiply_transpose` with <2, 3> for both
+  # arguments and deduce a return type of <2, 2> in initialization of `c`.
+  var c = multiply_transpose(a, b);
+  # A second call to `multiply_transpose` with <2, 3> for both arguments will
+  # reuse the previously specialized and inferred version and return `<2, 2>`.
+  var d = multiply_transpose(b, a);
+  # A new call with `<2, 2>` for both arguments will trigger another
+  # specialization of `multiply_transpose`.
+  var e = multiply_transpose(c, d);
+  # Finally, calling into `multiply_transpose` with incompatible shapes will
+  # trigger a shape inference error.
+ var e = multiply_transpose(transpose(a), c); +} +``` + +## The AST + +The AST is fairly straightforward from the above code, here is a dump of it: + +``` +Module: + Function + Proto 'multiply_transpose' @test/ast.toy:5:1' + Args: [a, b] + Block { + Return + BinOp: * @test/ast.toy:6:12 + var: a @test/ast.toy:6:10 + Call 'transpose' [ @test/ast.toy:6:14 + var: b @test/ast.toy:6:24 + ] + } // Block + Function + Proto 'main' @test/ast.toy:9:1' + Args: [] + Block { + VarDecl a<2, 3> @test/ast.toy:11:3 + Literal: <2, 3>[<3>[1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[4.000000e+00, 5.000000e+00, 6.000000e+00]] @test/ast.toy:11:17 + VarDecl b<2, 3> @test/ast.toy:12:3 + Literal: <6>[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @test/ast.toy:12:17 + VarDecl c<> @test/ast.toy:15:3 + Call 'multiply_transpose' [ @test/ast.toy:15:11 + var: a @test/ast.toy:15:30 + var: b @test/ast.toy:15:33 + ] + VarDecl d<> @test/ast.toy:18:3 + Call 'multiply_transpose' [ @test/ast.toy:18:11 + var: b @test/ast.toy:18:30 + var: a @test/ast.toy:18:33 + ] + VarDecl e<> @test/ast.toy:21:3 + Call 'multiply_transpose' [ @test/ast.toy:21:11 + var: b @test/ast.toy:21:30 + var: c @test/ast.toy:21:33 + ] + VarDecl e<> @test/ast.toy:24:3 + Call 'multiply_transpose' [ @test/ast.toy:24:11 + Call 'transpose' [ @test/ast.toy:24:30 + var: a @test/ast.toy:24:40 + ] + var: c @test/ast.toy:24:44 + ] + } // Block +``` + +You can reproduce this result and play with the example in the `examples/Ch1/` +directory, try running `path/to/BUILD/bin/toyc test/ast.toy -emit=ast`. + +The code for the lexer is fairly straighforward, it is all in a single header: +`examples/toy/Ch1/include/toy/Lexer.h`. The parser can be found in +`examples/toy/Ch1/include/toy/Parser.h`, it is a recursive descent parser. If +you are not familiar with such Lexer/Parser, these are very similar to the LLVM +Kaleidoscope equivalent that are detailed in the first two chapters of the +[Kaleidoscope Tutorial](https://llvm.org/docs/tutorial/LangImpl02.html#the-abstract-syntax-tree-ast). + +The [next chapter](Ch-2.md) will demonstrate how to convert this AST into MLIR. diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index 8e3f594..16e2cb4 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,3 +1,8 @@ +llvm_canonicalize_cmake_booleans( + LLVM_BUILD_EXAMPLES + ) + + configure_lit_site_cfg( ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py @@ -20,6 +25,13 @@ set(MLIR_TEST_DEPENDS mlir-translate ) + +if(LLVM_BUILD_EXAMPLES) + list(APPEND MLIR_TEST_DEPENDS + toyc-ch1 + ) +endif() + add_lit_testsuite(check-mlir "Running the MLIR regression tests" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${MLIR_TEST_DEPENDS} diff --git a/mlir/test/Examples/Toy/Ch1/ast.toy b/mlir/test/Examples/Toy/Ch1/ast.toy new file mode 100644 index 0000000..e8c8fe0 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch1/ast.toy @@ -0,0 +1,71 @@ +# RUN: toyc-ch1 %s -emit=ast 2>&1 | FileCheck %s + + +# User defined generic function that operates solely on +def multiply_transpose(a, b) { + return a * transpose(b); +} + +def main() { + # Define a variable `a` with shape <2, 3>, initialized with the literal value + var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + var b<2, 3> = [1, 2, 3, 4, 5, 6]; + # This call will specialize `multiply_transpose` with <2, 3> for both + # arguments and deduce a return type of <2, 2> in initialization of `c`. 
+ var c = multiply_transpose(a, b); + # A second call to `multiply_transpose` with <2, 3> for both arguments will + # reuse the previously specialized and inferred version and return `<2, 2>` + var d = multiply_transpose(b, a); + # A new call with `<2, 2>` for both dimension will trigger another + # specialization of `multiply_transpose`. + var e = multiply_transpose(b, c); + # Finally, calling into `multiply_transpose` with incompatible shape will + # trigger a shape inference error. + var e = multiply_transpose(transpose(a), c); +} + + +# CHECK: Module: +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'multiply_transpose' +# CHECK-NEXT: Params: [a, b] +# CHECK-NEXT: Block { +# CHECK-NEXT: Retur +# CHECK-NEXT: BinOp: * +# CHECK-NEXT: var: a +# CHECK-NEXT: Call 'transpose' [ +# CHECK-NEXT: var: b +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block +# CHECK-NEXT: Function +# CHECK-NEXT: Proto 'main' +# CHECK-NEXT: Params: [] +# CHECK-NEXT: Block { +# CHECK-NEXT: VarDecl a<2, 3> +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] +# CHECK-NEXT: VarDecl b<2, 3> +# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] +# CHECK-NEXT: VarDecl c<> +# CHECK-NEXT: Call 'multiply_transpose' [ +# CHECK-NEXT: var: a +# CHECK-NEXT: var: b +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl d<> +# CHECK-NEXT: Call 'multiply_transpose' [ +# CHECK-NEXT: var: b +# CHECK-NEXT: var: a +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> +# CHECK-NEXT: Call 'multiply_transpose' [ +# CHECK-NEXT: var: b +# CHECK-NEXT: var: c +# CHECK-NEXT: ] +# CHECK-NEXT: VarDecl e<> +# CHECK-NEXT: Call 'multiply_transpose' [ +# CHECK-NEXT: Call 'transpose' [ +# CHECK-NEXT: var: a +# CHECK-NEXT: ] +# CHECK-NEXT: var: c +# CHECK-NEXT: ] +# CHECK-NEXT: } // Block + diff --git a/mlir/test/Examples/lit.local.cfg b/mlir/test/Examples/lit.local.cfg new file mode 100644 index 0000000..97db322 --- /dev/null +++ b/mlir/test/Examples/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.build_examples: + config.unsupported = True diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index 51075f1..75459a1 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -21,7 +21,7 @@ config.name = 'MLIR' config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.td', '.mlir'] +config.suffixes = ['.td', '.mlir', '.toy'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -54,4 +54,10 @@ tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir] tools = [ 'mlir-opt', 'mlir-tblgen', 'mlir-translate', ] + +# The following tools are optional +tools.extend([ + ToolSubst('toy-ch1', unresolved='ignore'), +]) + llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in index 4cbdb3a..c701b04 100644 --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -30,6 +30,7 @@ config.host_arch = "@HOST_ARCH@" config.mlir_src_root = "@MLIR_SOURCE_DIR@" config.mlir_obj_root = "@MLIR_BINARY_DIR@" config.mlir_tools_dir = "@MLIR_TOOLS_DIR@" +config.build_examples = @LLVM_BUILD_EXAMPLES@ # Support substitution of the tools_dir with user parameters. This is # used when we can't determine the tool dir at configuration time.
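
For readers who want to drive the new Lexer/Parser API directly rather than
going through the `toyc-ch1` driver, the following minimal sketch (not part of
this commit) parses a Toy program held in memory and dumps its AST. It only
uses APIs added above (`LexerBuffer`, `Parser::ParseModule`, `toy::dump`); the
sample program string and the standalone `main` are purely illustrative.

```cpp
// Minimal sketch, assuming a standalone tool linked against the Ch1 sources:
// parse an in-memory Toy program and print its AST.
#include "toy/Parser.h"

#include <cstring>
#include <memory>

int main() {
  // Illustrative input; any program accepted by the grammar in Parser.h works.
  const char *program = "def main() {\n"
                        "  var a = [[1, 2, 3], [4, 5, 6]];\n"
                        "  print(a);\n"
                        "}\n";

  // LexerBuffer feeds the parser one line at a time and records source
  // locations against the given (here made-up) filename.
  toy::LexerBuffer lexer(program, program + std::strlen(program), "<memory>");

  // The parser returns nullptr (after printing a diagnostic) on any error.
  toy::Parser parser(lexer);
  std::unique_ptr<toy::ModuleAST> module = parser.ParseModule();
  if (!module)
    return 1;

  // The dumper implemented in parser/AST.cpp prints the indented tree.
  toy::dump(*module);
  return 0;
}
```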
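
The `ExprAST` hierarchy above deliberately uses LLVM-style RTTI (a `getKind()`
enum plus a `classof()` hook on every subclass) rather than `dynamic_cast`,
which is what the `dispatch` macro in `ASTDumper` relies on. As an illustration
of the same pattern, here is a sketch of a hypothetical helper (not in the
commit) that counts binary operators in a function body with `llvm::dyn_cast`.

```cpp
// Sketch of walking the Toy AST with LLVM-style RTTI; countBinaryOps is an
// assumed helper for illustration only.
#include "toy/AST.h"

#include "llvm/Support/Casting.h"

static int countBinaryOps(toy::ExprAST *expr) {
  if (!expr)
    return 0;
  // dyn_cast returns nullptr when the dynamic kind does not match, so each
  // branch only runs for the corresponding node kind.
  if (auto *bin = llvm::dyn_cast<toy::BinaryExprAST>(expr))
    return 1 + countBinaryOps(bin->getLHS()) + countBinaryOps(bin->getRHS());
  if (auto *decl = llvm::dyn_cast<toy::VarDeclExprAST>(expr))
    return countBinaryOps(decl->getInitVal());
  if (auto *ret = llvm::dyn_cast<toy::ReturnExprAST>(expr)) {
    auto inner = ret->getExpr();
    return inner.hasValue() ? countBinaryOps(*inner) : 0;
  }
  if (auto *print = llvm::dyn_cast<toy::PrintExprAST>(expr))
    return countBinaryOps(print->getArg());
  if (auto *call = llvm::dyn_cast<toy::CallExprAST>(expr)) {
    int count = 0;
    for (auto &arg : call->getArgs())
      count += countBinaryOps(arg.get());
    return count;
  }
  // Numbers, variable references, and (by construction of the parser) literal
  // arrays cannot contain binary operators, so nothing to recurse into.
  return 0;
}

static int countBinaryOps(toy::FunctionAST &func) {
  int count = 0;
  for (auto &expr : *func.getBody())
    count += countBinaryOps(expr.get());
  return count;
}
```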