~~~~~
The largest difference between TorchScript and the full Python language is that
-TorchScript only support a small set of types that are needed to express neural
-net models. In particular TorchScript supports:
+TorchScript only supports a small set of types that are needed to express neural
+net models. In particular, TorchScript supports:
``Tensor``
A PyTorch tensor of any dtype, dimension, or backend.
``Optional[T]``
  A value which is either ``None`` or of type ``T``
-```Dict[K, V]``
+``Dict[K, V]``
A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and
``float`` are allowed as key types.
# and type int in the false branch
-There are 2 scenarios in which you can annotate:
+Default Types
+^^^^^^^^^^^^^
-1. Function Argument Type Annotation
-
-By default, all parameters to a TorchScript function are assumed to be Tensor
-because this is the most common type used in modules. To specify that an
-argument to a TorchScript function is another type, it is possible to use
+By default, all parameters to a TorchScript function are assumed to be Tensor.
+To specify that an argument to a TorchScript function is another type, it is possible to use
MyPy-style type annotations using the types listed above:
Example::
In our examples, we use comment-based annotations to ensure Python 2
compatibility as well.
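The example body is elided above, so as an illustrative sketch (the function name is ours, and Python 3 annotations are used in place of the comment form):

```python
import torch
from torch import Tensor

@torch.jit.script
def scaled(x: Tensor, scale: float, repeat: int) -> Tensor:
    # `scale` and `repeat` are annotated, so they are not assumed to be Tensor
    return x * scale * repeat
```

The comment-based form (``# type: (Tensor, float, int) -> Tensor`` under the ``def`` line) expresses the same signature.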
-
-2. Variable Type Annotation
-
-A list by default is assumed to be ``List[Tensor]`` and empty dicts
+An empty list is assumed to be ``List[Tensor]`` and empty dicts
``Dict[str, Tensor]``. To instantiate an empty list or dict of other types,
use ``torch.jit.annotate``.
for i in range(10):
list_of_tuple.append((x, x))
- # This annotates the list to be a `Dict[int, Tensor]`
+    # This annotates the empty dict to be a `Dict[int, Tensor]`
int_tensor_dict = torch.jit.annotate(Dict[int, Tensor], {})
return list_of_tuple, int_tensor_dict
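A self-contained version of the snippet above might look like the following; the wrapper function and its signature are assumptions, since only the body appears here:

```python
import torch
from torch import Tensor
from typing import Dict, List, Tuple

@torch.jit.script
def make_containers(x: Tensor) -> Tuple[List[Tuple[Tensor, Tensor]], Dict[int, Tensor]]:
    # Annotate the empty list as `List[Tuple[Tensor, Tensor]]`
    list_of_tuple = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
    for i in range(10):
        list_of_tuple.append((x, x))
    # Annotate the empty dict as `Dict[int, Tensor]`
    int_tensor_dict = torch.jit.annotate(Dict[int, Tensor], {})
    return list_of_tuple, int_tensor_dict
```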
-Optional Type Refinement:
+Optional Type Refinement
+^^^^^^^^^^^^^^^^^^^^^^^^
TorchScript will refine the type of a variable of type ``Optional[T]`` when
a comparison to ``None`` is made inside the conditional of an ``if`` statement.
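For illustration (the function name is ours), a ``None`` check refines ``Optional[Tensor]`` to ``Tensor`` inside the branch:

```python
import torch
from torch import Tensor
from typing import Optional

@torch.jit.script
def add_bias(x: Tensor, bias: Optional[Tensor]) -> Tensor:
    if bias is not None:
        # here `bias` is refined from Optional[Tensor] to Tensor
        return x + bias
    return x
```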
The following Python expressions are supported:
Literals
+^^^^^^^^
``True``, ``False``, ``None``, ``'string literals'``, ``"string literals"``,
- number literals ``3`` (interpreted as int) ``3.4`` (interpreter as a float)
-
-Variables
- ``a``
-
- .. note::
- See `Variable Resolution`_ for how variables are resolved.
-
-Tuple Construction
- ``(3, 4)``, ``(3,)``
+ number literals ``3`` (interpreted as int), ``3.4`` (interpreted as a float)
List Construction
+"""""""""""""""""
``[3, 4]``, ``[]``, ``[torch.rand(3), torch.rand(4)]``
.. note::
an empty list is assumed to have type ``List[Tensor]``.
The types of other list literals are derived from the type of the members.
+Tuple Construction
+""""""""""""""""""
+ ``(3, 4)``, ``(3,)``
+
+
Dict Construction
+"""""""""""""""""
``{'hello': 3}``, ``{}``, ``{'a': torch.rand(3), 'b': torch.rand(4)}``
.. note::
an empty dict is assumed to have type ``Dict[str, Tensor]``.
The types of other dict literals are derived from the type of the members.
+Variables
+^^^^^^^^^
+ ``my_variable_name``
+
+ .. note::
+ See `Variable Resolution`_ for how variables are resolved.
+
+
Arithmetic Operators
+^^^^^^^^^^^^^^^^^^^^
``a + b``
+
``a - b``
+
``a * b``
+
``a / b``
+
``a ^ b``
+
``a @ b``
Comparison Operators
+^^^^^^^^^^^^^^^^^^^^
``a == b``
+
``a != b``
+
``a < b``
+
``a > b``
+
``a <= b``
+
``a >= b``
Logical Operators
+^^^^^^^^^^^^^^^^^
``a and b``
+
``a or b``
+
``not b``
Subscripts
+^^^^^^^^^^
``t[0]``
+
``t[-1]``
+
``t[0:2]``
+
``t[1:]``
+
``t[:1]``
+
``t[:]``
+
``t[0, 1]``
+
``t[0, 1:2]``
+
``t[0, :1]``
+
``t[-1, 1:, 0]``
+
``t[1:, -1, 0]``
+
``t[i:j, i]``
.. note::
TorchScript currently does not support mutating tensors in place, so any
tensor indexing can only appear on the right-hand side of an expression.
-Function calls
+Function Calls
+^^^^^^^^^^^^^^
 Calls to built-in functions: ``torch.rand(3, dtype=torch.float)``
Calls to other script functions:
def bar(x):
return foo(x)
-Method calls
+Method Calls
+^^^^^^^^^^^^
 Calls to methods of builtin types like ``Tensor``: ``x.mm(y)``
def forward(self, input):
return self.helper(input)
-If expressions
+Ternary Expressions
+^^^^^^^^^^^^^^^^^^^
``x if x > y else y``
Casts
- ``float(ten)``, ``int(3.5)``, ``bool(ten)``
+^^^^^
+ ``float(ten)``
+
+ ``int(3.5)``
+
+ ``bool(ten)``
Accessing Module Parameters
- ``self.my_parameter`` ``self.my_submodule.my_parameter``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ ``self.my_parameter``
+
+ ``self.my_submodule.my_parameter``
Statements
TorchScript supports the following types of statements:
Simple Assignments
-
::
a = b
a -= b
Pattern Matching Assignments
-
::
a, b = tuple_or_list
on the dynamic type of the Python value referenced.
Functions
- TorchScript can call python functions. This functionality is very useful when
+^^^^^^^^^
+
+ TorchScript can call Python functions. This functionality is very useful when
incrementally converting a model into script. The model can be moved function-by-function
to script, leaving calls to Python functions in place. This way you can incrementally
check the correctness of the model as you go.
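As a sketch of this incremental-conversion pattern (the function names are ours; in current PyTorch the decorator that keeps a call in Python is ``torch.jit.ignore``):

```python
import torch
from torch import Tensor

@torch.jit.ignore
def not_yet_converted(x: Tensor) -> Tensor:
    # still plain Python: free to use prints, pdb, or unsupported libraries
    return x + 1

@torch.jit.script
def converted(x: Tensor) -> Tensor:
    # scripted code calling back into the Python function above
    return not_yet_converted(x) * 2
```

Once the helper is verified, it can itself be converted to script with no change at the call site.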
Attribute Lookup On Python Modules
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TorchScript can look up attributes on modules. Builtin functions like ``torch.add``
are accessed this way. This allows TorchScript to call functions defined in
other modules.
Python-defined Constants
+^^^^^^^^^^^^^^^^^^^^^^^^
TorchScript also provides a way to use constants that are defined in Python.
These can be used to hard-code hyper-parameters into the function, or to
define universal constants. There are two ways of specifying that a Python
~~~~~~~~~
Disable JIT for Debugging
+^^^^^^^^^^^^^^^^^^^^^^^^^
If you want to disable all JIT modes (tracing and scripting) so you can
debug your program in raw Python, you can use the ``PYTORCH_JIT`` environment
variable. ``PYTORCH_JIT`` can be used to globally disable the
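A minimal sketch of the behavior: with JIT disabled, ``torch.jit.script`` leaves the function as plain Python, so ordinary debugging tools work. Note the environment variable must be set before ``torch`` is imported:

```python
import os
os.environ["PYTORCH_JIT"] = "0"  # disable scripting/tracing globally

import torch

@torch.jit.script  # with PYTORCH_JIT=0 this is a pass-through
def f(x):
    return x + 1
```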
Interpreting Graphs
+^^^^^^^^^^^^^^^^^^^
TorchScript uses a static single assignment (SSA) intermediate representation
(IR) to represent computation. The instructions in this format consist of
ATen (the C++ backend of PyTorch) operators and other primitive operators,
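The IR of a scripted function can be inspected directly; a small sketch:

```python
import torch

@torch.jit.script
def add_one(x):
    return x + 1

# .graph holds the SSA IR; each instruction is an aten/prim operator
print(add_one.graph)
```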
Tracing Edge Cases
+^^^^^^^^^^^^^^^^^^
There are some edge cases that exist where the trace of a given Python
function/module will not be representative of the underlying code. These
cases can include:
Automatic Trace Checking
+^^^^^^^^^^^^^^^^^^^^^^^^
One way to automatically catch many errors in traces is by using ``check_inputs``
on the ``torch.jit.trace()`` API. ``check_inputs`` takes a list of tuples
of inputs that will be used to re-trace the computation and verify the
}
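For example (the function and input shapes are ours), each extra input tuple is re-traced and the resulting graph is checked against the original trace:

```python
import torch

def affine(x):
    return x * 2 + 1

# each tuple in check_inputs is re-traced; divergence raises a checker error
traced = torch.jit.trace(
    affine,
    (torch.ones(2, 2),),
    check_inputs=[(torch.rand(3, 3),), (torch.rand(1, 5),)],
)
```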
Tracer Warnings
+^^^^^^^^^^^^^^^
The tracer produces warnings for several problematic patterns in traced
computation. As an example, take a trace of a function that contains an
in-place assignment on a slice (a view) of a Tensor::
print(traced.graph)
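The elided example can be reconstructed as a sketch like the following; the in-place write to ``x[0]`` (a view of the input) draws a ``TracerWarning`` when traced:

```python
import torch

def fill_row_zero(x):
    # in-place assignment into a slice of the input: the tracer warns that
    # the traced result may not match eager execution for other inputs
    x[0] = torch.zeros(x.shape[1])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
```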
-Builtin Functions
-~~~~~~~~~~~~~~~~~
-
-TorchScript supports a subset of the builtin tensor and neural network
-functions that PyTorch provides. Most methods on Tensor as well as functions in
-the ``torch`` namespace, all functions in ``torch.nn.functional`` and all
-modules from ``torch.nn`` are supported in TorchScript, excluding those in the
-table below. For unsupported modules, we suggest using :meth:`torch.jit.trace`.
-
-Unsupported ``torch.nn`` Modules ::
-
- torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
- torch.nn.modules.normalization.CrossMapLRN2d
- torch.nn.modules.fold.Fold
- torch.nn.modules.fold.Unfold
- torch.nn.modules.rnn.GRU
- torch.nn.modules.rnn.LSTM
- torch.nn.modules.rnn.RNN
- torch.nn.modules.rnn.GRUCell
- torch.nn.modules.rnn.LSTMCell
- torch.nn.modules.rnn.RNNCell
-
-
-.. automodule:: torch.jit.supported_ops
-
Frequently Asked Questions
--------------------------
specific device, so casting an already-loaded model may have unexpected
effects. Casting the model *before* saving it ensures that the tracer has
the correct device information.
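As a sketch of the recommended order (the model and filename are ours): move the model to its target device first, then trace and save:

```python
import torch

model = torch.nn.Linear(4, 2)
model.to("cpu")  # cast/relocate *before* tracing so the trace records the right device
traced = torch.jit.trace(model, torch.rand(1, 4))
traced.save("linear_cpu.pt")
```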
+
+
+Builtin Functions
+~~~~~~~~~~~~~~~~~
+
+TorchScript supports a subset of the builtin tensor and neural network
+functions that PyTorch provides. Most methods on Tensor as well as functions in
+the ``torch`` namespace, all functions in ``torch.nn.functional`` and all
+modules from ``torch.nn`` are supported in TorchScript, excluding those in the
+table below. For unsupported modules, we suggest using :meth:`torch.jit.trace`.
+
+Unsupported ``torch.nn`` Modules ::
+
+ torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss
+ torch.nn.modules.normalization.CrossMapLRN2d
+ torch.nn.modules.fold.Fold
+ torch.nn.modules.fold.Unfold
+ torch.nn.modules.rnn.GRU
+ torch.nn.modules.rnn.RNN
+
+
+.. automodule:: torch.jit.supported_ops