From: BowenBao Date: Fri, 20 Aug 2021 19:44:29 +0000 (-0700) Subject: [ONNX] Update repeat_interleave for dynamic repeats (#59979) (#62764) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~835 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=db0771b05d81e9ca5e46740b09589a8ff0bc3ec0;p=platform%2Fupstream%2Fpytorch.git [ONNX] Update repeat_interleave for dynamic repeats (#59979) (#62764) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62764 Fixes #58733 - Support dynamic interleave for cases with dynamic repeat values - Moved repeat_interleave symbolic from opset 11 to opset 13, as sequence as output types for loop outputs is needed for this change Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D30375179 Pulled By: msaroufim fbshipit-source-id: 787f96bf91d124fd0483761088c5f4ae930d96a9 Co-authored-by: Shubham Bhokare --- diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 67903fb..fd10629 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -4323,7 +4323,7 @@ class TestONNXRuntime(unittest.TestCase): x = torch.tensor([[1, 2], [3, 4]]) self.run_test(RepeatsDimsModel2(), (x,)) - @skipIfUnsupportedMinOpsetVersion(11) + @skipIfUnsupportedMinOpsetVersion(13) def test_dynamic_repeat_interleave(self): class SingleDynamicModel(torch.nn.Module): def forward(self, x): @@ -4345,25 +4345,62 @@ class TestONNXRuntime(unittest.TestCase): self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x], input_names=["input_1"], dynamic_axes={"input_1" : {1 : "w"}}) - class SingleDynamicModel2(torch.nn.Module): + class SingleDynamicModelFloat(torch.nn.Module): def forward(self, x): repeats = torch.tensor([4]) return torch.repeat_interleave(x, repeats, dim=0) - x = torch.tensor([[1, 2], [3, 4]]) - another_x = torch.tensor([[7, 8], [5, 6]]) - self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x], + x = torch.tensor([[1.1, 2.1], [3.1, 4.1]]) + another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]]) + self.run_test(SingleDynamicModelFloat(), x, test_with_inputs=[another_x], input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h"}}) - class AllDynamicModel(torch.nn.Module): - def forward(self, x): - repeats = torch.tensor([4]) - return torch.repeat_interleave(x, repeats, dim=0) + class DynamicRepeatsModel(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=1) - x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]]) + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) another_x = torch.tensor([[7, 8], [5, 6]]) - self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x], - input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h", 1 : "w"}}) + repeats = torch.tensor([2]) + another_repeats = torch.tensor([4]) + self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(another_x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"input_1" : {1 : "w"}, "repeats_1" : {0 : "r"}}) + + class DynamicRepeatsModel2(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=1) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + repeats = torch.tensor([2]) + another_repeats = torch.tensor([4]) + self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"repeats_1" : {0 : "r"}}) + + @skipIfUnsupportedMinOpsetVersion(13) + def test_multiple_dynamic_repeat_interleave(self): + class DynamicRepeatsModel(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=1) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + repeats = torch.tensor([2, 3, 4]) + another_repeats = torch.tensor([4, 3, 2]) + self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"repeats_1" : {0 : "r"}}) + + class DynamicRepeatsModel2(torch.nn.Module): + def forward(self, x, repeats): + return torch.repeat_interleave(x, repeats, dim=0) + + x = torch.tensor([[1, 2, 4], [3, 4, 7]]) + repeats = torch.tensor([2, 3]) + another_repeats = torch.tensor([4, 3]) + self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)], + input_names=["input_1", "repeats_1"], + dynamic_axes={"repeats_1" : {0 : "r"}}) def test_view(self): class ViewModel(torch.nn.Module): diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index ed7abf2..53440f1 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -894,110 +894,6 @@ def chunk(g, self, chunks, dim): chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) return split(g, self, chunk_vec, dim) -def repeat_interleave(g, self, repeats, dim=None, output_size=None): - input = self - final_dim = dim - # if dim is None flatten - # By default, use the flattened input array, and return a flat output array - if sym_help._is_none(dim): - input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1]))) - dim = 0 - else: - dim = sym_help._maybe_get_scalar(dim) - - repeats_dim = sym_help._get_tensor_rank(repeats) - repeats_sizes = sym_help._get_tensor_sizes(repeats) - input_sizes = sym_help._get_tensor_sizes(input) - if repeats_dim is None: - raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " - "repeats rank.") - if repeats_sizes is None: - raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " - "repeats size.") - if input_sizes is None: - raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " - "input size.") - # Handle cases where dim is negative - if dim < 0: - dim += len(input_sizes) - - output_sizes = input_sizes.copy() - perm_i = [0] - for idx, input_size in enumerate(input_sizes): - perm_i.append(idx + 1) - if input_size is None: - output_sizes[idx], input_sizes[idx] = 0, -1 - perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0] - - # Cases when repeats is a single value tensor and dim has unknown input size - if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0: - if not sym_help._is_tensor(repeats): - repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) - reps = sym_help._size_helper(g, input, dim) - reps = unsqueeze(g, reps, 0) - repeats = g.op("Expand", repeats, reps) - # There are cases when the repeats are 1-d tensor with multiple repeats, but dim - # provided along one of the dynamic axes provided. A simple example would be - # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 - # Now, repeat interleaving can be performed in pytorch when the value of * matches - # with the number of elements in repeat, for example if * -> 2, number of repeats - # should be 2 as well. - else: - return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim) - - reps_like = g.op("ConstantOfShape", g.op("Shape", repeats), - value_t=torch.tensor([1], dtype=torch.long)) - r_splits = split(g, repeats, reps_like, 0) - i_splits = split(g, input, reps_like, dim) - - output_sizes[dim], input_sizes[dim] = -1, 1 - - # Create a loop to iterate over each value along the dimension - # and perform individual interleaving using the repeats tensor - # Loop is of the following pattern - # input (trip_count, cond) - # int trip_count = ...; - # bool cond = ...; - # for (int i=0; i < trip_count && cond; ++i) { - # cond = ...; - # } - - # Loop conditions - loop_condition = g.op("Constant", value_t=torch.tensor(1)) - loop_condition = g.op("Cast", loop_condition, to_i=9) - loop_len = reps - loop = g.op("Loop", loop_len, loop_condition) - - # Loop inputs - loop_block = _add_block(loop.node()) - block_input_iter = _add_input_to_block(loop_block) - cond = _add_input_to_block(loop_block) - - r_split = loop_block.op("SequenceAt", r_splits, block_input_iter) - i_split = loop_block.op("SequenceAt", i_splits, block_input_iter) - - i_split = unsqueeze(loop_block, i_split, dim + 1) - r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])), - r_split, - loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))] - r_concat = loop_block.op("Concat", *r_concat, axis_i=0) - i_split = expand(loop_block, i_split, r_concat, None) - i_split = sym_help._reshape_helper(loop_block, i_split, - g.op("Constant", value_t=torch.LongTensor(output_sizes))) - - # Loop outputs - cond_out = loop_block.op("Cast", loop_condition, to_i=9) - _add_output_to_block(loop_block, cond_out) - _add_output_to_block(loop_block, i_split) - loop_out = loop.node().output() - - # In this loop, the outputs are scan outputs and are concatenated along - # the zero'th dimension (by default). In order to avoid this and concatenate - # along the dimension provided, some post-processing is required - loop_out = g.op("Transpose", loop_out, perm_i=perm_i) - return sym_help._reshape_helper(g, loop_out, - g.op("Constant", value_t=torch.LongTensor(output_sizes))) - def normal(g, loc, scale, seed): # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 7f20833..0baf785 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -5,7 +5,9 @@ import torch import torch.onnx.symbolic_helper as sym_help from torch.onnx.symbolic_helper import parse_args, _unimplemented -from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero +from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand +from torch.onnx.symbolic_opset11 import unsqueeze +from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block # EDITING THIS FILE? READ THIS FIRST! @@ -196,3 +198,117 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None): # user's modules. splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + +def repeat_interleave(g, self, repeats, dim=None, output_size=None): + input = self + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if sym_help._is_none(dim): + input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1]))) + dim = 0 + else: + dim = sym_help._maybe_get_scalar(dim) + + repeats_dim = sym_help._get_tensor_rank(repeats) + repeats_sizes = sym_help._get_tensor_sizes(repeats) + input_sizes = sym_help._get_tensor_sizes(input) + if repeats_dim is None: + raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " + "repeats rank.") + if repeats_sizes is None: + raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " + "repeats size.") + if input_sizes is None: + raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown " + "input size.") + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + print(output_sizes, input_sizes) + + cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None) + # If input size is dynamic or repeats vector is dynamic + if output_sizes[dim] == 0 or cond_dynamic_repeats: + reps = sym_help._size_helper(g, input, dim) + reps = unsqueeze(g, reps, 0) + # Check if repeats vector is a single integer value + # or a single dimension tensor with non-dynamic values + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if not sym_help._is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + repeats = g.op("Expand", repeats, reps) + # Check if repeats is dynamic + # As repeats is dynamic, we use a where node as a substitute for the if statement + # If repests_dim = 1, expand repeats otherwise use original tensor + elif cond_dynamic_repeats: + repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))) + repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))) + repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op("ConstantOfShape", g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long)) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, input, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=9) + loop_len = reps + + # Create an empty sequence to store final expansions + final_splits = g.op("SequenceEmpty") + loop = g.op("Loop", loop_len, loop_condition, final_splits) + + # Loop inputs + loop_block = _add_block(loop.node()) + block_input_iter = _add_input_to_block(loop_block) + cond = _add_input_to_block(loop_block) + final_splits = _add_input_to_block(loop_block) + + r_split = loop_block.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_block.op("SequenceAt", i_splits, block_input_iter) + + i_split = unsqueeze(loop_block, i_split, dim + 1) + r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])), + r_split, + loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))] + r_concat = loop_block.op("Concat", *r_concat, axis_i=0) + i_split = expand(loop_block, i_split, r_concat, None) + i_split = sym_help._reshape_helper(loop_block, i_split, + g.op("Constant", value_t=torch.LongTensor(output_sizes))) + final_splits = loop_block.op("SequenceInsert", final_splits, i_split) + + # Loop outputs + cond_out = loop_block.op("Cast", loop_condition, to_i=9) + _add_output_to_block(loop_block, cond_out) + _add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) + return loop_out diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 993284a..ce59e15 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -2058,7 +2058,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): if not sym_help._is_tensor(repeats): repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) if input_sizes[dim] == 0: - return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11, + return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size") else: reps = input_sizes[dim] @@ -2067,8 +2067,11 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None): # Cases where repeats is a 1 dim Tensor elif repeats_dim == 1: if input_sizes[dim] == 0: - return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11, + return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13, "Unsupported along dimension with unknown input size") + if repeats_sizes[0] is None: + return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13, + "Unsupported for cases with dynamic repeats") assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim" reps = repeats_sizes[0] else: