--- /dev/null
+torch.fx._symbolic_trace.ProxyableClassMeta []
+torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace']
+torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'flatten_inps', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'output', 'owning_module', 'placeholder', 'print_tabular', 'python_code', 'unflatten_outs']
+torch.fx.graph.PythonCode []
+torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder']
+torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update']
+torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove']
+torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node']
+torch.fx.interpreter.Transformer ['call_function', 'call_module', 'get_attr', 'placeholder', 'transform']
+torch.fx.node.Node ['all_input_nodes', 'append', 'args', 'format_node', 'is_impure', 'kwargs', 'next', 'normalized_arguments', 'prepend', 'prev', 'replace_all_uses_with', 'replace_input_with', 'stack_trace', 'update_arg', 'update_kwarg']
+torch.fx.passes.shape_prop.ShapeProp ['propagate', 'run_node']
+torch.fx.passes.shape_prop.TensorMetadata ['dtype', 'is_quantized', 'memory_format', 'q_scale', 'q_zero_point', 'qscheme', 'requires_grad', 'shape', 'stride']
+torch.fx.passes.split_module.Partition []
+torch.fx.proxy.Attribute ['node']
+torch.fx.proxy.GraphAppendingTracer []
+torch.fx.proxy.Proxy ['keys']
+torch.fx.proxy.TraceError []
+torch.fx.proxy.TracerBase ['create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'record_stack_traces', 'to_bool']
+torch.fx.subgraph_rewriter.Match ['anchor', 'nodes_map']
\ No newline at end of file
--- /dev/null
+torch.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (<module math>,), autowrap_functions: Tuple[Callable, ...] = (,), enable_cpatching: bool = False, param_shapes_constant: bool = False) -> None
+torch.fx._symbolic_trace.Tracer.call_module(self, m: torch.nn.modules.module.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx._symbolic_trace.Tracer.create_arg(self, a: Any) -> 'Argument'
+torch.fx._symbolic_trace.Tracer.is_leaf_module(self, m: torch.nn.modules.module.Module, module_qualified_name: str) -> bool
+torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str
+torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph
+torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None, enable_cpatching: bool = False) -> torch.fx.graph_module.GraphModule
+torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable])
+torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None)
+torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.eliminate_dead_code(self)
+torch.fx.graph.Graph.erase_node(self, to_erase: torch.fx.node.Node) -> None
+torch.fx.graph.Graph.get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: Dict[torch.fx.node.Node, torch.fx.node.Node], return_output_node = False) -> 'Optional[Argument]'
+torch.fx.graph.Graph.inserting_after(self, n: Optional[torch.fx.node.Node] = None)
+torch.fx.graph.Graph.inserting_before(self, n: Optional[torch.fx.node.Node] = None)
+torch.fx.graph.Graph.lint(self)
+torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Callable[[torch.fx.node.Node], Argument] = <function <lambda>>) -> torch.fx.node.Node
+torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
+torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.print_tabular(self)
+torch.fx.graph.Graph.python_code(self, root_module: str) -> torch.fx.graph.PythonCode
+torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
+torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
+torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
+torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool
+torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode
+torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
+torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: torch.fx.node.Node) -> Tuple[Tuple, Dict]
+torch.fx.interpreter.Interpreter.fetch_attr(self, target: str)
+torch.fx.interpreter.Interpreter.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.map_nodes_to_values(self, args: torch.fx.node.Argument, n: torch.fx.node.Node) -> torch.fx.node.Argument
+torch.fx.interpreter.Interpreter.output(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.run(self, *args, initial_env: Optional[Dict[torch.fx.node.Node, Any]] = None) -> Any
+torch.fx.interpreter.Interpreter.run_node(self, n: torch.fx.node.Node) -> Any
+torch.fx.interpreter.Transformer.__init__(self, module)
+torch.fx.interpreter.Transformer.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Transformer.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Transformer.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy
+torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy
+torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule
+torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None
+torch.fx.node.Node.append(self, x: 'Node') -> None
+torch.fx.node.Node.format_node(self, placeholder_names: List[str] = None, maybe_return_typename: List[str] = None) -> Optional[str]
+torch.fx.node.Node.prepend(self, x: 'Node') -> None
+torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node') -> List[Node]
+torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node')
+torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
+torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
+torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
+torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
+torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int])
+torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
+torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
+torch.fx.proxy.Proxy.keys(self)
+torch.fx.proxy.TracerBase.create_arg(self, a: Any) -> torch.fx.node.Argument
+torch.fx.proxy.TracerBase.create_node(self, kind: str, target: torch.fx.node.Target, args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, torch.fx.node.Argument], name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.proxy.TracerBase.create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[torch.fx.node.Node], Proxy] = None)
+torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator
+torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any
+torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy'
+torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool
+torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match]
\ No newline at end of file
import sys
import torch
import traceback
+import typing
+import types
import warnings
import unittest
from math import sqrt
from collections import namedtuple
from torch.fx.proxy import TraceError
+from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401
from fx.test_dce_pass import TestDCE # noqa: F401
assert op.name in known_no_schema or "nn.functional" in op.name
+class TestFXAPIBackwardCompatibility(JitTestCase):
+ def setUp(self):
+ self.maxDiff = None
+
+ def _fn_to_stable_annotation_str(self, obj):
+ """
+ Unfortunately we have to serialize function signatures manually since
+ serialization for `inspect.Signature` objects is not stable across
+        Python versions.
+ """
+ fn_name = torch.typename(obj)
+
+ signature = inspect.signature(obj)
+
+ sig_str = f'{fn_name}{signature}'
+
+ arg_strs = []
+ for k, v in signature.parameters.items():
+ maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
+ if v.annotation is not inspect.Signature.empty else ''
+
+ def default_val_str(val):
+ if isinstance(val, (tuple, list)):
+ str_pieces = ['(' if isinstance(val, tuple) else '[']
+ str_pieces.append(', '.join(default_val_str(v) for v in val))
+ if isinstance(val, tuple) and len(str_pieces) == 2:
+ str_pieces.append(',')
+ str_pieces.append(')' if isinstance(val, tuple) else ']')
+ return ''.join(str_pieces)
+
+ # Need to fix up some default value strings.
+ # First case: modules. Default module `repr` contains the FS path of the module.
+ # Don't leak that
+ if isinstance(val, types.ModuleType):
+ return f'<module {val.__name__}>'
+
+ # Second case: callables. Callables (such as lambdas) encode their address in
+ # their string repr. Don't do that
+ if callable(val):
+ return f'<function {val.__name__}>'
+
+ return str(val)
+
+ if v.default is not inspect.Signature.empty:
+                default_value_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
+                maybe_default = f' = {default_value_str}'
+ else:
+ maybe_default = ''
+ maybe_stars = ''
+ if v.kind == inspect.Parameter.VAR_POSITIONAL:
+ maybe_stars = '*'
+ elif v.kind == inspect.Parameter.VAR_KEYWORD:
+ maybe_stars = '**'
+ arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
+
+ return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
+ if signature.return_annotation is not inspect.Signature.empty else ''
+
+ return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
+
+ def _annotation_type_to_stable_str(self, t, sig_str):
+ if t is inspect.Signature.empty:
+ return ''
+
+ # Forward ref
+ if isinstance(t, str):
+ return f"'{t}'"
+ if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
+ return t.__forward_arg__
+ if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
+ return t.__forward_arg__
+
+ trivial_mappings = {
+ str : 'str',
+ int : 'int',
+ float: 'float',
+ bool: 'bool',
+ torch.dtype: 'torch.dtype',
+ torch.Tensor: 'torch.Tensor',
+ torch.device: 'torch.device',
+ torch.memory_format: 'torch.memory_format',
+ slice: 'slice',
+ torch.nn.Module: 'torch.nn.modules.module.Module',
+ torch.fx.Graph : 'torch.fx.graph.Graph',
+ torch.fx.Node : 'torch.fx.node.Node',
+ torch.fx.Proxy : 'torch.fx.proxy.Proxy',
+ torch.fx.node.Target : 'torch.fx.node.Target',
+ torch.fx.node.Argument : 'torch.fx.node.Argument',
+ torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
+ torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
+ torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
+ Ellipsis : '...',
+ typing.Any: 'Any',
+ type(None): 'NoneType',
+ None: 'None',
+ typing.Iterator: 'Iterator',
+ }
+
+ mapping = trivial_mappings.get(t, None)
+ if mapping:
+ return mapping
+
+ # Handle types with contained types
+ contained = getattr(t, '__args__', None) or []
+
+ # Callables contain a bare List for arguments
+ contained = t if isinstance(t, list) else contained
+
+ # Python 3.8 puts type vars into __args__ for unbound types such as Dict
+ if all(isinstance(ct, typing.TypeVar) for ct in contained):
+ contained = []
+
+ contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
+ contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
+
+
+ origin = getattr(t, '__origin__', None)
+ if origin is None:
+ # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
+ origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
+
+ if origin in {tuple, typing.Tuple}:
+ return f'Tuple{contained_type_str}'
+ if origin in {typing.Union}:
+ # Annoying hack to detect Optional
+ if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
+ not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
+ return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
+ return f'Union{contained_type_str}'
+ if origin in {dict, typing.Dict}:
+ return f'Dict{contained_type_str}'
+ if origin in {list, typing.List}:
+ return f'List{contained_type_str}'
+ if origin in {type, typing.Type}:
+ return f'Type{contained_type_str}'
+ if isinstance(t, typing.Callable):
+ if len(contained) > 0 and contained[0] is not Ellipsis:
+ return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
+ else:
+ return f'Callable{contained_type_str}'
+
+        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
+ f'Please add support for this type and confirm with the '
+ f'FX team that your signature change is valid.')
+
+
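+    # For intuition, a hypothetical round trip (illustrative only, not part of
+    # this test): serializing
+    #     def f(x: int, ys: typing.Optional[typing.List[str]] = None) -> bool: ...
+    # through the helpers above yields a stable, fully-qualified string like
+    #     <module>.f(x: int, ys: Optional[List[str]] = None) -> bool
+    # whereas repr(inspect.signature(f)) is not guaranteed to be stable across
+    # Python versions.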
+ def test_function_back_compat(self):
+ """
+ Test backward compatibility for function signatures with
+ @compatibility(is_backward_compatible=True). Currently this checks for
+ exact signature matches, which may lead to false positives. If this
+ becomes too annoying, we can refine this check to actually parse out
+ the saved schema strings and check if the change is truly backward-
+ incompatible.
+ """
+ signature_strs = []
+
+ for obj in _BACK_COMPAT_OBJECTS:
+ if not isinstance(obj, type):
+ signature_strs.append(self._fn_to_stable_annotation_str(obj))
+
+ signature_strs.sort()
+
+ try:
+ self.assertExpected('\n'.join(signature_strs), 'fx_backcompat_function_signatures')
+ except AssertionError as e:
+ msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
+ f"as backwards-compatible has experienced a signature change. See the " \
+ f"above exception context for more information. If this change was " \
+ f"unintended, please revert it. If it was intended, check with the FX " \
+ f"team to ensure that the proper deprecation protocols have been followed " \
+ f"and subsequently --accept the change."
+ raise AssertionError(msg)
+
+ def test_class_member_back_compat(self):
+ """
+ Test backward compatibility for members of classes with
+ @compatibility(is_backward_compatible=True). Currently this checks for
+ exact matches on the publicly visible members of the class.
+ """
+ class_method_strs = []
+
+ for obj in _BACK_COMPAT_OBJECTS:
+ if isinstance(obj, type):
+ public_members = [name for name in obj.__dict__ if not name.startswith('_')]
+ class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')
+
+ class_method_strs.sort()
+
+ try:
+ self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
+ except AssertionError as e:
+ msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
+ f"as backwards-compatible has experienced change in its public members. See the " \
+ f"above exception context for more information. If this change was " \
+ f"unintended, please revert it. If it was intended, check with the FX " \
+ f"team to ensure that the proper deprecation protocols have been followed " \
+ f"and subsequently --accept the change."
+ raise AssertionError(msg)
+
+ def test_public_api_surface(self):
+ mod = torch.fx
+
+ non_back_compat_objects = {}
+
+ def check_symbols_have_bc_designation(m, prefix):
+ if not m.__name__.startswith('torch.fx'):
+ return
+ if m.__name__.startswith('torch.fx.experimental'):
+ return
+ for k, v in m.__dict__.items():
+ if v is m:
+ continue
+ if k.startswith('_'):
+ continue
+ if isinstance(v, types.ModuleType):
+ check_symbols_have_bc_designation(v, prefix + [k])
+ elif isinstance(v, type) or isinstance(v, types.FunctionType):
+                    if v not in _MARKED_WITH_COMPATIBILITY:
+ non_back_compat_objects.setdefault(v)
+
+ check_symbols_have_bc_designation(mod, ['torch', 'fx'])
+
+
+ non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
+ # Only want objects in torch.fx
+ non_back_compat_strs = [
+ s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
+ # Only want objects in public namespaces
+ non_back_compat_strs = [
+ s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
+ non_back_compat_strs.sort()
+
+ if len(non_back_compat_strs) != 0:
+ raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
+ f"backwards-compatibility classification! Please decorate these "
+ f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
+ f"BC guarantees.")
+
class TestFunctionalTracing(JitTestCase):
IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
"has_torch_function_variadic", "handle_torch_function",
type_matches,
create_type_hint,
)
-from torch.fx.passes.shape_prop import extract_tensor_metadata, ShapeProp
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
from torch.fx.passes.split_module import split_module
from torch.testing._internal.common_device_type import (
ops,
# Fix for now to add type/shape to output
for node in traced.graph.nodes:
if node.op == "output":
- node.meta["tensor_meta"] = extract_tensor_metadata(a)
+ node.meta["tensor_meta"] = _extract_tensor_metadata(a)
for mod in module_with_submodules.modules():
if isinstance(mod, GraphModule):
for node in mod.graph.nodes:
- node.meta["tensor_meta"] = extract_tensor_metadata(a)
+ node.meta["tensor_meta"] = _extract_tensor_metadata(a)
for node in module_with_submodules.graph.nodes:
- node.meta["tensor_meta"] = extract_tensor_metadata(a)
+ node.meta["tensor_meta"] = _extract_tensor_metadata(a)
weights1 = {}
weights2 = {}
r'''
-**This feature is under a Beta release and its API may change.**
-
FX is a toolkit for developers to use to transform ``nn.Module``
instances. FX consists of three main components: a **symbolic tracer,**
an **intermediate representation**, and **Python code generation**. A
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
- graph(x):
- %param : [#users=1] = self.param
- %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
- %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
- %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
- return clamp_1
+ graph():
+ %x : [#users=1] = placeholder[target=x]
+ %param : [#users=1] = get_attr[target=param]
+ %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
+ %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
+ %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
+ return clamp
"""
# Code generation - valid Python code
"""
def forward(self, x):
param = self.param
- add_1 = x + param; x = param = None
- linear_1 = self.linear(add_1); add_1 = None
- clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None
- return clamp_1
+ add = x + param; x = param = None
+ linear = self.linear(add); add = None
+ clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
+ return clamp
"""
The **symbolic tracer** performs "symbolic execution" of the Python
--- /dev/null
+from typing import Any, Dict
+import textwrap
+
+_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
+_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
+
+def compatibility(is_backward_compatible : bool):
+ if is_backward_compatible:
+
+ def mark_back_compat(fn):
+ docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
+ docstring += """
+.. note::
+ Backwards-compatibility for this API is guaranteed.
+"""
+ fn.__doc__ = docstring
+ _BACK_COMPAT_OBJECTS.setdefault(fn)
+        _MARKED_WITH_COMPATIBILITY.setdefault(fn)
+ return fn
+
+ return mark_back_compat
+ else:
+
+ def mark_not_back_compat(fn):
+ docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
+ docstring += """
+.. warning::
+ This API is experimental and is *NOT* backward-compatible.
+"""
+ fn.__doc__ = docstring
+        _MARKED_WITH_COMPATIBILITY.setdefault(fn)
+ return fn
+
+ return mark_not_back_compat
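A hedged usage sketch of the decorator above (`my_pass` is a hypothetical function, not an FX API): marking an object registers it in the dictionaries that the BC tests walk, and appends the guarantee note to its docstring.

    from torch.fx._compatibility import (
        compatibility, _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY)

    @compatibility(is_backward_compatible=True)
    def my_pass(graph):
        """Hypothetical FX pass."""
        return graph

    assert my_pass in _BACK_COMPAT_OBJECTS        # picked up by the signature test
    assert my_pass in _MARKED_WITH_COMPATIBILITY  # counted as classified
    assert '.. note::' in my_pass.__doc__         # BC note appended to the docstring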
import torch.utils._pytree as pytree
import sys
+from ._compatibility import compatibility
from .node import Argument, map_aggregate, base_types
from .graph import Graph, _PyTreeInfo
from .graph_module import GraphModule
_proxyable_classes : Dict[Type, None] = {}
+@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
"""
ProxyableClassMeta allows you to make construction of a given Python class
def __exit__(self, type, value, tb):
sys.setprofile(None)
+@compatibility(is_backward_compatible=False)
class PHBase(object):
"""
Object representing an input placeholder to `concrete_args`
PH = PHBase()
+@compatibility(is_backward_compatible=True)
class Tracer(TracerBase):
# Reference: https://github.com/pytorch/pytorch/issues/54354
# The first line of this docstring overrides the one Sphinx generates for the
process. The different behaviors that can be overridden are described
in the docstrings of the methods on this class.
"""
+
+    # The default value for `autowrap_modules` includes a module object whose
+    # repr contains the local filepath to the `math` module, which would jitter
+    # across machines; the BC signature test therefore serializes module
+    # defaults as `<module math>` to keep the expected signature stable.
+ @compatibility(is_backward_compatible=True)
def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ),
autowrap_functions: Tuple[Callable, ...] = (),
enable_cpatching: bool = False,
autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
Python modules whose functions should be wrapped automatically
- without needing to use fx.wrap().
+ without needing to use fx.wrap(). Backward-compatibility for
+ this parameter is guaranteed.
             autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
Python functions that should be wrapped automatically without
- needing to use fx.wrap().
+                needing to use fx.wrap(). Backward compatibility for this
+ parameter is guaranteed.
+
+ param_shapes_constant (bool): When this flag is set, calls to shape,
+            size and a few other shape-like attributes of a module's parameter
+            will be evaluated directly, rather than returning a new Proxy value
+ for an attribute access. Backward compatibility for this parameter
+ is guaranteed.
enable_cpatching (bool): defaults to `False`,
             Allows you to enable/disable monkeypatching of torch functions at the
             C level. C-level monkeypatching works by directly modifying the PyCFunctionObject*
so that calling it returns a different function.
- Turning this on is likely to slow down tracing by 1.5-3x.
-
- param_shapes_constant (bool): see https://github.com/pytorch/pytorch/issues/61733. When
- this flag is set, calls to shape, size and a few other shape like attributes of a module's parameter
- will be evaluted directly, rather than returning a new Proxy value for an attribute access.
-
+ Turning this on is likely to slow down tracing by 1.5-3x. This
+ parameter is experimental and its backward-compatibility is NOT
+ guaranteed.
"""
super().__init__()
self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
+ @compatibility(is_backward_compatible=True)
def create_arg(self, a: Any) -> 'Argument':
"""
A method to specify the behavior of tracing when preparing values to
return super().create_arg(a)
+ @compatibility(is_backward_compatible=True)
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
"""
A method to specify whether a given ``nn.Module`` is a "leaf" module.
"""
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
+ @compatibility(is_backward_compatible=True)
def path_of_module(self, mod : torch.nn.Module) -> str:
"""
Helper method to find the qualified name of ``mod`` in the Module hierarchy
return n
raise NameError('module is not installed as a submodule')
+ @compatibility(is_backward_compatible=True)
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
"""
Method that specifies the behavior of this ``Tracer`` when it encounters
return forward(*args, **kwargs)
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+ # This method will be refactored
+ @compatibility(is_backward_compatible=False)
def create_args_for_root(self, root_fn, is_module, concrete_args=None):
"""
Create ``placeholder`` nodes corresponding to the signature of the ``root``
return attr_val
-
- def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
+ @compatibility(is_backward_compatible=True)
+ def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.
Args:
root (Union[Module, Callable]): Either a ``Module`` or a function to be
- traced through.
- concrete_args (Optional[Dict[str, any]]): Concrete arguments that should not be treated as Proxies.
+ traced through. Backwards-compatibility for this parameter is
+ guaranteed.
+ concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
+ not be treated as Proxies. This parameter is experimental and
+ its backwards-compatibility is *NOT* guaranteed.
Returns:
patcher.patch(frame_dict, name, _create_wrapped_func(value))
+@compatibility(is_backward_compatible=True)
def wrap(fn_or_name : Union[str, Callable]):
"""
This function can be called at module-level scope to register fn_or_name as a "leaf function".
_wrapped_fns_to_patch.append((f.f_globals, fn_name))
return fn_or_name
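A hedged sketch of `wrap()` in use (`lookup_scale` is a hypothetical function): registering a free function as a leaf keeps the tracer from descending into it, so it is recorded as a single `call_function` node and executed as-is at runtime.

    import torch
    import torch.fx

    def lookup_scale(x):
        return x.shape[0] * 2.0    # opaque to the tracer once wrapped

    torch.fx.wrap('lookup_scale')  # must be called at module-level scope

    class M(torch.nn.Module):
        def forward(self, x):
            return x * lookup_scale(x)

    gm = torch.fx.symbolic_trace(M())
    assert any(n.op == 'call_function' and n.target is lookup_scale
               for n in gm.graph.nodes)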
-def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None,
+@compatibility(is_backward_compatible=True)
+def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
enable_cpatching: bool = False) -> GraphModule:
- """Symbolic tracing API
+ """
+ Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
constructed by recording operations seen while tracing through ``root``.
Returns:
GraphModule: a Module created from the recorded operations from ``root``.
-
"""
tracer = Tracer(enable_cpatching=enable_cpatching)
graph = tracer.trace(root, concrete_args)
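A hedged sketch of the experimental `concrete_args` parameter, mirroring the pattern documented for `symbolic_trace`: pinning `b` to a concrete value burns that branch into the recorded graph.

    import torch.fx

    def f(a, b):
        if b:          # data-dependent control flow is normally untraceable
            return a
        return a * 2

    gm = torch.fx.symbolic_trace(f, concrete_args={'b': False})
    assert gm(3, False) == 6   # only the `b is False` branch was recorded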
from torch.fx.proxy import Proxy
+from ._compatibility import compatibility
-
+@compatibility(is_backward_compatible=False)
def annotate(val, type):
# val could be either a regular value (not tracing)
# or fx.Proxy (tracing)
register_acc_op_mapping,
register_custom_acc_mapper_fn,
)
-from torch.fx.passes.shape_prop import extract_tensor_metadata
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
this_arg_is_optional = True
with node.graph.inserting_before(node):
# Insert get_attr nodes for weight and bias
get_weight = node.graph.get_attr(weight_name)
- get_weight.meta["tensor_meta"] = extract_tensor_metadata(linear_module.weight())
+ get_weight.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.weight())
get_bias = None
if linear_module.bias() is not None:
get_bias = node.graph.get_attr(bias_name)
- get_bias.meta["tensor_meta"] = extract_tensor_metadata(linear_module.bias())
+ get_bias.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.bias())
# Create kwargs for acc_op.quantized_linear
kwargs = {
with node.graph.inserting_before(node):
# Insert get_attr nodes for weight and bias
get_weight = node.graph.get_attr(weight_name)
- get_weight.meta["tensor_meta"] = extract_tensor_metadata(conv_module.weight())
+ get_weight.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.weight())
get_bias = None
if conv_module.bias() is not None:
get_bias = node.graph.get_attr(bias_name)
- get_bias.meta["tensor_meta"] = extract_tensor_metadata(conv_module.bias())
+ get_bias.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.bias())
# Create kwargs for acc_op.conv
kwargs = {
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
import torch.utils._pytree as pytree
from . import _pytree as fx_pytree
+from ._compatibility import compatibility
from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
from dataclasses import dataclass
return False
+@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:
- """Represents all the information necessary to exec or save a graph as Python code."""
+ """
+ Represents all the information necessary to exec or save a graph as Python code.
+ """
# Python source code for the forward function definition.
src: str
    # Values in global scope during execution of `src_def`.
in_spec: pytree.TreeSpec
out_spec: Optional[pytree.TreeSpec]
+@compatibility(is_backward_compatible=True)
class Graph:
"""
``Graph`` is the main data structure used in the FX Intermediate Representation.
For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
"""
+
+ @compatibility(is_backward_compatible=True)
def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
"""
Construct an empty Graph.
@property
def owning_module(self):
+ """
+        Return the module that owns this ``Graph``, if there is one,
+ ``None`` if there is no owning module or if there are multiple owning
+ modules.
+ """
return self._owning_module
@owning_module.setter
"""
return _node_list(self)
+ @compatibility(is_backward_compatible=True)
def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
"""
Copy all nodes from a given graph into ``self``.
from the default implementation. This uses graph_copy to copy the nodes
    in an iterative way rather than recursively. It also populates the
memoization table to prevent unnecessary copies (e.g. references to
- nodes or other parts of the Graph from a custom GraphModule implementation
+ nodes or other parts of the Graph from a custom GraphModule implementation.
"""
memo = memo if memo else {}
g = Graph(tracer_cls=self._tracer_cls)
g.output(output_val, type_expr=getattr(old_output_val, 'type', None))
return g
+ @compatibility(is_backward_compatible=True)
def create_node(self, op: str, target: 'Target',
args: Optional[Tuple['Argument', ...]] = None,
kwargs: Optional[Dict[str, 'Argument']] = None,
self._len += 1
return n
+ @compatibility(is_backward_compatible=False)
def flatten_inps(self, *args):
flat_args, args_spec = pytree.tree_flatten(args)
return flat_args
+ @compatibility(is_backward_compatible=False)
def unflatten_outs(self, out):
if self._pytree_info is None:
return out
assert(self._pytree_info.out_spec is not None)
return pytree.tree_unflatten(out, self._pytree_info.out_spec)
+ @compatibility(is_backward_compatible=True)
def erase_node(self, to_erase : Node) -> None:
"""
Erases a ``Node`` from the ``Graph``. Throws an exception if
assert isinstance(new_kwargs, dict)
to_erase.kwargs = new_kwargs
+ @compatibility(is_backward_compatible=True)
def inserting_before(self, n: Optional[Node] = None):
"""Set the point at which create_node and companion methods will insert into the graph.
        When used within a 'with' statement, this will temporarily set the insert point and
assert n.graph == self, "Node to insert before is not in graph."
return _InsertPoint(self, n.prepend)
+ @compatibility(is_backward_compatible=True)
def inserting_after(self, n: Optional[Node] = None):
"""Set the point at which create_node and companion methods will insert into the graph.
        When used within a 'with' statement, this will temporarily set the insert point and
assert n.graph == self, "Node to insert after is not in graph."
return _InsertPoint(self, n.append)
- # sugar for create_node when you know the op
+ @compatibility(is_backward_compatible=True)
def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
"""
Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
"""
return self.create_node('placeholder', name, type_expr=type_expr)
+ @compatibility(is_backward_compatible=True)
def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
"""
Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
"necessary buffer")
return self.create_node('get_attr', qualified_name, type_expr=type_expr)
+ @compatibility(is_backward_compatible=True)
def call_module(self,
module_name: str,
args: Optional[Tuple['Argument', ...]] = None,
"necessary submodule")
return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
+ @compatibility(is_backward_compatible=True)
def call_method(self,
method_name: str,
args: Optional[Tuple['Argument', ...]] = None,
"""
return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
+ @compatibility(is_backward_compatible=True)
def call_function(self,
the_function: Callable[..., Any],
args: Optional[Tuple['Argument', ...]] = None,
"""
return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
+ @compatibility(is_backward_compatible=True)
def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
"""
Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
result_node.meta = copy.copy(node.meta)
return result_node
+ @compatibility(is_backward_compatible=True)
def output(self, result: 'Argument', type_expr: Optional[Any] = None):
"""
Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
op = _snake_case(op)
return op
+ @compatibility(is_backward_compatible=True)
def python_code(self, root_module: str) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
def __str__(self) -> str:
"""
- Print a human-readable (not machine-readable) string representation
+ Return a human-readable (not machine-readable) string representation
of this Graph
"""
placeholder_names : List[str] = []
s += '\n ' + node_str
return s
+ @compatibility(is_backward_compatible=True)
def print_tabular(self):
"""
Prints the intermediate representation of the graph in tabular
- format.
+ format. Note that this API requires the ``tabulate`` module to be
+ installed.
"""
try:
from tabulate import tabulate
print(tabulate(node_specs,
headers=['opcode', 'name', 'target', 'args', 'kwargs']))
+ @compatibility(is_backward_compatible=True)
def lint(self):
"""
Runs various checks on this Graph to make sure it is well-formed. In
else:
m_itr = new_m_itr
+ @compatibility(is_backward_compatible=True)
def eliminate_dead_code(self):
"""
Remove all dead code from the graph, based on each node's number of
def forward(self, x):
return x + self.attr_1
-
"""
        # Lint the graph first to make sure it's topologically sorted, otherwise
# DCE below will not behave as expected.
import linecache
from typing import Type, Dict, List, Any, Union, Optional, Set
from .graph import Graph, _is_from_torch, _custom_builtins, PythonCode
+from ._compatibility import compatibility
from torch.package import Importer, sys_importer
import copy
import itertools
# Normal exec loses the source code, however we can work with
# the linecache module to recover it.
-# Using exec_with_source will add it to our local cache
+# Using _exec_with_source will add it to our local cache
# and then tools like TorchScript will be able to get source info.
-class EvalCacheLoader(object):
+class _EvalCacheLoader(object):
def __init__(self):
self.eval_cache = {}
self.next_id = 0
self.next_id += 1
return key
-_loader = EvalCacheLoader()
+_loader = _EvalCacheLoader()
-def exec_with_source(src: str, globals: Dict[str, Any]):
+def _exec_with_source(src: str, globals: Dict[str, Any]):
key = _loader.cache(src, globals)
exec(compile(src, key, 'exec'), globals)
def _forward_from_src(src: str, globals: Dict[str, Any]):
# avoid mutating the passed in dict
globals_copy = globals.copy()
- exec_with_source(src, globals_copy)
+ _exec_with_source(src, globals_copy)
forward_fn = globals_copy['forward']
del globals_copy['forward']
return forward_fn
return '\n'.join(import_strs)
-def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
+def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
# BC: attribute name was changed from `code` to `_code` to facilitate
# making `code` into a property and adding a docstring to it
fn_src = body.get('_code') or body['code']
return _deserialize_graph_module(forward, body)
-def reduce_package_graph_module(
+def _reduce_package_graph_module(
importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
) -> torch.nn.Module:
forward = importer.import_module(generated_module_name).forward
return _deserialize_graph_module(forward, body)
-def reduce_deploy_graph_module(
+def _reduce_deploy_graph_module(
importer: PackageImporter, body: Dict[Any, Any], import_block: str
) -> torch.nn.Module:
ns = dict()
else:
setattr(to_module, field, from_obj)
+@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
"""
    GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a
regenerated. However, if you edit the contents of the ``graph`` without reassigning
the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
code.
-
"""
def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
# each instance of a graph module needs its own forward method
pass
return super().__new__(GraphModuleImpl)
+ @compatibility(is_backward_compatible=True)
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
            class_name (str): ``class_name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
error messages will report as originating from ``GraphModule``. It may be helpful to set this
to ``root``'s original name or a name that makes sense within the context of your transform.
-
"""
super().__init__()
self.__class__.__name__ = class_name
g.owning_module = self
self.recompile()
+ @compatibility(is_backward_compatible=False)
def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"):
"""Dumps out module to ``folder`` with ``module_name`` so that it can be
imported with ``from <folder> import <module_name>``
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")
+ @compatibility(is_backward_compatible=True)
def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
"""
Adds the given submodule to ``self``.
denoted by ``target`` must either a) not exist yet,
or b) reference an ``nn.Module`` (not a parameter or
other attribute)
-
"""
*prefix, field = target.split('.')
mod: torch.nn.Module = self
mod.add_module(field, m)
return True
+ @compatibility(is_backward_compatible=True)
def delete_submodule(self, target: str) -> bool:
"""
Deletes the given submodule from ``self``.
delattr(mod, target_submod)
return True
+ @compatibility(is_backward_compatible=True)
def delete_all_unused_submodules(self) -> None:
"""
Deletes all unused submodules from ``self``.
raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
return self._code
+ @compatibility(is_backward_compatible=True)
def recompile(self) -> PythonCode:
"""
Recompile this GraphModule from its ``graph`` attribute. This should be
python_code = self.recompile()
import_block = _format_import_block(python_code.globals, importer)
- return (reduce_deploy_graph_module, (dict_without_graph, import_block))
+ return (_reduce_deploy_graph_module, (dict_without_graph, import_block))
def __reduce_package__(self, exporter: PackageExporter):
dict_without_graph = self.__dict__.copy()
import_block = _format_import_block(python_code.globals, exporter.importer)
module_code = import_block + self.code
exporter.save_source_string(generated_module_name, module_code)
- return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
+ return (_reduce_package_graph_module, (dict_without_graph, generated_module_name))
def __reduce__(self):
"""
python_code = self.recompile()
import_block = _format_import_block(python_code.globals, sys_importer)
del dict_without_graph['_graph']
- return (reduce_graph_module, (dict_without_graph, import_block))
+ return (_reduce_graph_module, (dict_without_graph, import_block))
# because __reduce__ is defined for serialization,
# we need to define deepcopy otherwise it will call __reduce__
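A hedged sketch tying `Graph` and `GraphModule` together: build a graph with the node-creation sugar marked backward-compatible above, wrap it in a `GraphModule`, and call `recompile()` after any further in-place edits.

    import torch
    from torch.fx import Graph, GraphModule

    g = Graph()
    x = g.placeholder('x')
    y = g.call_function(torch.relu, (x,))
    g.output(y)

    gm = GraphModule(torch.nn.Module(), g)  # root supplies attrs/submodules; none needed here
    assert torch.equal(gm(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 2.0]))
    # After editing gm.graph in place, gm.recompile() regenerates gm.code.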
+from ._compatibility import compatibility
_help_mutation = """\
If you are attempting to modify the kwargs or args of a torch.fx.Node object,
'clear', 'extend', 'insert', 'pop', 'remove'])
immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),))
+compatibility(is_backward_compatible=True)(immutable_list)
+
immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update'])
immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),))
+compatibility(is_backward_compatible=True)(immutable_dict)
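A hedged sketch of why these containers are part of the BC surface: `Node.args` and `Node.kwargs` are stored as `immutable_list`/`immutable_dict`, so accidental in-place mutation fails loudly (at the time of this PR, with a `NotImplementedError` carrying the `_help_mutation` message above) rather than silently bypassing the Node's bookkeeping.

    from torch.fx.immutable_collections import immutable_list

    args = immutable_list([1, 2, 3])
    try:
        args.append(4)              # mutation methods raise instead of mutating
    except NotImplementedError as e:
        print(e)                    # includes the _help_mutation guidance
    assert list(args) == [1, 2, 3]  # container is unchanged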
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
+from ._compatibility import compatibility
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+@compatibility(is_backward_compatible=True)
class Interpreter:
"""
An Interpreter executes an FX graph Node-by-Node. This pattern
execution. This can be disabled to, for example, examine all of the intermediate
values in the execution by looking at the ``Interpreter.env`` attribute.
"""
+ @compatibility(is_backward_compatible=True)
def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
assert isinstance(module, GraphModule)
self.module = module
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+ @compatibility(is_backward_compatible=True)
def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any:
"""
Run `module` via interpretation and return the result.
output_val = self.env[node]
return output_val
+ @compatibility(is_backward_compatible=True)
def run_node(self, n : Node) -> Any:
"""
Run a specific node ``n`` and return the result.
return getattr(self, n.op)(n.target, args, kwargs)
# Main Node running APIs
-
+ @compatibility(is_backward_compatible=True)
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
else:
return next(self.args_iter)
+ @compatibility(is_backward_compatible=True)
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
assert isinstance(target, str)
return self.fetch_attr(target)
+ @compatibility(is_backward_compatible=True)
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
# Execute the function and return the result
return target(*args, **kwargs)
+ @compatibility(is_backward_compatible=True)
def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
assert isinstance(target, str)
return getattr(self_obj, target)(*args_tail, **kwargs)
+ @compatibility(is_backward_compatible=True)
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the result.
return submod(*args, **kwargs)
+ @compatibility(is_backward_compatible=True)
def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
return args[0]
# Helper methods
-
+ @compatibility(is_backward_compatible=True)
def fetch_attr(self, target : str):
"""
Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
attr_itr = getattr(attr_itr, atom)
return attr_itr
+ @compatibility(is_backward_compatible=True)
def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
"""
Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
assert isinstance(kwargs, dict)
return args, kwargs
+ @compatibility(is_backward_compatible=True)
def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
"""
Recursively descend through ``args`` and look up the concrete value
return self.env[n_arg]
return map_arg(args, load_arg)
+@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
"""
``Transformer`` is a special type of interpreter that produces a
Args:
module (GraphModule): The ``Module`` to be transformed.
"""
+
+ @compatibility(is_backward_compatible=True)
def __init__(self, module):
super().__init__(module)
self.new_graph = Graph()
self.tracer = TransformerTracer(self.new_graph)
self.tracer.root = module
+ @compatibility(is_backward_compatible=True)
def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
"""
Execute a ``placeholder`` node. In ``Transformer``, this is
assert isinstance(target, str)
return Proxy(self.new_graph.placeholder(target), self.tracer)
+ @compatibility(is_backward_compatible=True)
def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
"""
Execute a ``get_attr`` node. In ``Transformer``, this is
assert isinstance(target, str)
return Proxy(self.new_graph.get_attr(target), self.tracer)
+ @compatibility(is_backward_compatible=True)
def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
# Override so that the leaf module policy from `self.tracer` is respected.
assert isinstance(target, str)
submod = self.fetch_attr(target)
return self.tracer.call_module(submod, submod.forward, args, kwargs)
+ @compatibility(is_backward_compatible=True)
def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
# Override so that functions that were wrapped are still wrapped.
return self.tracer.create_proxy('call_function', target, args, kwargs)
+ @compatibility(is_backward_compatible=True)
def transform(self) -> GraphModule:
"""
Transform ``self.module`` and return the transformed
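A hedged sketch of the pattern these classes enable: subclass `Interpreter` and override one of the per-opcode methods to change how nodes execute (the swap below is an illustrative choice, not an FX built-in).

    import torch
    from torch.fx import Interpreter, symbolic_trace

    class SwapSigmoidForNeg(Interpreter):
        def call_function(self, target, args, kwargs):
            if target is torch.sigmoid:
                return torch.neg(*args, **kwargs)
            return super().call_function(target, args, kwargs)

    def fn(x):
        return torch.sigmoid(x) + 1.0

    gm = symbolic_trace(fn)
    inp = torch.randn(4)
    out = SwapSigmoidForNeg(gm).run(inp)
    assert torch.equal(out, torch.neg(inp) + 1.0)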
# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
+from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
else:
return str(arg)
+@compatibility(is_backward_compatible=True)
class Node:
"""
``Node`` is the data structure that represents individual operations within
- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
in the Graph printout.
"""
+
+ @compatibility(is_backward_compatible=True)
def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
return_type : Optional[Any] = None) -> None:
+ """
+ Instantiate an instance of ``Node``. Note: most often, you want to use the
+ Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
+ than instantiating a ``Node`` directly.
+
+ Args:
+ graph (Graph): The ``Graph`` to which this ``Node`` should belong.
+
+ name (str): The name to which the output of this ``Node`` should be assigned
+
+ op (str): The opcode for this ``Node``. Can be one of 'placeholder',
+ 'call_method', 'call_module', 'call_function', 'get_attr',
+ 'output'
+
+ target ('Target'): The target this op should call. See the broader
+ ``Node`` docstring for more details.
+
+ args (Tuple['Argument']): The args to be passed to ``target``
+
+ kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target``
+
+ return_type (Optional[Any]): The python type expression representing the
+ type of the output of this node. This field can be used for
+ annotation of values in the generated code or for other types
+ of analyses.
+ """
self.graph = graph
self.name = name # unique name of value being created
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
"""
return self._prev
+ @compatibility(is_backward_compatible=True)
def prepend(self, x: 'Node') -> None:
"""
Insert x before this node in the list of nodes in the graph. Example::
p._next, x._prev = x, p
x._next, self._prev = self, x
+ @compatibility(is_backward_compatible=True)
def append(self, x: 'Node') -> None:
"""
Insert x after this node in the list of nodes in the graph.
"""
return list(self._input_nodes.keys())
+ @compatibility(is_backward_compatible=True)
def update_arg(self, idx : int, arg : Argument) -> None:
"""
Update an existing positional argument to contain the new value
args[idx] = arg
self.args = tuple(args)
+ @compatibility(is_backward_compatible=True)
def update_kwarg(self, key : str, arg : Argument) -> None:
"""
Update an existing keyword argument to contain the new value
return f'operator.{target.__name__}'
return _get_qualified_name(target)
+ @compatibility(is_backward_compatible=True)
def format_node(self,
placeholder_names: List[str] = None,
maybe_return_typename: List[str] = None) -> Optional[str]:
f'{self.op}[target={self._pretty_print_target(self.target)}](' \
f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
+ @compatibility(is_backward_compatible=True)
def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']:
"""
Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
assert len(self.users) == 0
return to_process
+ @compatibility(is_backward_compatible=False)
def is_impure(self):
"""
Returns whether this op is impure, i.e. if its op is a placeholder or
return False
+ @compatibility(is_backward_compatible=False)
def normalized_arguments(
self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
kwarg_types : Optional[Dict[str, Any]] = None,
return None
-
+ @compatibility(is_backward_compatible=True)
def replace_input_with(self, old_input: 'Node', new_input: 'Node'):
"""
Loop through input nodes of ``self``, and replace all instances of
old_input (Node): The old input node to be replaced.
new_input (Node): The new input node to replace ``old_input``.
-
"""
def maybe_replace_node(n : Node) -> Node:
return new_input if n == old_input else n
self.__update_args_kwargs(new_args, new_kwargs)
+@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
- """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
+ """
+    Apply fn to each Node appearing in arg. arg may be a list, tuple, slice, or dict with string keys.
+ """
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
+@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
- """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
+ """
+    Apply fn to each Node appearing in arg. arg may be a list, tuple, slice, or dict with string keys.
+ """
if isinstance(a, tuple):
return tuple(map_aggregate(elem, fn) for elem in a)
elif isinstance(a, list):
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast
from torch._jit_internal import boolean_dispatched
+from ._compatibility import compatibility
+@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
"""
Simple named tuple for wrapping args/kwargs pairs.
return inspect.Signature(parameters, return_annotation=return_type)
+@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature]]:
"""
Given an operator on the `torch` namespace, return a list of `inspect.Signature`
return signatures
+@compatibility(is_backward_compatible=False)
def create_type_hint(x):
try:
if isinstance(x, list) or isinstance(x, tuple):
pass
return x
+@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
sig_origin_type = getattr(signature_type, '__origin__', signature_type)
return False
+@compatibility(is_backward_compatible=False)
def normalize_function(
target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
kwarg_types : Optional[Dict[str, Any]] = None,
return new_args_and_kwargs
+@compatibility(is_backward_compatible=False)
def normalize_module(
root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
import torch.fx
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional
+from torch.fx._compatibility import compatibility
+@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
# TensorMetadata is a structure containing pertinent information
# about a tensor within a PyTorch program.
q_scale : Optional[float]
q_zero_point : Optional[int]
-def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
+def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
"""
Extract a TensorMetadata NamedTuple describing `result`.
"""
return TensorMetadata(
shape, dtype, requires_grad, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point)
-
+@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and
if isinstance(obj, torch.Tensor):
nonlocal found_tensor
found_tensor = True
- return extract_tensor_metadata(obj)
+ return _extract_tensor_metadata(obj)
else:
return obj
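
# --- Usage sketch: run shape propagation over a traced module and read the
# recorded TensorMetadata off each node's .meta dict.
import torch, torch.fx
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

gm = torch.fx.symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))
for node in gm.graph.nodes:
    meta = node.meta.get('tensor_meta')
    if meta is not None:
        print(node.name, meta.shape, meta.dtype)
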
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
+from torch.fx._compatibility import compatibility
+@compatibility(is_backward_compatible=True)
class Partition:
def __init__(self, name: str):
self.name: str = name
f" parition dependents: {self.partition_dependents}"
# Creates subgraphs out of main graph
+@compatibility(is_backward_compatible=True)
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
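
# --- Usage sketch: partition a traced module into submodules by returning a
# partition id per node (the names 'add' and 'mul' assume default tracing).
import torch, torch.fx
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        return a * 2

gm = torch.fx.symbolic_trace(M())
parts = {'add': 0, 'mul': 1}
split = split_module(gm, M(), lambda n: parts.get(n.name, 0))
# split.submod_0 holds the add; split.submod_1 holds the mul
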
from .graph import magic_methods, reflectable_magic_methods, Graph
from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
+from ._compatibility import compatibility
+@compatibility(is_backward_compatible=True)
class TracerBase:
graph: Graph
record_stack_traces : bool = False
+ @compatibility(is_backward_compatible=True)
def create_node(self, kind : str, target : Target,
args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
type_expr : Optional[Any] = None) -> Node:
"""
return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
+ @compatibility(is_backward_compatible=True)
def proxy(self, node: Node) -> 'Proxy':
return Proxy(node, self)
-
-
+ @compatibility(is_backward_compatible=True)
def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
return frame
+ @compatibility(is_backward_compatible=True)
def create_arg(self, a: Any) -> Argument:
"""
A method that lowers the objects seen as arguments during symbolic evaluation
raise NotImplementedError(f"argument of type: {type(a)}")
+ @compatibility(is_backward_compatible=True)
def to_bool(self, obj: 'Proxy') -> bool:
"""Called when a proxy object is being converted to a boolean, such as
when used in control flow. Normally we don't know what to do because
"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
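
# --- Usage sketch: data-dependent branching converts a Proxy to bool, so
# tracing raises TraceError.
import torch, torch.fx

def f(x):
    if x.sum() > 0:
        return x
    return -x

try:
    torch.fx.symbolic_trace(f)
except torch.fx.proxy.TraceError:
    pass  # 'symbolically traced variables cannot be used as inputs to control flow'
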
+ @compatibility(is_backward_compatible=True)
def iter(self, obj: 'Proxy') -> Iterator:
"""Called when a proxy object is being iterated over, such as
when used in control flow. Normally we don't know what to do because
' Proxy docstring for help troubleshooting '
'Proxy iteration errors')
+ @compatibility(is_backward_compatible=True)
def keys(self, obj: 'Proxy') -> Any:
"""Called when a proxy object is has the keys() method called.
This is what happens when ** is called on a proxy. This should return an
# Used by Proxy objects when just appending to the graph while not tracing.
+@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
def __init__(self, graph: Graph):
super().__init__()
self.graph = graph
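
# --- Usage sketch: append to an existing Graph through Proxy arithmetic,
# with no active trace.
import torch.fx
from torch.fx.proxy import GraphAppendingTracer, Proxy

graph = torch.fx.Graph()
tracer = GraphAppendingTracer(graph)
x = Proxy(graph.placeholder('x'), tracer)
y = x + x                    # appends a call_function(operator.add) node
graph.output(y.node)
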
+@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
pass
-
+@compatibility(is_backward_compatible=True)
class Proxy:
"""
``Proxy`` objects are ``Node`` wrappers that flow through the
For a more detailed description into the Proxy internals, check out
the "Proxy" section in `torch/fx/OVERVIEW.md`
"""
+
+ @compatibility(is_backward_compatible=True)
def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
if tracer is None:
# This allows you to create a Proxy object around a raw Node
def __bool__(self) -> bool:
return self.tracer.to_bool(self)
+ @compatibility(is_backward_compatible=True)
def keys(self):
return self.tracer.keys(self)
return self.tracer.create_proxy('call_function', orig_method, args, kwargs,
name=self.tracer.graph._target_to_str(orig_method.__name__))
+@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
+ @compatibility(is_backward_compatible=True)
def __init__(self, root: Proxy, attr: str):
self.root = root
self.attr = attr
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
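
# --- Usage sketch: attribute access on a Proxy returns an Attribute; calling
# it records a call_method node targeting the attribute name.
import torch.fx
from torch.fx.proxy import Proxy

g = torch.fx.Graph()
p = Proxy(g.placeholder('x'))  # tracer defaults to a GraphAppendingTracer
r = p.relu()
assert r.node.op == 'call_method' and r.node.target == 'relu'
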
+@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
"""
- a special proxy which lets "shape", "size", "dim", and a few other
+ A special proxy which lets "shape", "size", "dim", and a few other
attribute accesses pass through to the underlying module parameter object,
so that conditional tests on these attributes will not throw an exception during tracing
"""
for method in magic_methods:
- def scope(method):
+ def _scope(method):
def impl(*args, **kwargs):
tracer = args[0].tracer
target = getattr(operator, method)
impl.__name__ = method
as_magic = f'__{method}__'
setattr(Proxy, as_magic, impl)
- scope(method)
+ _scope(method)
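
# --- Usage sketch: the loop above installs __add__, __mul__, etc., so
# operator syntax on a Proxy emits call_function nodes targeting the
# operator module.
import operator
import torch.fx
from torch.fx.proxy import Proxy

g = torch.fx.Graph()
p = Proxy(g.placeholder('x'))
q = p + 1
assert q.node.op == 'call_function' and q.node.target is operator.add
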
def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name}__'
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
+from ._compatibility import compatibility
import copy
from typing import Callable, Dict, List, NamedTuple, Optional, Set
import torch
+@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
-class SubgraphMatcher:
+class _SubgraphMatcher:
def __init__(self, pattern: Graph) -> None:
self.pattern = pattern
if len(pattern.nodes) == 0:
- raise ValueError("SubgraphMatcher cannot be initialized with an "
+ raise ValueError("_SubgraphMatcher cannot be initialized with an "
"empty pattern")
# `self.pattern_anchor` is the output Node in `pattern`
self.pattern_anchor = next(iter(reversed(pattern.nodes)))
gm.graph.lint()
+@compatibility(is_backward_compatible=True)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
"""
Matches all possible non-overlapping sets of operators and their
max_2 = torch.max(sum_2)
add_2 = add_1 + max_2
return add_2
-
"""
# Get the graphs for `gm`, `pattern`, `replacement`
original_graph = gm.graph
# Find all possible pattern matches in original_graph. Note that
# pattern matches may overlap with each other.
- matcher = SubgraphMatcher(pattern_graph)
+ matcher = _SubgraphMatcher(pattern_graph)
matches: List[Match] = []
# Consider each node as an "anchor" (deepest matching graph node)
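
# --- Usage sketch: rewrite every relu(x + y) into clamp(x + y, min=0.0).
import torch
from torch.fx import symbolic_trace, replace_pattern

class M(torch.nn.Module):
    def forward(self, a, b):
        return torch.relu(a + b)

def pattern(x, y):
    return torch.relu(x + y)

def replacement(x, y):
    return torch.clamp(x + y, min=0.0)

gm = symbolic_trace(M())
matches = replace_pattern(gm, pattern, replacement)  # List[Match]
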
from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
+from ._compatibility import compatibility
+
+@compatibility(is_backward_compatible=False)
class TensorType:
"""
TensorType defines a type for tensors, which consists of a list of dimensions.
Dyn = _DynType()
-
+@compatibility(is_backward_compatible=False)
def is_consistent(t1, t2):
"""
A binary relation denoted by ~ that determines if t1 is consistent with t2.
return False
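
# --- Usage sketch (module path assumed to be torch.fx.tensor_type, matching
# the imports above): Dyn is consistent with any type, and TensorTypes are
# consistent when their dimensions are.
from torch.fx.tensor_type import TensorType, Dyn, is_consistent

assert is_consistent(Dyn, TensorType((4, 5)))
assert is_consistent(TensorType((1, 2, 3)), TensorType((1, Dyn, 3)))
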
+@compatibility(is_backward_compatible=False)
def is_more_precise(t1, t2):
"""
A binary relation denoted by <= that determines if t1 is more precise than t2.
if isinstance(input_node_c, Node):
graph_c = input_node_c.graph
else:
+ assert isinstance(input_node_c, list)
graph_c = input_node_c[0].graph
# create a sequential list of the subgraphs' nodes from start to end,
if isinstance(input_node_c, Node):
graph_c = input_node_c.graph
else:
+ assert isinstance(input_node_c, list)
graph_c = input_node_c[0].graph
# generically handle all args and kwargs except for the input