From ea808df25deee98a3021f5cff1f818803d56c97b Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 10 Aug 2021 09:40:41 -0700
Subject: [PATCH] Test shape analysis with opinfos (#59814)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59814

Using opinfos to test shape analysis. By default, we just check that we
don't give incorrect answers; if `assert_jit_shape_analysis` is true, we
also test that the full shape is correctly propagated. And it found a
couple of bugs 😃

Test Plan: Imported from OSS

Reviewed By: Krovatkin

Differential Revision: D30200058

Pulled By: eellison

fbshipit-source-id: 6226be87f5390277cfa5a1fffaa1b072d4bc8803
---
 test/test_ops.py                                    | 21 ++++++++++++----
 torch/_C/__init__.pyi.in                            |  5 ++++
 torch/csrc/jit/passes/remove_mutation.cpp           |  2 +-
 torch/csrc/jit/python/init.cpp                      | 18 ++++++++++++++
 torch/csrc/jit/python/python_ir.cpp                 | 13 ++++++----
 torch/csrc/jit/runtime/symbolic_shape_registry.cpp  | 13 ++++------
 torch/testing/_internal/common_jit.py               | 28 +++++++++++++++++++++-
 .../_internal/common_methods_invocations.py         |  4 ++++
 .../testing/_internal/jit_metaprogramming_utils.py  |  1 +
 9 files changed, 86 insertions(+), 19 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index cc3b737..76a7b6a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -751,7 +751,8 @@ class TestJit(JitCommonTestCase):
 
             # Check traced forward, grad, and grad grad
             # TODO: fix tracing here
-            if not has_fake_function:
+            supports_tracing = not has_fake_function
+            if supports_tracing:
                 traced_fn = create_traced_fn(self, variant)
                 check_against_reference(self,
                                         traced_fn,
@@ -763,12 +764,24 @@ class TestJit(JitCommonTestCase):
 
             # Check alias annotation schema for correctness (make
             #   sure inputs that aren't supposed to be modified aren't)
-            # Note: only runs in float32 and int64 because schema isn't affected by dtype,
+            # Note: only runs in float32 because schema isn't affected by dtype,
             #   so running it on all dtypes would be excessive
-            if dtype in [torch.float32, torch.int32]:
+            if dtype == torch.float32:
                 check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
                                        func_type=func_type, aten_name=op.aten_name)
 
+            # TODO: use script graph as well
+            checked_shape_analysis = False
+            if supports_tracing:
+                out = variant(get_sample(), *sample.args, **sample.kwargs)
+
+                # TODO: handle multiple outputs
+                if isinstance(out, torch.Tensor):
+                    self.checkShapeAnalysis(out.size(), traced_fn.graph, op.assert_jit_shape_analysis)
+                    checked_shape_analysis = True
+            if op.assert_jit_shape_analysis:
+                self.assertTrue(checked_shape_analysis)
+
             # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
             if dtype is torch.float32:
                 # Sandcastle doesn't fuse nodes
@@ -780,7 +793,7 @@ class TestJit(JitCommonTestCase):
                 nonfusible_nodes = op.autodiff_nonfusible_nodes
                 fusible_nodes = op.autodiff_fusible_nodes
 
-                if not has_fake_function:
+                if supports_tracing:
                     self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
                 self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
         assert tested, "JIT Test does not execute any logic"
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 1bd3949..ffdafb1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -202,6 +202,8 @@ def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                         preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_inline(Graph) -> None: ...
 def _jit_pass_constant_propagation(Graph) -> None: ...
+def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
+def _jit_erase_non_input_shape_information(Graph) -> None: ...
 def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
@@ -920,6 +922,7 @@ Stack = List[IValue]
 
 class JitType:
     annotation_str : str
+    def isSubtypeOf(self, other: JitType) -> _bool: ...
 
 class InferredType:
     def __init__(self, arg: Union[JitType, str]): ...
@@ -1025,6 +1028,8 @@ class TensorType(JitType):
     def get(cls) -> TensorType: ...
     @classmethod
     def getInferred(cls) -> TensorType: ...
+    def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
+    def sizes(self) -> Optional[List[_int]]: ...
 
 # Defined in torch/csrc/jit/python/python_tree_views.cpp
 class SourceRange:
diff --git a/torch/csrc/jit/passes/remove_mutation.cpp b/torch/csrc/jit/passes/remove_mutation.cpp
index b4e335a..ede5b10 100644
--- a/torch/csrc/jit/passes/remove_mutation.cpp
+++ b/torch/csrc/jit/passes/remove_mutation.cpp
@@ -80,7 +80,7 @@ bool removableSetItem(Node* n) {
   if (index < 0) {
     index += n->inputs().size();
   }
-  return index < n->inputs().size();
+  return index < static_cast<int64_t>(n->input(0)->node()->inputs().size());
 }
 
 bool MutationRemover::listMutationFollowingListConstruct(Node* n) {
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index dfbea96..5fca575 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -525,6 +525,24 @@ void initJITBindings(PyObject* module) {
           py::arg("graph"))
       .def("_jit_pass_erase_shape_information", EraseShapeInformation)
       .def(
+          "_jit_erase_non_input_shape_information",
+          [](std::shared_ptr<Graph>& g) {
+            std::vector<TypePtr> input_types;
+            for (Value* v : g->inputs()) {
+              if (auto tt = v->type()->cast<TensorType>()) {
+                input_types.push_back(tt);
+              } else {
+                input_types.push_back(nullptr);
+              }
+            }
+            EraseShapeInformation(g);
+            for (size_t i = 0; i < input_types.size(); ++i) {
+              if (input_types[i]) {
+                g->inputs().at(i)->setType(input_types[i]);
+              }
+            }
+          })
+      .def(
           "_jit_pass_create_autodiff_subgraphs",
           [](const std::shared_ptr<Graph>& graph) {
             CreateAutodiffSubgraphs(graph);
diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp
index 06c2562..e0951c3 100644
--- a/torch/csrc/jit/python/python_ir.cpp
+++ b/torch/csrc/jit/python/python_ir.cpp
@@ -759,11 +759,16 @@ void initPythonIRBindings(PyObject* module_) {
           })
       .def(
           "with_sizes",
-          [](Type& t, std::vector<c10::optional<int64_t>> sizes) -> py::object {
-            if (auto ptt = t.expect<TensorType>()) {
-              return py::cast(ptt->withSymbolicShapes(sizes));
+          [](Type& t, c10::optional<std::vector<c10::optional<int64_t>>> sizes)
+              -> py::object {
+            auto ptt = t.expect<TensorType>();
+            if (!ptt) {
+              return py::none();
             }
-            return py::none();
+            if (!sizes) {
+              return py::cast(ptt->withSymbolicShapes(c10::SymbolicShape()));
+            }
+            return py::cast(ptt->withSymbolicShapes(*sizes));
           })
       .def(
           "varyingSizes",
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 08f1dfb..ffc2f44 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -73,17 +73,16 @@ const std::string shape_compute_functions =
 
     def mean_dim(self: List[int], dims: List[int], keep_dim: bool, dt : Any):
       out: List[int] = []
-      idx : int = 0
-      for elem in self:
+      for idx in range(len(self)):
         is_mean_dim : bool = False
         for reduce_dim in dims:
-          if idx == reduce_dim:
+          if idx == maybe_wrap_dim(reduce_dim, len(self)):
             is_mean_dim = True
         if is_mean_dim:
           if keep_dim:
             out.append(1)
         else:
-          out.append(elem)
+          out.append(self[idx])
       return out
 
     def broadcast_one_unused_input(self: List[int], other: List[int], unused: Any):
@@ -99,7 +98,6 @@ const std::string shape_compute_functions =
     def dot(self: List[int], tensor: List[int]):
       assert len(self) == 1 and len(tensor) == 1
       assert self[0] == tensor[0]
-      # TODO: return self
       out: List[int] = []
       return out
 
@@ -189,10 +187,7 @@ const std::string shape_compute_functions =
       return out
 
    def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
-      out = matmul(mat1, t(mat2))
-      if self is not None:
-        assert broadcast(self, out) == out
-      return out
+      return broadcast(self, mm(mat1, mat2))
 
     def check_non_negative(array: List[int]) -> bool:
       # TODO: look into rewriting with early return and getting loop unrolling to fire
diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py
index f615ad4..80cb4d0 100644
--- a/torch/testing/_internal/common_jit.py
+++ b/torch/testing/_internal/common_jit.py
@@ -15,6 +15,7 @@ from torch.testing._internal.common_utils import enable_profiling_mode  # noqa:
 
 # Standard library
 from itertools import chain
 from typing import List, Union
+from torch._C import TensorType
 
 import io
 
@@ -137,7 +138,6 @@ def check_against_reference(self, func, reference_func, output_func, args, kwarg
                 continue
             self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
 
-
 class JitCommonTestCase(TestCase):
     def createFunctionFromGraph(self, trace):
         graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
@@ -280,3 +280,29 @@ class JitCommonTestCase(TestCase):
                                                        nodes_in_diff_graph)
             self.assertEqual(should_autodiff_node,
                              found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
+
+    def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation):
+        # repropagate input shapes provided by tracing,
+        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
+        for enable_test_mode in [True, False]:
+            # here we are testing allowing/disallowing substituting in complete shapes as constants,
+            # disallowing constants helps stress test partial eval and substitution pipeline
+            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
+            torch._C._jit_erase_non_input_shape_information(traced_graph)
+            torch._C._jit_pass_constant_propagation(traced_graph)
+            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
+            # Add sizes to default tensor type to avoid checking something out of scope
+            # and difficulties with tracer leaving in other parts of tensor type
+            sizes = next(traced_graph.outputs()).type().symbolic_sizes()
+            out_type = TensorType.get().with_sizes(sizes)
+            actual_type = TensorType.get().with_sizes(out_size)
+
+            # always check actual shape is a subtype of the output
+            self.assertTrue(actual_type.isSubtypeOf(out_type))
+
+            # and then if assertion flag is provided, check shape analysis
+            # is successful
+            if assert_propagation:
+                self.assertEqual(out_type.sizes(), out_size)
+
+        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 969472e..10c2a2c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -519,6 +519,7 @@ class OpInfo(object):
                  # the following metadata relates to complex support and is checked in test_ops.py
                  test_conjugated_samples=True,
                  test_neg_view=True,
+                 assert_jit_shape_analysis=False,  # assert that jit shape analysis fully propagates shape
                  ):
 
         dtypes_args = (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM)
@@ -615,6 +616,8 @@ class OpInfo(object):
         if aliases is not None:
             self.aliases = tuple(AliasInfo(a) for a in aliases)  # type: ignore[assignment]
 
+        self.assert_jit_shape_analysis = assert_jit_shape_analysis
+
         self.test_conjugated_samples = test_conjugated_samples
         self.test_neg_view = test_neg_view
 
@@ -6469,6 +6472,7 @@ op_db: List[OpInfo] = [
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
                                                                 *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
            assert_autodiffed=True,
+           assert_jit_shape_analysis=True,
            sample_inputs_func=sample_inputs_matmul,
            skips=(
                # matmul does not correctly warn when resizing out= inputs
diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py
index 6e4b480..28b9eb8 100644
--- a/torch/testing/_internal/jit_metaprogramming_utils.py
+++ b/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -336,6 +336,7 @@ def create_traced_fn(self, fn):
         output = traced(*inputs_tensors)
         # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
         traced_fn.last_graph = traced.graph_for(*inputs_tensors)  # type: ignore[attr-defined]
+        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
         return output
     return traced_fn
 
-- 
2.7.4
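
Illustrative usage (not part of the patch): a minimal sketch of how the bindings exercised by
checkShapeAnalysis fit together on a traced graph. It assumes a PyTorch build that includes this
change; torch.add and the (3, 4) input shapes are arbitrary example choices.

import torch
from torch._C import TensorType

def fn(x, y):
    return torch.add(x, y)

# Trace so the graph inputs carry concrete tensor types.
traced = torch.jit.trace(fn, (torch.rand(3, 4), torch.rand(3, 4)))
graph = traced.graph

# Re-propagate shapes from the input types only, mirroring checkShapeAnalysis above.
torch._C._jit_erase_non_input_shape_information(graph)
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_propagate_shapes_on_graph(graph)

# The eager output shape must be a subtype of the propagated output type.
eager_out = fn(torch.rand(3, 4), torch.rand(3, 4))
propagated = TensorType.get().with_sizes(next(graph.outputs()).type().symbolic_sizes())
actual = TensorType.get().with_sizes(list(eager_out.size()))
assert actual.isSubtypeOf(propagated)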