From 32a93c2424c7c165a3f52a6dc8cee83aab4d7b63 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Thu, 2 Sep 2021 16:06:17 -0700 Subject: [PATCH] Revert D30675780: [FX] Prototype for guarding against mutable operations in tracing Test Plan: revert-hammer Differential Revision: D30675780 (https://github.com/pytorch/pytorch/commit/795387477fe90e03cb598f3077a32222896e65dd) Original commit changeset: b2116b51dcc8 fbshipit-source-id: d4f1173f4989556ea54974f4c2739ef85a705fae --- ..._back_compat-fx_backcompat_class_members.expect | 2 +- test/test_fx.py | 67 ++-------------------- torch/csrc/jit/python/init.cpp | 14 ++--- torch/fx/operator_schemas.py | 44 ++------------ torch/fx/proxy.py | 7 --- 5 files changed, 14 insertions(+), 120 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect index 5c3630a..88e4654 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_class_member_back_compat-fx_backcompat_class_members.expect @@ -15,5 +15,5 @@ torch.fx.proxy.Attribute ['node'] torch.fx.proxy.GraphAppendingTracer [] torch.fx.proxy.Proxy ['keys'] torch.fx.proxy.TraceError [] -torch.fx.proxy.TracerBase ['check_mutable_operations', 'create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'record_stack_traces', 'to_bool'] +torch.fx.proxy.TracerBase ['create_arg', 'create_node', 'create_proxy', 'iter', 'keys', 'proxy', 'record_stack_traces', 'to_bool'] torch.fx.subgraph_rewriter.Match ['anchor', 'nodes_map'] \ No newline at end of file diff --git a/test/test_fx.py b/test/test_fx.py index 57a2960..5220f67 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -132,17 +132,10 @@ class Foo(object): # noqa: B209 class TestFX(JitTestCase): def setUp(self): - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations - torch.fx.proxy.TracerBase.check_mutable_operations = True - - if not (TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS): - lib_file_path = find_library_location('libtorchbind_test.so') - torch.ops.load_library(str(lib_file_path)) - - def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag + if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS: + return + lib_file_path = find_library_location('libtorchbind_test.so') + torch.ops.load_library(str(lib_file_path)) def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): """Check that an nn.Module's results match the GraphModule version @@ -2374,19 +2367,6 @@ class TestFX(JitTestCase): traced.graph.lint() - def test_throw_out_variant(self): - def foo(x): - y = torch.rand_like(x) - torch.sigmoid(x, out=y) - return y - - class MyTracer(torch.fx.Tracer): - check_mutable_operations = True - - tracer = MyTracer() - with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'): - traced_graph = tracer.trace(foo) - def test_ast_rewriter_reassigns_submodules(self): class M(torch.nn.Module): def __init__(self): @@ -3041,15 +3021,6 @@ def run_getitem_target(): class TestOperatorSignatures(JitTestCase): - def setUp(self): - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations - torch.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag - @onlyCPU @ops(op_db, allowed_dtypes=(torch.float,)) def test_get_torch_func_signature_exhaustive(self, device, dtype, op): @@ -3119,15 +3090,6 @@ class TestFXAPIBackwardCompatibility(JitTestCase): def setUp(self): self.maxDiff = None - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations - torch.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag - - def _fn_to_stable_annotation_str(self, obj): """ Unfortunately we have to serialize function signatures manually since @@ -3364,15 +3326,6 @@ class TestFXAPIBackwardCompatibility(JitTestCase): f"BC guarantees.") class TestFunctionalTracing(JitTestCase): - def setUp(self): - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations - torch.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag - IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary", "has_torch_function_variadic", "handle_torch_function", "boolean_dispatch") @@ -3387,7 +3340,6 @@ class TestFunctionalTracing(JitTestCase): ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$") CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow") INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined") - MUTABLE = (RuntimeError, r"Tried to trace mutable operation") UNTRACEABLE_FUNCTIONALS = { "adaptive_avg_pool1d": BUILT_IN_FUNC, @@ -3507,8 +3459,6 @@ class TestFunctionalTracing(JitTestCase): "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT, "upsample_nearest": INTERPOLATE_ARGS_CONFLICT, - - "normalize" : MUTABLE, } # List of nn.functionals with Tensor inputs but not with type annotation @@ -3623,15 +3573,6 @@ instantiate_device_type_tests(TestOperatorSignatures, globals()) @skipIfNoTorchVision class TestVisionTracing(JitTestCase): - def setUp(self): - # Checking for mutable operations whil tracing is feature flagged - # Enable it in testing but not by default - self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations - torch.fx.proxy.TracerBase.check_mutable_operations = True - - def tearDown(self): - torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag - PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated") INCONSISTENT_TYPE = ( RuntimeError, diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 35197e4..7e43e51 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1280,15 +1280,11 @@ void initJITBindings(PyObject* module) { [](const FunctionSchema& self, const FunctionSchema& other) { return self == other; }) - .def( - "__str__", - [](FunctionSchema& self) { - std::stringstream ss; - ss << self; - return ss.str(); - }) - .def_property_readonly( - "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); + .def("__str__", [](FunctionSchema& self) { + std::stringstream ss; + ss << self; + return ss.str(); + }); py::class_(m, "Argument") .def_property_readonly("name", [](Argument& self) { return self.name(); }) .def_property_readonly("type", [](Argument& self) { return self.type(); }) diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 5e024e8..ac559b1 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -79,43 +79,7 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins return inspect.Signature(parameters, return_annotation=return_type) @compatibility(is_backward_compatible=False) -def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): - signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) - - if signatures and schemas: - matched_schemas = [] - - # Iterate through all of the schema until we find one that matches - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs - # values. If none matches, `new_args_and_kwargs` will be None - for candidate_signature, schema in zip(signatures, schemas): - try: - candidate_signature.bind(*args, **kwargs) - matched_schemas.append((candidate_signature, schema)) - except TypeError as e: - continue - - def throw_if_mutable(schema): - if schema.is_mutable: - raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' - f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' - f'are not supported') - - if len(matched_schemas) == 0: - # Did not match any schema. Cannot check for mutation - pass - elif len(matched_schemas) == 1: - # Matched exactly one schema, unambiguous - _, schema_to_check = matched_schemas[0] - throw_if_mutable(schema_to_check) - pass - else: - # Ambiguous schema match. Since mutability checking is best effort, - # do nothing. - pass - -@compatibility(is_backward_compatible=False) -def get_signature_for_torch_op(op : Callable, return_schemas : bool = False) -> Optional[List[inspect.Signature]]: +def get_signature_for_torch_op(op : Callable) -> Optional[List[inspect.Signature]]: """ Given an operator on the `torch` namespace, return a list of `inspect.Signature` objects corresponding to the overloads of that op.. May return `None` if a signature @@ -130,17 +94,17 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False) -> """ override = _manual_overrides.get(op) if override: - return (override, None) if return_schemas else None + return override aten_fn = torch.jit._builtins._find_builtin(op) if aten_fn is None: - return (None, None) if return_schemas else None + return None schemas = torch._C._jit_get_schemas_for_operator(aten_fn) signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] - return (signatures, schemas) if return_schemas else signatures + return signatures @compatibility(is_backward_compatible=False) def create_type_hint(x): diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index b25e45d..61b039f 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -8,15 +8,11 @@ from .graph import magic_methods, reflectable_magic_methods, Graph from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable from .node import Target, Node, Argument, base_types, map_aggregate from ._compatibility import compatibility -from .operator_schemas import check_for_mutable_operation @compatibility(is_backward_compatible=True) class TracerBase: graph: Graph record_stack_traces : bool = False - # Feature flag for mutable schema checking - # Enableby default in 1.12 - check_mutable_operations : bool = False @compatibility(is_backward_compatible=True) def create_node(self, kind : str, target : Target, @@ -29,9 +25,6 @@ class TracerBase: modification of values used in node creation. For example, one might want to disallow in-place operations from being recorded. """ - if kind == 'call_function' and self.check_mutable_operations: - check_for_mutable_operation(target, args, kwargs) - return self.graph.create_node(kind, target, args, kwargs, name, type_expr) @compatibility(is_backward_compatible=True) -- 2.7.4