From 5eb8cec663e6f7b932065fe732325137d7135aa4 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 7 Sep 2021 18:19:14 -0700 Subject: [PATCH] Add permute, arange (#63407) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63407 Test Plan: Imported from OSS Reviewed By: driazati Differential Revision: D30738149 Pulled By: eellison fbshipit-source-id: 36d572488408d38b0643aa93cb08aab5c45218ad --- test/jit/test_symbolic_shape_analysis.py | 42 +++++++++++++++++++- torch/csrc/jit/passes/peephole_non_tensor.cpp | 15 ++++---- torch/csrc/jit/runtime/symbolic_shape_registry.cpp | 45 ++++++++++++++++++++++ torch/testing/_internal/common_jit.py | 5 ++- .../_internal/common_methods_invocations.py | 5 +-- 5 files changed, 99 insertions(+), 13 deletions(-) diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index 6d4e33c..7c067a3 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -1,9 +1,10 @@ import torch -from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.jit_utils import JitTestCase, execWrapper import operator from torch.testing import FileCheck +from textwrap import dedent if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" @@ -177,3 +178,42 @@ class TestSymbolicShapeAnalysis(JitTestCase): torch._C._jit_pass_peephole(fn.graph) torch._C._jit_pass_constant_propagation(fn.graph) self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True) + + def test_arange_shape(self): + # no opinfo for tensor constructors + inps = [ + (10,), + (10, 10), + (0, 10), + (0, 1000), + (1, -1, -1), + (1, 0, -1), + (1, 2, 1), + (0.6, 0.89, 0.1), + (1, 10, 0.3), + (1, 10, 4), + (0.6, 0.7, 0.8), + (1, 10, 0.3), + # (True,), TODO: https://github.com/pytorch/pytorch/issues/63405 + # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405 + (0, 5), + (0, 5, 2), + (0, 5 + 1e-6), + (0, 5 - 1e-6), + (10, -1 + 1e-6, -1), + (10, -1, -1), + (10, -1 - 1e-6, -1), + ] + + for inp in inps: + funcs_template = dedent(''' + def func(): + return torch.arange({args}) + ''') + + inp_s = str(inp)[1:-1] # remove tuple parens + funcs_str = funcs_template.format(args=inp_s) + scope = {} + execWrapper(funcs_str, globals(), scope) + cu = torch.jit.CompilationUnit(funcs_str) + self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False) diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp index f93eb9e..cffb4bb 100644 --- a/torch/csrc/jit/passes/peephole_non_tensor.cpp +++ b/torch/csrc/jit/passes/peephole_non_tensor.cpp @@ -193,13 +193,14 @@ struct PeepholeOptimizeNonTensorImpl { node->output()->replaceAllUsesWith(node->input()); changed = true; } - } else if (node->kind() == aten::Int) { - if (node->input()->type()->cast()) { - GRAPH_UPDATE( - "Removing ", getHeader(node), " as input is already an integer"); - node->output()->replaceAllUsesWith(node->input()); - changed = true; - } + } else if ( + (node->kind() == aten::Int || node->kind() == aten::ceil) && + node->inputs().size() == 1 && + node->input()->type()->cast()) { + GRAPH_UPDATE( + "Removing ", getHeader(node), " as input is already an integer"); + node->output()->replaceAllUsesWith(node->input()); + changed = true; } else if (node->kind() == aten::ne || node->kind() == aten::eq) { if (node->inputs().size() != 2 || node->inputs().at(0) != node->inputs().at(1)) { diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index dd2a2e8..ae8ddd1 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -80,6 +80,9 @@ const std::string shape_compute_functions = out.append(elem) return out + def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool=False): + return view(self, sizes) + def mean_dim(self: List[int], dims: List[int], keep_dim: bool, dt : Any): out: List[int] = [] for idx in range(len(self)): @@ -344,12 +347,47 @@ const std::string shape_compute_functions = dim += dim_post_expr return dim + def zero_dim_tensor(input: Any): + out: List[int] = [] + return out + def multiply_integers(li: List[int]): out = 1 for elem in li: out = out * elem return out + def arange_end(end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any): + assert end >= 0 + return [int(torch.ceil(end))] + + def arange_start(start: number, end: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any): + assert end >= 0 + assert end >= start + return [int(torch.ceil(end - start))] + + def arange_start_step(start: number, end: number, step: number, inp0: Any, inp1: Any, inp2: Any, inp3: Any): + assert step != 0 + if step < 0: + assert start >= end + else: + assert end >= start + return [int(torch.ceil((end - start) / step))] + + def permute(input: List[int], dims: List[int]): + assert len(input) == len(dims) + ndim = len(dims) + seen_dims: List[int] = [] + newSizes: List[int] = [] + for i in range(ndim): + dim = maybe_wrap_dim(dims[i], ndim) + seen_dims.append(dim) + newSizes.append(input[dim]) + for i in range(1, ndim): + for j in range(i): + assert seen_dims[i] != seen_dims[j] + return newSizes + def flatten(input: List[int], start_dim: int, end_dim: int): start_dim = maybe_wrap_dim(start_dim, len(input)) end_dim = maybe_wrap_dim(end_dim, len(input)) @@ -420,8 +458,13 @@ static const OperatorMap& get_schema_to_function_graph() { {"aten::gelu(Tensor self) -> Tensor", "unary"}, {"aten::tanh(Tensor self) -> Tensor", "unary"}, {"aten::erf(Tensor self) -> (Tensor)", "unary"}, + {"prim::NumToTensor.Scalar(Scalar a) -> Tensor", "zero_dim_tensor"}, + {"prim::NumToTensor.bool(bool a) -> Tensor", "zero_dim_tensor"}, {"aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "unary_four_unused_inputs"}, {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", "unary_four_unused_inputs"}, + {"aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "arange_end"}, + {"aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start"}, + {"aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start_step"}, {"aten::squeeze(Tensor(a) self) -> Tensor(a)", "squeeze_nodim"}, {"aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "squeeze"}, {"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "unsqueeze"}, @@ -443,8 +486,10 @@ static const OperatorMap& get_schema_to_function_graph() { {"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"}, {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"}, {"aten::relu(Tensor self) -> Tensor", "unary"}, + {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"}, {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"}, {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "view"}, + {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "view_one_unused"}, {"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"}, {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"}, #ifdef USE_XNNPACK diff --git a/torch/testing/_internal/common_jit.py b/torch/testing/_internal/common_jit.py index 89533a6..c32300e 100644 --- a/torch/testing/_internal/common_jit.py +++ b/torch/testing/_internal/common_jit.py @@ -281,7 +281,7 @@ class JitCommonTestCase(TestCase): self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg) - def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation): + def checkShapeAnalysis(self, out_size, traced_graph, assert_propagation, constant_prop=True): # repropagte input shapes provided by tracing, prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() for enable_test_mode in [True, False]: @@ -289,7 +289,8 @@ class JitCommonTestCase(TestCase): # disallowing constants helps stress test partial eval and substitution pipeline torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode) torch._C._jit_erase_non_input_shape_information(traced_graph) - torch._C._jit_pass_constant_propagation(traced_graph) + if constant_prop: + torch._C._jit_pass_constant_propagation(traced_graph) torch._C._jit_pass_propagate_shapes_on_graph(traced_graph) # Add sizes to default tensor type to avoid checking something out of scope # and difficulties with tracer leaving in other parts of tensor type diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 086f01c..375aa5f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6477,10 +6477,8 @@ op_db: List[OpInfo] = [ op=lambda self, shape: self.expand(shape), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_expand, - skips=( - # Because expand does not have a function variant. - SkipInfo('TestJit', 'test_variant_consistency_jit'),), supports_forward_ad=True, + assert_jit_shape_analysis=True, supports_out=False), OpInfo('expand_as', op=lambda self, other: self.expand_as(other), @@ -7768,6 +7766,7 @@ op_db: List[OpInfo] = [ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, assert_autodiffed=True, + assert_jit_shape_analysis=True, supports_forward_ad=True, sample_inputs_func=sample_inputs_permute), OpInfo('pow', -- 2.7.4