-JIT Technical Overview
-======================
+# JIT Technical Overview
The JIT can run and optimize PyTorch programs separate from the Python interpreter. This overview is organized into sections that go over different independent components:
Sections start with a reference to the source file where the code related to the section resides.
-Core Program Representation
----------------------------
-
-
-### Modules ###
+## Table of Contents
+
+- [JIT Technical Overview](#jit-technical-overview)
+ * [Table of Contents](#table-of-contents)
+- [Core Program Representation](#core-program-representation)
+ * [Modules](#modules)
+ * [Parameters](#parameters)
+ * [Method](#method)
+ * [FunctionSchema](#functionschema)
+ * [Graph](#graph)
+ * [Node](#node)
+ * [Block](#block)
+ + [If](#if)
+ + [Loops](#loops)
+ * [Value](#value)
+ * [Type](#type)
+- [Generating Programs](#generating-programs)
+ * [Tracer](#tracer)
+ * [Script](#script)
+ * [Tree](#tree)
+ * [Tree Views](#tree-views)
+ * [frontend.py](#frontendpy)
+ * [Lexer](#lexer)
+ * [Tokens](#tokens)
+ * [Parser](#parser)
+ * [Compiler](#compiler)
+ * [SugaredValue](#sugaredvalue)
+ * [Resolver](#resolver)
+ * [Environment](#environment)
+ * [Python-Compiler Interaction](#python-compiler-interaction)
+- [Executing Programs](#executing-programs)
+ * [Evaluation Semantics](#evaluation-semantics)
+ * [IValue](#ivalue)
+ * [Operation](#operation)
+ * [Operator](#operator)
+ * [Interpreter](#interpreter)
+ * [Graph Executor](#graph-executor)
+ * [DifferentiableGraphOp](#differentiablegraphop)
+ * [Interpreter](#interpreter-1)
+ * [FusionGroup](#fusiongroup)
+ * [Handling Mutability](#handling-mutability)
+ + [Aliasing and mutation in the PyTorch API](#aliasing-and-mutation-in-the-pytorch-api)
+ + [Aliasing and mutation annotations in FunctionSchema](#aliasing-and-mutation-annotations-in-functionschema)
+ + [Alias Analysis in the IR](#alias-analysis-in-the-ir)
+ + [Writing optimization passes with `AliasDb`](#writing-optimization-passes-with--aliasdb-)
+- [Saving Programs](#saving-programs)
+ * [PythonPrint](#pythonprint)
+ * [Serialization](#serialization)
+ + [Overview](#overview)
+ + [`model.json`](#-modeljson-)
+ + [`code`](#-code-)
+ + [`tensors/`](#-tensors--)
+ + [`attributes.pkl`](#-attributespkl-)
+ + [Implementation Details](#implementation-details)
+- [Python Bindings](#python-bindings)
+
+
+# Core Program Representation
+
+## Modules ##
[script/module.h](script/module.h)
This mirrors the `nn.Module` objects used in Python. All TorchScript code is a member of some module. This includes pure functions such as those created by annotating a Python function with `@torch.jit.script`, which are represented internally as a Module that has a single method `forward` that contains the implementation of the function.
-### Parameters ###
+## Parameters ##
[script/module.h](script/module.h)
Modules contain Parameter objects, which simply hold a "slot" where a Tensor can be placed. These tensors are accessible by the Methods of the Module or the parent Module.
-### Method ###
+## Method ##
[script/module.h](script/module.h)
Methods also contain helper functions for inserting calls to the Method from other Method objects.
-### FunctionSchema ###
+## FunctionSchema ##
[aten/src/ATen/core/function_schema.h](../../../aten/src/ATen/core/function_schema.h)
Each Method has a FunctionSchema that describes the Types of the arguments and return values of a function. Operators (builtin primitives that are called by the Interpreter) also have FunctionSchemas. FunctionSchemas are analogous to a function _declaration_ in C++. They describe how to call the function but do not provide an implementation.
-### Graph ###
+## Graph ##
[ir.h](ir.h)
Graphs are the root of the intermediate representation (IR) used to define the implementation of TorchScript functions. If you are familiar with [LLVM](https://llvm.org), they are analogous to an `llvm::Function` object. A Graph is composed of Nodes, Blocks, and Values. Nodes are instructions (e.g. do a matrix multiply). Nodes are organized into Blocks of sequentially executed Nodes. Each Node produces a list of output Values, and also consumes a list of input Values. As an example, a user may write the following TorchScript code:
-```py
+```python
@torch.jit.script
def f(a, b):
c = a + b
Because Graph owns all its Nodes, Values, and Blocks, these values are always passed around by raw pointer. Generally developers should not write code that holds Value, Node, or Block objects indefinitely without also holding a shared_ptr to their owning Graph.
-### Node ###
+## Node ##
[ir.h](ir.h)
Attributes are _rarely used_. Operators like convolution or matrix-multiply have no attributes and take all of their arguments through the input list. This includes things that might typically be thought of as constants, like the stride of the convolution. In PyTorch, any of this information is potentially a dynamic property of the program, so Nodes are always encoded in a way that allows these values to be dynamically determined. However, we recognize that many inputs are almost always constants, so we make it easy to quickly check if an input is constant and get its value with `c10::optional<IValue> Node::get(Symbol name)`, which returns an IValue (a concrete value for the input) in the case the node is constant and `nullopt` otherwise.
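The idea behind `Node::get` can be sketched with a toy IR in Python (`ToyNode` and `get_constant` are illustrative names for this sketch, not the real C++ API):

```python
# Toy sketch (not the real C++ Node class): an input whose producing
# node is a constant yields a concrete value, everything else yields
# None (the analogue of c10::nullopt).
class ToyNode:
    def __init__(self, kind, inputs=(), value=None):
        self.kind = kind        # e.g. "prim::Constant" or "aten::conv2d"
        self.inputs = list(inputs)
        self.value = value      # payload, meaningful for constants only

def get_constant(node, index):
    """Return the concrete value of input `index` if it is produced
    by a constant node, else None."""
    producer = node.inputs[index]
    if producer.kind == "prim::Constant":
        return producer.value
    return None

stride = ToyNode("prim::Constant", value=(1, 1))
weight = ToyNode("aten::rand")
conv = ToyNode("aten::conv2d", inputs=[weight, stride])
assert get_constant(conv, 1) == (1, 1)   # stride is statically known
assert get_constant(conv, 0) is None     # weight is a dynamic value
```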
-### Block ###
+## Block ##
[ir.h](ir.h)
**Control-flow** is represented using sub-blocks rather than a control-flow graph representation. A `prim::If` has one block for the true branch and one block for the else. A `prim::Loop` has a block for the loop body (there is no condition block; instead the end of the loop body computes whether to re-enter the loop body). This representation ensures we have structured control-flow. Currently TorchScript does not allow for early returns or breaking out of loops early. This limitation makes a lot of optimizations easier and holds true for the vast majority of networks. Our frontend permits certain forms of syntax sugar that allow a limited amount of re-writing of if statements to avoid needing to support early returns. A Node can look up what Block it is in, and a Block can look up its parent (either the Node that has it as a subblock, or `nullptr` for the main Block).
-#### If ####
+### If ###
For if-statements (`prim::If`) the Blocks have no inputs, and the outputs are the new values of variables in the outer block whose values were altered in an if-statement.
Example IR for an if-statement looks like:
```
Here's an example translation of a Python program:
-```py
+```python
def f(a, b, c):
d = a + b
if c:
The outputs of the if-statement serve a role similar to "phi" nodes in traditional SSA control-flow graphs.
-#### Loops ####
+### Loops ###
Loops are implemented with `prim::Loop` which covers both `while` and `for` loops. A valid instantiation of this node always looks like this:
```
%y_1, ..., %y_r = prim::Loop(%max_trip_count, %initial_condition, %x_1, ..., %x_r)
```
The simplest way to explain the semantics is to consider this Python-like pseudo-code:
-```py
+```python
y_1, ..., y_r = x_1, ..., x_r
condition = initial_condition
i = 0
For example, this program:
-```py
+```python
def f(x):
z = x
for i in range(x.size(0)):
}
```
-### Value ###
+## Value ##
[ir.h](ir.h)
Values are abstract representations of data in the program. When executing, the actual tensors, lists, tuples, etc. are stored in IValues (_interpreter_ values), which are tagged unions of all possible values in TorchScript. In retrospect the name Value is a bit confusing because it seems like it should be the tagged union, but it originally came from an analogy to `llvm::Value`, which serves the same purpose as `jit::Value`.
-### Type ###
+## Type ##
[aten/src/ATen/core/jit_type.h](../../../aten/src/ATen/core/jit_type.h)
If type S is a subtype of P, then we can substitute an IValue that has type S anywhere something of type P is expected. This means that all subtyping relationships also require the representation of the IValue for subtypes to be compatible with the representation for the base type.
-Generating Programs
--------------------
+# Generating Programs #
JIT programs are created using either the tracing frontend (`torch.jit.trace`) or the scripting frontend (`torch.jit.script`). In both cases, the result of these frontends is a complete Module that contains all the code in Methods, and all the model weights in the Parameters of the Module. However, each frontend goes through a different pathway for generating those Modules.
-### Tracer ###
+## Tracer ##
[tracer.h](tracer.h)
As the trace runs, individual operators create Nodes in the Graph being traced to record what happens. This code is currently generated per operator in [tools/autograd/gen_variable_type.py](../../../tools/autograd/gen_variable_type.py). It results in code that looks like the following:
-```
+```cpp
torch::jit::Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
if (jit::tracer::isTracing()) {
The resulting Graph created by tracing is installed as the 'forward' method of the Module being created. A Module is produced regardless of whether the thing being traced was a function or a `torch.nn.Module`. In the function case, the Module produced will simply have a single `forward` function, no Parameters, and no sub-Modules.
-### Script ###
+## Script ##
The script frontend directly converts Python syntax into Modules. Like many compilers this happens in two phases. First, we generate an abstract syntax tree (AST), which is constructed out of Tree objects. The compiler (misnamed, but that is the name of the file) then does semantic analysis on the Tree and lowers it into a Module. We can generate Trees in two ways: (1) using frontend.py, which takes the Python AST and transliterates it into Tree objects, or (2) via the Lexer and Parser which parse Python syntax directly. The Lexer/Parser path may seem redundant but it is crucially important. We need to define builtin functions ([script/builtin_functions.py](script/builtin_functions.py)) when Python is not linked. We allow users to load TorchScript programs directly from strings without Python ([api/include/torch/jit.h](../../api/include/torch/jit.h)). We also use this Python syntax as the serialization format for TorchScript, since it allows us to make changes to our IR without breaking backward compatibility. Furthermore, the Lexer is reused to implement the FunctionSchema parser, which turns FunctionSchema declarations from strings into FunctionSchema objects.
The following sections look into each of the stages in the script frontend in detail.
-### Tree ###
+## Tree ##
[script/tree.h](script/tree.h)
Each tree also has a mandatory SourceRange object that describes the range of text that it came from. These will be used for error reporting in the rest of the code.
-### Tree Views ###
+## Tree Views ##
[script/tree_views.h](script/tree_views.h)
Trees are easy to construct, visualize, and traverse, but extracting information from a large compound tree like that of a function definition is unwieldy since it requires numeric indexing. Tree _Views_ are a small layer on top of a tree that make it possible to create and de-structure trees of particular kinds. For example, here is the tree view for the apply node, which provides named accessors for its subtrees: the function being called, the inputs, and the attributes (i.e. kwargs):
-```
+```cpp
struct Apply : public Expr {
Expr callee() const {
return Expr(subtree(0));
The typical way to traverse a tree is to `switch` on the kind and then construct the appropriate Treeview:
-```
+```cpp
switch (tree.kind()) {
case TK_VAR:
auto var = Var(tree); // construct tree-view
```
-### frontend.py ###
+## frontend.py ##
[torch/jit/frontend.py](../../jit/frontend.py)
So this code simply constructs the Tree, filtering out the AST nodes of Python that we do not support.
-### Lexer ###
+## Lexer ##
[script/lexer.h](script/lexer.h)
Similar to Python, the Lexer handles the white-space sensitive nature of Python blocks. The Tokens `TK_INDENT`, `TK_DEDENT`, and `TK_NEWLINE` are injected into the token stream when code first becomes indented, when it dedents, and at the end of a statement. For instance for this stream:
```
if
.
.
We would get a token stream `TK_IF TK_NEWLINE TK_INDENT . TK_NEWLINE . TK_NEWLINE TK_DEDENT`. Unmatched opening brackets disable the injection of these tokens. The result is that the Parser can simply treat `TK_INDENT`, `TK_DEDENT`, and `TK_NEWLINE` like C's `{`, `}`, and `;`.
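The indentation handling can be sketched in Python (`tokenize_blocks` is a hypothetical simplification of the real Lexer; it ignores the bracket-matching rule mentioned above):

```python
def tokenize_blocks(lines):
    """Toy sketch of layout-token injection: emit TK_INDENT/TK_DEDENT
    when the leading-whitespace width changes, and TK_NEWLINE at the
    end of each statement. Real tokens are replaced by the raw text."""
    tokens = []
    indents = [0]                       # stack of active indentation widths
    for line in lines:
        stripped = line.lstrip(' ')
        if not stripped:
            continue                    # blank lines produce no tokens
        width = len(line) - len(stripped)
        if width > indents[-1]:         # code became more indented
            indents.append(width)
            tokens.append('TK_INDENT')
        while width < indents[-1]:      # code dedented, possibly several levels
            indents.pop()
            tokens.append('TK_DEDENT')
        tokens.append(stripped)         # stand-in for the real token(s)
        tokens.append('TK_NEWLINE')
    while len(indents) > 1:             # close any still-open blocks at EOF
        indents.pop()
        tokens.append('TK_DEDENT')
    return tokens

# Mirrors the `if / . / .` stream above:
assert tokenize_blocks(["if", "  .", "  ."]) == [
    'if', 'TK_NEWLINE', 'TK_INDENT',
    '.', 'TK_NEWLINE', '.', 'TK_NEWLINE', 'TK_DEDENT']
```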
-### Tokens ###
+## Tokens ##
[script/lexer.h](script/lexer.h)
Tokens are either keywords (`def`), operators (`+`), literals (`3.4`), or identifiers (`foo`). A `token_kind` integer identifies what it is and is the exact same type as the `kind` of a Tree. For single-character Tokens (e.g. `+`), the kind is the same as the character, enabling statements like:
-```
+```cpp
if (lexer.nextIf('+')) {
// handle + ...
}
Multi-character token kinds are defined in a list, `TC_FORALL_TOKEN_KINDS`. Tokens also have a `text()` field that records the actual string producing the token and is used by identifiers and literals to construct the actual values (e.g. the numeric value of a floating point literal).
-### Parser ###
+## Parser ##
[script/parser.h](script/parser.h)
The Parser is written as a [top-down precedence parser](https://eli.thegreenplace.net/2010/01/02/top-down-operator-precedence-parsing), or "Pratt" parser. They are simpler and easier to understand than typical parser generators, while still being flexible enough to parse programming languages. For the most part parsing is done by recursive descent. To resolve operator precedence issues, the function to parse an expression is augmented with a precedence _p_ such that calling the function means _parse an expression whose operators all have precedence higher than p_.
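A minimal Pratt-style parser for arithmetic expressions illustrates the technique (an illustrative sketch, not the JIT's actual Parser; `parse_exp(p)` parses an expression whose operators all bind tighter than precedence _p_):

```python
# Toy Pratt parser: tokens are already split, numbers are leaves,
# binary operators combine sub-expressions according to precedence.
PRECEDENCE = {'+': 10, '-': 10, '*': 20, '/': 20}

def parse(tokens):
    pos = [0]

    def peek():
        return tokens[pos[0]] if pos[0] < len(tokens) else None

    def next_tok():
        tok = tokens[pos[0]]
        pos[0] += 1
        return tok

    def parse_exp(p=0):
        node = next_tok()   # a number; a real parser handles prefixes here
        # Keep consuming operators as long as they bind tighter than p.
        while peek() in PRECEDENCE and PRECEDENCE[peek()] > p:
            op = next_tok()
            # Parse the right operand at the operator's own precedence,
            # which makes operators of equal precedence left-associative.
            node = (op, node, parse_exp(PRECEDENCE[op]))
        return node

    return parse_exp()

assert parse(['1', '+', '2', '*', '3']) == ('+', '1', ('*', '2', '3'))
assert parse(['1', '-', '2', '-', '3']) == ('-', ('-', '1', '2'), '3')
```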
-### Compiler ###
+## Compiler ##
[script/compiler.h](script/compiler.h)
The Environment tracks the mapping between variable names and the SugaredValues they refer to.
-### SugaredValue ###
+## SugaredValue ##
[script/sugared_value.h](script/sugared_value.h)
Finally, normal Values are also represented by the SimpleValue SugaredValue in places where either a SugaredValue or a normal Value may validly appear.
-### Resolver ###
+## Resolver ##
[script/compiler.h](script/compiler.h)
This makes it possible to use most of the compiler functionality when Python is not present.
-### Environment ###
+## Environment ##
[script/compiler.cpp](script/compiler.cpp)
The Environment object tracks the assignment of variable names to SugaredValues during compilation. It is local to the compiler file. A stack of environments exist, with a new environment being created for sub-blocks introduced by control flow. The Environment also handles turning the AST representation into SSA-form by tracking which variables were modified inside a sub-block and inserting the correct inputs/outputs to the Blocks of if-statements and loops.
-### Python-Compiler Interaction ###
+## Python-Compiler Interaction ##
[script/init.cpp](script/init.cpp)
A set of special SugaredValues are used to translate between objects in the Python environment and Values in the Graph during the compilation process. The entry-point for this behavior is `toSugaredValue(py::object obj, ...)` which takes a pybind11 Python value and figures out how to turn it into an appropriate SugaredValue. Values exist to represent Python functions, Python modules, and ScriptModule objects.
-Executing Programs
-------------------
+# Executing Programs #
TorchScript is executed using an interpreter attached to a JIT-optimizer and compiler. The entry-point for execution is the GraphExecutor object that is created on demand inside a Method when the method is first called. This section first goes over the semantics of graphs, i.e. what does it mean to execute a graph? It then details how the implementation works.
-### Evaluation Semantics ###
+## Evaluation Semantics ##
TorchScript programs implement the very small subset of Python that is necessary to run models.
TorchScript includes immutable value types:
-* int
-* float
-* Tuple[T0, T1, ...]
+* `int`
+* `float`
+* `Tuple[T0, T1, ...]`
As well as mutable reference types:
-* Tensor
-* List[T]
-* Dict[K, V]
+* `Tensor`
+* `List[T]`
+* `Dict[K, V]`
A value of a reference type points to an underlying memory location where the data for the reference type is stored, and variable assignment for a reference type can cause multiple values to point to the same underlying data. This is similar to Python's class model.
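Plain Python lists exhibit the same reference semantics, which makes for a quick illustration:

```python
# Reference semantics with a plain Python list (TorchScript's List[T]
# behaves analogously): assignment aliases, it does not copy.
a = [1, 2, 3]
b = a               # b points to the same underlying storage as a
b.append(4)         # a mutation through b ...
assert a == [1, 2, 3, 4]   # ... is visible through a as well
```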
It is important to remember that TorchScript uses these semantics for Tensors, so not all computation on Tensors is pure. Individual Tensors may be *views* of the same underlying data. Views are established by special view-creating operations, such as indexing into a tensor:
-```
+```python
t = torch.rand(3, 4)
t2 = t[0] # view of one slice of t
```
Some builtin operators also mutably write to the underlying tensor. In the standard library these operators are always named with a trailing underscore, or take a named `out` tensor where the result is written:
-```
+```python
t2.relu_() # inplace relu operator, note t is modified as well!
torch.add(t, t, out=t) # update t, without using temporary memory if possible
```
We also provide user-accessible parallel execution through the `fork` and `wait` primitives. The `fork` primitive begins execution of `fn` in parallel with the current thread of execution, immediately returning a Future object that will hold the result of the forked function. The `wait` method of the future then causes the invoking task to wait for the value being computed on the forked task.
-```
+```python
def fn(arg0, arg1, ...):
...
return v
Optimization passes that wish to exploit multi-threaded execution may automatically convert serial Blocks into parallel execution by inserting extra fork and wait events. This design enables our users to manually specify parallelism while also allowing optimization passes to exploit it when safe and profitable.
-### IValue ###
+## IValue ##
[ivalue.h](../../include/ATen/core/ivalue.h)
IValue contains methods to check the type (e.g. `isTensor`) and to convert to a particular type (e.g. `toTensor`). We do not publicly expose the type tag and force clients to use the `isX` methods. This enables us to change the underlying implementation of IValue later, e.g. to use an 8-byte value with NaN-boxing. Most operators work on a specific static type, so dynamic dispatch on the tag is not frequently required.
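The isX/toX design can be illustrated with a Python stand-in (`ToyIValue` is a hypothetical analogue, not the real C++ class):

```python
# Toy tagged union: the tag stays private and clients go through
# isX/toX accessors, mirroring the IValue design described above.
class ToyIValue:
    def __init__(self, tag, payload):
        self._tag = tag          # private, so the representation can change later
        self._payload = payload

    def is_int(self):
        return self._tag == 'int'

    def to_int(self):
        if not self.is_int():
            raise TypeError('expected an int IValue')
        return self._payload

v = ToyIValue('int', 3)
assert v.is_int()
assert v.to_int() == 3
```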
-### Operation ###
+## Operation ##
All builtin operators are represented using a stack machine concept. An operator pops its arguments off the top of the stack and pushes its result to the stack:
-```
+```cpp
using Stack = std::vector<IValue>;
using Operation = std::function<int(Stack&)>;
Operations also return a jump offset, relative to the address of the next operator in the program, to enable dynamic control flow. Except for special Operations in the interpreter that handle control-flow, all Operations should return 0 here. It is a bit weird to force all Operations to return 0, but it avoids having another level of indirection to wrap void functions in something that returns 0.
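The calling convention can be mimicked in Python (an illustrative sketch of the concept, not the real interpreter code):

```python
# Toy stack-machine operation: pop inputs off the shared stack, push
# outputs back, and return a jump offset (0 for non-control-flow ops).
def add_op(stack):
    b = stack.pop()
    a = stack.pop()
    stack.append(a + b)
    return 0          # only control-flow operations return non-zero

stack = [3, 4]        # caller pushes the arguments
offset = add_op(stack)
assert stack == [7]   # the result replaces the arguments on the stack
assert offset == 0
```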
-### Operator ###
+## Operator ##
[operator.h](operator.h)
The Operator object represents a single registered operator in the system. It combines a FunctionSchema that describes how an Operation executes with a method to look up the corresponding Operation given the Node representing the operator in a Graph. Most Operators are defined by providing a FunctionSchema and an Operation function. However, primitives like `prim::Unpack` require knowledge of their Node to know how to operate (e.g. how many elements to unpack). These Operators have a function that takes a `Node*` and returns an operation.
-### Interpreter ###
+## Interpreter ##
[interpreter.cpp](interpreter.cpp)
= Store move(28)
```
-### Graph Executor ###
+## Graph Executor ##
[graph_executor.cpp](graph_executor.cpp)
This section will use an example this LSTM program:
-```
+```python
@torch.jit.script
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
}
```
-### DifferentiableGraphOp ###
+## DifferentiableGraphOp ##
[graph_executor.cpp](graph_executor.cpp)
A DifferentiableGraphOp combines an explicit forward Graph `f` with a paired backward graph `df`. When it runs, the input Tensors to `f` are detached from the autograd, the body of `f` is run, and then the autograd graph for the outputs of `f` are hooked up to the `df` function. The `df` function's outputs are also hooked up to the autograd graph.
-### Interpreter ###
+## Interpreter ##
* Code
* InterpreterState and interpreter design
* Fork/Wait
-### FusionGroup ###
+## FusionGroup ##
* inserted by passes
-### Handling Mutability ###
-#### Aliasing and mutation in the PyTorch API
-In PyTorch, tensors are reference types. Operators can return "views" of the input tensor, creating a new tensor object that shares the same underlying storage as the original:
-```py
+## Handling Mutability ##
+### Aliasing and mutation in the PyTorch API
+In PyTorch, tensors are reference types. Operators can return "views" of the input tensor, creating a new tensor object that shares the same underlying storage as the original:
+```python
a = torch.rand(2, 3)
b = a
# At this point, `a` and `b` share their storage.
```
Some operators will *mutate* one or more of their operands in-place. These are typically denoted with a trailing underscore, or by taking an `out` argument as input:
-```py
+```python
a = torch.zeros(2, 3)
b = torch.ones(2, 3)
a.add_(b) # in-place add, so `a` is modified.
torch.add(a, b, out=a) # another way to express the same thing
```
-#### Aliasing and mutation annotations in FunctionSchema
+### Aliasing and mutation annotations in FunctionSchema
The JIT's `FunctionSchema` allows operator writers to add annotations specifying the aliasing and mutation behavior of an operator. Optimization passes will use this information to determine whether transformations are semantics-preserving. This section provides a description of the alias annotation language, assuming that the reader already knows what `FunctionSchema` looks like.
First, here is a pure function which always returns new memory:
```
Note the alias set `*`. This is the **wildcard set**. Optimization passes must assume that values in the wildcard set may alias any other value in the graph. This behavior is conservative and will disallow optimizations, but is guaranteed to be safe. In most cases, people shouldn't be writing operators with wildcard annotations. They are used as a temporary workaround for when our alias analysis isn't sophisticated enough to understand something yet, but we don't want to block feature development.
-This annotation language is consumed by the `FunctionSchema` parser, which produces `AliasInfo` objects summarizing the aliasing relationships for each schema `Argument`.
+This annotation language is consumed by the `FunctionSchema` parser, which produces `AliasInfo` objects summarizing the aliasing relationships for each schema `Argument`.
-#### Alias Analysis in the IR
+### Alias Analysis in the IR
An alias analysis pass consumes the per-operator aliasing information to construct a database of aliasing and mutation relationships in a graph, called `AliasDb`. This section focuses on the alias analysis pass; the public interface to `AliasDb` will be described later.
The core data structure in the AliasDb is called `AliasTracker`, which is a DAG where the edges are "may point to" relationships and the vertices are aliasing `Element`s. The most common kind of `Element` is an IR `Value`, but there are other kinds of things that can alias that aren't first-class `Value`s in the IR, like wildcards or contained types (such as in a list or tuple).
```
view(Tensor(a) self, int[] size) -> Tensor(a)
```
-and add an edge from `%output` to `%self`. The alias analysis pass is flow-insensitive, as we are only adding "points-to" edges when processing a node.
+and add an edge from `%output` to `%self`. The alias analysis pass is flow-insensitive, as we are only adding "points-to" edges when processing a node.
As a more involved example, the following TorchScript snippet:
-```py
+```python
@torch.jit.script
def foo(a : Tensor, b : Tensor):
c = 2 * b
So to determine whether `a` and `b` may alias, we traverse the `AliasTracker` DAG and figure out if `a` and `b` share any leaf nodes. If they do, then we know `a` and `b` might point to the same memory location, i.e. `a` and `b` may alias. This kind of query is common enough that `AliasTracker` does path compression to speed up leaf-finding, so that aliasing queries can be serviced in amortized constant time.
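The leaf-intersection query can be sketched in Python (`leaves` and `may_alias` are illustrative names for this sketch; the real `AliasTracker` additionally does path compression to make the query amortized constant time):

```python
# Toy may-alias query over a "may point to" DAG: each value maps to
# the set of values it may point to; two values may alias iff they
# can reach a common leaf (a value that points at nothing).
def leaves(points_to, v):
    targets = points_to.get(v, set())
    if not targets:
        return {v}                 # v points nowhere: it is its own leaf
    out = set()
    for t in targets:
        out |= leaves(points_to, t)
    return out

def may_alias(points_to, a, b):
    return bool(leaves(points_to, a) & leaves(points_to, b))

# A toy graph: `a` is a view of `b`, while `c` is fresh memory.
points_to = {'a': {'b'}}
assert may_alias(points_to, 'a', 'b')
assert not may_alias(points_to, 'c', 'b')
```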
-#### Writing optimization passes with `AliasDb`
-`AliasDb` provides a high-level interface to help people write mutability-safe optimization passes.
+### Writing optimization passes with `AliasDb`
+`AliasDb` provides a high-level interface to help people write mutability-safe optimization passes.
-In particular, `moveAfterTopologicallyValid()` (and it's `moveBefore` variant) will reorder nodes in a way that preserves data dependencies and avoids any data hazards. The rules for this are that all mutable *writes* to a give memory location must occur in the same order (avoid WAW hazards), and that no reads can be reordered before or after any write (WAR, RAW hazards).
+In particular, `moveAfterTopologicallyValid()` (and its `moveBefore` variant) will reorder nodes in a way that preserves data dependencies and avoids any data hazards. The rules for this are that all mutable *writes* to a given memory location must occur in the same order (avoiding WAW hazards), and that no reads can be reordered before or after any write (WAR, RAW hazards).
However, reordering of reads across writes *is allowed* if we can prove that the read cannot alias the thing being written. This commonly happens for tensors that come from operators that produce fresh results within the function. It also happens whenever the creation of the mutable tensor is seen in the function (so it gets assigned a fresh variable), and all of its writes occur in that function.
TODO: fusion, operators
-Saving Programs
----------------
+# Saving Programs
+
+
+## Python Printer
+
+[python_print.cpp](python_print.cpp)
+[import_source.cpp](import_source.cpp)
+
+The Python Printer takes a `Graph` and produces Python-like code that represents the same graph. Using some special values in [import_source.cpp](import_source.cpp), this code can be read back in by the compiler to produce the same `Graph`. In Python a `ScriptModule`'s `code` property shows the Python Printed graph.
+
+The table below shows the graph and code for this small `ScriptModule`:
+```python
+class M(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x, y, z):
+ # type: (Tensor, int, float) -> Tensor
+ if y > 2:
+ x = x + z
+ else:
+ x = x + y
+ return x
+
+m = M()
+```
+
+`m.graph`
+```
+graph(%x.1 : Tensor,
+ %y : int,
+ %z : float):
+ %5 : int = prim::Constant[value=1]()
+ %3 : int = prim::Constant[value=2]()
+ %4 : bool = aten::gt(%y, %3)
+ %x : Tensor = prim::If(%4)
+ block0():
+ %x.2 : Tensor = aten::add(%x.1, %z, %5)
+ -> (%x.2)
+ block1():
+ %x.3 : Tensor = aten::add(%x.1, %y, %5)
+ -> (%x.3)
+ return (%x)
+```
+
+`m.code`
+```python
+def forward(self,
+ x: Tensor,
+ y: int,
+ z: float) -> Tensor:
+ if torch.gt(y, 2):
+ x0 = torch.add(x, z, 1)
+ else:
+ x0 = torch.add(x, y, 1)
+ return x0
+```
+
+## Serialization
+
+[export.cpp](export.cpp)
+[pickler.cpp](pickler.cpp)
+[import.cpp](import.cpp)
+[caffe2/proto/torch.proto](../../../caffe2/proto/torch.proto)
+
+
+TorchScript programs are serialized with a call to `torch.jit.save()`. The resulting file (ending in `.pt` by convention) can be loaded/executed in C++ and Python.
+
+### Overview
+
+The `.pt` file is a zip archive (which can be opened with tools such as `unzip`) and contains:
+ * `code/` - the Python printed graph of a module
+ * `model.json` - a JSON file of a model Protobuf definition (defined in [torch.proto](../../../caffe2/proto/torch.proto))
+ * `tensors/` - each of the tensors of the model, with their tensor storage stored directly in a file
+ * `attributes.pkl` - a Python `pickle` archive of the attributes of a module
+
+### `model.json`
+The `model.json` contains the structure information of the model. Each model must contain one main Module, and each module may contain multiple submodules, and each module contains a bunch of parameters (tensors). We serialize the metadata for each tensor inline in `model.json` (e.g., dims, strides, record name, etc).
+
+### `code/`
+
+The `code` directory contains the Python Printed `Graph`s of the main module and its submodules.
+
+### `tensors/`
+
+During export a list of all the tensors in a model is created. Tensors can come from either module parameters or Tensor type attributes. Metadata about each tensor is stored in `model.json` with an index into this list. The `data` field refers to the file which contains the tensor storage data. Tensors are saved by directly writing the Tensor storage to a file.
+
+`model.json`
+```json
+{
+ ...
+ "tensors": [
+ {
+ "dims": [
+ "40",
+ "800"
+ ],
+ "offset": "0",
+ "strides": [
+ "800",
+ "1"
+ ],
+ "requiresGrad": true,
+ "dataType": "FLOAT",
+ "data": {
+ "key": "tensors/0"
+ },
+ "device": "cpu"
+ }
+ ],
+ ...
+}
+```
+
+### `attributes.pkl`
+
+Attributes are all module properties that are not parameters or constants. Attributes are saved in a list in the order they were defined on the module. The list is stored as a Python `pickle` archive. `pickle`'s format was chosen due to:
+* **user friendliness** - the attributes file can be loaded in Python with `pickle` without having PyTorch installed
+* **size limits** - formats such as Protobuf impose size limits on total message size, whereas pickle limits are on individual values (e.g. strings cannot be longer than 4 GB)
+* **standard format** - `pickle` is a standard Python module with a reasonably simple format. The format is a program to be consumed by a stack machine that is detailed in Python's [`pickletools.py`](https://svn.python.org/projects/python/trunk/Lib/pickletools.py)
+* **built-in memoization** - for shared reference types (e.g. Tensor, string, lists, dicts)
+* **self describing** - a separate definition file is not needed to understand the pickled data
+* **eager mode save** - `torch.save()` already produces a `pickle` archive, so doing the same with attributes may ease unification of these formats in the future
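The memoization point is easy to demonstrate with the standard `pickle` module: an object referenced twice is stored once and restored as a single shared object:

```python
import pickle

# pickle memoizes shared reference types: the same list appears twice
# in the input, but the archive stores it once and unpickling
# restores both references to one shared object.
shared = [1, 2]
data = pickle.dumps([shared, shared])
a, b = pickle.loads(data)
assert a == b == [1, 2]
assert a is b          # identity is preserved, not just equality
```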
+
+A given module may have many attributes of different types and many submodules, each with their own attributes. Attributes are recorded in `model.json`:
+* `type` - the full type of the attribute (in [Mypy syntax](https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html))
+* `name` - the attribute's name
+* `id` - the offset into the saved list of all model attributes
+
+`model.json`
+```json
+{
+ "mainModule": {
+ "submodules": [
+ {
+ ...
+ "attributes": [
+ {
+ "type": "Dict[str, str]",
+ "name": "my_submodule_dictionary",
+ "id": 0
+ },
+ {
+ "type": "List[Tuple[int, int]]",
+ "name": "my_submodule_list",
+ "id": 1
+ }
+ ]
+ ...
+ },
+ ...
+ ],
+ ...
+ "attributes": [
+ {
+ "type": "Dict[str, str]",
+ "name": "my_main_module_dictionary",
+ "id": 2
+ },
+ {
+ "type": "Tensor",
+ "name": "my_main_module_tensor",
+ "id": 3
+ }
+ ]
+ ...
+ },
+}
+```
+
+Attributes of the main module and its submodules are saved to a single file in the `zip` archive of a `.pt` file named `attributes.pkl`. A single file is used so that attributes can reference each other and share values. Unpickling this will return a list of values corresponding to the attributes.
+
+All attributes are written into the `attributes.pkl` file with the exception of tensors, which store only a tensor table index (see "tensors" above). Classes are used to mark special data types, such as this tensor table index or specialized lists. To load the `attributes.pkl` file without PyTorch for inspection or manual editing, these classes must be defined, so a custom [`Unpickler`](https://docs.python.org/3/library/pickle.html#pickle.Unpickler) is necessary:
+
+```python
+import pickle
+
+# Tensor objects are stored as instances of this class
+class TensorID(object):
+ def __setstate__(self, id):
+ self.id = id
+
+# List[int] has internal specializations, and these are indicated with this class
+class IntList(object):
+ def __setstate__(self, data):
+ self.data = data
+
+class JitUnpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if not module == '__main__':
+ return None
+
+ if name == 'TensorID':
+ return TensorID
+ elif name == 'IntList':
+ return IntList
+
+JitUnpickler(open("my_model/attributes.pkl", "rb")).load()
+```
+
+#### Binary Format
+
+Running the following snippet produces a `ScriptModule` with several attributes.
+
+```python
+class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ self.float = torch.jit.Attribute(2.3, float)
+ self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
+ self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
+ self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
+
+ @torch.jit.script_method
+ def forward(self):
+ return (self.float, self.tuple, self.tensor, self.int_list)
+
+M().save("out.pt")
+```
+
+In a terminal, Python's `pickletools` module can be used to decode the binary blob of `attributes.pkl` into a human readable format.
+
+```bash
+unzip -o out.pt
+python -m pickletools out/attributes.pkl
+```
+
+The output of the above commands demonstrates the concepts described earlier. Attributes are wrapped in a list (`2: EMPTY_LIST`) and appear in the order they are defined on the module. Classes for certain special types (`List[int]`, `Tensor`) can be seen at `37: GLOBAL` and `66: GLOBAL`, followed by data specific to that type, then finally by an instruction to build the object at `65: BUILD` and `113: BUILD` respectively.
+```
+ 0: \x80 PROTO 2
+ 2: ] EMPTY_LIST
+ 3: ( MARK
+ 4: G BINFLOAT 2.3
+ 13: ( MARK
+ 14: J BININT 1
+ 19: J BININT 2
+ 24: J BININT 3
+ 29: J BININT 4
+ 34: t TUPLE (MARK at 13)
+ 35: q BINPUT 0
+ 37: c GLOBAL '__main__ TensorID'
+ 56: q BINPUT 1
+ 58: ) EMPTY_TUPLE
+ 59: \x81 NEWOBJ
+ 60: J BININT 0
+ 65: b BUILD
+ 66: c GLOBAL '__main__ IntList'
+ 84: q BINPUT 2
+ 86: ) EMPTY_TUPLE
+ 87: \x81 NEWOBJ
+ 88: ] EMPTY_LIST
+ 89: q BINPUT 3
+ 91: ( MARK
+ 92: J BININT 1
+ 97: J BININT 2
+ 102: J BININT 3
+ 107: J BININT 4
+ 112: e APPENDS (MARK at 91)
+ 113: b BUILD
+ 114: e APPENDS (MARK at 3)
+ 115: . STOP
+highest protocol among opcodes = 2
+```
+
+
+
+### Implementation Details
+
+[export.cpp](export.cpp) and [import.cpp](import.cpp) handle producing the proper protobuf definitions and (de)serializing tensor data. They use [pickler.h](pickler.h), which implements a subset of the `pickle` stack machine.
-TODO: python_print, serialization format
-Python Bindings
----------------
+# Python Bindings
TODO: Script Module, torch.jit.trace, __constant__ handling, weak script modules