[WIP][FX] BC guarantees for 1.10 (#63888)
authorJames Reed <jamesreed@fb.com>
Tue, 31 Aug 2021 02:54:50 +0000 (19:54 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 02:56:46 +0000 (19:56 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63888

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D30523133

Pulled By: jamesr66a

fbshipit-source-id: b04cc0d842a74862f42ecba98b757310cd2ec7b0

21 files changed:
test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect [new file with mode: 0644]
test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect [new file with mode: 0644]
test/test_fx.py
test/test_fx_experimental.py
torch/fx/__init__.py
torch/fx/_compatibility.py [new file with mode: 0644]
torch/fx/_symbolic_trace.py
torch/fx/annotate.py
torch/fx/experimental/fx_acc/acc_ops.py
torch/fx/graph.py
torch/fx/graph_module.py
torch/fx/immutable_collections.py
torch/fx/interpreter.py
torch/fx/node.py
torch/fx/operator_schemas.py
torch/fx/passes/shape_prop.py
torch/fx/passes/split_module.py
torch/fx/proxy.py
torch/fx/subgraph_rewriter.py
torch/fx/tensor_type.py
torch/quantization/ns/graph_passes.py

diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect
new file mode 100644 (file)
index 0000000..88e4654
--- /dev/null
@@ -0,0 +1,19 @@
+torch.fx._symbolic_trace.ProxyableClassMeta []
+torch.fx._symbolic_trace.Tracer ['call_module', 'create_arg', 'create_args_for_root', 'is_leaf_module', 'path_of_module', 'trace']
+torch.fx.graph.Graph ['call_function', 'call_method', 'call_module', 'create_node', 'eliminate_dead_code', 'erase_node', 'flatten_inps', 'get_attr', 'graph_copy', 'inserting_after', 'inserting_before', 'lint', 'node_copy', 'nodes', 'output', 'owning_module', 'placeholder', 'print_tabular', 'python_code', 'unflatten_outs']
+torch.fx.graph.PythonCode []
+torch.fx.graph_module.GraphModule ['add_submodule', 'code', 'delete_all_unused_submodules', 'delete_submodule', 'graph', 'recompile', 'to_folder']
+torch.fx.immutable_collections.immutable_dict ['clear', 'pop', 'popitem', 'update']
+torch.fx.immutable_collections.immutable_list ['append', 'clear', 'extend', 'insert', 'pop', 'remove']
+torch.fx.interpreter.Interpreter ['call_function', 'call_method', 'call_module', 'fetch_args_kwargs_from_env', 'fetch_attr', 'get_attr', 'map_nodes_to_values', 'output', 'placeholder', 'run', 'run_node']
+torch.fx.interpreter.Transformer ['call_function', 'call_module', 'get_attr', 'placeholder', 'transform']
+torch.fx.node.Node ['all_input_nodes', 'append', 'args', 'format_node', 'is_impure', 'kwargs', 'next', 'normalized_arguments', 'prepend', 'prev', 'replace_all_uses_with', 'replace_input_with', 'stack_trace', 'update_arg', 'update_kwarg']
+torch.fx.passes.shape_prop.ShapeProp ['propagate', 'run_node']
+torch.fx.passes.shape_prop.TensorMetadata ['dtype', 'is_quantized', 'memory_format', 'q_scale', 'q_zero_point', 'qscheme', 'requires_grad', 'shape', 'stride']
+torch.fx.passes.split_module.Partition []
+torch.fx.proxy.Attribute ['node']
+torch.fx.proxy.GraphAppendingTracer []
+torch.fx.proxy.Proxy ['keys']
+torch.fx.proxy.TraceError []
+torch.fx.proxy.TracerBase ['create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'record_stack_traces', 'to_bool']
+torch.fx.subgraph_rewriter.Match ['anchor', 'nodes_map']
\ No newline at end of file
diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect
new file mode 100644 (file)
index 0000000..a73fde7
--- /dev/null
@@ -0,0 +1,70 @@
+torch.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (<module math>,), autowrap_functions: Tuple[Callable, ...] = (,), enable_cpatching: bool = False, param_shapes_constant: bool = False) -> None
+torch.fx._symbolic_trace.Tracer.call_module(self, m: torch.nn.modules.module.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx._symbolic_trace.Tracer.create_arg(self, a: Any) -> 'Argument'
+torch.fx._symbolic_trace.Tracer.is_leaf_module(self, m: torch.nn.modules.module.Module, module_qualified_name: str) -> bool
+torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str
+torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph
+torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None, enable_cpatching: bool = False) -> torch.fx.graph_module.GraphModule
+torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable])
+torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None)
+torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.eliminate_dead_code(self)
+torch.fx.graph.Graph.erase_node(self, to_erase: torch.fx.node.Node) -> None
+torch.fx.graph.Graph.get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: Dict[torch.fx.node.Node, torch.fx.node.Node], return_output_node = False) -> 'Optional[Argument]'
+torch.fx.graph.Graph.inserting_after(self, n: Optional[torch.fx.node.Node] = None)
+torch.fx.graph.Graph.inserting_before(self, n: Optional[torch.fx.node.Node] = None)
+torch.fx.graph.Graph.lint(self)
+torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Callable[[torch.fx.node.Node], Argument] = <function <lambda>>) -> torch.fx.node.Node
+torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
+torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.print_tabular(self)
+torch.fx.graph.Graph.python_code(self, root_module: str) -> torch.fx.graph.PythonCode
+torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
+torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
+torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
+torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool
+torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode
+torch.fx.interpreter.Interpreter.__init__(self, module: torch.fx.graph_module.GraphModule, garbage_collect_values: bool = True)
+torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: torch.fx.node.Node) -> Tuple[Tuple, Dict]
+torch.fx.interpreter.Interpreter.fetch_attr(self, target: str)
+torch.fx.interpreter.Interpreter.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.map_nodes_to_values(self, args: torch.fx.node.Argument, n: torch.fx.node.Node) -> torch.fx.node.Argument
+torch.fx.interpreter.Interpreter.output(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Interpreter.run(self, *args, initial_env: Optional[Dict[torch.fx.node.Node, Any]] = None) -> Any
+torch.fx.interpreter.Interpreter.run_node(self, n: torch.fx.node.Node) -> Any
+torch.fx.interpreter.Transformer.__init__(self, module)
+torch.fx.interpreter.Transformer.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Transformer.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any
+torch.fx.interpreter.Transformer.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy
+torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy
+torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule
+torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None
+torch.fx.node.Node.append(self, x: 'Node') -> None
+torch.fx.node.Node.format_node(self, placeholder_names: List[str] = None, maybe_return_typename: List[str] = None) -> Optional[str]
+torch.fx.node.Node.prepend(self, x: 'Node') -> None
+torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node') -> List[Node]
+torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node')
+torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None
+torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None
+torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Argument], torch.fx.node.Argument]) -> torch.fx.node.Argument
+torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
+torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int])
+torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
+torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
+torch.fx.proxy.Proxy.keys(self)
+torch.fx.proxy.TracerBase.create_arg(self, a: Any) -> torch.fx.node.Argument
+torch.fx.proxy.TracerBase.create_node(self, kind: str, target: torch.fx.node.Target, args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, torch.fx.node.Argument], name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.proxy.TracerBase.create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[torch.fx.node.Node], Proxy] = None)
+torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator
+torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any
+torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy'
+torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool
+torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match]
\ No newline at end of file
index 47873d7..eadcf6c 100644 (file)
@@ -11,6 +11,8 @@ import pickle
 import sys
 import torch
 import traceback
+import typing
+import types
 import warnings
 import unittest
 from math import sqrt
@@ -31,6 +33,7 @@ from copy import deepcopy
 from collections import namedtuple
 
 from torch.fx.proxy import TraceError
+from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMATIBLITY
 
 from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
 from fx.test_dce_pass import TestDCE  # noqa: F401
@@ -3060,6 +3063,245 @@ class TestOperatorSignatures(JitTestCase):
             assert op.name in known_no_schema or "nn.functional" in op.name
 
 
+class TestFXAPIBackwardCompatibility(JitTestCase):
+    def setUp(self):
+        self.maxDiff = None
+
+    def _fn_to_stable_annotation_str(self, obj):
+        """
+        Unfortunately we have to serialize function signatures manually since
+        serialization for `inspect.Signature` objects is not stable across
+        python versions
+        """
+        fn_name = torch.typename(obj)
+
+        signature = inspect.signature(obj)
+
+        sig_str = f'{fn_name}{signature}'
+
+        arg_strs = []
+        for k, v in signature.parameters.items():
+            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
+                if v.annotation is not inspect.Signature.empty else ''
+
+            def default_val_str(val):
+                if isinstance(val, (tuple, list)):
+                    str_pieces = ['(' if isinstance(val, tuple) else '[']
+                    str_pieces.append(', '.join(default_val_str(v) for v in val))
+                    if isinstance(val, tuple) and len(str_pieces) == 2:
+                        str_pieces.append(',')
+                    str_pieces.append(')' if isinstance(val, tuple) else ']')
+                    return ''.join(str_pieces)
+
+                # Need to fix up some default value strings.
+                # First case: modules. Default module `repr` contains the FS path of the module.
+                # Don't leak that
+                if isinstance(val, types.ModuleType):
+                    return f'<module {val.__name__}>'
+
+                # Second case: callables. Callables (such as lambdas) encode their address in
+                # their string repr. Don't do that
+                if callable(val):
+                    return f'<function {val.__name__}>'
+
+                return str(val)
+
+            if v.default is not inspect.Signature.empty:
+                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
+                maybe_default = f' = {default_val_str}'
+            else:
+                maybe_default = ''
+            maybe_stars = ''
+            if v.kind == inspect.Parameter.VAR_POSITIONAL:
+                maybe_stars = '*'
+            elif v.kind == inspect.Parameter.VAR_KEYWORD:
+                maybe_stars = '**'
+            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
+
+        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
+            if signature.return_annotation is not inspect.Signature.empty else ''
+
+        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
+
+    def _annotation_type_to_stable_str(self, t, sig_str):
+        if t is inspect.Signature.empty:
+            return ''
+
+        # Forward ref
+        if isinstance(t, str):
+            return f"'{t}'"
+        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
+            return t.__forward_arg__
+        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
+            return t.__forward_arg__
+
+        trivial_mappings = {
+            str : 'str',
+            int : 'int',
+            float: 'float',
+            bool: 'bool',
+            torch.dtype: 'torch.dtype',
+            torch.Tensor: 'torch.Tensor',
+            torch.device: 'torch.device',
+            torch.memory_format: 'torch.memory_format',
+            slice: 'slice',
+            torch.nn.Module: 'torch.nn.modules.module.Module',
+            torch.fx.Graph : 'torch.fx.graph.Graph',
+            torch.fx.Node : 'torch.fx.node.Node',
+            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
+            torch.fx.node.Target : 'torch.fx.node.Target',
+            torch.fx.node.Argument : 'torch.fx.node.Argument',
+            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
+            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
+            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
+            Ellipsis : '...',
+            typing.Any: 'Any',
+            type(None): 'NoneType',
+            None: 'None',
+            typing.Iterator: 'Iterator',
+        }
+
+        mapping = trivial_mappings.get(t, None)
+        if mapping:
+            return mapping
+
+        # Handle types with contained types
+        contained = getattr(t, '__args__', None) or []
+
+        # Callables contain a bare List for arguments
+        contained = t if isinstance(t, list) else contained
+
+        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
+        if all(isinstance(ct, typing.TypeVar) for ct in contained):
+            contained = []
+
+        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
+        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
+
+
+        origin = getattr(t, '__origin__', None)
+        if origin is None:
+            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
+            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
+
+        if origin in {tuple, typing.Tuple}:
+            return f'Tuple{contained_type_str}'
+        if origin in {typing.Union}:
+            # Annoying hack to detect Optional
+            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
+                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
+                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
+            return f'Union{contained_type_str}'
+        if origin in {dict, typing.Dict}:
+            return f'Dict{contained_type_str}'
+        if origin in {list, typing.List}:
+            return f'List{contained_type_str}'
+        if origin in {type, typing.Type}:
+            return f'Type{contained_type_str}'
+        if isinstance(t, typing.Callable):
+            if len(contained) > 0 and contained[0] is not Ellipsis:
+                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
+            else:
+                return f'Callable{contained_type_str}'
+
+        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.'
+                           f'Please add support for this type and confirm with the '
+                           f'FX team that your signature change is valid.')
+
+
+    def test_function_back_compat(self):
+        """
+        Test backward compatibility for function signatures with
+        @compatibility(is_backward_compatible=True). Currently this checks for
+        exact signature matches, which may lead to false positives. If this
+        becomes too annoying, we can refine this check to actually parse out
+        the saved schema strings and check if the change is truly backward-
+        incompatible.
+        """
+        signature_strs = []
+
+        for obj in _BACK_COMPAT_OBJECTS:
+            if not isinstance(obj, type):
+                signature_strs.append(self._fn_to_stable_annotation_str(obj))
+
+        signature_strs.sort()
+
+        try:
+            self.assertExpected('\n'.join(signature_strs), 'fx_backcompat_function_signatures')
+        except AssertionError as e:
+            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
+                  f"as backwards-compatible has experienced a signature change. See the " \
+                  f"above exception context for more information. If this change was " \
+                  f"unintended, please revert it. If it was intended, check with the FX " \
+                  f"team to ensure that the proper deprecation protocols have been followed " \
+                  f"and subsequently --accept the change."
+            raise AssertionError(msg)
+
+    def test_class_member_back_compat(self):
+        """
+        Test backward compatibility for members of classes with
+        @compatibility(is_backward_compatible=True). Currently this checks for
+        exact matches on the publicly visible members of the class.
+        """
+        class_method_strs = []
+
+        for obj in _BACK_COMPAT_OBJECTS:
+            if isinstance(obj, type):
+                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
+                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')
+
+        class_method_strs.sort()
+
+        try:
+            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
+        except AssertionError as e:
+            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
+                  f"as backwards-compatible has experienced change in its public members. See the " \
+                  f"above exception context for more information. If this change was " \
+                  f"unintended, please revert it. If it was intended, check with the FX " \
+                  f"team to ensure that the proper deprecation protocols have been followed " \
+                  f"and subsequently --accept the change."
+            raise AssertionError(msg)
+
+    def test_public_api_surface(self):
+        mod = torch.fx
+
+        non_back_compat_objects = {}
+
+        def check_symbols_have_bc_designation(m, prefix):
+            if not m.__name__.startswith('torch.fx'):
+                return
+            if m.__name__.startswith('torch.fx.experimental'):
+                return
+            for k, v in m.__dict__.items():
+                if v is m:
+                    continue
+                if k.startswith('_'):
+                    continue
+                if isinstance(v, types.ModuleType):
+                    check_symbols_have_bc_designation(v, prefix + [k])
+                elif isinstance(v, type) or isinstance(v, types.FunctionType):
+                    if v not in _MARKED_WITH_COMATIBLITY:
+                        non_back_compat_objects.setdefault(v)
+
+        check_symbols_have_bc_designation(mod, ['torch', 'fx'])
+
+
+        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
+        # Only want objects in torch.fx
+        non_back_compat_strs = [
+            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
+        # Only want objects in public namespaces
+        non_back_compat_strs = [
+            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
+        non_back_compat_strs.sort()
+
+        if len(non_back_compat_strs) != 0:
+            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
+                                 f"backwards-compatibility classification! Please decorate these "
+                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
+                                 f"BC guarantees.")
+
 class TestFunctionalTracing(JitTestCase):
     IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                     "has_torch_function_variadic", "handle_torch_function",
index f000b0a..e723ee4 100644 (file)
@@ -32,7 +32,7 @@ from torch.fx.operator_schemas import (
     type_matches,
     create_type_hint,
 )
-from torch.fx.passes.shape_prop import extract_tensor_metadata, ShapeProp
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
 from torch.fx.passes.split_module import split_module
 from torch.testing._internal.common_device_type import (
     ops,
@@ -96,13 +96,13 @@ class TestFXExperimental(JitTestCase):
         # Fix for now to add type/shape to output
         for node in traced.graph.nodes:
             if node.op == "output":
-                node.meta["tensor_meta"] = extract_tensor_metadata(a)
+                node.meta["tensor_meta"] = _extract_tensor_metadata(a)
         for mod in module_with_submodules.modules():
             if isinstance(mod, GraphModule):
                 for node in mod.graph.nodes:
-                    node.meta["tensor_meta"] = extract_tensor_metadata(a)
+                    node.meta["tensor_meta"] = _extract_tensor_metadata(a)
         for node in module_with_submodules.graph.nodes:
-            node.meta["tensor_meta"] = extract_tensor_metadata(a)
+            node.meta["tensor_meta"] = _extract_tensor_metadata(a)
 
         weights1 = {}
         weights2 = {}
index 4ff795e..6524c2d 100644 (file)
@@ -1,6 +1,4 @@
 r'''
-**This feature is under a Beta release and its API may change.**
-
 FX is a toolkit for developers to use to transform ``nn.Module``
 instances. FX consists of three main components: a **symbolic tracer,**
 an **intermediate representation**, and **Python code generation**. A
@@ -28,12 +26,13 @@ demonstration of these components in action:
     # High-level intermediate representation (IR) - Graph representation
     print(symbolic_traced.graph)
     """
-    graph(x):
-        %param : [#users=1] = self.param
-        %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
-        %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
-        %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
-        return clamp_1
+    graph():
+        %x : [#users=1] = placeholder[target=x]
+        %param : [#users=1] = get_attr[target=param]
+        %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
+        %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
+        %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
+        return clamp
     """
 
     # Code generation - valid Python code
@@ -41,10 +40,10 @@ demonstration of these components in action:
     """
     def forward(self, x):
         param = self.param
-        add_1 = x + param;  x = param = None
-        linear_1 = self.linear(add_1);  add_1 = None
-        clamp_1 = linear_1.clamp(min = 0.0, max = 1.0);  linear_1 = None
-        return clamp_1
+        add = x + param;  x = param = None
+        linear = self.linear(add);  add = None
+        clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
+        return clamp
     """
 
 The **symbolic tracer** performs "symbolic execution" of the Python
diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py
new file mode 100644 (file)
index 0000000..2d33813
--- /dev/null
@@ -0,0 +1,34 @@
+from typing import Any, Dict
+import textwrap
+
+_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
+_MARKED_WITH_COMATIBLITY : Dict[Any, None] = {}
+
+def compatibility(is_backward_compatible : bool):
+    if is_backward_compatible:
+
+        def mark_back_compat(fn):
+            docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
+            docstring += """
+.. note::
+    Backwards-compatibility for this API is guaranteed.
+"""
+            fn.__doc__ = docstring
+            _BACK_COMPAT_OBJECTS.setdefault(fn)
+            _MARKED_WITH_COMATIBLITY.setdefault(fn)
+            return fn
+
+        return mark_back_compat
+    else:
+
+        def mark_not_back_compat(fn):
+            docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
+            docstring += """
+.. warning::
+    This API is experimental and is *NOT* backward-compatible.
+"""
+            fn.__doc__ = docstring
+            _MARKED_WITH_COMATIBLITY.setdefault(fn)
+            return fn
+
+        return mark_not_back_compat
index 25f739e..d381973 100644 (file)
@@ -12,6 +12,7 @@ from torch._C import ScriptObject  # type: ignore[attr-defined]
 import torch.utils._pytree as pytree
 
 import sys
+from ._compatibility import compatibility
 from .node import Argument, map_aggregate, base_types
 from .graph import Graph, _PyTreeInfo
 from .graph_module import GraphModule
@@ -25,6 +26,7 @@ _orig_module_getattr : Callable = torch.nn.Module.__getattr__
 
 _proxyable_classes : Dict[Type, None] = {}
 
+@compatibility(is_backward_compatible=True)
 class ProxyableClassMeta(type):
     """
     ProxyableClassMeta allows you to make construction of a given Python class
@@ -157,6 +159,7 @@ class _CPatchManager(object):
     def __exit__(self, type, value, tb):
         sys.setprofile(None)
 
+@compatibility(is_backward_compatible=False)
 class PHBase(object):
     """
     Object representing an input placeholder to `concrete_args`
@@ -166,6 +169,7 @@ class PHBase(object):
 
 PH = PHBase()
 
+@compatibility(is_backward_compatible=True)
 class Tracer(TracerBase):
     # Reference: https://github.com/pytorch/pytorch/issues/54354
     # The first line of this docstring overrides the one Sphinx generates for the
@@ -182,6 +186,11 @@ class Tracer(TracerBase):
     process. The different behaviors that can be overridden are described
     in the docstrings of the methods on this class.
     """
+
+    # Not checking BC on this API because the default value for `autowrap_modules`
+    # includes the local filepath to the `math` module, which would jitter
+    # across machines.
+    @compatibility(is_backward_compatible=True)
     def __init__(self, autowrap_modules: Tuple[ModuleType] = (math, ),
                  autowrap_functions: Tuple[Callable, ...] = (),
                  enable_cpatching: bool = False,
@@ -197,11 +206,19 @@ class Tracer(TracerBase):
 
             autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
                 Python modules whose functions should be wrapped automatically
-                without needing to use fx.wrap().
+                without needing to use fx.wrap(). Backward-compatibility for
+                this parameter is guaranteed.
 
             autowrap_function (Tuple[Callable, ...]): defaults to `()`,
                 Python functions that should be wrapped automatically without
-                needing to use fx.wrap().
+                needing to use fx.wrap(). Backward compabilibility for this
+                parameter is guaranteed.
+
+            param_shapes_constant (bool): When this flag is set,  calls to shape,
+                size and a few other shape like attributes of a module's parameter
+                will be evaluted directly, rather than returning a new Proxy value
+                for an attribute access. Backward compatibility for this parameter
+                is guaranteed.
 
             enable_cpatching (bool): defaults to `False`,
                 Allows you to enable/disable monkeypatching of torch functions at the
@@ -210,12 +227,9 @@ class Tracer(TracerBase):
                 C-level monkeypatching works by directly modifying the PyCFunctionObject*
                 so that calling it returns a different function.
 
-                Turning this on is likely to slow down tracing by 1.5-3x.
-
-            param_shapes_constant (bool): see https://github.com/pytorch/pytorch/issues/61733. When
-            this flag is set,  calls to shape, size and a few other shape like attributes of a module's parameter
-            will be evaluted directly, rather than returning a new Proxy value for an attribute access.
-
+                Turning this on is likely to slow down tracing by 1.5-3x. This
+                parameter is experimental and its backward-compatibility is NOT
+                guaranteed.
         """
 
         super().__init__()
@@ -235,6 +249,7 @@ class Tracer(TracerBase):
 
         self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
 
+    @compatibility(is_backward_compatible=True)
     def create_arg(self, a: Any) -> 'Argument':
         """
         A method to specify the behavior of tracing when preparing values to
@@ -325,6 +340,7 @@ class Tracer(TracerBase):
 
         return super().create_arg(a)
 
+    @compatibility(is_backward_compatible=True)
     def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
         """
         A method to specify whether a given ``nn.Module`` is a "leaf" module.
@@ -346,6 +362,7 @@ class Tracer(TracerBase):
         """
         return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)
 
+    @compatibility(is_backward_compatible=True)
     def path_of_module(self, mod : torch.nn.Module) -> str:
         """
         Helper method to find the qualified name of ``mod`` in the Module hierarchy
@@ -372,6 +389,7 @@ class Tracer(TracerBase):
                     return n
             raise NameError('module is not installed as a submodule')
 
+    @compatibility(is_backward_compatible=True)
     def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Method that specifies the behavior of this ``Tracer`` when it encounters
@@ -404,6 +422,8 @@ class Tracer(TracerBase):
             return forward(*args, **kwargs)
         return self.create_proxy('call_module', module_qualified_name, args, kwargs)
 
+    # This method will be refactored
+    @compatibility(is_backward_compatible=False)
     def create_args_for_root(self, root_fn, is_module, concrete_args=None):
         """
         Create ``placeholder`` nodes corresponding to the signature of the ``root``
@@ -509,8 +529,8 @@ class Tracer(TracerBase):
 
         return attr_val
 
-
-    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
+    @compatibility(is_backward_compatible=True)
+    def trace(self, root: Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
         """
         Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
         can either be an ``nn.Module`` instance or a Python callable.
@@ -524,8 +544,11 @@ class Tracer(TracerBase):
         Args:
 
             root (Union[Module, Callable]): Either a ``Module`` or a function to be
-                traced through.
-            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should not be treated as Proxies.
+                traced through. Backwards-compatibility for this parameter is
+                guaranteed.
+            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
+                not be treated as Proxies. This parameter is experimental and
+                its backwards-compatibility is *NOT* guaranteed.
 
         Returns:
 
@@ -772,6 +795,7 @@ def _autowrap_check(patcher : _Patcher, frame_dict : Dict[str, Any], function_id
                 patcher.patch(frame_dict, name, _create_wrapped_func(value))
 
 
+@compatibility(is_backward_compatible=True)
 def wrap(fn_or_name : Union[str, Callable]):
     """
     This function can be called at module-level scope to register fn_or_name as a "leaf function".
@@ -828,9 +852,11 @@ def wrap(fn_or_name : Union[str, Callable]):
     _wrapped_fns_to_patch.append((f.f_globals, fn_name))
     return fn_or_name
 
-def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None,
+@compatibility(is_backward_compatible=True)
+def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None,
                    enable_cpatching: bool = False) -> GraphModule:
-    """Symbolic tracing API
+    """
+    Symbolic tracing API
 
     Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
     constructed by recording operations seen while tracing through ``root``.
@@ -876,7 +902,6 @@ def symbolic_trace(root : Union[torch.nn.Module, Callable], concrete_args: Optio
 
     Returns:
         GraphModule: a Module created from the recorded operations from ``root``.
-
     """
     tracer = Tracer(enable_cpatching=enable_cpatching)
     graph = tracer.trace(root, concrete_args)
index 6e0646a..032ce14 100644 (file)
@@ -1,6 +1,7 @@
 from torch.fx.proxy import Proxy
+from ._compatibility import compatibility
 
-
+@compatibility(is_backward_compatible=False)
 def annotate(val, type):
     # val could be either a regular value (not tracing)
     # or fx.Proxy (tracing)
index 692ca63..1b4b469 100644 (file)
@@ -10,7 +10,7 @@ from torch.fx.experimental.fx_acc.acc_normalizer import (
     register_acc_op_mapping,
     register_custom_acc_mapper_fn,
 )
-from torch.fx.passes.shape_prop import extract_tensor_metadata
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
 
 this_arg_is_optional = True
 
@@ -1134,12 +1134,12 @@ def packed_quantized_linear_mapper(
     with node.graph.inserting_before(node):
         # Insert get_attr nodes for weight and bias
         get_weight = node.graph.get_attr(weight_name)
-        get_weight.meta["tensor_meta"] = extract_tensor_metadata(linear_module.weight())
+        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.weight())
 
         get_bias = None
         if linear_module.bias() is not None:
             get_bias = node.graph.get_attr(bias_name)
-            get_bias.meta["tensor_meta"] = extract_tensor_metadata(linear_module.bias())
+            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(linear_module.bias())
 
         # Create kwargs for acc_op.quantized_linear
         kwargs = {
@@ -1182,12 +1182,12 @@ def packed_quantized_conv2d_mapper(
     with node.graph.inserting_before(node):
         # Insert get_attr nodes for weight and bias
         get_weight = node.graph.get_attr(weight_name)
-        get_weight.meta["tensor_meta"] = extract_tensor_metadata(conv_module.weight())
+        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.weight())
 
         get_bias = None
         if conv_module.bias() is not None:
             get_bias = node.graph.get_attr(bias_name)
-            get_bias.meta["tensor_meta"] = extract_tensor_metadata(conv_module.bias())
+            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(conv_module.bias())
 
         # Create kwargs for acc_op.conv
         kwargs = {
index 1ee6f05..29ffc41 100644 (file)
@@ -1,6 +1,7 @@
 from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
 import torch.utils._pytree as pytree
 from . import _pytree as fx_pytree
+from ._compatibility import compatibility
 
 from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
 from dataclasses import dataclass
@@ -175,9 +176,12 @@ class _Namespace:
         return False
 
 
+@compatibility(is_backward_compatible=True)
 @dataclass
 class PythonCode:
-    """Represents all the information necessary to exec or save a graph as Python code."""
+    """
+    Represents all the information necessary to exec or save a graph as Python code.
+    """
     # Python source code for the forward function definition.
     src: str
     # Values in global scope during exection of `src_def`.
@@ -240,6 +244,7 @@ class _PyTreeInfo(NamedTuple):
     in_spec: pytree.TreeSpec
     out_spec: Optional[pytree.TreeSpec]
 
+@compatibility(is_backward_compatible=True)
 class Graph:
     """
     ``Graph`` is the main data structure used in the FX Intermediate Representation.
@@ -283,6 +288,8 @@ class Graph:
 
     For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
     """
+
+    @compatibility(is_backward_compatible=True)
     def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
         """
         Construct an empty Graph.
@@ -299,6 +306,11 @@ class Graph:
 
     @property
     def owning_module(self):
+        """
+        Return the module that owns this ``GraphModule``, if there is one,
+        ``None`` if there is no owning module or if there are multiple owning
+        modules.
+        """
         return self._owning_module
 
     @owning_module.setter
@@ -322,6 +334,7 @@ class Graph:
         """
         return _node_list(self)
 
+    @compatibility(is_backward_compatible=True)
     def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
         """
         Copy all nodes from a given graph into ``self``.
@@ -354,7 +367,7 @@ class Graph:
         from the default implementation. This uses graph_copy to copy the nodes
         in an iterative way, rather than recursive. It also populates the
         memoization table to prevent unnecessary copies (e.g. references to
-        nodes or other parts of the Graph from a custom GraphModule implementation
+        nodes or other parts of the Graph from a custom GraphModule implementation.
         """
         memo = memo if memo else {}
         g = Graph(tracer_cls=self._tracer_cls)
@@ -364,6 +377,7 @@ class Graph:
         g.output(output_val, type_expr=getattr(old_output_val, 'type', None))
         return g
 
+    @compatibility(is_backward_compatible=True)
     def create_node(self, op: str, target: 'Target',
                     args: Optional[Tuple['Argument', ...]] = None,
                     kwargs: Optional[Dict[str, 'Argument']] = None,
@@ -410,10 +424,12 @@ class Graph:
         self._len += 1
         return n
 
+    @compatibility(is_backward_compatible=False)
     def flatten_inps(self, *args):
         flat_args, args_spec = pytree.tree_flatten(args)
         return flat_args
 
+    @compatibility(is_backward_compatible=False)
     def unflatten_outs(self, out):
         if self._pytree_info is None:
             return out
@@ -422,6 +438,7 @@ class Graph:
         assert(self._pytree_info.out_spec is not None)
         return pytree.tree_unflatten(out, self._pytree_info.out_spec)
 
+    @compatibility(is_backward_compatible=True)
     def erase_node(self, to_erase : Node) -> None:
         """
         Erases a ``Node`` from the ``Graph``. Throws an exception if
@@ -448,6 +465,7 @@ class Graph:
         assert isinstance(new_kwargs, dict)
         to_erase.kwargs = new_kwargs
 
+    @compatibility(is_backward_compatible=True)
     def inserting_before(self, n: Optional[Node] = None):
         """Set the point at which create_node and companion methods will insert into the graph.
         When used within a 'with' statement, this will temporary set the insert point and
@@ -470,6 +488,7 @@ class Graph:
         assert n.graph == self, "Node to insert before is not in graph."
         return _InsertPoint(self, n.prepend)
 
+    @compatibility(is_backward_compatible=True)
     def inserting_after(self, n: Optional[Node] = None):
         """Set the point at which create_node and companion methods will insert into the graph.
         When used within a 'with' statement, this will temporary set the insert point and
@@ -492,7 +511,7 @@ class Graph:
         assert n.graph == self, "Node to insert after is not in graph."
         return _InsertPoint(self, n.append)
 
-    # sugar for create_node when you know the op
+    @compatibility(is_backward_compatible=True)
     def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
         """
         Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
@@ -514,6 +533,7 @@ class Graph:
         """
         return self.create_node('placeholder', name, type_expr=type_expr)
 
+    @compatibility(is_backward_compatible=True)
     def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
         """
         Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
@@ -571,6 +591,7 @@ class Graph:
                           "necessary buffer")
         return self.create_node('get_attr', qualified_name, type_expr=type_expr)
 
+    @compatibility(is_backward_compatible=True)
     def call_module(self,
                     module_name: str,
                     args: Optional[Tuple['Argument', ...]] = None,
@@ -615,6 +636,7 @@ class Graph:
                           "necessary submodule")
         return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
 
+    @compatibility(is_backward_compatible=True)
     def call_method(self,
                     method_name: str,
                     args: Optional[Tuple['Argument', ...]] = None,
@@ -649,6 +671,7 @@ class Graph:
         """
         return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
 
+    @compatibility(is_backward_compatible=True)
     def call_function(self,
                       the_function: Callable[..., Any],
                       args: Optional[Tuple['Argument', ...]] = None,
@@ -684,6 +707,7 @@ class Graph:
         """
         return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
 
+    @compatibility(is_backward_compatible=True)
     def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
         """
         Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
@@ -714,6 +738,7 @@ class Graph:
         result_node.meta = copy.copy(node.meta)
         return result_node
 
+    @compatibility(is_backward_compatible=True)
     def output(self, result: 'Argument', type_expr: Optional[Any] = None):
         """
         Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
@@ -745,6 +770,7 @@ class Graph:
         op = _snake_case(op)
         return op
 
+    @compatibility(is_backward_compatible=True)
     def python_code(self, root_module: str) -> PythonCode:
         """
         Turn this ``Graph`` into valid Python code.
@@ -995,7 +1021,7 @@ def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 
     def __str__(self) -> str:
         """
-        Print a human-readable (not machine-readable) string representation
+        Return a human-readable (not machine-readable) string representation
         of this Graph
         """
         placeholder_names : List[str] = []
@@ -1011,10 +1037,12 @@ def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
                 s += '\n    ' + node_str
         return s
 
+    @compatibility(is_backward_compatible=True)
     def print_tabular(self):
         """
         Prints the intermediate representation of the graph in tabular
-        format.
+        format. Note that this API requires the ``tabulate`` module to be
+        installed.
         """
         try:
             from tabulate import tabulate
@@ -1027,6 +1055,7 @@ def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
         print(tabulate(node_specs,
               headers=['opcode', 'name', 'target', 'args', 'kwargs']))
 
+    @compatibility(is_backward_compatible=True)
     def lint(self):
         """
         Runs various checks on this Graph to make sure it is well-formed. In
@@ -1097,6 +1126,7 @@ def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
                         else:
                             m_itr = new_m_itr
 
+    @compatibility(is_backward_compatible=True)
     def eliminate_dead_code(self):
         """
         Remove all dead code from the graph, based on each node's number of
@@ -1124,7 +1154,6 @@ def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 
             def forward(self, x):
                 return x + self.attr_1
-
         """
         # Lint the graph first to make sure its topologically sorted, otherwise
         # DCE below will not behave as expected.
index c918573..e7750db 100644 (file)
@@ -6,6 +6,7 @@ from torch.package import PackageImporter, PackageExporter
 import linecache
 from typing import Type, Dict, List, Any, Union, Optional, Set
 from .graph import Graph, _is_from_torch, _custom_builtins, PythonCode
+from ._compatibility import compatibility
 from torch.package import Importer, sys_importer
 import copy
 import itertools
@@ -17,9 +18,9 @@ import warnings
 
 # Normal exec loses the source code, however we can work with
 # the linecache module to recover it.
-# Using exec_with_source will add it to our local cache
+# Using _exec_with_source will add it to our local cache
 # and then tools like TorchScript will be able to get source info.
-class EvalCacheLoader(object):
+class _EvalCacheLoader(object):
     def __init__(self):
         self.eval_cache = {}
         self.next_id = 0
@@ -62,10 +63,10 @@ class EvalCacheLoader(object):
         self.next_id += 1
         return key
 
-_loader = EvalCacheLoader()
+_loader = _EvalCacheLoader()
 
 
-def exec_with_source(src: str, globals: Dict[str, Any]):
+def _exec_with_source(src: str, globals: Dict[str, Any]):
     key = _loader.cache(src, globals)
     exec(compile(src, key, 'exec'), globals)
 
@@ -73,7 +74,7 @@ def exec_with_source(src: str, globals: Dict[str, Any]):
 def _forward_from_src(src: str, globals: Dict[str, Any]):
     # avoid mutating the passed in dict
     globals_copy = globals.copy()
-    exec_with_source(src, globals_copy)
+    _exec_with_source(src, globals_copy)
     forward_fn = globals_copy['forward']
     del globals_copy['forward']
     return forward_fn
@@ -95,7 +96,7 @@ def _format_import_block(globals: Dict[str, Any], importer: Importer):
     return '\n'.join(import_strs)
 
 
-def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
+def _reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
     # BC: attribute name was changed from `code` to `_code` to facilitate
     # making `code` into a property and adding a docstring to it
     fn_src = body.get('_code') or body['code']
@@ -103,14 +104,14 @@ def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Mod
     return _deserialize_graph_module(forward, body)
 
 
-def reduce_package_graph_module(
+def _reduce_package_graph_module(
     importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
 ) -> torch.nn.Module:
     forward = importer.import_module(generated_module_name).forward
     return _deserialize_graph_module(forward, body)
 
 
-def reduce_deploy_graph_module(
+def _reduce_deploy_graph_module(
     importer: PackageImporter, body: Dict[Any, Any], import_block: str
 ) -> torch.nn.Module:
     ns = dict()
@@ -219,6 +220,7 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
     else:
         setattr(to_module, field, from_obj)
 
+@compatibility(is_backward_compatible=True)
 class GraphModule(torch.nn.Module):
     """
     GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
@@ -231,7 +233,6 @@ class GraphModule(torch.nn.Module):
         regenerated. However, if you edit the contents of the ``graph`` without reassigning
         the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
         code.
-
     """
     def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
         # each instance of a graph module needs its own forward method
@@ -243,6 +244,7 @@ class GraphModule(torch.nn.Module):
             pass
         return super().__new__(GraphModuleImpl)
 
+    @compatibility(is_backward_compatible=True)
     def __init__(self,
                  root: Union[torch.nn.Module, Dict[str, Any]],
                  graph: Graph,
@@ -266,7 +268,6 @@ class GraphModule(torch.nn.Module):
             class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
                 error messages will report as originating from ``GraphModule``. It may be helpful to set this
                 to ``root``'s original name or a name that makes sense within the context of your transform.
-
         """
         super().__init__()
         self.__class__.__name__ = class_name
@@ -334,6 +335,7 @@ class GraphModule(torch.nn.Module):
         g.owning_module = self
         self.recompile()
 
+    @compatibility(is_backward_compatible=False)
     def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"):
         """Dumps out module to ``folder`` with ``module_name`` so that it can be
         imported with ``from <folder> import <module_name>``
@@ -398,6 +400,7 @@ class {module_name}(torch.nn.Module):
             warnings.warn("Was not able to save the following children modules as reprs -"
                           f"saved as pickled files instead: {blobified_modules}")
 
+    @compatibility(is_backward_compatible=True)
     def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
         """
         Adds the given submodule to ``self``.
@@ -418,7 +421,6 @@ class {module_name}(torch.nn.Module):
                 denoted by ``target`` must either a) not exist yet,
                 or b) reference an ``nn.Module`` (not a parameter or
                 other attribute)
-
         """
         *prefix, field = target.split('.')
         mod: torch.nn.Module = self
@@ -439,6 +441,7 @@ class {module_name}(torch.nn.Module):
         mod.add_module(field, m)
         return True
 
+    @compatibility(is_backward_compatible=True)
     def delete_submodule(self, target: str) -> bool:
         """
         Deletes the given submodule from ``self``.
@@ -481,6 +484,7 @@ class {module_name}(torch.nn.Module):
         delattr(mod, target_submod)
         return True
 
+    @compatibility(is_backward_compatible=True)
     def delete_all_unused_submodules(self) -> None:
         """
         Deletes all unused submodules from ``self``.
@@ -535,6 +539,7 @@ class {module_name}(torch.nn.Module):
             raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
         return self._code
 
+    @compatibility(is_backward_compatible=True)
     def recompile(self) -> PythonCode:
         """
         Recompile this GraphModule from its ``graph`` attribute. This should be
@@ -613,7 +618,7 @@ class {module_name}(torch.nn.Module):
 
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, importer)
-        return (reduce_deploy_graph_module, (dict_without_graph, import_block))
+        return (_reduce_deploy_graph_module, (dict_without_graph, import_block))
 
     def __reduce_package__(self, exporter: PackageExporter):
         dict_without_graph = self.__dict__.copy()
@@ -625,7 +630,7 @@ class {module_name}(torch.nn.Module):
         import_block = _format_import_block(python_code.globals, exporter.importer)
         module_code = import_block + self.code
         exporter.save_source_string(generated_module_name, module_code)
-        return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
+        return (_reduce_package_graph_module, (dict_without_graph, generated_module_name))
 
     def __reduce__(self):
         """
@@ -639,7 +644,7 @@ class {module_name}(torch.nn.Module):
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, sys_importer)
         del dict_without_graph['_graph']
-        return (reduce_graph_module, (dict_without_graph, import_block))
+        return (_reduce_graph_module, (dict_without_graph, import_block))
 
     # because __reduce__ is defined for serialization,
     # we need to define deepcopy otherwise it will call __reduce__
index 459c30e..1093a07 100644 (file)
@@ -1,3 +1,4 @@
+from ._compatibility import compatibility
 
 _help_mutation = """\
 If you are attempting to modify the kwargs or args of a torch.fx.Node object,
@@ -20,5 +21,8 @@ immutable_list = _create_immutable_container(list,
                                               'clear', 'extend', 'insert', 'pop', 'remove'])
 immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),))
 
+compatibility(is_backward_compatible=True)(immutable_list)
+
 immutable_dict = _create_immutable_container(dict, ['__delitem__', '__setitem__', 'clear', 'pop', 'popitem', 'update'])
 immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),))
+compatibility(is_backward_compatible=True)(immutable_dict)
index 20dcf62..64233b4 100644 (file)
@@ -3,8 +3,10 @@ from .graph import Graph
 from .node import Argument, Node, Target, map_arg, map_aggregate
 from .proxy import Proxy
 from ._symbolic_trace import Tracer
+from ._compatibility import compatibility
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
+@compatibility(is_backward_compatible=True)
 class Interpreter:
     """
     An Interpreter executes an FX graph Node-by-Node. This pattern
@@ -59,6 +61,7 @@ class Interpreter:
             execution. This can be disabled to, for example, examine all of the intermediate
             values in the execution by looking at the ``Interpreter.env`` attribute.
     """
+    @compatibility(is_backward_compatible=True)
     def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
         assert isinstance(module, GraphModule)
         self.module = module
@@ -84,6 +87,7 @@ class Interpreter:
                 map_arg(node.args, lambda n: register_last_uses(n, node))
                 map_arg(node.kwargs, lambda n: register_last_uses(n, node))
 
+    @compatibility(is_backward_compatible=True)
     def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None) -> Any:
         """
         Run `module` via interpretation and return the result.
@@ -123,6 +127,7 @@ class Interpreter:
                 output_val = self.env[node]
                 return output_val
 
+    @compatibility(is_backward_compatible=True)
     def run_node(self, n : Node) -> Any:
         """
         Run a specific node ``n`` and return the result.
@@ -142,7 +147,7 @@ class Interpreter:
         return getattr(self, n.op)(n.target, args, kwargs)
 
     # Main Node running APIs
-
+    @compatibility(is_backward_compatible=True)
     def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Execute a ``placeholder`` node. Note that this is stateful:
@@ -168,6 +173,7 @@ class Interpreter:
         else:
             return next(self.args_iter)
 
+    @compatibility(is_backward_compatible=True)
     def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Execute a ``get_attr`` node. Will retrieve an attribute
@@ -186,6 +192,7 @@ class Interpreter:
         assert isinstance(target, str)
         return self.fetch_attr(target)
 
+    @compatibility(is_backward_compatible=True)
     def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Execute a ``call_function`` node and return the result.
@@ -205,6 +212,7 @@ class Interpreter:
         # Execute the function and return the result
         return target(*args, **kwargs)
 
+    @compatibility(is_backward_compatible=True)
     def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Execute a ``call_method`` node and return the result.
@@ -226,6 +234,7 @@ class Interpreter:
         assert isinstance(target, str)
         return getattr(self_obj, target)(*args_tail, **kwargs)
 
+    @compatibility(is_backward_compatible=True)
     def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Execute a ``call_module`` node and return the result.
@@ -248,6 +257,7 @@ class Interpreter:
 
         return submod(*args, **kwargs)
 
+    @compatibility(is_backward_compatible=True)
     def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         """
         Execute an ``output`` node. This really just retrieves
@@ -266,7 +276,7 @@ class Interpreter:
         return args[0]
 
     # Helper methods
-
+    @compatibility(is_backward_compatible=True)
     def fetch_attr(self, target : str):
         """
         Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
@@ -285,6 +295,7 @@ class Interpreter:
             attr_itr = getattr(attr_itr, atom)
         return attr_itr
 
+    @compatibility(is_backward_compatible=True)
     def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
         """
         Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
@@ -302,6 +313,7 @@ class Interpreter:
         assert isinstance(kwargs, dict)
         return args, kwargs
 
+    @compatibility(is_backward_compatible=True)
     def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
         """
         Recursively descend through ``args`` and look up the concrete value
@@ -319,6 +331,7 @@ class Interpreter:
             return self.env[n_arg]
         return map_arg(args, load_arg)
 
+@compatibility(is_backward_compatible=True)
 class Transformer(Interpreter):
     """
     ``Transformer`` is a special type of interpreter that produces a
@@ -357,6 +370,8 @@ class Transformer(Interpreter):
     Args:
         module (GraphModule): The ``Module`` to be transformed.
     """
+
+    @compatibility(is_backward_compatible=True)
     def __init__(self, module):
         super().__init__(module)
         self.new_graph = Graph()
@@ -371,6 +386,7 @@ class Transformer(Interpreter):
         self.tracer = TransformerTracer(self.new_graph)
         self.tracer.root = module
 
+    @compatibility(is_backward_compatible=True)
     def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
         """
         Execute a ``placeholder`` node. In ``Transformer``, this is
@@ -387,6 +403,7 @@ class Transformer(Interpreter):
         assert isinstance(target, str)
         return Proxy(self.new_graph.placeholder(target), self.tracer)
 
+    @compatibility(is_backward_compatible=True)
     def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
         """
         Execute a ``get_attr`` node. In ``Transformer``, this is
@@ -403,16 +420,19 @@ class Transformer(Interpreter):
         assert isinstance(target, str)
         return Proxy(self.new_graph.get_attr(target), self.tracer)
 
+    @compatibility(is_backward_compatible=True)
     def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         # Override so that the leaf module policy from `self.tracer` is respected.
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         return self.tracer.call_module(submod, submod.forward, args, kwargs)
 
+    @compatibility(is_backward_compatible=True)
     def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
         # Override so that functions that were wrapped are still wrapped.
         return self.tracer.create_proxy('call_function', target, args, kwargs)
 
+    @compatibility(is_backward_compatible=True)
     def transform(self) -> GraphModule:
         """
         Transform ``self.module`` and return the transformed
index 8c4faf7..61dfba7 100644 (file)
@@ -1,5 +1,6 @@
 # Nodes represent a definition of a value in our graph of operators.
 from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
+from ._compatibility import compatibility
 from .immutable_collections import immutable_dict, immutable_list
 import torch
 import builtins
@@ -85,6 +86,7 @@ def _format_arg(arg) -> str:
     else:
         return str(arg)
 
+@compatibility(is_backward_compatible=True)
 class Node:
     """
     ``Node`` is the data structure that represents individual operations within
@@ -112,9 +114,37 @@ class Node:
     - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
       in the Graph printout.
     """
+
+    @compatibility(is_backward_compatible=True)
     def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
                  args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
                  return_type : Optional[Any] = None) -> None:
+        """
+        Instantiate an instance of ``Node``. Note: most often, you want to use the
+        Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
+        than instantiating a ``Node`` directly.
+
+        Args:
+            graph (Graph): The ``Graph`` to which this ``Node`` should belong.
+
+            name (str): The name to which the output of this ``Node`` should be assigned
+
+            op (str): The opcode for this ``Node``. Can be one of 'placeholder',
+                'call_method', 'call_module', 'call_function', 'get_attr',
+                'output'
+
+            target ('Target'): The target this op should call. See the broader
+                ``Node`` docstring for more details.
+
+            args (Tuple['Argument']): The args to be passed to ``target``
+
+            kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target``
+
+            return_type (Optional[Any]): The python type expression representing the
+                type of the output of this node. This field can be used for
+                annotation of values in the generated code or for other types
+                of analyses.
+        """
         self.graph = graph
         self.name = name  # unique name of value being created
         assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
@@ -187,6 +217,7 @@ class Node:
         """
         return self._prev
 
+    @compatibility(is_backward_compatible=True)
     def prepend(self, x: 'Node') -> None:
         """
         Insert x before this node in the list of nodes in the graph. Example::
@@ -205,6 +236,7 @@ class Node:
         p._next, x._prev = x, p
         x._next, self._prev = self, x
 
+    @compatibility(is_backward_compatible=True)
     def append(self, x: 'Node') -> None:
         """
         Insert x after this node in the list of nodes in the graph.
@@ -279,6 +311,7 @@ class Node:
         """
         return list(self._input_nodes.keys())
 
+    @compatibility(is_backward_compatible=True)
     def update_arg(self, idx : int, arg : Argument) -> None:
         """
         Update an existing positional argument to contain the new value
@@ -293,6 +326,7 @@ class Node:
         args[idx] = arg
         self.args = tuple(args)
 
+    @compatibility(is_backward_compatible=True)
     def update_kwarg(self, key : str, arg : Argument) -> None:
         """
         Update an existing keyword argument to contain the new value
@@ -365,6 +399,7 @@ class Node:
                 return f'operator.{target.__name__}'
         return _get_qualified_name(target)
 
+    @compatibility(is_backward_compatible=True)
     def format_node(self,
                     placeholder_names: List[str] = None,
                     maybe_return_typename: List[str] = None) -> Optional[str]:
@@ -420,6 +455,7 @@ class Node:
                    f'{self.op}[target={self._pretty_print_target(self.target)}](' \
                    f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
 
+    @compatibility(is_backward_compatible=True)
     def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']:
         """
         Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
@@ -449,6 +485,7 @@ class Node:
         assert len(self.users) == 0
         return to_process
 
+    @compatibility(is_backward_compatible=False)
     def is_impure(self):
         """
         Returns whether this op is impure, i.e. if its op is a placeholder or
@@ -478,6 +515,7 @@ class Node:
 
         return False
 
+    @compatibility(is_backward_compatible=False)
     def normalized_arguments(
             self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
             kwarg_types : Optional[Dict[str, Any]] = None,
@@ -513,7 +551,7 @@ class Node:
 
         return None
 
-
+    @compatibility(is_backward_compatible=True)
     def replace_input_with(self, old_input: 'Node', new_input: 'Node'):
         """
         Loop through input nodes of ``self``, and replace all instances of
@@ -523,7 +561,6 @@ class Node:
 
             old_input (Node): The old input node to be replaced.
             new_input (Node): The new input node to replace ``old_input``.
-
         """
         def maybe_replace_node(n : Node) -> Node:
             return new_input if n == old_input else n
@@ -535,13 +572,19 @@ class Node:
         self.__update_args_kwargs(new_args, new_kwargs)
 
 
+@compatibility(is_backward_compatible=True)
 def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
-    """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
+    """
+    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
+    """
     assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"
     return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
 
+@compatibility(is_backward_compatible=True)
 def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
-    """ Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
+    """
+    Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
+    """
     if isinstance(a, tuple):
         return tuple(map_aggregate(elem, fn) for elem in a)
     elif isinstance(a, list):
index 5f61ebe..ac559b1 100644 (file)
@@ -6,7 +6,9 @@ import enum
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast
 from torch._jit_internal import boolean_dispatched
+from ._compatibility import compatibility
 
+@compatibility(is_backward_compatible=False)
 class ArgsKwargsPair(NamedTuple):
     """
     Simple named tuple for wrapping args/kwargs pairs.
@@ -76,6 +78,7 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins
 
     return inspect.Signature(parameters, return_annotation=return_type)
 
+@compatibility(is_backward_compatible=False)
 def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature]]:
     """
     Given an operator on the `torch` namespace, return a list of `inspect.Signature`
@@ -103,6 +106,7 @@ def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature
 
     return signatures
 
+@compatibility(is_backward_compatible=False)
 def create_type_hint(x):
     try:
         if isinstance(x, list) or isinstance(x, tuple):
@@ -130,6 +134,7 @@ def create_type_hint(x):
         pass
     return x
 
+@compatibility(is_backward_compatible=False)
 def type_matches(signature_type : Any, argument_type : Any):
     sig_origin_type = getattr(signature_type, '__origin__', signature_type)
 
@@ -177,6 +182,7 @@ def type_matches(signature_type : Any, argument_type : Any):
 
     return False
 
+@compatibility(is_backward_compatible=False)
 def normalize_function(
         target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
         kwarg_types : Optional[Dict[str, Any]] = None,
@@ -272,6 +278,7 @@ def normalize_function(
 
     return new_args_and_kwargs
 
+@compatibility(is_backward_compatible=False)
 def normalize_module(
         root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
         normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
index 6f0f72d..816fbe7 100644 (file)
@@ -2,7 +2,9 @@ import torch
 import torch.fx
 from torch.fx.node import Node, map_aggregate
 from typing import Any, Tuple, NamedTuple, Optional
+from torch.fx._compatibility import compatibility
 
+@compatibility(is_backward_compatible=True)
 class TensorMetadata(NamedTuple):
     # TensorMetadata is a structure containing pertinent information
     # about a tensor within a PyTorch program.
@@ -20,7 +22,7 @@ class TensorMetadata(NamedTuple):
     q_scale : Optional[float]
     q_zero_point : Optional[int]
 
-def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
+def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
     """
     Extract a TensorMetadata NamedTuple describing `result`.
     """
@@ -58,7 +60,7 @@ def extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
     return TensorMetadata(
         shape, dtype, requires_grad, stride, memory_format, is_quantized, qscheme, q_scale, q_zero_point)
 
-
+@compatibility(is_backward_compatible=True)
 class ShapeProp(torch.fx.Interpreter):
     """
     Execute an FX graph Node-by-Node and
@@ -113,7 +115,7 @@ class ShapeProp(torch.fx.Interpreter):
             if isinstance(obj, torch.Tensor):
                 nonlocal found_tensor
                 found_tensor = True
-                return extract_tensor_metadata(obj)
+                return _extract_tensor_metadata(obj)
             else:
                 return obj
 
index 989ec92..c42af7e 100644 (file)
@@ -1,7 +1,9 @@
 import torch
 from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
+from torch.fx._compatibility import compatibility
 
+@compatibility(is_backward_compatible=True)
 class Partition:
     def __init__(self, name: str):
         self.name: str = name
@@ -23,6 +25,7 @@ class Partition:
             f" parition dependents: {self.partition_dependents}"
 
 # Creates subgraphs out of main graph
+@compatibility(is_backward_compatible=True)
 def split_module(
     m: GraphModule,
     root_m: torch.nn.Module,
index c0b83bc..61b039f 100644 (file)
@@ -7,11 +7,14 @@ import traceback
 from .graph import magic_methods, reflectable_magic_methods, Graph
 from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable
 from .node import Target, Node, Argument, base_types, map_aggregate
+from ._compatibility import compatibility
 
+@compatibility(is_backward_compatible=True)
 class TracerBase:
     graph: Graph
     record_stack_traces : bool = False
 
+    @compatibility(is_backward_compatible=True)
     def create_node(self, kind : str, target : Target,
                     args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                     type_expr : Optional[Any] = None) -> Node:
@@ -24,11 +27,11 @@ class TracerBase:
         """
         return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
 
+    @compatibility(is_backward_compatible=True)
     def proxy(self, node: Node) -> 'Proxy':
         return Proxy(node, self)
 
-
-
+    @compatibility(is_backward_compatible=True)
     def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                      name: Optional[str] = None, type_expr : Optional[Any] = None,
                      proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
@@ -86,6 +89,7 @@ class TracerBase:
 
         return frame
 
+    @compatibility(is_backward_compatible=True)
     def create_arg(self, a: Any) -> Argument:
         """
         A method that lowers the objects seen as arguments during symbolic evaluation
@@ -131,6 +135,7 @@ class TracerBase:
 
         raise NotImplementedError(f"argument of type: {type(a)}")
 
+    @compatibility(is_backward_compatible=True)
     def to_bool(self, obj: 'Proxy') -> bool:
         """Called when a proxy object is being converted to a boolean, such as
         when used in control flow.  Normally we don't know what to do because
@@ -139,6 +144,7 @@ class TracerBase:
         """
         raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
 
+    @compatibility(is_backward_compatible=True)
     def iter(self, obj: 'Proxy') -> Iterator:
         """Called when a proxy object is being iterated over, such as
         when used in control flow.  Normally we don't know what to do because
@@ -154,6 +160,7 @@ class TracerBase:
                          ' Proxy docstring for help troubleshooting '
                          'Proxy iteration errors')
 
+    @compatibility(is_backward_compatible=True)
     def keys(self, obj: 'Proxy') -> Any:
         """Called when a proxy object is has the keys() method called.
         This is what happens when ** is called on a proxy. This should return an
@@ -163,15 +170,17 @@ class TracerBase:
 
 
 # used in Proxy object when just appending to the graph while not tracing.
+@compatibility(is_backward_compatible=True)
 class GraphAppendingTracer(TracerBase):
     def __init__(self, graph: Graph):
         super().__init__()
         self.graph = graph
 
+@compatibility(is_backward_compatible=True)
 class TraceError(ValueError):
     pass
 
-
+@compatibility(is_backward_compatible=True)
 class Proxy:
     """
     ``Proxy`` objects are ``Node`` wrappers that flow through the
@@ -200,6 +209,8 @@ class Proxy:
     For a more detailed description into the Proxy internals, check out
     the "Proxy" section in `torch/fx/OVERVIEW.md`
     """
+
+    @compatibility(is_backward_compatible=True)
     def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
         if tracer is None:
             # This allows you to create a Proxy object around a raw Node
@@ -232,6 +243,7 @@ class Proxy:
     def __bool__(self) -> bool:
         return self.tracer.to_bool(self)
 
+    @compatibility(is_backward_compatible=True)
     def keys(self):
         return self.tracer.keys(self)
 
@@ -253,7 +265,9 @@ class Proxy:
             return self.tracer.create_proxy('call_function', orig_method, args, kwargs,
                                             name=self.tracer.graph._target_to_str(orig_method.__name__))
 
+@compatibility(is_backward_compatible=True)
 class Attribute(Proxy):
+    @compatibility(is_backward_compatible=True)
     def __init__(self, root: Proxy, attr: str):
         self.root = root
         self.attr = attr
@@ -272,9 +286,10 @@ class Attribute(Proxy):
         return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
 
 
+@compatibility(is_backward_compatible=False)
 class ParameterProxy(Proxy):
     """
-    a special proxy which lets "shape", "size", "dim", and a few other
+    A special proxy which lets "shape", "size", "dim", and a few other
     attribute accesses pass through to the underlying  module parameter object,
     so that conditional tests on these attributes will not throw exception during tracing
     """
@@ -309,7 +324,7 @@ class ParameterProxy(Proxy):
 
 
 for method in magic_methods:
-    def scope(method):
+    def _scope(method):
         def impl(*args, **kwargs):
             tracer = args[0].tracer
             target = getattr(operator, method)
@@ -317,7 +332,7 @@ for method in magic_methods:
         impl.__name__ = method
         as_magic = f'__{method}__'
         setattr(Proxy, as_magic, impl)
-    scope(method)
+    _scope(method)
 
 def _define_reflectable(orig_method_name):
     method_name = f'__r{orig_method_name}__'
index e779f6c..72ea56a 100644 (file)
@@ -2,22 +2,24 @@ from .graph_module import GraphModule
 from .graph import Graph
 from .node import Node
 from ._symbolic_trace import symbolic_trace
+from ._compatibility import compatibility
 
 import copy
 from typing import Callable, Dict, List, NamedTuple, Optional, Set
 import torch
 
+@compatibility(is_backward_compatible=True)
 class Match(NamedTuple):
     # Node from which the match was found
     anchor: Node
     # Maps nodes in the pattern subgraph to nodes in the larger graph
     nodes_map: Dict[Node, Node]
 
-class SubgraphMatcher:
+class _SubgraphMatcher:
     def __init__(self, pattern: Graph) -> None:
         self.pattern = pattern
         if len(pattern.nodes) == 0:
-            raise ValueError("SubgraphMatcher cannot be initialized with an "
+            raise ValueError("_SubgraphMatcher cannot be initialized with an "
                              "empty pattern")
         # `self.pattern_anchor` is the output Node in `pattern`
         self.pattern_anchor = next(iter(reversed(pattern.nodes)))
@@ -129,6 +131,7 @@ def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
 
     gm.graph.lint()
 
+@compatibility(is_backward_compatible=True)
 def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
     """
     Matches all possible non-overlapping sets of operators and their
@@ -242,7 +245,6 @@ def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -
             max_2 = torch.max(sum_2)
             add_2 = add_1 + max_2
             return add_2
-
     """
     # Get the graphs for `gm`, `pattern`, `replacement`
     original_graph = gm.graph
@@ -251,7 +253,7 @@ def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -
 
     # Find all possible pattern matches in original_graph. Note that
     # pattern matches may overlap with each other.
-    matcher = SubgraphMatcher(pattern_graph)
+    matcher = _SubgraphMatcher(pattern_graph)
     matches: List[Match] = []
 
     # Consider each node as an "anchor" (deepest matching graph node)
index 18387ee..0840122 100644 (file)
@@ -1,6 +1,9 @@
 from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
 
+from ._compatibility import compatibility
 
+
+@compatibility(is_backward_compatible=False)
 class TensorType:
     """
     TensorType defines a type for tensors, which consists of a list of dimensions.
@@ -48,7 +51,7 @@ class _DynType:
 
 Dyn = _DynType()
 
-
+@compatibility(is_backward_compatible=False)
 def is_consistent(t1, t2):
     """
     A binary relation denoted by ~ that determines if t1 is consistent with t2.
@@ -74,6 +77,7 @@ def is_consistent(t1, t2):
         return False
 
 
+@compatibility(is_backward_compatible=False)
 def is_more_precise(t1, t2):
     """
     A binary relation denoted by <= that determines if t1 is more precise than t2.
index 36e737e..51eb6c2 100644 (file)
@@ -361,6 +361,7 @@ def _insert_copy_of_subgraph_a_after_input_node_c(
     if isinstance(input_node_c, Node):
         graph_c = input_node_c.graph
     else:
+        assert isinstance(input_node_c, list)
         graph_c = input_node_c[0].graph
 
     # create a sequential list of the subgraphs' nodes from start to end,
@@ -450,6 +451,7 @@ def _insert_copy_of_node_a_after_input_node_c(
     if isinstance(input_node_c, Node):
         graph_c = input_node_c.graph
     else:
+        assert isinstance(input_node_c, list)
         graph_c = input_node_c[0].graph
 
     # generically handle all args and kwargs except for the input