Toy tutorial Chapter 5: Lowering to Linalg and LLVM
authorMehdi Amini <aminim@google.com>
Tue, 9 Apr 2019 06:00:49 +0000 (23:00 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 9 Apr 2019 06:26:54 +0000 (23:26 -0700)
--

PiperOrigin-RevId: 242606796

31 files changed:
mlir/examples/Linalg/Linalg1/CMakeLists.txt
mlir/examples/Linalg/Linalg3/CMakeLists.txt
mlir/examples/toy/CMakeLists.txt
mlir/examples/toy/Ch3/mlir/ToyDialect.cpp
mlir/examples/toy/Ch4/mlir/ToyDialect.cpp
mlir/examples/toy/Ch5/CMakeLists.txt [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/AST.h [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/Dialect.h [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/Lexer.h [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/Lowering.h [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/MLIRGen.h [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/Parser.h [new file with mode: 0644]
mlir/examples/toy/Ch5/include/toy/Passes.h [new file with mode: 0644]
mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/mlir/LateLowering.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/mlir/MLIRGen.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/mlir/ToyCombine.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/mlir/ToyDialect.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/parser/AST.cpp [new file with mode: 0644]
mlir/examples/toy/Ch5/toyc.cpp [new file with mode: 0644]
mlir/g3doc/Tutorials/Toy/Ch-5.md [new file with mode: 0644]
mlir/test/CMakeLists.txt
mlir/test/Examples/Toy/Ch5/ast.toy [new file with mode: 0644]
mlir/test/Examples/Toy/Ch5/codegen.toy [new file with mode: 0644]
mlir/test/Examples/Toy/Ch5/invalid.mlir [new file with mode: 0644]
mlir/test/Examples/Toy/Ch5/lowering.toy [new file with mode: 0644]
mlir/test/Examples/Toy/Ch5/scalar.toy [new file with mode: 0644]
mlir/test/Examples/Toy/Ch5/transpose_transpose.toy [new file with mode: 0644]
mlir/test/Examples/Toy/Ch5/trivialReshape.toy [new file with mode: 0644]
mlir/test/lit.cfg.py

index ea07276..99f2362 100644 (file)
@@ -17,6 +17,8 @@ add_llvm_example(linalg-example-1
 
 target_link_libraries(linalg-example-1
   PRIVATE
+    Linalg1DialectConstruction
+    Linalg1
     MLIRAnalysis
     MLIRDialect
     MLIREDSC
@@ -25,8 +27,6 @@ target_link_libraries(linalg-example-1
     MLIRParser
     MLIRPass
     MLIRTransforms
-    Linalg1
-    Linalg1DialectConstruction
     )
 
 whole_archive_link(linalg-example-1
@@ -35,6 +35,8 @@ whole_archive_link(linalg-example-1
 
 target_link_libraries(linalg-conversion-1
   PRIVATE
+    Linalg1DialectConstruction
+    Linalg1
     MLIRAnalysis
     MLIRDialect
     MLIREDSC
@@ -43,8 +45,6 @@ target_link_libraries(linalg-conversion-1
     MLIRParser
     MLIRPass
     MLIRTransforms
-    Linalg1
-    Linalg1DialectConstruction
     )
 
 whole_archive_link(linalg-conversion-1
index 8b1b3e9..e3d5d94 100644 (file)
@@ -1,3 +1,5 @@
+add_definitions(-DLINALG_STEP=3)
+
 add_subdirectory(lib)
 
 set(LLVM_LINK_COMPONENTS
index cb65c89..73d5caf 100644 (file)
@@ -10,3 +10,4 @@ add_subdirectory(Ch1)
 add_subdirectory(Ch2)
 add_subdirectory(Ch3)
 add_subdirectory(Ch4)
+add_subdirectory(Ch5)
index 98dfe13..7972d75 100644 (file)
@@ -297,7 +297,7 @@ template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
     raw_string_ostream os(msg);
     os << "expects a Toy Array for its argument, got "
        << op->getOperand()->getType();
-    return op->emitOpError(msg);
+    return op->emitOpError(os.str());
   }
   return mlir::success();
 }
index 3900dd4..eebc971 100644 (file)
@@ -297,7 +297,7 @@ template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
     raw_string_ostream os(msg);
     os << "expects a Toy Array for its argument, got "
        << op->getOperand()->getType();
-    return op->emitOpError(msg);
+    return op->emitOpError(os.str());
   }
   return mlir::success();
 }
diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt
new file mode 100644 (file)
index 0000000..d83c65c
--- /dev/null
@@ -0,0 +1,37 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+add_toy_chapter(toyc-ch5
+  toyc.cpp
+  parser/AST.cpp
+  mlir/EarlyLowering.cpp
+  mlir/LateLowering.cpp
+  mlir/MLIRGen.cpp
+  mlir/ShapeInferencePass.cpp
+  mlir/ToyDialect.cpp
+  mlir/ToyCombine.cpp
+  )
+include_directories(include/)
+target_link_libraries(toyc-ch5
+  PRIVATE
+    Linalg3DialectConstruction
+    Linalg3
+    Linalg2
+    Linalg1
+    MLIRAnalysis
+    MLIREDSC
+    MLIRExecutionEngine
+    MLIRIR
+    MLIRLLVMIR
+    MLIRParser
+    MLIRPass
+    MLIRTargetLLVMIR
+    MLIRTransforms
+    MLIRSupport
+)
+whole_archive_link(toyc-ch5
+  MLIRAffineOps
+  MLIRStandardOps
+)
+
diff --git a/mlir/examples/toy/Ch5/include/toy/AST.h b/mlir/examples/toy/Ch5/include/toy/AST.h
new file mode 100644 (file)
index 0000000..456a323
--- /dev/null
@@ -0,0 +1,256 @@
+//===- AST.h - Node definition for the Toy AST ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST for the Toy language. It is optimized for
+// simplicity, not efficiency. The AST forms a tree structure where each node
+// references its children using std::unique_ptr<>.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_AST_H_
+#define MLIR_TUTORIAL_TOY_AST_H_
+
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <vector>
+
+namespace toy {
+
+/// A variable
+struct VarType {
+  enum { TY_FLOAT, TY_INT } elt_ty;
+  std::vector<int> shape;
+};
+
+/// Base class for all expression nodes.
+class ExprAST {
+public:
+  enum ExprASTKind {
+    Expr_VarDecl,
+    Expr_Return,
+    Expr_Num,
+    Expr_Literal,
+    Expr_Var,
+    Expr_BinOp,
+    Expr_Call,
+    Expr_Print, // builtin
+    Expr_If,
+    Expr_For,
+  };
+
+  ExprAST(ExprASTKind kind, Location location)
+      : kind(kind), location(location) {}
+
+  virtual ~ExprAST() = default;
+
+  ExprASTKind getKind() const { return kind; }
+
+  const Location &loc() { return location; }
+
+private:
+  const ExprASTKind kind;
+  Location location;
+};
+
+/// A block-list of expressions.
+using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
+
+/// Expression class for numeric literals like "1.0".
+class NumberExprAST : public ExprAST {
+  double Val;
+
+public:
+  NumberExprAST(Location loc, double Val) : ExprAST(Expr_Num, loc), Val(Val) {}
+
+  double getValue() { return Val; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
+};
+
+///
+class LiteralExprAST : public ExprAST {
+  std::vector<std::unique_ptr<ExprAST>> values;
+  std::vector<int64_t> dims;
+
+public:
+  LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
+                 std::vector<int64_t> dims)
+      : ExprAST(Expr_Literal, loc), values(std::move(values)),
+        dims(std::move(dims)) {}
+
+  std::vector<std::unique_ptr<ExprAST>> &getValues() { return values; }
+  std::vector<int64_t> &getDims() { return dims; }
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Literal; }
+};
+
+/// Expression class for referencing a variable, like "a".
+class VariableExprAST : public ExprAST {
+  std::string name;
+
+public:
+  VariableExprAST(Location loc, const std::string &name)
+      : ExprAST(Expr_Var, loc), name(name) {}
+
+  llvm::StringRef getName() { return name; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
+};
+
+///
+class VarDeclExprAST : public ExprAST {
+  std::string name;
+  VarType type;
+  std::unique_ptr<ExprAST> initVal;
+
+public:
+  VarDeclExprAST(Location loc, const std::string &name, VarType type,
+                 std::unique_ptr<ExprAST> initVal)
+      : ExprAST(Expr_VarDecl, loc), name(name), type(std::move(type)),
+        initVal(std::move(initVal)) {}
+
+  llvm::StringRef getName() { return name; }
+  ExprAST *getInitVal() { return initVal.get(); }
+  VarType &getType() { return type; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
+};
+
+///
+class ReturnExprAST : public ExprAST {
+  llvm::Optional<std::unique_ptr<ExprAST>> expr;
+
+public:
+  ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
+      : ExprAST(Expr_Return, loc), expr(std::move(expr)) {}
+
+  llvm::Optional<ExprAST *> getExpr() {
+    if (expr.hasValue())
+      return expr->get();
+    return llvm::NoneType();
+  }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Return; }
+};
+
+/// Expression class for a binary operator.
+class BinaryExprAST : public ExprAST {
+  char Op;
+  std::unique_ptr<ExprAST> LHS, RHS;
+
+public:
+  char getOp() { return Op; }
+  ExprAST *getLHS() { return LHS.get(); }
+  ExprAST *getRHS() { return RHS.get(); }
+
+  BinaryExprAST(Location loc, char Op, std::unique_ptr<ExprAST> LHS,
+                std::unique_ptr<ExprAST> RHS)
+      : ExprAST(Expr_BinOp, loc), Op(Op), LHS(std::move(LHS)),
+        RHS(std::move(RHS)) {}
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_BinOp; }
+};
+
+/// Expression class for function calls.
+class CallExprAST : public ExprAST {
+  std::string Callee;
+  std::vector<std::unique_ptr<ExprAST>> Args;
+
+public:
+  CallExprAST(Location loc, const std::string &Callee,
+              std::vector<std::unique_ptr<ExprAST>> Args)
+      : ExprAST(Expr_Call, loc), Callee(Callee), Args(std::move(Args)) {}
+
+  llvm::StringRef getCallee() { return Callee; }
+  llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return Args; }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Call; }
+};
+
+/// Expression class for builtin print calls.
+class PrintExprAST : public ExprAST {
+  std::unique_ptr<ExprAST> Arg;
+
+public:
+  PrintExprAST(Location loc, std::unique_ptr<ExprAST> Arg)
+      : ExprAST(Expr_Print, loc), Arg(std::move(Arg)) {}
+
+  ExprAST *getArg() { return Arg.get(); }
+
+  /// LLVM style RTTI
+  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Print; }
+};
+
+/// This class represents the "prototype" for a function, which captures its
+/// name, and its argument names (thus implicitly the number of arguments the
+/// function takes).
+class PrototypeAST {
+  Location location;
+  std::string name;
+  std::vector<std::unique_ptr<VariableExprAST>> args;
+
+public:
+  PrototypeAST(Location location, const std::string &name,
+               std::vector<std::unique_ptr<VariableExprAST>> args)
+      : location(location), name(name), args(std::move(args)) {}
+
+  const Location &loc() { return location; }
+  const std::string &getName() const { return name; }
+  const std::vector<std::unique_ptr<VariableExprAST>> &getArgs() {
+    return args;
+  }
+};
+
+/// This class represents a function definition itself.
+class FunctionAST {
+  std::unique_ptr<PrototypeAST> Proto;
+  std::unique_ptr<ExprASTList> Body;
+
+public:
+  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
+              std::unique_ptr<ExprASTList> Body)
+      : Proto(std::move(Proto)), Body(std::move(Body)) {}
+  PrototypeAST *getProto() { return Proto.get(); }
+  ExprASTList *getBody() { return Body.get(); }
+};
+
+/// This class represents a list of functions to be processed together
+class ModuleAST {
+  std::vector<FunctionAST> functions;
+
+public:
+  ModuleAST(std::vector<FunctionAST> functions)
+      : functions(std::move(functions)) {}
+
+  auto begin() -> decltype(functions.begin()) { return functions.begin(); }
+  auto end() -> decltype(functions.end()) { return functions.end(); }
+};
+
+void dump(ModuleAST &);
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_AST_H_
diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h
new file mode 100644 (file)
index 0000000..9d7f82d
--- /dev/null
@@ -0,0 +1,393 @@
+//===- Dialect.h - Dialect definition for the Toy IR ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the IR Dialect for the Toy language.
+// See g3doc/Tutorials/Toy/Ch-3.md for more information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_
+#define MLIR_TUTORIAL_TOY_DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+class Builder;
+}
+
+namespace toy {
+
+/// This is the definition of the Toy dialect. A dialect inherits from
+/// mlir::Dialect and register custom operations and types (in its constructor).
+/// It can also overridding general behavior of dialects exposed as virtual
+/// method, for example regarding verification and parsing/printing.
+class ToyDialect : public mlir::Dialect {
+public:
+  explicit ToyDialect(mlir::MLIRContext *ctx);
+
+  /// Parse a type registered to this dialect. Overridding this method is
+  /// required for dialects that have custom types.
+  /// Technically this is only needed to be able to round-trip to textual IR.
+  mlir::Type parseType(llvm::StringRef tyData,
+                       mlir::Location loc) const override;
+
+  /// Print a type registered to this dialect. Overridding this method is
+  /// only required for dialects that have custom types.
+  /// Technically this is only needed to be able to round-trip to textual IR.
+  void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+/////////////////////// Custom Types for the Dialect ///////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+struct ToyArrayTypeStorage;
+}
+
+/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
+enum ToyTypeKind {
+  // The enum starts at the range reserved for this dialect.
+  TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
+  TOY_ARRAY,
+};
+
+/// Type for Toy arrays.
+/// In MLIR Types are reference to immutable and uniqued objects owned by the
+/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
+/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
+/// provides the public facade API to interact with the type.
+class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
+                                                 detail::ToyArrayTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Returns the dimensions for this array, or and empty range for a generic
+  /// array.
+  llvm::ArrayRef<int64_t> getShape();
+
+  /// Predicate to test if this array is generic (shape haven't been inferred
+  /// yet).
+  bool isGeneric() { return getShape().empty(); }
+
+  /// Return the rank of this array (0 if it is generic).
+  int getRank() { return getShape().size(); }
+
+  /// Return the type of individual elements in the array.
+  mlir::Type getElementType();
+
+  /// Get a MemRef equivalent to this array type.
+  mlir::MemRefType toMemref();
+
+  /// Get the unique instance of this Type from the context.
+  /// A ToyArrayType is only defined by the shape of the array.
+  static ToyArrayType get(mlir::MLIRContext *context,
+                          llvm::ArrayRef<int64_t> shape = {});
+
+  /// Support method to enable LLVM-style RTTI type casting.
+  static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Constant operation turns a literal into an SSA value. The data is attached
+/// to the operation as an attribute. For example:
+///
+///   %0 = "toy.constant"()
+///       {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
+///     : () -> !toy<"array<2, 3>">
+///
+/// An operation inherits from `class Op` and specifies optional traits. Here we
+/// indicate that `toy.constant` does not have any operands and returns a single
+/// result. The traits provide some utilities methods for the operation, for
+/// instance we will be able to use `getResult()`, but `getOperand()` won't be
+/// available.
+class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
+                                   mlir::OpTrait::OneResult,
+                                   mlir::OpTrait::HasNoSideEffect> {
+public:
+  /// This is the name used by MLIR to match an operation to this class during
+  /// parsing.
+  static llvm::StringRef getOperationName() { return "toy.constant"; }
+
+  /// The operation can have extra verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<PrintOp>(...)
+  /// This method populates the `state` that MLIR uses to create operations.
+  /// The `toy.constant` operation does not have arguments but attaches a
+  /// constant array as an attribute and returns it as an SSA value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    llvm::ArrayRef<int64_t> shape,
+                    mlir::DenseElementsAttr value);
+
+  /// Similar to the one above, but takes a single float and returns a
+  /// !toy<"array<1>">.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::FloatAttr value);
+
+  mlir::DenseElementsAttr getValue() {
+    return getAttr("value").cast<mlir::DenseElementsAttr>();
+  }
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Generic calls represent calls to a user defined function that needs to
+/// be specialized for the shape of its arguments. The callee name is attached
+/// as a literal string as an attribute. The arguments list must match the
+/// arguments expected by the callee. For example:
+///
+///   %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
+///         : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+///
+/// This is only valid if a function named "my_func" exists and takes two
+/// arguments.
+class GenericCallOp
+    : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
+                      mlir::OpTrait::OneResult> {
+public:
+  /// MLIR will use this to register the operation with the parser/printer.
+  static llvm::StringRef getOperationName() { return "toy.generic_call"; }
+
+  /// Operations can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to the builder to allow:
+  ///   mlir::Builder::create<GenericCallOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.generic_call` operation accepts a callee name and a list of
+  /// arguments for the call.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    llvm::StringRef callee,
+                    llvm::ArrayRef<mlir::Value *> arguments);
+
+  /// Return the name of the callee.
+  llvm::StringRef getCalleeName();
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Return operations terminate blocks (and functions as well). They take a
+/// single argument and the type must match the function return type.
+class ReturnOp
+    : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
+                      mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.return"; }
+
+  /// Operations can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<PrintOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.return` operation accepts an optional single array as an argument
+  /// and does not have any returned value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value = nullptr);
+
+  /// Return true if there is a returned value.
+  bool hasOperand() { return 0 != getNumOperands(); }
+
+  /// Helper to return the optional operand. Caller must check if the operand
+  /// is present before calling this.
+  mlir::Value *getOperand() { return getOperation()->getOperand(0); }
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// The print builtin takes a single array argument and does not return any.
+class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
+                                mlir::OpTrait::ZeroResult> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.print"; }
+
+  /// Operations can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<PrintOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.print` operation accepts a single array as argument and does
+  /// not have any returned value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value);
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
+                                    mlir::OpTrait::OneResult,
+                                    mlir::OpTrait::HasNoSideEffect> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.transpose"; }
+
+  /// Operation can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<TransposeOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.transpose` operation accepts a single array as argument and
+  /// returns the transposed array as its only result.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value);
+
+  // Register our patterns for rewrite by the Canonicalization framework.
+  static void
+  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
+                              mlir::MLIRContext *context);
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Reshape operation is transforming its input array into a new array with the
+/// same number of elements but different shapes. For example:
+///
+///    %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>">
+///
+class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
+                                  mlir::OpTrait::OneResult,
+                                  mlir::OpTrait::HasNoSideEffect> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.reshape"; }
+
+  /// Operation can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<ReshapeOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.reshape` operation accepts a single array as argument and
+  /// returns the array with the specified reshapedType as its only result.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value, ToyArrayType reshapedType);
+
+  // Register our patterns for rewrite by the Canonicalization framework.
+  static void
+  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
+                              mlir::MLIRContext *context);
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Binary operation implementing a multiplication. For two-dimensional array
+/// a matrix multiplication is implemented, while for one dimensional array a
+/// dot product is performed.
+class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
+                              mlir::OpTrait::OneResult,
+                              mlir::OpTrait::HasNoSideEffect> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.mul"; }
+
+  /// Operation can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<PrintOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.mul` operation accepts two operands as argument and returns
+  /// a single value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *lhs, mlir::Value *rhs);
+
+  /// Convenience accessor for LHS of the expression.
+  mlir::Value *getLHS() { return getOperand(0); }
+
+  /// Convenience accessor for RHS of the expression.
+  mlir::Value *getRHS() { return getOperand(1); }
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// Element wise addition of two arrays. The shape must match.
+class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
+                              mlir::OpTrait::OneResult,
+                              mlir::OpTrait::HasNoSideEffect> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.add"; }
+
+  /// Operation can add custom verification beyond the traits they define.
+  mlir::LogicalResult verify();
+
+  /// Interface to mlir::Builder::create<PrintOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// The `toy.mul` operation accepts two operands as argument and returns
+  /// a single value.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *lhs, mlir::Value *rhs);
+
+  /// Convenience accessor for LHS of the expression.
+  mlir::Value *getLHS() { return getOperand(0); }
+
+  /// Convenience accessor for RHS of the expression.
+  mlir::Value *getRHS() { return getOperand(1); }
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// AllocOp is a temporary operation for buffer allocation, created as part of
+/// partial lowering.
+class AllocOp : public mlir::Op<AllocOp, mlir::OpTrait::ZeroOperands,
+                                mlir::OpTrait::OneResult> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.alloc"; }
+
+  /// Interface to mlir::Builder::create<AllocOp>(...)
+  /// This method populate the `state` that MLIR use to create operations.
+  /// `toy.alloc` does not have any argument and returns a toy array.
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Type retType);
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+/// FIXME: should be in std?
+class TypeCastOp : public mlir::Op<TypeCastOp, mlir::OpTrait::OneOperand,
+                                   mlir::OpTrait::OneResult,
+                                   mlir::OpTrait::HasNoSideEffect> {
+public:
+  static llvm::StringRef getOperationName() { return "toy.cast"; }
+
+  static void build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value, mlir::Type destTy);
+
+  // Register our patterns for rewrite by the Canonicalization framework.
+  static void
+  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
+                              mlir::MLIRContext *context);
+
+  /// Inherit constructor.
+  using Op::Op;
+};
+
+} // end namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch5/include/toy/Lexer.h b/mlir/examples/toy/Ch5/include/toy/Lexer.h
new file mode 100644 (file)
index 0000000..d73adb9
--- /dev/null
@@ -0,0 +1,239 @@
+//===- Lexer.h - Lexer for the Toy language -------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple Lexer for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_LEXER_H_
+#define MLIR_TUTORIAL_TOY_LEXER_H_
+
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <string>
+
+namespace toy {
+
+/// Structure definition a location in a file.
+struct Location {
+  std::shared_ptr<std::string> file; ///< filename
+  int line;                          ///< line number.
+  int col;                           ///< column number.
+};
+
+// List of Token returned by the lexer.
+enum Token : int {
+  tok_semicolon = ';',
+  tok_parenthese_open = '(',
+  tok_parenthese_close = ')',
+  tok_bracket_open = '{',
+  tok_bracket_close = '}',
+  tok_sbracket_open = '[',
+  tok_sbracket_close = ']',
+
+  tok_eof = -1,
+
+  // commands
+  tok_return = -2,
+  tok_var = -3,
+  tok_def = -4,
+
+  // primary
+  tok_identifier = -5,
+  tok_number = -6,
+};
+
+/// The Lexer is an abstract base class providing all the facilities that the
+/// Parser expects. It goes through the stream one token at a time and keeps
+/// track of the location in the file for debugging purpose.
+/// It relies on a subclass to provide a `readNextLine()` method. The subclass
+/// can proceed by reading the next line from the standard input or from a
+/// memory mapped file.
+class Lexer {
+public:
+  /// Create a lexer for the given filename. The filename is kept only for
+  /// debugging purpose (attaching a location to a Token).
+  Lexer(std::string filename)
+      : lastLocation(
+            {std::make_shared<std::string>(std::move(filename)), 0, 0}) {}
+  virtual ~Lexer() = default;
+
+  /// Look at the current token in the stream.
+  Token getCurToken() { return curTok; }
+
+  /// Move to the next token in the stream and return it.
+  Token getNextToken() { return curTok = getTok(); }
+
+  /// Move to the next token in the stream, asserting on the current token
+  /// matching the expectation.
+  void consume(Token tok) {
+    assert(tok == curTok && "consume Token mismatch expectation");
+    getNextToken();
+  }
+
+  /// Return the current identifier (prereq: getCurToken() == tok_identifier)
+  llvm::StringRef getId() {
+    assert(curTok == tok_identifier);
+    return IdentifierStr;
+  }
+
+  /// Return the current number (prereq: getCurToken() == tok_number)
+  double getValue() {
+    assert(curTok == tok_number);
+    return NumVal;
+  }
+
+  /// Return the location for the beginning of the current token.
+  Location getLastLocation() { return lastLocation; }
+
+  // Return the current line in the file.
+  int getLine() { return curLineNum; }
+
+  // Return the current column in the file.
+  int getCol() { return curCol; }
+
+private:
+  /// Delegate to a derived class fetching the next line. Returns an empty
+  /// string to signal end of file (EOF). Lines are expected to always finish
+  /// with "\n"
+  virtual llvm::StringRef readNextLine() = 0;
+
+  /// Return the next character from the stream. This manages the buffer for the
+  /// current line and request the next line buffer to the derived class as
+  /// needed.
+  int getNextChar() {
+    // The current line buffer should not be empty unless it is the end of file.
+    if (curLineBuffer.empty())
+      return EOF;
+    ++curCol;
+    auto nextchar = curLineBuffer.front();
+    curLineBuffer = curLineBuffer.drop_front();
+    if (curLineBuffer.empty())
+      curLineBuffer = readNextLine();
+    if (nextchar == '\n') {
+      ++curLineNum;
+      curCol = 0;
+    }
+    return nextchar;
+  }
+
+  ///  Return the next token from standard input.
+  Token getTok() {
+    // Skip any whitespace.
+    while (isspace(LastChar))
+      LastChar = Token(getNextChar());
+
+    // Save the current location before reading the token characters.
+    lastLocation.line = curLineNum;
+    lastLocation.col = curCol;
+
+    if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9_]*
+      IdentifierStr = (char)LastChar;
+      while (isalnum((LastChar = Token(getNextChar()))) || LastChar == '_')
+        IdentifierStr += (char)LastChar;
+
+      if (IdentifierStr == "return")
+        return tok_return;
+      if (IdentifierStr == "def")
+        return tok_def;
+      if (IdentifierStr == "var")
+        return tok_var;
+      return tok_identifier;
+    }
+
+    if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
+      std::string NumStr;
+      do {
+        NumStr += LastChar;
+        LastChar = Token(getNextChar());
+      } while (isdigit(LastChar) || LastChar == '.');
+
+      NumVal = strtod(NumStr.c_str(), nullptr);
+      return tok_number;
+    }
+
+    if (LastChar == '#') {
+      // Comment until end of line.
+      do
+        LastChar = Token(getNextChar());
+      while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
+
+      if (LastChar != EOF)
+        return getTok();
+    }
+
+    // Check for end of file.  Don't eat the EOF.
+    if (LastChar == EOF)
+      return tok_eof;
+
+    // Otherwise, just return the character as its ascii value.
+    Token ThisChar = Token(LastChar);
+    LastChar = Token(getNextChar());
+    return ThisChar;
+  }
+
+  /// The last token read from the input.
+  Token curTok = tok_eof;
+
+  /// Location for `curTok`.
+  Location lastLocation;
+
+  /// If the current Token is an identifier, this string contains the value.
+  std::string IdentifierStr;
+
+  /// If the current Token is a number, this contains the value.
+  double NumVal = 0;
+
+  /// The last value returned by getNextChar(). We need to keep it around as we
+  /// always need to read ahead one character to decide when to end a token and
+  /// we can't put it back in the stream after reading from it.
+  Token LastChar = Token(' ');
+
+  /// Keep track of the current line number in the input stream
+  int curLineNum = 0;
+
+  /// Keep track of the current column number in the input stream
+  int curCol = 0;
+
+  /// Buffer supplied by the derived class on calls to `readNextLine()`
+  llvm::StringRef curLineBuffer = "\n";
+};
+
+/// A lexer implementation operating on a buffer in memory.
+class LexerBuffer final : public Lexer {
+public:
+  LexerBuffer(const char *begin, const char *end, std::string filename)
+      : Lexer(std::move(filename)), current(begin), end(end) {}
+
+private:
+  /// Provide one line at a time to the Lexer, return an empty string when
+  /// reaching the end of the buffer.
+  llvm::StringRef readNextLine() override {
+    auto *begin = current;
+    while (current <= end && *current && *current != '\n')
+      ++current;
+    if (current <= end && *current)
+      ++current;
+    llvm::StringRef result{begin, static_cast<size_t>(current - begin)};
+    return result;
+  }
+  const char *current, *end;
+};
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_LEXER_H_
diff --git a/mlir/examples/toy/Ch5/include/toy/Lowering.h b/mlir/examples/toy/Ch5/include/toy/Lowering.h
new file mode 100644 (file)
index 0000000..362a342
--- /dev/null
@@ -0,0 +1,45 @@
+//===- Lowering.h - Lexer for the Toy language ----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file exposes the interface to the lowering for Toy. It is divided in
+// two parts:  an *early lowering* that emits operations in the `Linalg`
+// dialects for a subset of the Toy IR, and a *late lowering* that materializes
+// buffers and converts all operations and type to the LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_
+#define MLIR_EXAMPLES_TOY_LOWERING_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class DialectConversion;
+} // namespace mlir
+
+namespace toy {
+/// Create a pass for lowering operations in the `Linalg` dialects, for a subset
+/// of the Toy IR (matmul).
+mlir::Pass *createEarlyLoweringPass();
+
+/// Create a pass for the late lowering toward LLVM dialect.
+mlir::Pass *createLateLoweringPass();
+
+} // namespace toy
+
+#endif // MLIR_EXAMPLES_TOY_LOWERING_H_
diff --git a/mlir/examples/toy/Ch5/include/toy/MLIRGen.h b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h
new file mode 100644 (file)
index 0000000..21637bc
--- /dev/null
@@ -0,0 +1,42 @@
+//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares a simple interface to perform IR generation targeting MLIR
+// from a Module AST for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_MLIRGEN_H_
+#define MLIR_TUTORIAL_TOY_MLIRGEN_H_
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class Module;
+} // namespace mlir
+
+namespace toy {
+class ModuleAST;
+
+/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
+/// or nullptr on failure.
+std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
+                                      ModuleAST &moduleAST);
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
diff --git a/mlir/examples/toy/Ch5/include/toy/Parser.h b/mlir/examples/toy/Ch5/include/toy/Parser.h
new file mode 100644 (file)
index 0000000..bc7aa52
--- /dev/null
@@ -0,0 +1,494 @@
+//===- Parser.h - Toy Language Parser -------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the parser for the Toy language. It processes the Token
+// provided by the Lexer and returns an AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PARSER_H
+#define MLIR_TUTORIAL_TOY_PARSER_H
+
+#include "toy/AST.h"
+#include "toy/Lexer.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <map>
+#include <utility>
+#include <vector>
+
+namespace toy {
+
+/// This is a simple recursive parser for the Toy language. It produces a well
+/// formed AST from a stream of Token supplied by the Lexer. No semantic checks
+/// or symbol resolution is performed. For example, variables are referenced by
+/// string and the code could reference an undeclared variable and the parsing
+/// succeeds.
+class Parser {
+public:
+  /// Create a Parser for the supplied lexer.
+  Parser(Lexer &lexer) : lexer(lexer) {}
+
+  /// Parse a full Module. A module is a list of function definitions.
+  std::unique_ptr<ModuleAST> ParseModule() {
+    lexer.getNextToken(); // prime the lexer
+
+    // Parse functions one at a time and accumulate in this vector.
+    std::vector<FunctionAST> functions;
+    while (auto F = ParseDefinition()) {
+      functions.push_back(std::move(*F));
+      if (lexer.getCurToken() == tok_eof)
+        break;
+    }
+    // If we didn't reach EOF, there was an error during parsing
+    if (lexer.getCurToken() != tok_eof)
+      return parseError<ModuleAST>("nothing", "at end of module");
+
+    return llvm::make_unique<ModuleAST>(std::move(functions));
+  }
+
+private:
+  Lexer &lexer;
+
+  /// Parse a return statement.
+  /// return :== return ; | return expr ;
+  std::unique_ptr<ReturnExprAST> ParseReturn() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_return);
+
+    // return takes an optional argument
+    llvm::Optional<std::unique_ptr<ExprAST>> expr;
+    if (lexer.getCurToken() != ';') {
+      expr = ParseExpression();
+      if (!expr)
+        return nullptr;
+    }
+    return llvm::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
+  }
+
+  /// Parse a literal number.
+  /// numberexpr ::= number
+  std::unique_ptr<ExprAST> ParseNumberExpr() {
+    auto loc = lexer.getLastLocation();
+    auto Result =
+        llvm::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
+    lexer.consume(tok_number);
+    return std::move(Result);
+  }
+
+  /// Parse a literal array expression.
+  /// tensorLiteral ::= [ literalList ] | number
+  /// literalList ::= tensorLiteral | tensorLiteral, literalList
+  std::unique_ptr<ExprAST> ParseTensorLitteralExpr() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(Token('['));
+
+    // Hold the list of values at this nesting level.
+    std::vector<std::unique_ptr<ExprAST>> values;
+    // Hold the dimensions for all the nesting inside this level.
+    std::vector<int64_t> dims;
+    do {
+      // We can have either another nested array or a number literal.
+      if (lexer.getCurToken() == '[') {
+        values.push_back(ParseTensorLitteralExpr());
+        if (!values.back())
+          return nullptr; // parse error in the nested array.
+      } else {
+        if (lexer.getCurToken() != tok_number)
+          return parseError<ExprAST>("<num> or [", "in literal expression");
+        values.push_back(ParseNumberExpr());
+      }
+
+      // End of this list on ']'
+      if (lexer.getCurToken() == ']')
+        break;
+
+      // Elements are separated by a comma.
+      if (lexer.getCurToken() != ',')
+        return parseError<ExprAST>("] or ,", "in literal expression");
+
+      lexer.getNextToken(); // eat ,
+    } while (true);
+    if (values.empty())
+      return parseError<ExprAST>("<something>", "to fill literal expression");
+    lexer.getNextToken(); // eat ]
+    /// Fill in the dimensions now. First the current nesting level:
+    dims.push_back(values.size());
+    /// If there is any nested array, process all of them and ensure that
+    /// dimensions are uniform.
+    if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
+          return llvm::isa<LiteralExprAST>(expr.get());
+        })) {
+      auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
+      if (!firstLiteral)
+        return parseError<ExprAST>("uniform well-nested dimensions",
+                                   "inside literal expession");
+
+      // Append the nested dimensions to the current level
+      auto &firstDims = firstLiteral->getDims();
+      dims.insert(dims.end(), firstDims.begin(), firstDims.end());
+
+      // Sanity check that shape is uniform across all elements of the list.
+      for (auto &expr : values) {
+        auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get());
+        if (!exprLiteral)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expession");
+        if (exprLiteral->getDims() != firstDims)
+          return parseError<ExprAST>("uniform well-nested dimensions",
+                                     "inside literal expession");
+      }
+    }
+    return llvm::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
+                                             std::move(dims));
+  }
+
+  /// parenexpr ::= '(' expression ')'
+  std::unique_ptr<ExprAST> ParseParenExpr() {
+    lexer.getNextToken(); // eat (.
+    auto V = ParseExpression();
+    if (!V)
+      return nullptr;
+
+    if (lexer.getCurToken() != ')')
+      return parseError<ExprAST>(")", "to close expression with parentheses");
+    lexer.consume(Token(')'));
+    return V;
+  }
+
+  /// identifierexpr
+  ///   ::= identifier
+  ///   ::= identifier '(' expression ')'
+  std::unique_ptr<ExprAST> ParseIdentifierExpr() {
+    std::string name = lexer.getId();
+
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat identifier.
+
+    if (lexer.getCurToken() != '(') // Simple variable ref.
+      return llvm::make_unique<VariableExprAST>(std::move(loc), name);
+
+    // This is a function call.
+    lexer.consume(Token('('));
+    std::vector<std::unique_ptr<ExprAST>> Args;
+    if (lexer.getCurToken() != ')') {
+      while (true) {
+        if (auto Arg = ParseExpression())
+          Args.push_back(std::move(Arg));
+        else
+          return nullptr;
+
+        if (lexer.getCurToken() == ')')
+          break;
+
+        if (lexer.getCurToken() != ',')
+          return parseError<ExprAST>(", or )", "in argument list");
+        lexer.getNextToken();
+      }
+    }
+    lexer.consume(Token(')'));
+
+    // It can be a builtin call to print
+    if (name == "print") {
+      if (Args.size() != 1)
+        return parseError<ExprAST>("<single arg>", "as argument to print()");
+
+      return llvm::make_unique<PrintExprAST>(std::move(loc),
+                                             std::move(Args[0]));
+    }
+
+    // Call to a user-defined function
+    return llvm::make_unique<CallExprAST>(std::move(loc), name,
+                                          std::move(Args));
+  }
+
+  /// primary
+  ///   ::= identifierexpr
+  ///   ::= numberexpr
+  ///   ::= parenexpr
+  ///   ::= tensorliteral
+  std::unique_ptr<ExprAST> ParsePrimary() {
+    switch (lexer.getCurToken()) {
+    default:
+      llvm::errs() << "unknown token '" << lexer.getCurToken()
+                   << "' when expecting an expression\n";
+      return nullptr;
+    case tok_identifier:
+      return ParseIdentifierExpr();
+    case tok_number:
+      return ParseNumberExpr();
+    case '(':
+      return ParseParenExpr();
+    case '[':
+      return ParseTensorLitteralExpr();
+    case ';':
+      return nullptr;
+    case '}':
+      return nullptr;
+    }
+  }
+
+  /// Recursively parse the right hand side of a binary expression, the ExprPrec
+  /// argument indicates the precedence of the current binary operator.
+  ///
+  /// binoprhs ::= ('+' primary)*
+  std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
+                                         std::unique_ptr<ExprAST> LHS) {
+    // If this is a binop, find its precedence.
+    while (true) {
+      int TokPrec = GetTokPrecedence();
+
+      // If this is a binop that binds at least as tightly as the current binop,
+      // consume it, otherwise we are done.
+      if (TokPrec < ExprPrec)
+        return LHS;
+
+      // Okay, we know this is a binop.
+      int BinOp = lexer.getCurToken();
+      lexer.consume(Token(BinOp));
+      auto loc = lexer.getLastLocation();
+
+      // Parse the primary expression after the binary operator.
+      auto RHS = ParsePrimary();
+      if (!RHS)
+        return parseError<ExprAST>("expression", "to complete binary operator");
+
+      // If BinOp binds less tightly with RHS than the operator after RHS, let
+      // the pending operator take RHS as its LHS.
+      int NextPrec = GetTokPrecedence();
+      if (TokPrec < NextPrec) {
+        RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
+        if (!RHS)
+          return nullptr;
+      }
+
+      // Merge LHS/RHS.
+      LHS = llvm::make_unique<BinaryExprAST>(std::move(loc), BinOp,
+                                             std::move(LHS), std::move(RHS));
+    }
+  }
+
+  /// expression::= primary binoprhs
+  std::unique_ptr<ExprAST> ParseExpression() {
+    auto LHS = ParsePrimary();
+    if (!LHS)
+      return nullptr;
+
+    return ParseBinOpRHS(0, std::move(LHS));
+  }
+
+  /// type ::= < shape_list >
+  /// shape_list ::= num | num , shape_list
+  std::unique_ptr<VarType> ParseType() {
+    if (lexer.getCurToken() != '<')
+      return parseError<VarType>("<", "to begin type");
+    lexer.getNextToken(); // eat <
+
+    auto type = llvm::make_unique<VarType>();
+
+    while (lexer.getCurToken() == tok_number) {
+      type->shape.push_back(lexer.getValue());
+      lexer.getNextToken();
+      if (lexer.getCurToken() == ',')
+        lexer.getNextToken();
+    }
+
+    if (lexer.getCurToken() != '>')
+      return parseError<VarType>(">", "to end type");
+    lexer.getNextToken(); // eat >
+    return type;
+  }
+
+  /// Parse a variable declaration, it starts with a `var` keyword followed by
+  /// and identifier and an optional type (shape specification) before the
+  /// initializer.
+  /// decl ::= var identifier [ type ] = expr
+  std::unique_ptr<VarDeclExprAST> ParseDeclaration() {
+    if (lexer.getCurToken() != tok_var)
+      return parseError<VarDeclExprAST>("var", "to begin declaration");
+    auto loc = lexer.getLastLocation();
+    lexer.getNextToken(); // eat var
+
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<VarDeclExprAST>("identified",
+                                        "after 'var' declaration");
+    std::string id = lexer.getId();
+    lexer.getNextToken(); // eat id
+
+    std::unique_ptr<VarType> type; // Type is optional, it can be inferred
+    if (lexer.getCurToken() == '<') {
+      type = ParseType();
+      if (!type)
+        return nullptr;
+    }
+
+    if (!type)
+      type = llvm::make_unique<VarType>();
+    lexer.consume(Token('='));
+    auto expr = ParseExpression();
+    return llvm::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
+                                             std::move(*type), std::move(expr));
+  }
+
+  /// Parse a block: a list of expression separated by semicolons and wrapped in
+  /// curly braces.
+  ///
+  /// block ::= { expression_list }
+  /// expression_list ::= block_expr ; expression_list
+  /// block_expr ::= decl | "return" | expr
+  std::unique_ptr<ExprASTList> ParseBlock() {
+    if (lexer.getCurToken() != '{')
+      return parseError<ExprASTList>("{", "to begin block");
+    lexer.consume(Token('{'));
+
+    auto exprList = llvm::make_unique<ExprASTList>();
+
+    // Ignore empty expressions: swallow sequences of semicolons.
+    while (lexer.getCurToken() == ';')
+      lexer.consume(Token(';'));
+
+    while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
+      if (lexer.getCurToken() == tok_var) {
+        // Variable declaration
+        auto varDecl = ParseDeclaration();
+        if (!varDecl)
+          return nullptr;
+        exprList->push_back(std::move(varDecl));
+      } else if (lexer.getCurToken() == tok_return) {
+        // Return statement
+        auto ret = ParseReturn();
+        if (!ret)
+          return nullptr;
+        exprList->push_back(std::move(ret));
+      } else {
+        // General expression
+        auto expr = ParseExpression();
+        if (!expr)
+          return nullptr;
+        exprList->push_back(std::move(expr));
+      }
+      // Ensure that elements are separated by a semicolon.
+      if (lexer.getCurToken() != ';')
+        return parseError<ExprASTList>(";", "after expression");
+
+      // Ignore empty expressions: swallow sequences of semicolons.
+      while (lexer.getCurToken() == ';')
+        lexer.consume(Token(';'));
+    }
+
+    if (lexer.getCurToken() != '}')
+      return parseError<ExprASTList>("}", "to close block");
+
+    lexer.consume(Token('}'));
+    return exprList;
+  }
+
+  /// prototype ::= def id '(' decl_list ')'
+  /// decl_list ::= identifier | identifier, decl_list
+  std::unique_ptr<PrototypeAST> ParsePrototype() {
+    auto loc = lexer.getLastLocation();
+    lexer.consume(tok_def);
+    if (lexer.getCurToken() != tok_identifier)
+      return parseError<PrototypeAST>("function name", "in prototype");
+
+    std::string FnName = lexer.getId();
+    lexer.consume(tok_identifier);
+
+    if (lexer.getCurToken() != '(')
+      return parseError<PrototypeAST>("(", "in prototype");
+    lexer.consume(Token('('));
+
+    std::vector<std::unique_ptr<VariableExprAST>> args;
+    if (lexer.getCurToken() != ')') {
+      do {
+        std::string name = lexer.getId();
+        auto loc = lexer.getLastLocation();
+        lexer.consume(tok_identifier);
+        auto decl = llvm::make_unique<VariableExprAST>(std::move(loc), name);
+        args.push_back(std::move(decl));
+        if (lexer.getCurToken() != ',')
+          break;
+        lexer.consume(Token(','));
+        if (lexer.getCurToken() != tok_identifier)
+          return parseError<PrototypeAST>(
+              "identifier", "after ',' in function parameter list");
+      } while (true);
+    }
+    if (lexer.getCurToken() != ')')
+      return parseError<PrototypeAST>("}", "to end function prototype");
+
+    // success.
+    lexer.consume(Token(')'));
+    return llvm::make_unique<PrototypeAST>(std::move(loc), FnName,
+                                           std::move(args));
+  }
+
+  /// Parse a function definition, we expect a prototype initiated with the
+  /// `def` keyword, followed by a block containing a list of expressions.
+  ///
+  /// definition ::= prototype block
+  std::unique_ptr<FunctionAST> ParseDefinition() {
+    auto Proto = ParsePrototype();
+    if (!Proto)
+      return nullptr;
+
+    if (auto block = ParseBlock())
+      return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(block));
+    return nullptr;
+  }
+
+  /// Get the precedence of the pending binary operator token.
+  int GetTokPrecedence() {
+    if (!isascii(lexer.getCurToken()))
+      return -1;
+
+    // 1 is lowest precedence.
+    switch (static_cast<char>(lexer.getCurToken())) {
+    case '-':
+      return 20;
+    case '+':
+      return 20;
+    case '*':
+      return 40;
+    default:
+      return -1;
+    }
+  }
+
+  /// Helper function to signal errors while parsing, it takes an argument
+  /// indicating the expected token and another argument giving more context.
+  /// Location is retrieved from the lexer to enrich the error message.
+  template <typename R, typename T, typename U = const char *>
+  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
+    auto curToken = lexer.getCurToken();
+    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
+                 << lexer.getLastLocation().col << "): expected '" << expected
+                 << "' " << context << " but has Token " << curToken;
+    if (isprint(curToken))
+      llvm::errs() << " '" << (char)curToken << "'";
+    llvm::errs() << "\n";
+    return nullptr;
+  }
+};
+
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PARSER_H
diff --git a/mlir/examples/toy/Ch5/include/toy/Passes.h b/mlir/examples/toy/Ch5/include/toy/Passes.h
new file mode 100644 (file)
index 0000000..dd73b95
--- /dev/null
@@ -0,0 +1,33 @@
+//===- Passes.h - Toy Passes Definition -----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file exposes the entry points to create compiler passes for Toy.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_PASSES_H
+#define MLIR_TUTORIAL_TOY_PASSES_H
+
+namespace mlir {
+class Pass;
+} // namespace mlir
+
+namespace toy {
+mlir::Pass *createShapeInferencePass();
+} // namespace toy
+
+#endif // MLIR_TUTORIAL_TOY_PASSES_H
diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
new file mode 100644 (file)
index 0000000..db6ba73
--- /dev/null
@@ -0,0 +1,158 @@
+//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements early lowering of Toy IR to Linalg Dialect: we only
+// lower the computationally intensive part of the program (matmul...) to a
+// dialect specialized for optimizations.
+//
+// This is intended to showcase how multiple dialects can cohabit in the same
+// function. After this lowering, you would still have toy.print in the IR for
+// example.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "linalg1/Intrinsics.h"
+#include "linalg1/ViewOp.h"
+#include "linalg3/TensorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+
+#include <algorithm>
+
+using namespace mlir;
+
+namespace {
+/// Utility function for type casting: this is making the type checker happy,
+/// while delaying the actual work involved to convert the type. Most of the
+/// time both side of the cast (producer and consumer) will be lowered to a
+/// dialect like LLVM and end up with the same LLVM representation, at which
+/// point this becomes a no-op and is eliminated.
+Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
+  if (val->getType() == destTy)
+    return val;
+  return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
+      .getResult();
+}
+
+/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
+/// lowered to a memref during buffer allocation, at which point the type cast
+/// becomes useless.
+Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
+  if (val->getType().isa<MemRefType>())
+    return val;
+  auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
+  if (!toyArrayTy)
+    return val;
+  return typeCast(builder, val, toyArrayTy.toMemref());
+}
+
+/// Lower toy.mul to Linalg `matmul`.
+///
+/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// similarly to the PatternRewriter introduced in the previous chapter.
+/// It will be called by the DialectConversion framework (see `LateLowering`
+/// class below).
+class MulOpConversion : public DialectOpConversion {
+public:
+  explicit MulOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::MulOp::getOperationName(), 1, context) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    using namespace edsc;
+    using intrinsics::constant_index;
+    using linalg::intrinsics::range;
+    using linalg::intrinsics::view;
+    toy::MulOp mul = op->cast<toy::MulOp>();
+    auto loc = mul.getLoc();
+    Value *result = memRefTypeCast(
+        rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
+                      .getResult());
+    Value *lhs = memRefTypeCast(rewriter, operands[0]);
+    auto memrefLHSTy = lhs->getType().cast<MemRefType>();
+    Value *rhs = memRefTypeCast(rewriter, operands[1]);
+    auto memrefRHSTy = rhs->getType().cast<MemRefType>();
+    mlir::edsc::ScopedContext scope(rewriter, loc);
+    edsc::ValueHandle r0 =
+        range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)),
+              constant_index(1));
+    edsc::ValueHandle r1 =
+        range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)),
+              constant_index(1));
+    edsc::ValueHandle r2 =
+        range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)),
+              constant_index(1));
+    auto lhsView = view(lhs, {r0, r1});
+    auto rhsView = view(rhs, {r1, r2});
+    auto resultView = view(result, {r0, r2});
+    rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
+    return {typeCast(rewriter, result, mul.getType())};
+  }
+};
+
+// The conversion class from Toy IR Dialect to a mix of Linalg and LLVM.
+class EarlyLowering : public DialectConversion {
+protected:
+  // Initialize the list of converters.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *context) override {
+    return ConversionListBuilder<MulOpConversion>::build(&allocator, context);
+  }
+
+private:
+  llvm::BumpPtrAllocator allocator;
+};
+
+/// This is lowering to Linalg the parts that are computationally intensive
+/// (like matmul for example...) while keeping the rest of the code in the Toy
+/// dialect.
+struct EarlyLoweringPass : public ModulePass<EarlyLoweringPass> {
+
+  void runOnModule() override {
+    if (failed(EarlyLowering().convert(&getModule()))) {
+      getModule().getContext()->emitError(
+          mlir::UnknownLoc::get(getModule().getContext()),
+          "Error lowering Toy\n");
+      signalPassFailure();
+    }
+  }
+};
+} // end anonymous namespace
+
+namespace toy {
+Pass *createEarlyLoweringPass() { return new EarlyLoweringPass(); }
+
+std::unique_ptr<mlir::DialectConversion> makeToyEarlyLowering() {
+  return llvm::make_unique<EarlyLowering>();
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
new file mode 100644 (file)
index 0000000..4ef62d3
--- /dev/null
@@ -0,0 +1,452 @@
+//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements late lowering of IR mixing Toy and Linalg to LLVM.
+// It involves intemerdiate steps:
+// -
+// - a mix of affine and standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "linalg1/Intrinsics.h"
+#include "linalg1/ViewOp.h"
+#include "linalg3/ConvertToLLVMDialect.h"
+#include "linalg3/TensorOps.h"
+#include "linalg3/Transforms.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Type.h"
+
+#include <algorithm>
+
+using namespace mlir;
+
+namespace {
+/// Utility function for type casting: this is making the type checker happy,
+/// while delaying the actual work involved to convert the type. Most of the
+/// time both side of the cast (producer and consumer) will be lowered to a
+/// dialect like LLVM and end up with the same LLVM representation, at which
+/// point this becomes a no-op and is eliminated.
+Value *typeCast(FuncBuilder &builder, Value *val, Type destTy) {
+  if (val->getType() == destTy)
+    return val;
+  return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
+      .getResult();
+}
+
+/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
+/// lowered to a memref during buffer allocation, at which point the type cast
+/// becomes useless.
+Value *memRefTypeCast(FuncBuilder &builder, Value *val) {
+  if (val->getType().isa<MemRefType>())
+    return val;
+  auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
+  if (!toyArrayTy)
+    return val;
+  return typeCast(builder, val, toyArrayTy.toMemref());
+}
+
+/// Lower a toy.add to an affine loop nest.
+///
+/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// similarly to the PatternRewriter introduced in the previous chapter.
+/// It will be called by the DialectConversion framework (see `LateLowering`
+/// class below).
+class AddOpConversion : public DialectOpConversion {
+public:
+  explicit AddOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+
+  /// Lower the `op` by generating IR using the `rewriter` builder. The builder
+  /// is setup with a new function, the `operands` array has been populated with
+  /// the rewritten operands for `op` in the new function.
+  /// The results created by the new IR with the builder are returned, and their
+  /// number must match the number of result of `op`.
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto add = op->cast<toy::AddOp>();
+    auto loc = add.getLoc();
+    // Create a `toy.alloc` operation to allocate the output buffer for this op.
+    Value *result = memRefTypeCast(
+        rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
+                      .getResult());
+    Value *lhs = memRefTypeCast(rewriter, operands[0]);
+    Value *rhs = memRefTypeCast(rewriter, operands[1]);
+
+    using namespace edsc;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
+    IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
+    IndexHandle i, j, M(vRes.ub(0));
+    if (vRes.rank() == 1) {
+      LoopNestBuilder({&i}, {zero}, {M}, {1})({iRes(i) = iLHS(i) + iRHS(i)});
+    } else {
+      assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
+      IndexHandle N(vRes.ub(1));
+      LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
+                      {1, 1})({iRes(i, j) = iLHS(i, j) + iRHS(i, j)});
+    }
+
+    // Return the newly allocated buffer, with a type.cast to preserve the
+    // consumers.
+    return {typeCast(rewriter, result, add.getType())};
+  }
+};
+
+/// Lowers `toy.print` to a loop nest calling `printf` on every individual
+/// elements of the array.
+class PrintOpConversion : public DialectOpConversion {
+public:
+  explicit PrintOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::PrintOp::getOperationName(), 1, context) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    // Get or create the declaration of the printf function in the module.
+    Function *printfFunc = getPrintf(*op->getFunction()->getModule());
+
+    auto print = op->cast<toy::PrintOp>();
+    auto loc = print.getLoc();
+    // We will operate on a MemRef abstraction, we use a type.cast to get one
+    // if our operand is still a Toy array.
+    Value *operand = memRefTypeCast(rewriter, operands[0]);
+    Type retTy = printfFunc->getType().getResult(0);
+
+    // Create our loop nest now
+    using namespace edsc;
+    using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
+    MemRefView vOp(operand);
+    IndexedValue iOp(operand);
+    IndexHandle i, j, M(vOp.ub(0));
+
+    ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
+    if (vOp.rank() == 1) {
+      // clang-format off
+      LoopBuilder(&i, zero, M, 1)({
+        llvmCall(retTy,
+                 rewriter.getFunctionAttr(printfFunc),
+                 {fmtCst, iOp(i)})
+      });
+      llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol});
+      // clang-format on
+    } else {
+      IndexHandle N(vOp.ub(1));
+      // clang-format off
+      LoopBuilder(&i, zero, M, 1)({
+        LoopBuilder(&j, zero, N, 1)({
+          llvmCall(retTy,
+                   rewriter.getFunctionAttr(printfFunc),
+                   {fmtCst, iOp(i, j)})
+        }),
+        llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
+      });
+      // clang-format on
+    }
+    return {};
+  }
+
+private:
+  // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
+  // of stores into the buffer, and return a MemRef into the buffer.
+  Value *getConstantCharBuffer(FuncBuilder &builder, Location loc,
+                               StringRef data) const {
+    auto retTy =
+        builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
+    Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
+    using namespace edsc;
+    using intrinsics::constant_index;
+    using intrinsics::constant_int;
+    ScopedContext scope(builder, loc);
+    MemRefView vOp(result);
+    IndexedValue iOp(result);
+    for (uint64_t i = 0; i < data.size(); ++i) {
+      iOp(constant_index(i)) = constant_int(data[i], 8);
+    }
+    iOp(constant_index(data.size())) = constant_int(0, 8);
+    return result;
+  }
+
+  /// Return the prototype declaration for printf in the module, create it if
+  /// necessary.
+  Function *getPrintf(Module &module) const {
+    auto *printfFunc = module.getNamedFunction("printf");
+    if (printfFunc)
+      return printfFunc;
+
+    // Create a function declaration for printf, signature is `i32 (i8*, ...)`
+    Builder builder(&module);
+    MLIRContext *context = module.getContext();
+    LLVM::LLVMDialect *llvmDialect = static_cast<LLVM::LLVMDialect *>(
+        module.getContext()->getRegisteredDialect("llvm"));
+    auto &llvmModule = llvmDialect->getLLVMModule();
+    llvm::IRBuilder<> llvmBuilder(llvmModule.getContext());
+
+    auto llvmI32Ty = LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(32));
+    auto llvmI8PtrTy =
+        LLVM::LLVMType::get(context, llvmBuilder.getIntNTy(8)->getPointerTo());
+    auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty});
+    printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy);
+    // It should be variadic, but we don't support it fully just yet.
+    printfFunc->setAttr("std.varargs", builder.getBoolAttr(true));
+    module.getFunctions().push_back(printfFunc);
+    return printfFunc;
+  }
+};
+
+/// Lowers constant to a sequence of store in a buffer.
+class ConstantOpConversion : public DialectOpConversion {
+public:
+  explicit ConstantOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::ConstantOp::getOperationName(), 1, context) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
+    auto loc = cstOp.getLoc();
+    auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
+    auto shape = retTy.getShape();
+    Value *result = memRefTypeCast(
+        rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());
+
+    auto cstValue = cstOp.getValue();
+    auto f64Ty = rewriter.getF64Type();
+    using namespace edsc;
+    using intrinsics::constant_float;
+    using intrinsics::constant_index;
+    ScopedContext scope(rewriter, loc);
+    MemRefView vOp(result);
+    IndexedValue iOp(result);
+    for (uint64_t i = 0; i < shape[0]; ++i) {
+      if (shape.size() == 1) {
+        auto value = cstValue.getValue(ArrayRef<uint64_t>{i})
+                         .cast<FloatAttr>()
+                         .getValue();
+        iOp(constant_index(i)) = constant_float(value, f64Ty);
+        continue;
+      }
+      for (uint64_t j = 0; j < shape[1]; ++j) {
+        auto value = cstValue.getValue(ArrayRef<uint64_t>{i, j})
+                         .cast<FloatAttr>()
+                         .getValue();
+        iOp(constant_index(i), constant_index(j)) =
+            constant_float(value, f64Ty);
+      }
+    }
+    return {result};
+  }
+};
+
+/// Lower transpose operation to an affine loop nest.
+class TransposeOpConversion : public DialectOpConversion {
+public:
+  explicit TransposeOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::TransposeOp::getOperationName(), 1, context) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto transpose = op->cast<toy::TransposeOp>();
+    auto loc = transpose.getLoc();
+    Value *result = memRefTypeCast(
+        rewriter,
+        rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
+            .getResult());
+    Value *operand = memRefTypeCast(rewriter, operands[0]);
+
+    using namespace edsc;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    MemRefView vRes(result), vOperand(operand);
+    IndexedValue iRes(result), iOperand(operand);
+    IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
+    // clang-format off
+    LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})({
+      iRes(i, j) = iOperand(j, i)
+    });
+    // clang-format on
+
+    return {typeCast(rewriter, result, transpose.getType())};
+  }
+};
+
+// Lower toy.return to standard return operation.
+class ReturnOpConversion : public DialectOpConversion {
+public:
+  explicit ReturnOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::ReturnOp::getOperationName(), 1, context) {}
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto retOp = op->cast<toy::ReturnOp>();
+    using namespace edsc;
+    auto loc = retOp.getLoc();
+    // Argument is optional, handle both cases.
+    if (retOp.getNumOperands())
+      rewriter.create<ReturnOp>(loc, operands[0]);
+    else
+      rewriter.create<ReturnOp>(loc);
+    return {};
+  }
+};
+
+/// This is the main class registering our individual converter classes with
+/// the DialectConversion framework in MLIR.
+class LateLowering : public DialectConversion {
+protected:
+  /// Initialize the list of converters.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *context) override {
+    return ConversionListBuilder<AddOpConversion, PrintOpConversion,
+                                 ConstantOpConversion, TransposeOpConversion,
+                                 ReturnOpConversion>::build(&allocator,
+                                                            context);
+  }
+
+  /// Convert a Toy type, this gets called for block and region arguments, and
+  /// attributes.
+  Type convertType(Type t) override {
+    if (auto array = t.cast<toy::ToyArrayType>()) {
+      return array.toMemref();
+    }
+    return t;
+  }
+
+private:
+  llvm::BumpPtrAllocator allocator;
+};
+
+/// This is lowering to Linalg the parts that can be (matmul and add on arrays)
+/// and is targeting LLVM otherwise.
+struct LateLoweringPass : public ModulePass<LateLoweringPass> {
+
+  void runOnModule() override {
+    // Perform Toy specific lowering
+    if (failed(LateLowering().convert(&getModule()))) {
+      getModule().getContext()->emitError(
+          UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n");
+      signalPassFailure();
+    }
+    // At this point the IR is almost using only standard and affine dialects.
+    // A few things remain before we emit LLVM IR. First to reuse as much of
+    // MLIR as possible we will try to lower everything to the standard and/or
+    // affine dialect: they already include conversion to the LLVM dialect.
+
+    // First patch calls type to return memref instead of ToyArray
+    for (auto &function : getModule()) {
+      function.walk([&](Operation *op) {
+        auto callOp = op->dyn_cast<CallOp>();
+        if (!callOp)
+          return;
+        if (!callOp.getNumResults())
+          return;
+        auto retToyTy =
+            callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
+        if (!retToyTy)
+          return;
+        callOp.getResult(0)->setType(retToyTy.toMemref());
+      });
+    }
+
+    for (auto &function : getModule()) {
+      function.walk([&](Operation *op) {
+        // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
+        if (auto allocOp = op->dyn_cast<toy::AllocOp>()) {
+          auto result = allocTensor(allocOp);
+          allocOp.replaceAllUsesWith(result);
+          allocOp.erase();
+          return;
+        }
+        // Eliminate all type.cast before lowering to LLVM.
+        if (auto typeCastOp = op->dyn_cast<toy::TypeCastOp>()) {
+          typeCastOp.replaceAllUsesWith(typeCastOp.getOperand());
+          typeCastOp.erase();
+          return;
+        }
+      });
+    }
+
+    // Lower Linalg to affine
+    for (auto &function : getModule())
+      linalg::lowerToLoops(&function);
+
+    getModule().dump();
+
+    // Finally convert to LLVM Dialect
+    linalg::convertLinalg3ToLLVM(getModule());
+  }
+
+  /// Allocate buffers (malloc/free) for Toy operations. This can't be done as
+  /// part of dialect conversion framework since we need to insert `dealloc`
+  /// operations just before the return, but the conversion framework is
+  /// operating in a brand new function: we don't have the return to hook the
+  /// dealloc operations.
+  Value *allocTensor(toy::AllocOp alloc) {
+    FuncBuilder builder(alloc);
+    auto retTy = alloc.getResult()->getType();
+
+    auto memRefTy = retTy.dyn_cast<MemRefType>();
+    if (!memRefTy)
+      memRefTy = retTy.cast<toy::ToyArrayType>().toMemref();
+    if (!memRefTy) {
+      alloc.emitOpError("is expected to allocate a Toy array or a MemRef");
+      llvm_unreachable("fatal error");
+    }
+    auto loc = alloc.getLoc();
+    Value *result = builder.create<AllocOp>(loc, memRefTy).getResult();
+
+    // Insert a `dealloc` operation right before the `return` operations, unless
+    // it is returned itself in which case the caller is responsible for it.
+    builder.getFunction()->walk([&](Operation *op) {
+      auto returnOp = op->dyn_cast<ReturnOp>();
+      if (!returnOp)
+        return;
+      if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc)
+        return;
+      builder.setInsertionPoint(returnOp);
+      builder.create<DeallocOp>(alloc.getLoc(), result);
+    });
+    return result;
+  }
+};
+} // end anonymous namespace
+
+namespace toy {
+Pass *createLateLoweringPass() { return new LateLoweringPass(); }
+
+std::unique_ptr<DialectConversion> makeToyLateLowering() {
+  return llvm::make_unique<LateLowering>();
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
new file mode 100644 (file)
index 0000000..e2001fb
--- /dev/null
@@ -0,0 +1,480 @@
+//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR generation targeting MLIR from a Module AST
+// for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/MLIRGen.h"
+#include "toy/AST.h"
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/StandardOps/Ops.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+#include <numeric>
+
+using namespace toy;
+using llvm::cast;
+using llvm::dyn_cast;
+using llvm::isa;
+using llvm::make_unique;
+using llvm::ScopedHashTableScope;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+/// Implementation of a simple MLIR emission from the Toy AST.
+///
+/// This will emit operations that are specific to the Toy language, preserving
+/// the semantics of the language and (hopefully) allow to perform accurate
+/// analysis and transformation based on these high level semantics.
+///
+/// At this point we take advantage of the "raw" MLIR APIs to create operations
+/// that haven't been registered in any way with MLIR. These operations are
+/// unknown to MLIR, custom passes could operate by string-matching the name of
+/// these operations, but no other type checking or semantic is associated with
+/// them natively by MLIR.
+class MLIRGenImpl {
+public:
+  MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
+
+  /// Public API: convert the AST for a Toy module (source file) to an MLIR
+  /// Module.
+  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+    // We create an empty MLIR module and codegen functions one at a time and
+    // add them to the module.
+    theModule = make_unique<mlir::Module>(&context);
+
+    for (FunctionAST &F : moduleAST) {
+      auto func = mlirGen(F);
+      if (!func)
+        return nullptr;
+      theModule->getFunctions().push_back(func.release());
+    }
+
+    // FIXME: (in the next chapter...) without registering a dialect in MLIR,
+    // this won't do much, but it should at least check some structural
+    // properties.
+    if (failed(theModule->verify())) {
+      context.emitError(mlir::UnknownLoc::get(&context),
+                        "Module verification error");
+      return nullptr;
+    }
+
+    return std::move(theModule);
+  }
+
+private:
+  /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
+  /// the ownership of many internal structure of the IR and provide a level
+  /// of "uniquing" across multiple modules (types for instance).
+  mlir::MLIRContext &context;
+
+  /// A "module" matches a source file: it contains a list of functions.
+  std::unique_ptr<mlir::Module> theModule;
+
+  /// The builder is a helper class to create IR inside a function. It is
+  /// re-initialized every time we enter a function and kept around as a
+  /// convenience for emitting individual operations.
+  /// The builder is stateful, in particular it keeeps an "insertion point":
+  /// this is where the next operations will be introduced.
+  std::unique_ptr<mlir::FuncBuilder> builder;
+
+  /// The symbol table maps a variable name to a value in the current scope.
+  /// Entering a function creates a new scope, and the function arguments are
+  /// added to the mapping. When the processing of a function is terminated, the
+  /// scope is destroyed and the mappings created in this scope are dropped.
+  llvm::ScopedHashTable<StringRef, mlir::Value *> symbolTable;
+
+  /// Helper conversion for a Toy AST location to an MLIR location.
+  mlir::FileLineColLoc loc(Location loc) {
+    return mlir::FileLineColLoc::get(
+        mlir::UniquedFilename::get(*loc.file, &context), loc.line, loc.col,
+        &context);
+  }
+
+  /// Declare a variable in the current scope, return true if the variable
+  /// wasn't declared yet.
+  bool declare(llvm::StringRef var, mlir::Value *value) {
+    if (symbolTable.count(var)) {
+      return false;
+    }
+    symbolTable.insert(var, value);
+    return true;
+  }
+
+  /// Create the prototype for an MLIR function with as many arguments as the
+  /// provided Toy AST prototype.
+  mlir::Function *mlirGen(PrototypeAST &proto) {
+    // This is a generic function, the return type will be inferred later.
+    llvm::SmallVector<mlir::Type, 4> ret_types;
+    // Arguments type is uniformly a generic array.
+    llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
+                                               getType(VarType{}));
+    auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
+    auto *function = new mlir::Function(loc(proto.loc()), proto.getName(),
+                                        func_type, /* attrs = */ {});
+
+    // Mark the function as generic: it'll require type specialization for every
+    // call site.
+    if (function->getNumArguments())
+      function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
+
+    return function;
+  }
+
+  /// Emit a new function and add it to the MLIR module.
+  std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) {
+    // Create a scope in the symbol table to hold variable declarations.
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+
+    // Create an MLIR function for the given prototype.
+    std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto()));
+    if (!function)
+      return nullptr;
+
+    // Let's start the body of the function now!
+    // In MLIR the entry block of the function is special: it must have the same
+    // argument list as the function itself.
+    function->addEntryBlock();
+
+    auto &entryBlock = function->front();
+    auto &protoArgs = funcAST.getProto()->getArgs();
+    // Declare all the function arguments in the symbol table.
+    for (const auto &name_value :
+         llvm::zip(protoArgs, entryBlock.getArguments())) {
+      declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
+    }
+
+    // Create a builder for the function, it will be used throughout the codegen
+    // to create operations in this function.
+    builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
+
+    // Emit the body of the function.
+    if (!mlirGen(*funcAST.getBody()))
+      return nullptr;
+
+    // Implicitly return void if no return statement was emited.
+    // FIXME: we may fix the parser instead to always return the last expression
+    // (this would possibly help the REPL case later)
+    if (function->getBlocks().back().back().getName().getStringRef() !=
+        "toy.return") {
+      ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
+      mlirGen(fakeRet);
+    }
+
+    return function;
+  }
+
+  /// Emit a binary operation
+  mlir::Value *mlirGen(BinaryExprAST &binop) {
+    // First emit the operations for each side of the operation before emitting
+    // the operation itself. For example if the expression is `a + foo(a)`
+    // 1) First it will visiting the LHS, which will return a reference to the
+    //    value holding `a`. This value should have been emitted at declaration
+    //    time and registered in the symbol table, so nothing would be
+    //    codegen'd. If the value is not in the symbol table, an error has been
+    //    emitted and nullptr is returned.
+    // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted
+    //    and the result value is returned. If an error occurs we get a nullptr
+    //    and propagate.
+    //
+    mlir::Value *L = mlirGen(*binop.getLHS());
+    if (!L)
+      return nullptr;
+    mlir::Value *R = mlirGen(*binop.getRHS());
+    if (!R)
+      return nullptr;
+    auto location = loc(binop.loc());
+
+    // Derive the operation name from the binary operator. At the moment we only
+    // support '+' and '*'.
+    switch (binop.getOp()) {
+    case '+':
+      return builder->create<AddOp>(location, L, R).getResult();
+      break;
+    case '*':
+      return builder->create<MulOp>(location, L, R).getResult();
+    default:
+      context.emitError(loc(binop.loc()),
+                        Twine("Error: invalid binary operator '") +
+                            Twine(binop.getOp()) + "'");
+      return nullptr;
+    }
+  }
+
+  // This is a reference to a variable in an expression. The variable is
+  // expected to have been declared and so should have a value in the symbol
+  // table, otherwise emit an error and return nullptr.
+  mlir::Value *mlirGen(VariableExprAST &expr) {
+    if (symbolTable.count(expr.getName()))
+      return symbolTable.lookup(expr.getName());
+    context.emitError(loc(expr.loc()), Twine("Error: unknown variable '") +
+                                           expr.getName() + "'");
+    return nullptr;
+  }
+
+  // Emit a return operation, return true on success.
+  bool mlirGen(ReturnExprAST &ret) {
+    auto location = loc(ret.loc());
+    // `return` takes an optional expression, we need to account for it here.
+    if (!ret.getExpr().hasValue()) {
+      builder->create<ReturnOp>(location);
+      return true;
+    }
+    auto *expr = mlirGen(*ret.getExpr().getValue());
+    if (!expr)
+      return false;
+    builder->create<ReturnOp>(location, expr);
+    return true;
+  }
+
+  // Emit a literal/constant array. It will be emitted as a flattened array of
+  // data in an Attribute attached to a `toy.constant` operation.
+  // See documentation on [Attributes](LangRef.md#attributes) for more details.
+  // Here is an excerpt:
+  //
+  //   Attributes are the mechanism for specifying constant data in MLIR in
+  //   places where a variable is never allowed [...]. They consist of a name
+  //   and a [concrete attribute value](#attribute-values). It is possible to
+  //   attach attributes to operations, functions, and function arguments. The
+  //   set of expected attributes, their structure, and their interpretation
+  //   are all contextually dependent on what they are attached to.
+  //
+  // Example, the source level statement:
+  //   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  // will be converted to:
+  //   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+  //     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+  //      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
+  //
+  mlir::Value *mlirGen(LiteralExprAST &lit) {
+    auto location = loc(lit.loc());
+    // The attribute is a vector with an attribute per element (number) in the
+    // array, see `collectData()` below for more details.
+    std::vector<mlir::Attribute> data;
+    data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
+                                 std::multiplies<int>()));
+    collectData(lit, data);
+
+    // FIXME: using a tensor type is a HACK here.
+    // Can we do differently without registering a dialect? Using a string blob?
+    mlir::Type elementType = mlir::FloatType::getF64(&context);
+    auto dataType = builder->getTensorType(lit.getDims(), elementType);
+
+    // This is the actual attribute that actually hold the list of values for
+    // this array literal.
+    auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
+                             .cast<mlir::DenseElementsAttr>();
+
+    // Build the MLIR op `toy.constant`, only boilerplate below.
+    return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
+        .getResult();
+  }
+
+  // Recursive helper function to accumulate the data that compose an array
+  // literal. It flattens the nested structure in the supplied vector. For
+  // example with this array:
+  //  [[1, 2], [3, 4]]
+  // we will generate:
+  //  [ 1, 2, 3, 4 ]
+  // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
+  // Attributes are the way MLIR attaches constant to operations and functions.
+  void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
+    if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
+      for (auto &value : lit->getValues())
+        collectData(*value, data);
+      return;
+    }
+    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
+    mlir::Type elementType = mlir::FloatType::getF64(&context);
+    auto attr = mlir::FloatAttr::getChecked(
+        elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
+    data.push_back(attr);
+  }
+
+  // Emit a call expression. It emits specific operations for the `transpose`
+  // builtin. Other identifiers are assumed to be user-defined functions.
+  mlir::Value *mlirGen(CallExprAST &call) {
+    auto location = loc(call.loc());
+    std::string callee = call.getCallee();
+    if (callee == "transpose") {
+      if (call.getArgs().size() != 1) {
+        context.emitError(
+            location, Twine("MLIR codegen encountered an error: toy.transpose "
+                            "does not accept multiple arguments"));
+        return nullptr;
+      }
+      mlir::Value *arg = mlirGen(*call.getArgs()[0]);
+      return builder->create<TransposeOp>(location, arg).getResult();
+    }
+
+    // Codegen the operands first
+    SmallVector<mlir::Value *, 4> operands;
+    for (auto &expr : call.getArgs()) {
+      auto *arg = mlirGen(*expr);
+      if (!arg)
+        return nullptr;
+      operands.push_back(arg);
+    }
+    // Calls to user-defined function are mapped to a custom call that takes
+    // the callee name as an attribute.
+    return builder->create<GenericCallOp>(location, call.getCallee(), operands)
+        .getResult();
+  }
+
+  // Emit a call expression. It emits specific operations for two builtins:
+  // transpose(x) and print(x). Other identifiers are assumed to be user-defined
+  // functions. Return false on failure.
+  bool mlirGen(PrintExprAST &call) {
+    auto *arg = mlirGen(*call.getArg());
+    if (!arg)
+      return false;
+    auto location = loc(call.loc());
+    builder->create<PrintOp>(location, arg);
+    return true;
+  }
+
+  // Emit a constant for a single number (FIXME: semantic? broadcast?)
+  mlir::Value *mlirGen(NumberExprAST &num) {
+    auto location = loc(num.loc());
+    mlir::Type elementType = mlir::FloatType::getF64(&context);
+    auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
+                                            loc(num.loc()));
+    return builder->create<ConstantOp>(location, attr).getResult();
+  }
+
+  // Dispatch codegen for the right expression subclass using RTTI.
+  mlir::Value *mlirGen(ExprAST &expr) {
+    switch (expr.getKind()) {
+    case toy::ExprAST::Expr_BinOp:
+      return mlirGen(cast<BinaryExprAST>(expr));
+    case toy::ExprAST::Expr_Var:
+      return mlirGen(cast<VariableExprAST>(expr));
+    case toy::ExprAST::Expr_Literal:
+      return mlirGen(cast<LiteralExprAST>(expr));
+    case toy::ExprAST::Expr_Call:
+      return mlirGen(cast<CallExprAST>(expr));
+    case toy::ExprAST::Expr_Num:
+      return mlirGen(cast<NumberExprAST>(expr));
+    default:
+      context.emitError(
+          loc(expr.loc()),
+          Twine("MLIR codegen encountered an unhandled expr kind '") +
+              Twine(expr.getKind()) + "'");
+      return nullptr;
+    }
+  }
+
+  // Handle a variable declaration, we'll codegen the expression that forms the
+  // initializer and record the value in the symbol table before returning it.
+  // Future expressions will be able to reference this variable through symbol
+  // table lookup.
+  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
+    mlir::Value *value = nullptr;
+    auto location = loc(vardecl.loc());
+    if (auto init = vardecl.getInitVal()) {
+      value = mlirGen(*init);
+      if (!value)
+        return nullptr;
+      // We have the initializer value, but in case the variable was declared
+      // with specific shape, we emit a "reshape" operation. It will get
+      // optimized out later as needed.
+      if (!vardecl.getType().shape.empty()) {
+        value = builder
+                    ->create<ReshapeOp>(
+                        location, value,
+                        getType(vardecl.getType()).cast<ToyArrayType>())
+                    .getResult();
+      }
+    } else {
+      context.emitError(loc(vardecl.loc()),
+                        "Missing initializer in variable declaration");
+      return nullptr;
+    }
+    // Register the value in the symbol table
+    declare(vardecl.getName(), value);
+    return value;
+  }
+
+  /// Codegen a list of expression, return false if one of them hit an error.
+  bool mlirGen(ExprASTList &blockAST) {
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+    for (auto &expr : blockAST) {
+      // Specific handling for variable declarations, return statement, and
+      // print. These can only appear in block list and not in nested
+      // expressions.
+      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
+        if (!mlirGen(*vardecl))
+          return false;
+        continue;
+      }
+      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
+        if (!mlirGen(*ret))
+          return false;
+        return true;
+      }
+      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
+        if (!mlirGen(*print))
+          return false;
+        continue;
+      }
+      // Generic expression dispatch codegen.
+      if (!mlirGen(*expr))
+        return false;
+    }
+    return true;
+  }
+
+  /// Build a type from a list of shape dimensions. Types are `array` followed
+  /// by an optional dimension list, example: array<2, 2>
+  /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
+  ///   !toy.array<2, 2>
+  template <typename T> mlir::Type getType(T shape) {
+    SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
+    return ToyArrayType::get(&context, shape64);
+  }
+
+  /// Build an MLIR type from a Toy AST variable type
+  /// (forward to the generic getType(T) above).
+  mlir::Type getType(const VarType &type) { return getType(type.shape); }
+};
+
+} // namespace
+
+namespace toy {
+
+// The public API for codegen.
+std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
+                                      ModuleAST &moduleAST) {
+  return MLIRGenImpl(context).mlirGen(moduleAST);
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
new file mode 100644 (file)
index 0000000..7e3ea3f
--- /dev/null
@@ -0,0 +1,387 @@
+//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Module level pass performing interprocedural
+// propagation of array shapes through function specialization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "toy-shape-inference"
+
+using namespace toy;
+using llvm::MutableArrayRef;
+using llvm::SmallVector;
+using llvm::SmallVectorImpl;
+using llvm::StringRef;
+using llvm::Twine;
+
+/// Create mangled name for function specialization. We will simply append the
+/// shape of the arguments to the function name. For example calling
+///
+///   "toy.generic_call"(%1, %3) {callee: "foo"}
+///       : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
+///
+/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
+/// have provide a function with a similar name. But we will claim this as a
+/// feature: this allow the user to provide custom specialization!
+static std::string mangle(StringRef funcName,
+                          MutableArrayRef<mlir::OpOperand> operands) {
+  std::string mangledName;
+  mangledName.reserve(funcName.size() + operands.size() * 6);
+  mangledName = funcName;
+  for (auto &operand : operands) {
+    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
+    mangledName += "_";
+    const char *sep = "";
+    for (auto dim : arrayTy.getShape()) {
+      mangledName += (sep + Twine(dim)).str();
+      sep = "x";
+    }
+  }
+  return mangledName;
+}
+
+namespace {
+
+/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
+/// whole. MLIR also supports FunctionPass which are restricted to modify a
+/// single function at a time. This pass couldn't be a function pass due the
+/// nature of its interprocedural transformations.
+///
+/// The algorithm has two levels, first intra-procedurally:
+///
+///   1) Build a worklist containing all the operations that are returning
+///      a generic Toy array: these are the operations that need shape
+///      inference.
+///   2) Iterate on the worklist:
+///     a) find an operation to process: the next ready operation in the
+///        worklist has all of its arguments non-generic,
+///     b) if no operation is found, break out of the loop,
+///     c) remove the operation from the worklist,
+///     d) infer the shape of its output from the arguments type.
+///   3) If the worklist is empty, the algorithm succeeded and we infer the
+///      return type for the function from the return operation.
+///
+/// There is a twist though: when a call to a generic function is encountered,
+/// shape inference requires the return type of the callee to be inferred first.
+/// At this point we need to run specialize the callee by cloning it. Here is
+/// the inter-procedural flow:
+///
+///   1) Keep a worklist of function to process. Start with function "main".
+///   2) While the worklist isn't empty:
+///     a) Take the last inserted function in the worklist.
+///     b) Run the intra-procedural shape inference on this function.
+///     c) If the intra-procedural shape inference can't complete, it returns
+///        a Function that needs to be inferred first. In this case, queue this
+///        new function and continue. Otherwise the inference succeeded and we
+///        can pop from the queue.
+///
+class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
+public:
+  // One entry in the inter-procedural worklist. It keeps track of the
+  // function to process, the mangled name for this specialization, and the
+  // types of the arguments on which to specialize.
+  struct FunctionToSpecialize {
+    mlir::Function *function;
+    std::string mangledName;
+    std::vector<mlir::Type> argumentsType;
+  };
+
+  void runOnModule() override {
+    auto &module = getModule();
+    auto *main = module.getNamedFunction("main");
+    if (!main) {
+      module.getContext()->emitError(
+          mlir::UnknownLoc::get(module.getContext()),
+          "Shape inference failed: can't find a main function\n");
+      signalPassFailure();
+      return;
+    }
+
+    /// Inter-procedural loop, initialize with `main` and iterate till
+    /// successfully infer the full reachable call-graph from main.
+    SmallVector<FunctionToSpecialize, 8> worklist;
+    worklist.push_back({main, "", {}});
+    while (!worklist.empty()) {
+      if (failed(specialize(worklist)))
+        return;
+    }
+
+    // Delete any generic function left
+    // FIXME: we may want this as a separate pass.
+    for (mlir::Function &function : llvm::make_early_inc_range(module)) {
+      if (auto genericAttr =
+              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
+        if (genericAttr.getValue())
+          function.erase();
+      }
+    }
+  }
+
+  /// Run inference on a function. If a mangledName is provided, we need to
+  /// specialize the function: to this end clone it first.
+  mlir::LogicalResult
+  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
+    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
+    mlir::Function *f = functionToSpecialize.function;
+
+    // Check if cloning for specialization is needed (usually anything but main)
+    // We will create a new function with the concrete types for the parameters
+    // and clone the body into it.
+    if (!functionToSpecialize.mangledName.empty()) {
+      if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
+        funcWorklist.pop_back();
+        // Function already specialized, move on.
+        return mlir::success();
+      }
+      // Create a new function with a generic array return type, it will be
+      // updated when the inference for the function body completes.
+      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
+                                          {ToyArrayType::get(&getContext())},
+                                          &getContext());
+      auto *newFunction = new mlir::Function(
+          f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
+      getModule().getFunctions().push_back(newFunction);
+
+      // Clone the function body
+      mlir::BlockAndValueMapping mapper;
+      f->cloneInto(newFunction, mapper);
+      LLVM_DEBUG({
+        llvm::dbgs() << "====== Cloned : \n";
+        f->dump();
+        llvm::dbgs() << "====== Into : \n";
+        newFunction->dump();
+      });
+      f = newFunction;
+      f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+      // Remap the entry-block arguments
+      // FIXME: this seems like a bug in `cloneInto()` above?
+      auto &entryBlock = f->getBlocks().front();
+      int blockArgSize = entryBlock.getArguments().size();
+      assert(blockArgSize == f->getType().getInputs().size());
+      entryBlock.addArguments(f->getType().getInputs());
+      auto argList = entryBlock.getArguments();
+      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
+        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
+        entryBlock.eraseArgument(0);
+      }
+      assert(succeeded(f->verify()));
+    }
+    LLVM_DEBUG(llvm::dbgs()
+               << "Run shape inference on : '" << f->getName() << "'\n");
+
+    auto *toyDialect = getContext().getRegisteredDialect("toy");
+    if (!toyDialect) {
+      getContext().emitError(mlir::UnknownLoc::get(&getContext()),
+                             "Toy dialect is not registered");
+      signalPassFailure();
+      return mlir::failure();
+    }
+
+    // Populate the worklist with the operations that need shape inference:
+    // these are the Toy operations that return a generic array.
+    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
+    f->walk([&](mlir::Operation *op) {
+      if (op->getDialect() == toyDialect) {
+        if (op->getNumResults() == 1 &&
+            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
+          opWorklist.insert(op);
+      }
+    });
+
+    // Iterate on the operations in the worklist until all operations have been
+    // inferred or no change happened (fix point).
+    while (!opWorklist.empty()) {
+      // Find the next operation ready for inference, that is an operation
+      // with all operands already resolved (non-generic).
+      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
+        return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
+          return !v->getType().cast<ToyArrayType>().isGeneric();
+        });
+      });
+      if (nextop == opWorklist.end())
+        break; // failure: no operations can be inferred.
+
+      mlir::Operation *op = *nextop;
+      opWorklist.erase(op);
+      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+
+      // The add operation is trivial: propagate the input type as is.
+      if (auto addOp = op->dyn_cast<AddOp>()) {
+        op->getResult(0)->setType(op->getOperand(0)->getType());
+        continue;
+      }
+
+      // Transpose is easy: just invert the dimensions.
+      if (op->getName().getStringRef() == "toy.transpose") {
+        SmallVector<int64_t, 2> dims;
+        auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
+        dims.insert(dims.end(), arrayTy.getShape().begin(),
+                    arrayTy.getShape().end());
+        if (dims.size() == 2)
+          std::swap(dims[0], dims[1]);
+        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
+        continue;
+      }
+
+      // Multiplication is a bit trickier, handle rank 1 as dot product and rank
+      // 2 as matrix multiplications.
+      // We need to be careful about rank mismatch here: the verifier could
+      // catch it but shape inference earlier in the pass could generate an
+      // invalid IR (from an invalid Toy input of course) and we wouldn't want
+      // to crash here.
+      if (auto mulOp = op->dyn_cast<MulOp>()) {
+        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
+        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
+        auto lhsRank = lhs.getShape().size();
+        auto rhsRank = rhs.getShape().size();
+        if (lhsRank != rhsRank) {
+          op->emitError("Shape mismatch: LHS and RHS must have the same "
+                        "rank for multiplication, got " +
+                        Twine(lhsRank) + " vs  " + Twine(lhsRank));
+          return mlir::failure();
+        }
+        SmallVector<int64_t, 2> dims;
+        if (lhsRank == 1) {
+          // dot product, result shape is <1>
+          dims.push_back(1);
+        } else {
+          if (lhsRank != 2) {
+            op->emitError(
+                "Shape mismatch: expect rank 1 or 2 for mul operands, got " +
+                Twine(lhsRank));
+            return mlir::failure();
+          }
+          dims.push_back(lhs.getShape()[0]);
+          dims.push_back(rhs.getShape()[1]);
+        }
+        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
+        continue;
+      }
+
+      // Process calls: lookup the callee after mangling the name with the
+      // argument shapes. If the callee does not exist, we stop the inference
+      // for this function, queue the callee in the inter-procedural work list,
+      // and return. The current function stays in the work list and will
+      // restart after the callee is processed.
+      if (auto callOp = op->dyn_cast<GenericCallOp>()) {
+        auto calleeName = callOp.getCalleeName();
+        auto *callee = getModule().getNamedFunction(calleeName);
+        if (!callee) {
+          f->emitError(
+              llvm::Twine("Shape inference failed, call to unknown '") +
+              calleeName + "'");
+          signalPassFailure();
+          return mlir::failure();
+        }
+        auto mangledName = mangle(calleeName, op->getOpOperands());
+        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
+                                << "', mangled: '" << mangledName << "'\n");
+        auto *mangledCallee = getModule().getNamedFunction(mangledName);
+        if (!mangledCallee) {
+          // Can't find the target, this is where we queue the request for the
+          // callee and stop the inference for the current function now.
+          std::vector<mlir::Type> funcArgs;
+          for (auto operand : op->getOperands())
+            funcArgs.push_back(operand->getType());
+          funcWorklist.push_back(
+              {callee, std::move(mangledName), std::move(funcArgs)});
+          return mlir::success();
+        }
+        // Found a specialized callee! Let's turn this into a normal call
+        // operation.
+        SmallVector<mlir::Value *, 8> operands;
+        for (mlir::Value *v : op->getOperands())
+          operands.push_back(v);
+        mlir::FuncBuilder builder(f);
+        builder.setInsertionPoint(op);
+        auto newCall =
+            builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
+        if (newCall.getNumResults()) {
+          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
+          op->erase();
+          continue;
+        }
+      }
+    }
+
+    // Done with inference on this function, removing it from the worklist.
+    funcWorklist.pop_back();
+    // Mark the function as non-generic now that inference has succeeded
+    f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+
+    // If the operation worklist isn't empty, this indicates a failure.
+    if (!opWorklist.empty()) {
+      std::string str;
+      llvm::raw_string_ostream errorMsg(str);
+      errorMsg << "Shape inference failed, " << opWorklist.size()
+               << " operations couldn't be inferred\n";
+      for (auto *ope : opWorklist)
+        errorMsg << " - " << *ope << "\n";
+      f->emitError(errorMsg.str());
+      signalPassFailure();
+      return mlir::failure();
+    }
+
+    // Finally, update the return type of the function based on the argument to
+    // the return operation.
+    for (auto &block : f->getBlocks()) {
+      auto ret = block.getTerminator()->cast<ReturnOp>();
+      if (!ret)
+        continue;
+      if (ret.getNumOperands() &&
+          f->getType().getResult(0) == ret.getOperand()->getType())
+        // type match, we're done
+        break;
+      SmallVector<mlir::Type, 1> retTy;
+      if (ret.getNumOperands())
+        retTy.push_back(ret.getOperand()->getType());
+      mlir::Type elementType = mlir::FloatType::getF64(&getContext());
+      std::vector<mlir::Type> argumentsType;
+      for (auto arg : f->getArguments())
+        argumentsType.push_back(arg->getType());
+      auto newType =
+          mlir::FunctionType::get(argumentsType, retTy, &getContext());
+      f->setType(newType);
+      assert(succeeded(f->verify()));
+      break;
+    }
+    return mlir::success();
+  }
+};
+} // end anonymous namespace
+
+namespace toy {
+mlir::Pass *createShapeInferencePass() { return new ShapeInferencePass(); }
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
new file mode 100644 (file)
index 0000000..8d6aed6
--- /dev/null
@@ -0,0 +1,209 @@
+//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple combiner for optimizing pattern in the Toy
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+
+#include <numeric>
+
+namespace toy {
+
+namespace {
+
+/// Fold transpose(transpose(x)) -> transpose(x)
+struct SimplifyRedundantTranspose : public mlir::RewritePattern {
+  /// We register this pattern to match every toy.transpose in the IR.
+  /// The "benefit" is used by the framework to order the patterns and process
+  /// them in order of profitability.
+  SimplifyRedundantTranspose(mlir::MLIRContext *context)
+      : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  /// This method is attempting to match a pattern and rewrite it. The rewriter
+  /// argument is the orchestrator of the sequence of rewrites. It is expected
+  /// to interact with it to perform any changes to the IR from here.
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // We can directly cast the current operation as this will only get invoked
+    // on TransposeOp.
+    TransposeOp transpose = op->cast<TransposeOp>();
+    // look through the input to the current transpose
+    mlir::Value *transposeInput = transpose.getOperand();
+    mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
+    // If the input is defined by another Transpose, bingo!
+    TransposeOp transposeInputOp =
+        mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
+    if (!transposeInputOp)
+      return matchFailure();
+
+    // Use the rewriter to perform the replacement
+    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
+struct SimplifyReshapeConstant : public mlir::RewritePattern {
+  SimplifyReshapeConstant(mlir::MLIRContext *context)
+      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    ReshapeOp reshape = op->cast<ReshapeOp>();
+    // look through the input to the current reshape
+    mlir::Value *reshapeInput = reshape.getOperand();
+    mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
+    // If the input is defined by another reshape, bingo!
+    ConstantOp constantOp =
+        mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
+    if (!constantOp)
+      return matchFailure();
+
+    auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
+    if (auto valueAttr =
+            constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
+      // FIXME Check matching of element count!
+      //      auto oldType = constantOp.getType();
+      auto newType = rewriter.getTensorType(
+          reshapeType.getShape(), valueAttr.getType().getElementType());
+      auto newAttr =
+          mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
+      auto newConstant = rewriter.create<ConstantOp>(
+          constantOp.getLoc(), reshapeType.getShape(), newAttr);
+      rewriter.replaceOp(op, {newConstant});
+    } else if (auto valueAttr =
+                   constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
+      // Broadcast
+      auto dataSize = std::accumulate(reshapeType.getShape().begin(),
+                                      reshapeType.getShape().end(), 1,
+                                      std::multiplies<int>());
+      std::vector<mlir::Attribute> data(dataSize, valueAttr);
+      auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
+                                             reshapeType.getElementType());
+      auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
+      auto newConstant = rewriter.create<ConstantOp>(
+          constantOp.getLoc(), reshapeType.getShape(), newAttr);
+      rewriter.replaceOp(op, {newConstant});
+    } else {
+      llvm_unreachable("Unsupported Constant format");
+    }
+    return matchSuccess();
+  }
+};
+
+/// Fold reshape(reshape(x)) -> reshape(x)
+struct SimplifyReshapeReshape : public mlir::RewritePattern {
+  SimplifyReshapeReshape(mlir::MLIRContext *context)
+      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    ReshapeOp reshape = op->cast<ReshapeOp>();
+    // look through the input to the current reshape
+    mlir::Value *reshapeInput = reshape.getOperand();
+    mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
+    // If the input is defined by another reshape, bingo!
+    ReshapeOp reshapeInputOp =
+        mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
+    if (!reshapeInputOp)
+      return matchFailure();
+
+    // Use the rewriter to perform the replacement
+    rewriter.replaceOp(op, {reshapeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Fold reshape(x)) -> x, when input type matches output type
+struct SimplifyNullReshape : public mlir::RewritePattern {
+  SimplifyNullReshape(mlir::MLIRContext *context)
+      : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    ReshapeOp reshape = op->cast<ReshapeOp>();
+    if (reshape.getOperand()->getType() != reshape.getResult()->getType())
+      return matchFailure();
+    rewriter.replaceOp(reshape, {reshape.getOperand()});
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace.
+
+// Register our patterns for rewrite by the Canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(
+    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+  results.push_back(llvm::make_unique<SimplifyRedundantTranspose>(context));
+}
+
+// Register our patterns for rewrite by the Canonicalization framework.
+void ReshapeOp::getCanonicalizationPatterns(
+    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+  results.push_back(llvm::make_unique<SimplifyReshapeConstant>(context));
+  results.push_back(llvm::make_unique<SimplifyReshapeReshape>(context));
+  results.push_back(llvm::make_unique<SimplifyNullReshape>(context));
+}
+
+namespace {
+
+/// Fold type.cast(x) -> x, when input type matches output type
+struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
+  SimplifyIdentityTypeCast(mlir::MLIRContext *context)
+      : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
+                       context) {}
+
+  mlir::PatternMatchResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    TypeCastOp typeCast = op->cast<TypeCastOp>();
+    auto resTy = typeCast.getResult()->getType();
+    auto *candidateOp = op;
+    while (candidateOp && candidateOp->isa<TypeCastOp>()) {
+      if (resTy == candidateOp->getOperand(0)->getType()) {
+        rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
+        return matchSuccess();
+      }
+      candidateOp = candidateOp->getOperand(0)->getDefiningOp();
+    }
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace.
+
+void TypeCastOp::getCanonicalizationPatterns(
+    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+  results.push_back(llvm::make_unique<SimplifyIdentityTypeCast>(context));
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/mlir/ToyDialect.cpp b/mlir/examples/toy/Ch5/mlir/ToyDialect.cpp
new file mode 100644 (file)
index 0000000..be117f5
--- /dev/null
@@ -0,0 +1,405 @@
+//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/raw_ostream.h"
+
+using llvm::ArrayRef;
+using llvm::raw_ostream;
+using llvm::raw_string_ostream;
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace toy {
+namespace detail {
+
+/// This class holds the implementation of the ToyArrayType.
+/// It is intended to be uniqued based on its content and owned by the context.
+struct ToyArrayTypeStorage : public mlir::TypeStorage {
+  /// This defines how we unique this type in the context: our key contains
+  /// only the shape, a more complex type would have multiple entries in the
+  /// tuple here.
+  /// The element of the tuples usually matches 1-1 the arguments from the
+  /// public `get()` method arguments from the facade.
+  using KeyTy = std::tuple<ArrayRef<int64_t>>;
+  static unsigned hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key));
+  }
+  /// When the key hash hits an existing type, we compare the shape themselves
+  /// to confirm we have the right type.
+  bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
+
+  /// This is a factory method to create our type storage. It is only
+  /// invoked after looking up the type in the context using the key and not
+  /// finding it.
+  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                        const KeyTy &key) {
+    // Copy the shape array into the bumpptr allocator owned by the context.
+    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
+
+    // Allocate the instance for the ToyArrayTypeStorage itself
+    auto *storage = allocator.allocate<ToyArrayTypeStorage>();
+    // Initialize the instance using placement new.
+    return new (storage) ToyArrayTypeStorage(shape);
+  }
+
+  ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+  ArrayRef<int64_t> shape;
+
+  /// Constructor is only invoked from the `construct()` method above.
+  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
+};
+
+} // namespace detail
+
+mlir::Type ToyArrayType::getElementType() {
+  return mlir::FloatType::getF64(getContext());
+}
+
+ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
+                               ArrayRef<int64_t> shape) {
+  return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
+}
+
+ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
+
+mlir::MemRefType ToyArrayType::toMemref() {
+  auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0);
+  return memRefType;
+}
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+  addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
+                MulOp, AddOp, ReturnOp, AllocOp, TypeCastOp>();
+  addTypes<ToyArrayType>();
+}
+
+/// Parse a type registered to this dialect, we expect only Toy arrays.
+mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
+  // Sanity check: we only support array or array<...>
+  if (!tyData.startswith("array")) {
+    getContext()->emitError(loc, "Invalid Toy type '" + tyData +
+                                     "', array expected");
+    return nullptr;
+  }
+  // Drop the "array" prefix from the type name, we expect either an empty
+  // string or just the shape.
+  tyData = tyData.drop_front(StringRef("array").size());
+  // This is the generic array case without shape, early return it.
+  if (tyData.empty())
+    return ToyArrayType::get(getContext());
+
+  // Use a regex to parse the shape (for efficient we should store this regex in
+  // the dialect itself).
+  SmallVector<StringRef, 4> matches;
+  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
+  if (!shapeRegex.match(tyData, &matches)) {
+    getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
+    return nullptr;
+  }
+  SmallVector<int64_t, 4> shape;
+  // Iterate through the captures, skip the first one which is the full string.
+  for (auto dimStr :
+       llvm::make_range(std::next(matches.begin()), matches.end())) {
+    if (dimStr.startswith(","))
+      continue; // POSIX misses non-capturing groups.
+    if (dimStr.empty())
+      continue; // '*' makes it an optional group capture
+    // Convert the capture to an integer
+    unsigned long long dim;
+    if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
+      getContext()->emitError(
+          loc, "Couldn't parse dimension as integer, matched: " + dimStr);
+      return mlir::Type();
+    }
+    shape.push_back(dim);
+  }
+  // Finally we collected all the dimensions in the shape,
+  // create the array type.
+  return ToyArrayType::get(getContext(), shape);
+}
+
+/// Print a Toy array type, for example `array<2, 3, 4>`
+void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
+  auto arrayTy = type.dyn_cast<ToyArrayType>();
+  if (!arrayTy) {
+    os << "unknown toy type";
+    return;
+  }
+  os << "array";
+  if (!arrayTy.getShape().empty()) {
+    os << "<";
+    mlir::interleaveComma(arrayTy.getShape(), os);
+    os << ">";
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//////////////////// Custom Operations for the Dialect /////////////////////////
+////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to verify that the result of an operation is a Toy array type.
+template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
+  if (!op->getResult()->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its argument, got "
+       << op->getResult()->getType();
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Helper to verify that the two operands of a binary operation are Toy
+/// arrays..
+template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
+  if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its LHS, got "
+       << op->getOperand(0)->getType();
+    return op->emitOpError(os.str());
+  }
+  if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its LHS, got "
+       << op->getOperand(0)->getType();
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
+  state->types.push_back(ToyArrayType::get(builder->getContext(), shape));
+  auto dataAttribute = builder->getNamedAttr("value", value);
+  state->attributes.push_back(dataAttribute);
+}
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+void ConstantOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                       mlir::FloatAttr value) {
+  // Broadcast and forward to the other build factory
+  mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
+  auto dataType = builder->getTensorType({1}, elementType);
+  auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
+                           .cast<mlir::DenseElementsAttr>();
+
+  ConstantOp::build(builder, state, {1}, dataAttribute);
+}
+
+/// Verifier for constant operation.
+mlir::LogicalResult ConstantOp::verify() {
+  // Ensure that the return type is a Toy array
+  if (failed(verifyToyReturnArray(this)))
+    return mlir::failure();
+
+  // We expect the constant itself to be stored as an attribute.
+  auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
+  if (!dataAttr) {
+    return emitOpError(
+        "missing valid `value` DenseElementsAttribute on toy.constant()");
+  }
+  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
+  if (!attrType) {
+    return emitOpError(
+        "missing valid `value` DenseElementsAttribute on toy.constant()");
+  }
+
+  // If the return type of the constant is not a generic array, the shape must
+  // match the shape of the attribute holding the data.
+  auto resultType = getResult()->getType().cast<ToyArrayType>();
+  if (!resultType.isGeneric()) {
+    if (attrType.getRank() != resultType.getRank()) {
+      return emitOpError("The rank of the toy.constant return type must match "
+                         "the one of the attached value attribute: " +
+                         Twine(attrType.getRank()) +
+                         " != " + Twine(resultType.getRank()));
+    }
+    for (int dim = 0; dim < attrType.getRank(); ++dim) {
+      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+        std::string msg;
+        raw_string_ostream os(msg);
+        return emitOpError(
+            "Shape mismatch between toy.constant return type and its "
+            "attribute at dimension " +
+            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
+            " != " + Twine(resultType.getShape()[dim]));
+      }
+    }
+  }
+  return mlir::success();
+}
+
+void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
+  // Generic call always returns a generic ToyArray initially
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.assign(arguments.begin(), arguments.end());
+  auto calleeAttr = builder->getStringAttr(callee);
+  state->attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
+}
+
+mlir::LogicalResult GenericCallOp::verify() {
+  // Verify that every operand is a Toy Array
+  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
+    if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
+      std::string msg;
+      raw_string_ostream os(msg);
+      os << "expects a Toy Array for its " << opId << " operand, got "
+         << getOperand(opId)->getType();
+      return emitOpError(os.str());
+    }
+  }
+  return mlir::success();
+}
+
+/// Return the name of the callee.
+StringRef GenericCallOp::getCalleeName() {
+  return getAttr("callee").cast<mlir::StringAttr>().getValue();
+}
+
+template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
+  if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
+    std::string msg;
+    raw_string_ostream os(msg);
+    os << "expects a Toy Array for its argument, got "
+       << op->getOperand()->getType();
+    return op->emitOpError(os.str());
+  }
+  return mlir::success();
+}
+
+void ReturnOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                     mlir::Value *value) {
+  // Return does not return any value and has an optional single argument
+  if (value)
+    state->operands.push_back(value);
+}
+
+mlir::LogicalResult ReturnOp::verify() {
+  if (getNumOperands() > 1)
+    return emitOpError("expects zero or one operand, got " +
+                       Twine(getNumOperands()));
+  if (hasOperand() && failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void PrintOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Value *value) {
+  // Print does not return any value and has a single argument
+  state->operands.push_back(value);
+}
+
+mlir::LogicalResult PrintOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void TransposeOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                        mlir::Value *value) {
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.push_back(value);
+}
+
+mlir::LogicalResult TransposeOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                      mlir::Value *value, ToyArrayType reshapedType) {
+  state->types.push_back(reshapedType);
+  state->operands.push_back(value);
+}
+
+mlir::LogicalResult ReshapeOp::verify() {
+  if (failed(verifyToySingleOperand(this)))
+    return mlir::failure();
+  auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
+  if (!retTy)
+    return emitOpError("toy.reshape is expected to produce a Toy array");
+  if (retTy.isGeneric())
+    return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
+                       "got a generic one.");
+  return mlir::success();
+}
+
+void AddOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                  mlir::Value *lhs, mlir::Value *rhs) {
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.push_back(lhs);
+  state->operands.push_back(rhs);
+}
+
+mlir::LogicalResult AddOp::verify() {
+  if (failed(verifyToyBinOperands(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void MulOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                  mlir::Value *lhs, mlir::Value *rhs) {
+  state->types.push_back(ToyArrayType::get(builder->getContext()));
+  state->operands.push_back(lhs);
+  state->operands.push_back(rhs);
+}
+
+mlir::LogicalResult MulOp::verify() {
+  if (failed(verifyToyBinOperands(this)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+void AllocOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                    mlir::Type retType) {
+  state->types.push_back(retType);
+}
+
+void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState *state,
+                       mlir::Value *value, mlir::Type destTy) {
+  state->operands.push_back(value);
+  state->types.push_back(destTy);
+}
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/parser/AST.cpp b/mlir/examples/toy/Ch5/parser/AST.cpp
new file mode 100644 (file)
index 0000000..869f2ef
--- /dev/null
@@ -0,0 +1,263 @@
+//===- AST.cpp - Helper for printing out the Toy AST ----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the AST dump for the Toy language.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/AST.h"
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+
+namespace {
+
+// RAII helper to manage increasing/decreasing the indentation as we traverse
+// the AST
+struct Indent {
+  Indent(int &level) : level(level) { ++level; }
+  ~Indent() { --level; }
+  int &level;
+};
+
+/// Helper class that implement the AST tree traversal and print the nodes along
+/// the way. The only data member is the current indentation level.
+class ASTDumper {
+public:
+  void dump(ModuleAST *Node);
+
+private:
+  void dump(VarType &type);
+  void dump(VarDeclExprAST *varDecl);
+  void dump(ExprAST *expr);
+  void dump(ExprASTList *exprList);
+  void dump(NumberExprAST *num);
+  void dump(LiteralExprAST *Node);
+  void dump(VariableExprAST *Node);
+  void dump(ReturnExprAST *Node);
+  void dump(BinaryExprAST *Node);
+  void dump(CallExprAST *Node);
+  void dump(PrintExprAST *Node);
+  void dump(PrototypeAST *Node);
+  void dump(FunctionAST *Node);
+
+  // Actually print spaces matching the current indentation level
+  void indent() {
+    for (int i = 0; i < curIndent; i++)
+      llvm::errs() << "  ";
+  }
+  int curIndent = 0;
+};
+
+} // namespace
+
+/// Return a formatted string for the location of any node
+template <typename T> static std::string loc(T *Node) {
+  const auto &loc = Node->loc();
+  return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
+          llvm::Twine(loc.col))
+      .str();
+}
+
+// Helper Macro to bump the indentation level and print the leading spaces for
+// the current indentations
+#define INDENT()                                                               \
+  Indent level_(curIndent);                                                    \
+  indent();
+
+/// Dispatch to a generic expressions to the appropriate subclass using RTTI
+void ASTDumper::dump(ExprAST *expr) {
+#define dispatch(CLASS)                                                        \
+  if (CLASS *node = llvm::dyn_cast<CLASS>(expr))                               \
+    return dump(node);
+  dispatch(VarDeclExprAST);
+  dispatch(LiteralExprAST);
+  dispatch(NumberExprAST);
+  dispatch(VariableExprAST);
+  dispatch(ReturnExprAST);
+  dispatch(BinaryExprAST);
+  dispatch(CallExprAST);
+  dispatch(PrintExprAST);
+  // No match, fallback to a generic message
+  INDENT();
+  llvm::errs() << "<unknown Expr, kind " << expr->getKind() << ">\n";
+}
+
+/// A variable declaration is printing the variable name, the type, and then
+/// recurse in the initializer value.
+void ASTDumper::dump(VarDeclExprAST *varDecl) {
+  INDENT();
+  llvm::errs() << "VarDecl " << varDecl->getName();
+  dump(varDecl->getType());
+  llvm::errs() << " " << loc(varDecl) << "\n";
+  dump(varDecl->getInitVal());
+}
+
+/// A "block", or a list of expression
+void ASTDumper::dump(ExprASTList *exprList) {
+  INDENT();
+  llvm::errs() << "Block {\n";
+  for (auto &expr : *exprList)
+    dump(expr.get());
+  indent();
+  llvm::errs() << "} // Block\n";
+}
+
+/// A literal number, just print the value.
+void ASTDumper::dump(NumberExprAST *num) {
+  INDENT();
+  llvm::errs() << num->getValue() << " " << loc(num) << "\n";
+}
+
+/// Helper to print recurisvely a literal. This handles nested array like:
+///    [ [ 1, 2 ], [ 3, 4 ] ]
+/// We print out such array with the dimensions spelled out at every level:
+///    <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
+void printLitHelper(ExprAST *lit_or_num) {
+  // Inside a literal expression we can have either a number or another literal
+  if (auto num = llvm::dyn_cast<NumberExprAST>(lit_or_num)) {
+    llvm::errs() << num->getValue();
+    return;
+  }
+  auto *literal = llvm::cast<LiteralExprAST>(lit_or_num);
+
+  // Print the dimension for this literal first
+  llvm::errs() << "<";
+  {
+    const char *sep = "";
+    for (auto dim : literal->getDims()) {
+      llvm::errs() << sep << dim;
+      sep = ", ";
+    }
+  }
+  llvm::errs() << ">";
+
+  // Now print the content, recursing on every element of the list
+  llvm::errs() << "[ ";
+  const char *sep = "";
+  for (auto &elt : literal->getValues()) {
+    llvm::errs() << sep;
+    printLitHelper(elt.get());
+    sep = ", ";
+  }
+  llvm::errs() << "]";
+}
+
+/// Print a literal, see the recursive helper above for the implementation.
+void ASTDumper::dump(LiteralExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Literal: ";
+  printLitHelper(Node);
+  llvm::errs() << " " << loc(Node) << "\n";
+}
+
+/// Print a variable reference (just a name).
+void ASTDumper::dump(VariableExprAST *Node) {
+  INDENT();
+  llvm::errs() << "var: " << Node->getName() << " " << loc(Node) << "\n";
+}
+
+/// Return statement print the return and its (optional) argument.
+void ASTDumper::dump(ReturnExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Return\n";
+  if (Node->getExpr().hasValue())
+    return dump(*Node->getExpr());
+  {
+    INDENT();
+    llvm::errs() << "(void)\n";
+  }
+}
+
+/// Print a binary operation, first the operator, then recurse into LHS and RHS.
+void ASTDumper::dump(BinaryExprAST *Node) {
+  INDENT();
+  llvm::errs() << "BinOp: " << Node->getOp() << " " << loc(Node) << "\n";
+  dump(Node->getLHS());
+  dump(Node->getRHS());
+}
+
+/// Print a call expression, first the callee name and the list of args by
+/// recursing into each individual argument.
+void ASTDumper::dump(CallExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Call '" << Node->getCallee() << "' [ " << loc(Node) << "\n";
+  for (auto &arg : Node->getArgs())
+    dump(arg.get());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print a builtin print call, first the builtin name and then the argument.
+void ASTDumper::dump(PrintExprAST *Node) {
+  INDENT();
+  llvm::errs() << "Print [ " << loc(Node) << "\n";
+  dump(Node->getArg());
+  indent();
+  llvm::errs() << "]\n";
+}
+
+/// Print type: only the shape is printed in between '<' and '>'
+void ASTDumper::dump(VarType &type) {
+  llvm::errs() << "<";
+  const char *sep = "";
+  for (auto shape : type.shape) {
+    llvm::errs() << sep << shape;
+    sep = ", ";
+  }
+  llvm::errs() << ">";
+}
+
+/// Print a function prototype, first the function name, and then the list of
+/// parameters names.
+void ASTDumper::dump(PrototypeAST *Node) {
+  INDENT();
+  llvm::errs() << "Proto '" << Node->getName() << "' " << loc(Node) << "'\n";
+  indent();
+  llvm::errs() << "Params: [";
+  const char *sep = "";
+  for (auto &arg : Node->getArgs()) {
+    llvm::errs() << sep << arg->getName();
+    sep = ", ";
+  }
+  llvm::errs() << "]\n";
+}
+
+/// Print a function, first the prototype and then the body.
+void ASTDumper::dump(FunctionAST *Node) {
+  INDENT();
+  llvm::errs() << "Function \n";
+  dump(Node->getProto());
+  dump(Node->getBody());
+}
+
+/// Print a module, actually loop over the functions and print them in sequence.
+void ASTDumper::dump(ModuleAST *Node) {
+  INDENT();
+  llvm::errs() << "Module:\n";
+  for (auto &F : *Node)
+    dump(&F);
+}
+
+namespace toy {
+
+// Public API
+void dump(ModuleAST &module) { ASTDumper().dump(&module); }
+
+} // namespace toy
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
new file mode 100644 (file)
index 0000000..b140b36
--- /dev/null
@@ -0,0 +1,324 @@
+//===- toyc.cpp - The Toy Compiler ----------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the entry point for the Toy compiler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+#include "toy/Lowering.h"
+#include "toy/MLIRGen.h"
+#include "toy/Parser.h"
+#include "toy/Passes.h"
+
+#include "linalg1/Dialect.h"
+#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace toy;
+namespace cl = llvm::cl;
+
+static cl::opt<std::string> inputFilename(cl::Positional,
+                                          cl::desc("<input toy file>"),
+                                          cl::init("-"),
+                                          cl::value_desc("filename"));
+
+namespace {
+enum InputType { Toy, MLIR };
+}
+static cl::opt<enum InputType> inputType(
+    "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
+    cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
+    cl::values(clEnumValN(MLIR, "mlir",
+                          "load the input file as an MLIR file")));
+
+namespace {
+enum Action {
+  None,
+  DumpAST,
+  DumpMLIR,
+  DumpMLIRLinalg,
+  DumpLLVMDialect,
+  DumpLLVMIR,
+  RunJIT
+};
+}
+static cl::opt<enum Action> emitAction(
+    "emit", cl::desc("Select the kind of output desired"),
+    cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
+    cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")),
+    cl::values(clEnumValN(DumpMLIRLinalg, "mlir-linalg",
+                          "output the MLIR dump after linalg lowering")),
+    cl::values(clEnumValN(DumpLLVMDialect, "llvm-dialect",
+                          "output the LLVM MLIR Dialect dump")),
+    cl::values(clEnumValN(DumpLLVMIR, "llvm-ir", "output the LLVM IR dump")),
+    cl::values(
+        clEnumValN(RunJIT, "jit",
+                   "JIT the code and run it by invoking the main function")));
+
+static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
+
+/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
+std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(filename);
+  if (std::error_code EC = FileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return nullptr;
+  }
+  auto buffer = FileOrErr.get()->getBuffer();
+  LexerBuffer lexer(buffer.begin(), buffer.end(), filename);
+  Parser parser(lexer);
+  return parser.ParseModule();
+}
+
+mlir::LogicalResult optimize(mlir::Module &module) {
+  mlir::PassManager pm;
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(createShapeInferencePass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+
+  // Apply any generic pass manager command line options.
+  applyPassManagerCLOptions(pm);
+
+  return pm.run(&module);
+}
+
+mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) {
+  mlir::PassManager pm;
+  pm.addPass(createEarlyLoweringPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+  if (!OnlyLinalg) {
+    pm.addPass(createLateLoweringPass());
+    pm.addPass(mlir::createCanonicalizerPass());
+    pm.addPass(mlir::createCSEPass());
+  }
+  // Apply any generic pass manager command line options.
+  applyPassManagerCLOptions(pm);
+
+  return pm.run(&module);
+}
+
+mlir::LogicalResult lowerLLVMModule(mlir::Module &module) {
+  mlir::PassManager pm;
+  pm.addPass(createEarlyLoweringPass());
+  pm.addPass(createLateLoweringPass());
+
+  // Apply any generic pass manager command line options.
+  applyPassManagerCLOptions(pm);
+
+  return pm.run(&module);
+}
+
+std::unique_ptr<mlir::Module> loadFileAndProcessModule(
+    mlir::MLIRContext &context, bool EnableLinalgLowering = false,
+    bool EnableLLVMLowering = false, bool EnableOpt = false) {
+
+  std::unique_ptr<mlir::Module> module;
+  if (inputType == InputType::MLIR ||
+      llvm::StringRef(inputFilename).endswith(".mlir")) {
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+        llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+    if (std::error_code EC = fileOrErr.getError()) {
+      llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+      return nullptr;
+    }
+    llvm::SourceMgr sourceMgr;
+    sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+    module.reset(mlir::parseSourceFile(sourceMgr, &context));
+    if (!module) {
+      llvm::errs() << "Error can't load file " << inputFilename << "\n";
+      return nullptr;
+    }
+    if (failed(module->verify())) {
+      llvm::errs() << "Error verifying MLIR module\n";
+      return nullptr;
+    }
+  } else {
+    auto moduleAST = parseInputFile(inputFilename);
+    module = mlirGen(context, *moduleAST);
+  }
+  if (!module)
+    return nullptr;
+  if (EnableOpt) {
+    if (failed(optimize(*module))) {
+      llvm::errs() << "Module optimization failed\n";
+      return nullptr;
+    }
+  }
+  if (EnableLLVMLowering || EnableLinalgLowering) {
+    if (failed(lowerDialect(*module, !EnableLLVMLowering))) {
+      llvm::errs() << "Module lowering failed\n";
+      return nullptr;
+    }
+  }
+  return module;
+}
+
+int dumpMLIR() {
+  mlir::MLIRContext context;
+  auto module =
+      loadFileAndProcessModule(context, /*EnableLinalgLowering=*/false,
+                               /*EnableLLVMLowering=*/false, EnableOpt);
+  if (!module)
+    return -1;
+  module->dump();
+  return 0;
+}
+
+int dumpMLIRLinalg() {
+  mlir::MLIRContext context;
+  auto module = loadFileAndProcessModule(context, /*EnableLinalgLowering=*/true,
+                                         /*EnableLLVMLowering=*/false,
+                                         /* EnableOpt=*/true);
+  if (!module)
+    return -1;
+  module->dump();
+  return 0;
+}
+
+int dumpLLVMDialect() {
+  mlir::MLIRContext context;
+  auto module = loadFileAndProcessModule(
+      context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
+      /* EnableOpt=*/true);
+  if (!module) {
+    llvm::errs() << "Failed to load/lower MLIR module\n";
+    return -1;
+  }
+  module->dump();
+  return 0;
+}
+
+int dumpLLVMIR() {
+  mlir::MLIRContext context;
+  auto module = loadFileAndProcessModule(
+      context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
+      /* EnableOpt=*/true);
+  if (!module) {
+    llvm::errs() << "Failed to load/lower MLIR module\n";
+    return -1;
+  }
+  auto llvmModule = translateModuleToLLVMIR(*module);
+  if (!llvmModule) {
+    llvm::errs() << "Failed to emit LLVM IR\n";
+    return -1;
+  }
+  // Initialize LLVM targets.
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+  auto optPipeline = mlir::makeOptimizingTransformer(
+      /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0);
+  if (auto err = optPipeline(llvmModule.get())) {
+    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
+    return -1;
+  }
+  llvm::errs() << *llvmModule << "\n";
+  return 0;
+}
+
+int runJit() {
+  mlir::MLIRContext context;
+  auto module = loadFileAndProcessModule(
+      context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
+      /* EnableOpt=*/true);
+
+  // Initialize LLVM targets.
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+
+  // Create an MLIR execution engine.  Note that it takes a null pass manager
+  // to make sure it won't run "default" passes on the MLIR that would trigger
+  // a second conversion to LLVM IR.  The execution engine eagerly JIT-compiles
+  // the module.
+  auto optPipeline = mlir::makeOptimizingTransformer(
+      /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0);
+  auto maybeEngine =
+      mlir::ExecutionEngine::create(module.get(), /*pm=*/nullptr, optPipeline);
+  assert(maybeEngine && "failed to construct an execution engine");
+  auto &engine = maybeEngine.get();
+
+  // Invoke the JIT-compiled function with the arguments.  Note that, for API
+  // uniformity reasons, it takes a list of type-erased pointers to arguments.
+  auto invocationResult = engine->invoke("main");
+  if (invocationResult) {
+    llvm::errs() << "JIT invocation failed\n";
+    return -1;
+  }
+
+  return 0;
+}
+
+int dumpAST() {
+  if (inputType == InputType::MLIR) {
+    llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
+    return 5;
+  }
+
+  auto moduleAST = parseInputFile(inputFilename);
+  if (!moduleAST)
+    return 1;
+
+  dump(*moduleAST);
+  return 0;
+}
+
+int main(int argc, char **argv) {
+  // Register our Dialects with MLIR
+  mlir::registerDialect<ToyDialect>();
+  mlir::registerDialect<linalg::LinalgDialect>();
+
+  mlir::registerPassManagerCLOptions();
+  cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
+
+  switch (emitAction) {
+  case Action::DumpAST:
+    return dumpAST();
+  case Action::DumpMLIR:
+    return dumpMLIR();
+  case Action::DumpMLIRLinalg:
+    return dumpMLIRLinalg();
+  case Action::DumpLLVMDialect:
+    return dumpLLVMDialect();
+  case Action::DumpLLVMIR:
+    return dumpLLVMIR();
+  case Action::RunJIT:
+    return runJit();
+  default:
+    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
+    return -1;
+  }
+
+  return 0;
+}
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-5.md b/mlir/g3doc/Tutorials/Toy/Ch-5.md
new file mode 100644 (file)
index 0000000..5b69bdd
--- /dev/null
@@ -0,0 +1,299 @@
+# Chapter 5: CodeGen via Lowering to Lower-Level Dialects
+
+At this point, we are eager to generate actual code and see our Toy language
+taking life. We will obviously use LLVM to generate code, but just showing the
+LLVM builder interface wouldn't be very exciting here. Instead, we will show how
+to perform progressive lowering through a mix of dialects coexisting in the same
+function.
+
+To make it more interesting, we will consider that we want to reuse existing
+optimizations implemented in a dialect optimizing linear algebra: `Linalg`. This
+dialect is tailored to the computation heavy part of the program, and is
+limited: it doesn't support representing our `toy.print` builtin for instance,
+neither should it! Instead we can target `Linalg` for the computation heavy part
+of Toy (mostly matmul), we will target the `Affine` dialect for other
+well-formed loop nest, and directly the `LLVM IR` dialect for lowering `print`.
+
+# The `DialectConversion` Framework
+
+Similarly to the canonicalization patterns introduced in the previous section,
+the `DialectConversion` framework involves its own set of patterns. This
+framework operates a bit differently from the canonicalizer: a new function is
+created and the pattern matching operation in the original function are expected
+to emit the IR in the new function.
+
+Dialect conversion requires three components, implemented by overriding virtual
+methods defined in `DialectConversion`:
+
+-   Type Conversion: for things like block arguments' type.
+-   Function signature conversion: for every function it is invoked with the
+    function type and the conversion generates a new prototype for the converted
+    function. The default implementation will call into the type conversion for
+    the returned values and for each of the parameters.
+-   Operations convertions: each pattern is expected to generate new results
+    matching the current operations' in the new function. This may involve
+    generating one or multiple new operations, or possibly just remapping
+    existing operands (folding).
+
+A typical starting point for implementing our lowering would be:
+
+```c++
+class Lowering : public DialectConversion {
+public:
+  // This gets called for block and region arguments, and attributes.
+  Type convertType(Type t) override { /*...*/ }
+
+  // This gets called for functions.
+  FunctionType convertFunctionSignatureType(FunctionType type,
+      ArrayRef<NamedAttributeList> argAttrs,
+      SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) { /*...*/ }
+
+  // This gets called once to set up operation converters.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *context) override {
+    return ConversionListBuilder<MulOpConversion,
+                                 PrintOpConversion,
+                                 TransposeOpConversion>::build(allocator, context);
+  }
+
+private:
+  llvm::BumpPtrAllocator allocator;
+};
+```
+
+Individual operation converters are following this pattern:
+
+```c++
+/// Lower a toy.add to an affine loop nest.
+///
+/// This class inherit from `DialectOpConversion` and override `rewrite`,
+/// similarly to the PatternRewriter introduced in the previous chapter.
+/// It will be called by the DialectConversion framework (see `LateLowering`
+/// class below).
+class AddOpConversion : public DialectOpConversion {
+public:
+  explicit AddOpConversion(MLIRContext *context)
+      : DialectOpConversion(toy::AddOp::getOperationName(), 1, context) {}
+
+  /// Lower the `op` by generating IR using the `rewriter` builder. The builder
+  /// is setup with a new function, the `operands` array has been populated with
+  /// the rewritten operands for `op` in the new function.
+  /// The results created by the new IR with the builder are returned, and their
+  /// number must match the number of result of `op`.
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    ...
+
+    // Return the newly allocated buffer, it will be used as an operand when
+    // converting the operations corresponding to the users of this `toy.add`.
+    return result;
+  }
+```
+
+## Linalg
+
+Linalg is an advanced dialect for dense algebra optimizations. It is implemented
+as [a separate tutorial](../Linalg/Ch-1.md) in parallel with Toy. We are acting
+as a user of this dialect by lowering Toy matrix multiplications to
+`linalg.matmul`.
+
+To support this, we will split our lowering in two parts: an *early lowering*
+that emits operations in the `Linalg` dialect for a subset of the Toy IR, and a
+*late lowering* that materializes buffers and converts all operations and type
+to the LLVM dialect. We will then be able to run specific optimizations in
+between the two lowering.
+
+Let's look again at our example `multiply_transpose`:
+
+```mlir
+func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
+  attributes  {toy.generic: true} {
+  %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
+  %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
+  "toy.return"(%1) : (!toy.array) -> ()
+}
+```
+
+After shape inference, and lowering to `Linalg`, here is what our IR will look
+like:
+
+```mlir
+func @multiply_transpose_2x3_2x3(%arg0: !toy.array<2, 3>, %arg1: !toy.array<2, 3>) -> !toy.array<2, 2>
+  attributes  {toy.generic: false} {
+  %c3 = constant 3 : index
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c1 = constant 1 : index
+  %0 = "toy.transpose"(%arg1) : (!toy.array<2, 3>) -> !toy.array<3, 2>
+  %1 = "toy.alloc"() : () -> !toy.array<2, 2>
+  %2 = "toy.cast"(%1) : (!toy.array<2, 2>) -> memref<2x2xf64>
+  %3 = "toy.cast"(%arg0) : (!toy.array<2, 3>) -> memref<2x3xf64>
+  %4 = "toy.cast"(%0) : (!toy.array<3, 2>) -> memref<3x2xf64>
+  %5 = linalg.range %c0:%c2:%c1 : !linalg.range
+  %6 = linalg.range %c0:%c3:%c1 : !linalg.range
+  %7 = linalg.view %3[%5, %6] : !linalg<"view<?x?xf64>">
+  %8 = linalg.view %4[%6, %5] : !linalg<"view<?x?xf64>">
+  %9 = linalg.view %2[%5, %5] : !linalg<"view<?x?xf64>">
+  linalg.matmul(%7, %8, %9) : !linalg<"view<?x?xf64>">
+  "toy.return"(%1) : (!toy.array<2, 2>) -> ()
+}
+```
+
+Note how the operations from multiple dialects are coexisting in this function.
+
+You can reproduce this result with `bin/toyc-ch5
+test/Examples/Toy/Ch5/lowering.toy -emit=mlir-linalg`
+
+## Emitting LLVM
+
+The availability of various dialects allows for a smooth lowering by reducing
+the impedance mismatch between dialects. For example we don't need to lower our
+`toy.print` over array directly to LLVM IR, we can use the well structured loop
+from the `Affine` dialect for convenience when scanning the array and insert a
+call to `llvm.printf` in the body. We will rely on MLIR lowering to LLVM for the
+`Affine` dialect, we get it for free. Here is a simplified version of the code
+in this chapter for lowering `toy.print`:
+
+```c++
+    // Create our loop nest now
+    using namespace edsc;
+    using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
+    ScopedContext scope(rewriter, loc);
+    ValueHandle zero = intrinsics::constant_index(0);
+    ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
+    ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
+    MemRefView vOp(operand);
+    IndexedValue iOp(operand);
+    IndexHandle i, j, M(vOp.ub(0)), N(vOp.ub(1));
+    LoopBuilder(&i, zero, M, 1)({
+      LoopBuilder(&j, zero, N, 1)({
+        llvmCall(retTy,
+                 rewriter.getFunctionAttr(printfFunc),
+                 {fmtCst, iOp(i, j)})
+      }),
+      llvmCall(retTy, rewriter.getFunctionAttr(printfFunc), {fmtEol})
+    });
+```
+
+For instance the Toy IR may contain:
+
+```
+  "toy.print"(%0) : (!toy.array<2, 2>) -> ()
+```
+
+which the converter above will turn into this sequence:
+
+```mlir
+  affine.for %i0 = 0 to 2 {
+    affine.for %i1 = 0 to 2 {
+      %3 = load %0[%i0, %i1] : memref<2x2xf64>
+      %4 = llvm.call @printf(%1, %3) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
+    }
+    %5 = llvm.call @printf(%2, %cst_21) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
+  }
+```
+
+Note the mix of a loop nest in the `Affine` dialect, with an operation
+`llvm.call` in the body. MLIR knows already how to lower this to:
+
+```mlir
+  llvm.br ^bb1(%87 : !llvm.i64)
+^bb1(%89: !llvm.i64):   // 2 preds: ^bb0, ^bb5
+  %90 = llvm.icmp "slt" %89, %88 : !llvm.i64
+  llvm.cond_br %90, ^bb2, ^bb6
+^bb2:   // pred: ^bb1
+  %91 = llvm.constant(0 : index) : !llvm.i64
+  %92 = llvm.constant(2 : index) : !llvm.i64
+  llvm.br ^bb3(%91 : !llvm.i64)
+^bb3(%93: !llvm.i64):   // 2 preds: ^bb2, ^bb4
+  %94 = llvm.icmp "slt" %93, %92 : !llvm.i64
+  llvm.cond_br %94, ^bb4, ^bb5
+^bb4:   // pred: ^bb3
+  %95 = llvm.constant(2 : index) : !llvm.i64
+  %96 = llvm.constant(2 : index) : !llvm.i64
+  %97 = llvm.mul %89, %96 : !llvm.i64
+  %98 = llvm.add %97, %93 : !llvm.i64
+  %99 = llvm.getelementptr %6[%98] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
+  %100 = llvm.load %99 : !llvm<"double*">
+  %101 = llvm.call @printf(%48, %100) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
+  %102 = llvm.constant(1 : index) : !llvm.i64
+  %103 = llvm.add %93, %102 : !llvm.i64
+  llvm.br ^bb3(%103 : !llvm.i64)
+^bb5:   // pred: ^bb3
+  %104 = llvm.call @printf(%76, %71) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
+  %105 = llvm.constant(1 : index) : !llvm.i64
+  %106 = llvm.add %89, %105 : !llvm.i64
+  llvm.br ^bb1(%106 : !llvm.i64)
+```
+
+We appreciate the ease to generate the former, as well as the readability!
+
+You may reproduce these results with `echo "def main() { print([[1,2],[3,4]]); }
+" | bin/toyc-ch5 -x toy - -emit=llvm-dialect` and `echo "def main() {
+print([[1,2],[3,4]]); } " | bin/toyc-ch5 -x toy - -emit=llvm-ir`.
+
+# CodeGen: Getting Out of MLIR
+
+At this point, all the IR is expressed in the LLVM dialect, MLIR can perform a
+straight conversion to an LLVM module. You may look into
+[`Ch5/toyc.cpp`](../../../examples/toy/Ch5/toyc.cpp) for the `dumpLLVM()`
+function:
+
+```c++
+int dumpLLVM() {
+  mlir::MLIRContext context;
+  auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true);
+  auto llvmModule = translateModuleToLLVMIR(*module);
+  if (!llvmModule) {
+    llvm::errs() << "Failed to emit LLVM IR\n";
+    return -1;
+  }
+  llvm::errs() << *llvmModule << "\n";
+  return 0;
+}
+```
+
+Adding a JIT isn't much more involved either:
+
+```c++
+int runJit() {
+  mlir::MLIRContext context;
+  auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true);
+
+  // Initialize LLVM targets.
+  llvm::InitializeNativeTarget();
+  llvm::InitializeNativeTargetAsmPrinter();
+
+  // Create an MLIR execution engine.  Note that it takes a null pass manager
+  // to make sure it won't run "default" passes on the MLIR that would trigger
+  // a second conversion to LLVM IR.  The execution engine eagerly JIT-compiles
+  // the module.
+  auto maybeEngine =
+      mlir::ExecutionEngine::create(module.get(), /*pm=*/nullptr);
+  assert(maybeEngine && "failed to construct an execution engine");
+  auto &engine = maybeEngine.get();
+
+  // Invoke the JIT-compiled function with the arguments.  Note that, for API
+  // uniformity reasons, it takes a list of type-erased pointers to arguments.
+  auto invocationResult = engine->invoke("main");
+  if(invocationResult) {
+    llvm::errs() << "JIT invocation failed\n";
+    return -1;
+  }
+
+  return 0;
+}
+```
+
+You can play with it, from the build directory:
+
+```bash
+$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch5 -emit=jit
+1.000000 2.000000
+3.000000 4.000000
+```
+
+You can also play with `-emit=mlir`, `-emit=mlir-linalg`, `-emit=llvm-dialect`,
+and `-emit=llvm-ir` to compare the various level of IR involved. Try also
+options like `--print-ir-after-all` to track the evolution of the IR throughout
+the pipeline.
index a98d366..39fa952 100644 (file)
@@ -32,6 +32,7 @@ if(LLVM_BUILD_EXAMPLES)
     toyc-ch2
     toyc-ch3
     toyc-ch4
+    toyc-ch5
     )
 endif()
 
diff --git a/mlir/test/Examples/Toy/Ch5/ast.toy b/mlir/test/Examples/Toy/Ch5/ast.toy
new file mode 100644 (file)
index 0000000..0eaa513
--- /dev/null
@@ -0,0 +1,73 @@
+# RUN: toyc-ch5 %s -emit=ast 2>&1 | FileCheck %s
+
+
+# User defined generic function that operates solely on 
+def multiply_transpose(a, b) {
+  return a * transpose(b);
+}
+
+def main() {
+  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+  # The shape is inferred from the supplied literal.
+  var a = [[1, 2, 3], [4, 5, 6]];
+  # b is identical to a, the literal array is implicitely reshaped: defining new
+  # variables is the way to reshape arrays (element count must match).
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  # This call will specialize `multiply_transpose` with <2, 3> for both
+  # arguments and deduce a return type of <2, 2> in initialization of `c`.
+  var c = multiply_transpose(a, b);
+  # A second call to `multiply_transpose` with <2, 3> for both arguments will
+  # reuse the previously specialized and inferred version and return `<2, 2>`
+  var d = multiply_transpose(b, a);
+  # A new call with `<2, 2>` for both dimension will trigger another
+  # specialization of `multiply_transpose`.
+  var e = multiply_transpose(b, c);
+  # Finally, calling into `multiply_transpose` with incompatible shape will
+  # trigger a shape inference error.
+  var e = multiply_transpose(transpose(a), c);
+}
+
+
+# CHECK: Module:
+# CHECK-NEXT:     Function
+# CHECK-NEXT:       Proto 'multiply_transpose' @{{.*}}Toy/Ch5/ast.toy:5:1'
+# CHECK-NEXT:       Params: [a, b]
+# CHECK-NEXT:       Block {
+# CHECK-NEXT:         Retur
+# CHECK-NEXT:           BinOp: * @{{.*}}Toy/Ch5/ast.toy:6:14
+# CHECK-NEXT:             var: a @{{.*}}Toy/Ch5/ast.toy:6:10
+# CHECK-NEXT:             Call 'transpose' [ @{{.*}}Toy/Ch5/ast.toy:6:14
+# CHECK-NEXT:               var: b @{{.*}}Toy/Ch5/ast.toy:6:24
+# CHECK-NEXT:             ]
+# CHECK-NEXT:       } // Block
+# CHECK-NEXT:     Function
+# CHECK-NEXT:       Proto 'main' @{{.*}}Toy/Ch5/ast.toy:9:1'
+# CHECK-NEXT:       Params: []
+# CHECK-NEXT:       Block {
+# CHECK-NEXT:         VarDecl a<> @{{.*}}Toy/Ch5/ast.toy:12:3
+# CHECK-NEXT:           Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}Toy/Ch5/ast.toy:12:11
+# CHECK-NEXT:         VarDecl b<2, 3> @{{.*}}Toy/Ch5/ast.toy:15:3
+# CHECK-NEXT:           Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}Toy/Ch5/ast.toy:15:17
+# CHECK-NEXT:         VarDecl c<> @{{.*}}Toy/Ch5/ast.toy:18:3
+# CHECK-NEXT:           Call 'multiply_transpose' [ @{{.*}}Toy/Ch5/ast.toy:18:11
+# CHECK-NEXT:             var: a @{{.*}}Toy/Ch5/ast.toy:18:30
+# CHECK-NEXT:             var: b @{{.*}}Toy/Ch5/ast.toy:18:33
+# CHECK-NEXT:           ]
+# CHECK-NEXT:         VarDecl d<> @{{.*}}Toy/Ch5/ast.toy:21:3
+# CHECK-NEXT:           Call 'multiply_transpose' [ @{{.*}}Toy/Ch5/ast.toy:21:11
+# CHECK-NEXT:             var: b @{{.*}}Toy/Ch5/ast.toy:21:30
+# CHECK-NEXT:             var: a @{{.*}}Toy/Ch5/ast.toy:21:33
+# CHECK-NEXT:           ]
+# CHECK-NEXT:         VarDecl e<> @{{.*}}Toy/Ch5/ast.toy:24:3
+# CHECK-NEXT:           Call 'multiply_transpose' [ @{{.*}}Toy/Ch5/ast.toy:24:11
+# CHECK-NEXT:             var: b @{{.*}}Toy/Ch5/ast.toy:24:30
+# CHECK-NEXT:             var: c @{{.*}}Toy/Ch5/ast.toy:24:33
+# CHECK-NEXT:           ]
+# CHECK-NEXT:         VarDecl e<> @{{.*}}Toy/Ch5/ast.toy:27:3
+# CHECK-NEXT:           Call 'multiply_transpose' [ @{{.*}}Toy/Ch5/ast.toy:27:11
+# CHECK-NEXT:             Call 'transpose' [ @{{.*}}Toy/Ch5/ast.toy:27:30
+# CHECK-NEXT:               var: a @{{.*}}Toy/Ch5/ast.toy:27:40
+# CHECK-NEXT:             ]
+# CHECK-NEXT:             var: c @{{.*}}Toy/Ch5/ast.toy:27:44
+# CHECK-NEXT:           ]
+
diff --git a/mlir/test/Examples/Toy/Ch5/codegen.toy b/mlir/test/Examples/Toy/Ch5/codegen.toy
new file mode 100644 (file)
index 0000000..607cb15
--- /dev/null
@@ -0,0 +1,32 @@
+# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
+
+# User defined generic function that operates on unknown shaped arguments
+def multiply_transpose(a, b) {
+  return a * transpose(b);
+}
+
+def main() {
+  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  var c = multiply_transpose(a, b);
+  var d = multiply_transpose(b, a);
+  print(d);
+}
+
+# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
+# CHECK-NEXT:   attributes  {toy.generic: true} {
+# CHECK-NEXT:   %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
+# CHECK-NEXT:   %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
+# CHECK-NEXT:   "toy.return"(%1) : (!toy.array) -> ()
+# CHECK-NEXT: }
+
+# CHECK-LABEL: func @main() {
+# CHECK-NEXT:   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>, {{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> !toy.array<2, 3>
+# CHECK-NEXT:   %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
+# CHECK-NEXT:   %2 = "toy.constant"() {value: dense<tensor<6xf64>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]>} : () -> !toy.array<6>
+# CHECK-NEXT:   %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
+# CHECK-NEXT:   %4 = "toy.generic_call"(%1, %3) {callee: "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
+# CHECK-NEXT:   %5 = "toy.generic_call"(%3, %1) {callee: "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
+# CHECK-NEXT:   "toy.print"(%5) : (!toy.array) -> ()
+# CHECK-NEXT:   "toy.return"() : () -> ()
+
diff --git a/mlir/test/Examples/Toy/Ch5/invalid.mlir b/mlir/test/Examples/Toy/Ch5/invalid.mlir
new file mode 100644 (file)
index 0000000..df8e2df
--- /dev/null
@@ -0,0 +1,11 @@
+// RUN: not toyc-ch5 %s -emit=mlir 2>&1
+
+
+// This IR is not "valid":
+// - toy.print should not return a value.
+// - toy.print should take an argument.
+// - There should be a block terminator.
+// This all round-trip since this is opaque for MLIR.
+func @main() {
+  %0 = "toy.print"()  : () -> !toy.array<2, 3>
+}
diff --git a/mlir/test/Examples/Toy/Ch5/lowering.toy b/mlir/test/Examples/Toy/Ch5/lowering.toy
new file mode 100644 (file)
index 0000000..3c198a6
--- /dev/null
@@ -0,0 +1,16 @@
+# RUN: toyc-ch5 %s -emit=llvm-ir 2>&1 | FileCheck %s
+
+# User defined generic function that operates on unknown shaped arguments
+def multiply_transpose(a, b) {
+  return a * transpose(b);
+}
+
+# CHECK: define void @main() {
+# CHECK:  %1 = call i8* @malloc(i64 48)
+def main() {
+  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  var b<2, 3> = [1, 2, 3, 4, 5, 6];
+  var c = multiply_transpose(a, b);
+  var d = multiply_transpose(b, a);
+  print(d);
+}
diff --git a/mlir/test/Examples/Toy/Ch5/scalar.toy b/mlir/test/Examples/Toy/Ch5/scalar.toy
new file mode 100644 (file)
index 0000000..b4a82dd
--- /dev/null
@@ -0,0 +1,14 @@
+# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
+
+def main() {
+  var a<2, 2> = 5.5;
+  print(a);
+}
+
+# CHECK-LABEL: func @main() {
+# CHECK-NEXT:    %0 = "toy.constant"() {value: dense<tensor<1xf64>, [5.500000e+00]>} : () -> !toy.array<1>
+# CHECK-NEXT:    %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
+# CHECK-NEXT:    "toy.print"(%1) : (!toy.array<2, 2>) -> ()
+# CHECK-NEXT:    "toy.return"() : () -> ()
+# CHECK-NEXT:  }
+
diff --git a/mlir/test/Examples/Toy/Ch5/transpose_transpose.toy b/mlir/test/Examples/Toy/Ch5/transpose_transpose.toy
new file mode 100644 (file)
index 0000000..109cbd8
--- /dev/null
@@ -0,0 +1,19 @@
+# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
+# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
+
+def transpose_transpose(x) {
+  return transpose(transpose(x));
+}
+
+def main() {
+  print(transpose_transpose([[1, 2], [3, 4]]));
+}
+
+#CHECK-LABEL: func @transpose_transpose
+#CHECK: transpose
+#CHECK-LABEL: main
+
+
+#OPT-LABEL: func @transpose_transpose
+#OPT-NOT: transpose
+
diff --git a/mlir/test/Examples/Toy/Ch5/trivialReshape.toy b/mlir/test/Examples/Toy/Ch5/trivialReshape.toy
new file mode 100644 (file)
index 0000000..cb9946d
--- /dev/null
@@ -0,0 +1,24 @@
+# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
+# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
+
+# We expect no reshape in this function with optimizations enabled
+def foo(a) {
+  var b<2,1> = a;
+  var c<2,1> = b;
+  print(c);
+}
+
+def main() {
+  var a<2, 1> = [1, 2];
+  foo(a);
+}
+
+# without optimizations, match the reshape
+#CHECK-LABEL: func @foo
+#CHECK: reshape
+#CHECK-LABEL: main
+
+# with optimizations, ensure no reshape
+#OPT-LABEL: main
+#OPT-LABEL: func @foo_2x1
+#OPT-NOT: reshape
index d065441..fc97f11 100644 (file)
@@ -61,6 +61,7 @@ tools.extend([
     ToolSubst('toy-ch2', unresolved='ignore'),
     ToolSubst('toy-ch3', unresolved='ignore'),
     ToolSubst('toy-ch4', unresolved='ignore'),
+    ToolSubst('toy-ch5', unresolved='ignore'),
 ])
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)