From cf2d15bf84d70f40b15435ebc1fdc7c23273eed6 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 7 Sep 2021 18:19:14 -0700
Subject: [PATCH] Add support for slice, select with int, index_select (#63365)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63365

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738144

Pulled By: eellison

fbshipit-source-id: 7e0c572209bdc6e62ecb4fd1f06f80291de69803
---
 test/test_ops.py                                   | 29 +++++++----
 torch/csrc/jit/runtime/symbolic_shape_registry.cpp | 59 ++++++++++++++++++++++
 .../_internal/common_methods_invocations.py        | 10 +++-
 3 files changed, 85 insertions(+), 13 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 3946870..90e52bb 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -708,6 +708,8 @@ class TestJit(JitCommonTestCase):
             variants = {'method': getattr(torch.Tensor, op.name)}
 
         samples = op.sample_inputs(device, dtype, requires_grad=False)
+        support_script = op.supports_scripting
+
         tested = False
         for sample in samples:
             # Test traced and scripted consistency
@@ -732,7 +734,8 @@ class TestJit(JitCommonTestCase):
                 # DifferentiableGraph nodes if they are present
                 with disable_autodiff_subgraph_inlining():
                     # Check scripted forward, grad, and grad grad
-                    script_fn = create_script_fn(self, name, func_type)
+                    if support_script:
+                        script_fn = create_script_fn(self, name, func_type)
 
                     def out_fn(output):
                         # Processes the output for autograd
@@ -743,13 +746,14 @@ class TestJit(JitCommonTestCase):
                     def get_sample():
                         return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input
 
-                    check_against_reference(self,
-                                            script_fn,
-                                            func,
-                                            out_fn,
-                                            (get_sample(),) + sample.args,
-                                            sample.kwargs,
-                                            no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
+                    if support_script:
+                        check_against_reference(self,
+                                                script_fn,
+                                                func,
+                                                out_fn,
+                                                (get_sample(),) + sample.args,
+                                                sample.kwargs,
+                                                no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
 
                     # Check traced forward, grad, and grad grad
                     # TODO: fix tracing here
@@ -772,8 +776,10 @@ class TestJit(JitCommonTestCase):
                     # Note: only runs in float32 because schema isn't affected by dtype,
                     # so running it on all dtypes is would be excessive
                     if dtype == torch.float32:
-                        check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
-                                               func_type=func_type, aten_name=op.aten_name)
+                        # TODO: no reason why we cant run this with tracing graph
+                        if support_script:
+                            check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
+                                                   func_type=func_type, aten_name=op.aten_name)
 
                     # TODO: use script graph as well
                     checked_shape_analysis = False
@@ -800,7 +806,8 @@ class TestJit(JitCommonTestCase):
 
                 if supports_tracing:
                     self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
-                self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
+                if support_script:
+                    self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
         assert tested, "JIT Test does not execute any logic"
 
         # alias testing is only done with torch.float for the same reason
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index ad14f3a..dd2a2e8 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -140,6 +140,62 @@ const std::string shape_compute_functions =
             out.append(li[i])
          return out
 
+        def index_select(self: List[int], dim: int, index: List[int]):
+          dim = maybe_wrap_dim(dim, len(self))
+          numel = multiply_integers(index)
+          assert len(index) <= 1
+          assert dim == 0 or dim < len(self)
+          result_size: List[int] = []
+          for i in range(len(self)):
+            if dim == i:
+              result_size.append(numel)
+            else:
+              result_size.append(self[i])
+          return result_size
+
+        def max_int():
+          return 9223372036854775807
+
+        def slice(self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int):
+          ndim = len(self)
+          assert ndim != 0
+          dim = maybe_wrap_dim(dim, ndim)
+          start_val = start if start is not None else 0
+          end_val = end if end is not None else max_int()
+          assert step > 0
+          if (start_val == max_int()):
+            start_val = 0
+          if start_val < 0:
+            start_val += self[dim]
+          if end_val < 0:
+            end_val += self[dim]
+          if start_val < 0:
+            start_val = 0
+          elif start_val >= self[dim]:
+            start_val = self[dim]
+          if end_val < start_val:
+            end_val = start_val
+          elif end_val >= self[dim]:
+            end_val = self[dim]
+          len = end_val - start_val
+          out = _copy(self)
+          out[dim] = (len + step - 1) // step
+          return out
+
+        def select(self: List[int], dim: int, index: int):
+          ndim = len(self)
+          assert ndim != 0
+          dim = maybe_wrap_dim(dim, ndim)
+          size = self[dim]
+          assert not (index < -size or index >= size)
+          if index < 0:
+            index += size
+          out: List[int] = []
+          for i in range(ndim):
+            if i != dim:
+              out.append(self[i])
+          return out
+
         def matmul(tensor1: List[int] , tensor2: List[int]):
           dim_tensor1 = len(tensor1)
           dim_tensor2 = len(tensor2)
@@ -369,6 +425,9 @@ static const OperatorMap& get_schema_to_function_graph() {
       {"aten::squeeze(Tensor(a) self) -> Tensor(a)", "squeeze_nodim"},
      {"aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "squeeze"},
      {"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "unsqueeze"},
+      {"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", "slice"},
+      {"aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", "select"},
+      {"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", "index_select"},
      {"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
       "float eps=1e-05, bool cudnn_enable=True) -> Tensor", "unary_five_unused_inputs"},
      {"aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "unary_two_unused_inputs"},
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8cb0ee4..086f01c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -537,6 +537,7 @@ class OpInfo(object):
                  # the following metadata relates to sparse support and is used in test_sparse.py
                  supports_sparse=False,  # whether the op supports sparse inputs
+                 supports_scripting=True,  # only run tracing tests
 
                  # the following metadata relates to complex support and is checked in test_ops.py
                  test_conjugated_samples=True,
                  test_neg_view=True,
@@ -636,6 +637,7 @@ class OpInfo(object):
         if aliases is not None:
             self.aliases = tuple(AliasInfo(a) for a in aliases)  # type: ignore[assignment]
 
+        self.supports_scripting = supports_scripting
         self.assert_jit_shape_analysis = assert_jit_shape_analysis
 
         self.test_conjugated_samples = test_conjugated_samples
@@ -8056,6 +8058,7 @@ op_db: List[OpInfo] = [
     OpInfo('select',
            dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
            sample_inputs_func=sample_inputs_select,
+           assert_jit_shape_analysis=True,
            supports_forward_ad=True,
            supports_out=False),
     UnaryUfuncInfo('signbit',
@@ -8618,6 +8621,7 @@ op_db: List[OpInfo] = [
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_index_select,
            supports_forward_ad=True,
+           assert_jit_shape_analysis=True,
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
     OpInfo('index_add',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
@@ -8629,9 +8633,11 @@ op_db: List[OpInfo] = [
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
            supports_inplace_autograd=False,
+           supports_scripting=False,
            op=torch.Tensor.__getitem__,
-           sample_inputs_func=sample_inputs_getitem,
-           skips=(SkipInfo('TestJit', 'test_variant_consistency_jit'),)),
+           skips=(SkipInfo('TestJit', 'test_variant_consistency_jit', device_type='cuda'),),
+           assert_jit_shape_analysis=False,  # TODO: support index.Tensor()
+           sample_inputs_func=sample_inputs_getitem,),
     OpInfo('index_put',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
-- 
2.7.4
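
Editorial note (not part of the patch): the shape functions registered above compute output shapes purely from input shapes expressed as List[int], never from tensor data, which is what lets JIT shape analysis propagate sizes through slice/select/index_select nodes symbolically. The standalone Python sketch below restates that logic outside the registry and cross-checks it against eager PyTorch. The names wrap_dim, slice_shape, select_shape, index_select_shape, and MAX_INT are illustrative stand-ins chosen here (wrap_dim approximates the registry's maybe_wrap_dim helper); they are not APIs introduced by this patch.

from typing import List, Optional

import torch

MAX_INT = 9223372036854775807  # mirrors max_int() in the patch


def wrap_dim(dim: int, ndim: int) -> int:
    # Plain-Python stand-in for the maybe_wrap_dim helper: map negative dims into [0, ndim).
    if dim < 0:
        dim += ndim
    assert 0 <= dim < ndim
    return dim


def slice_shape(shape: List[int], dim: int, start: Optional[int],
                end: Optional[int], step: int) -> List[int]:
    # Condensed restatement of the clamping in the patch's `slice` shape function.
    ndim = len(shape)
    assert ndim != 0 and step > 0
    dim = wrap_dim(dim, ndim)
    start_val = 0 if start is None or start == MAX_INT else start
    end_val = MAX_INT if end is None else end
    if start_val < 0:
        start_val += shape[dim]
    if end_val < 0:
        end_val += shape[dim]
    start_val = min(max(start_val, 0), shape[dim])
    end_val = min(max(end_val, start_val), shape[dim])
    out = list(shape)
    out[dim] = (end_val - start_val + step - 1) // step  # ceil division by step
    return out


def select_shape(shape: List[int], dim: int, index: int) -> List[int]:
    # `select` with an int index drops the selected dimension entirely.
    ndim = len(shape)
    assert ndim != 0
    dim = wrap_dim(dim, ndim)
    assert -shape[dim] <= index < shape[dim]
    return [s for i, s in enumerate(shape) if i != dim]


def index_select_shape(shape: List[int], dim: int, index_shape: List[int]) -> List[int]:
    # `index_select` keeps the rank but replaces the size along `dim` with the
    # number of indices; the index tensor must be 0-d or 1-d.
    assert len(index_shape) <= 1
    dim = wrap_dim(dim, len(shape))
    numel = 1
    for s in index_shape:
        numel *= s
    return [numel if i == dim else s for i, s in enumerate(shape)]


if __name__ == "__main__":
    x = torch.randn(4, 5, 6)
    idx = torch.tensor([0, 2])
    assert slice_shape([4, 5, 6], dim=1, start=1, end=None, step=2) == list(x[:, 1::2].shape)
    assert select_shape([4, 5, 6], dim=2, index=-1) == list(x.select(2, -1).shape)
    assert index_select_shape([4, 5, 6], dim=1, index_shape=list(idx.shape)) == list(x.index_select(1, idx).shape)
    print("shape-function sketch matches eager PyTorch")

Running the sketch should print the confirmation line; it is only meant to illustrate why the registry entries added in symbolic_shape_registry.cpp can be written as ordinary TorchScript over List[int] inputs.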