[ONNX] Update repeat_interleave for dynamic repeats (#59979) (#62764)
author     BowenBao <bowbao@microsoft.com>
Fri, 20 Aug 2021 19:44:29 +0000 (12:44 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 19:46:54 +0000 (12:46 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62764

Fixes #58733

- Support dynamic repeat_interleave for cases where the repeats tensor has dynamic values
- Moved the repeat_interleave symbolic from opset 11 to opset 13, since sequence-typed Loop outputs (available from opset 13) are needed for this change; a minimal export sketch follows below
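
A minimal export sketch (illustrative only, mirroring the new tests below; the output file name and the export call here are assumptions, not part of this change):

    import torch

    class DynamicRepeatsModel(torch.nn.Module):
        def forward(self, x, repeats):
            return torch.repeat_interleave(x, repeats, dim=1)

    x = torch.tensor([[1, 2, 4], [3, 4, 7]])
    repeats = torch.tensor([2, 3, 4])
    # Dynamic repeats require opset >= 13, where Loop outputs may be sequence-typed
    torch.onnx.export(DynamicRepeatsModel(), (x, repeats), "repeat_interleave.onnx",
                      opset_version=13,
                      input_names=["input_1", "repeats_1"],
                      dynamic_axes={"repeats_1": {0: "r"}})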

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D30375179

Pulled By: msaroufim

fbshipit-source-id: 787f96bf91d124fd0483761088c5f4ae930d96a9

Co-authored-by: Shubham Bhokare <shubhambhokare@gmail.com>
test/onnx/test_pytorch_onnx_onnxruntime.py
torch/onnx/symbolic_opset11.py
torch/onnx/symbolic_opset13.py
torch/onnx/symbolic_opset9.py

test/onnx/test_pytorch_onnx_onnxruntime.py
index 67903fb..fd10629 100644
@@ -4323,7 +4323,7 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.tensor([[1, 2], [3, 4]])
         self.run_test(RepeatsDimsModel2(), (x,))
 
-    @skipIfUnsupportedMinOpsetVersion(11)
+    @skipIfUnsupportedMinOpsetVersion(13)
     def test_dynamic_repeat_interleave(self):
         class SingleDynamicModel(torch.nn.Module):
             def forward(self, x):
@@ -4345,25 +4345,62 @@ class TestONNXRuntime(unittest.TestCase):
         self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x],
                       input_names=["input_1"], dynamic_axes={"input_1" : {1 : "w"}})
 
-        class SingleDynamicModel2(torch.nn.Module):
+        class SingleDynamicModelFloat(torch.nn.Module):
             def forward(self, x):
                 repeats = torch.tensor([4])
                 return torch.repeat_interleave(x, repeats, dim=0)
 
-        x = torch.tensor([[1, 2], [3, 4]])
-        another_x = torch.tensor([[7, 8], [5, 6]])
-        self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x],
+        x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
+        another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
+        self.run_test(SingleDynamicModelFloat(), x, test_with_inputs=[another_x],
                       input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h"}})
 
-        class AllDynamicModel(torch.nn.Module):
-            def forward(self, x):
-                repeats = torch.tensor([4])
-                return torch.repeat_interleave(x, repeats, dim=0)
+        class DynamicRepeatsModel(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=1)
 
-        x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]])
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
         another_x = torch.tensor([[7, 8], [5, 6]])
-        self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x],
-                      input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h", 1 : "w"}})
+        repeats = torch.tensor([2])
+        another_repeats = torch.tensor([4])
+        self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(another_x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"input_1" : {1 : "w"}, "repeats_1" : {0 : "r"}})
+
+        class DynamicRepeatsModel2(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=1)
+
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        repeats = torch.tensor([2])
+        another_repeats = torch.tensor([4])
+        self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"repeats_1" : {0 : "r"}})
+
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_multiple_dynamic_repeat_interleave(self):
+        class DynamicRepeatsModel(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=1)
+
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        repeats = torch.tensor([2, 3, 4])
+        another_repeats = torch.tensor([4, 3, 2])
+        self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"repeats_1" : {0 : "r"}})
+
+        class DynamicRepeatsModel2(torch.nn.Module):
+            def forward(self, x, repeats):
+                return torch.repeat_interleave(x, repeats, dim=0)
+
+        x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+        repeats = torch.tensor([2, 3])
+        another_repeats = torch.tensor([4, 3])
+        self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+                      input_names=["input_1", "repeats_1"],
+                      dynamic_axes={"repeats_1" : {0 : "r"}})
 
     def test_view(self):
         class ViewModel(torch.nn.Module):
torch/onnx/symbolic_opset11.py
index ed7abf2..53440f1 100644
@@ -894,110 +894,6 @@ def chunk(g, self, chunks, dim):
     chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
     return split(g, self, chunk_vec, dim)
 
-def repeat_interleave(g, self, repeats, dim=None, output_size=None):
-    input = self
-    final_dim = dim
-    # if dim is None flatten
-    # By default, use the flattened input array, and return a flat output array
-    if sym_help._is_none(dim):
-        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
-        dim = 0
-    else:
-        dim = sym_help._maybe_get_scalar(dim)
-
-    repeats_dim = sym_help._get_tensor_rank(repeats)
-    repeats_sizes = sym_help._get_tensor_sizes(repeats)
-    input_sizes = sym_help._get_tensor_sizes(input)
-    if repeats_dim is None:
-        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
-                           "repeats rank.")
-    if repeats_sizes is None:
-        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
-                           "repeats size.")
-    if input_sizes is None:
-        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
-                           "input size.")
-    # Handle cases where dim is negative
-    if dim < 0:
-        dim += len(input_sizes)
-
-    output_sizes = input_sizes.copy()
-    perm_i = [0]
-    for idx, input_size in enumerate(input_sizes):
-        perm_i.append(idx + 1)
-        if input_size is None:
-            output_sizes[idx], input_sizes[idx] = 0, -1
-    perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]
-
-    # Cases when repeats is a single value tensor and dim has unknown input size
-    if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
-        if not sym_help._is_tensor(repeats):
-            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
-        reps = sym_help._size_helper(g, input, dim)
-        reps = unsqueeze(g, reps, 0)
-        repeats = g.op("Expand", repeats, reps)
-    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
-    # provided along one of the dynamic axes provided. A simple example would be
-    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
-    # Now, repeat interleaving can be performed in pytorch when the value of * matches
-    # with the number of elements in repeat, for example if * -> 2, number of repeats
-    # should be 2 as well.
-    else:
-        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
-
-    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
-                     value_t=torch.tensor([1], dtype=torch.long))
-    r_splits = split(g, repeats, reps_like, 0)
-    i_splits = split(g, input, reps_like, dim)
-
-    output_sizes[dim], input_sizes[dim] = -1, 1
-
-    # Create a loop to iterate over each value along the dimension
-    # and perform individual interleaving using the repeats tensor
-    # Loop is of the following pattern
-    # input (trip_count, cond)
-    #   int trip_count = ...;
-    #   bool cond = ...;
-    #   for (int i=0; i < trip_count && cond; ++i) {
-    #     cond = ...;
-    #   }
-
-    # Loop conditions
-    loop_condition = g.op("Constant", value_t=torch.tensor(1))
-    loop_condition = g.op("Cast", loop_condition, to_i=9)
-    loop_len = reps
-    loop = g.op("Loop", loop_len, loop_condition)
-
-    # Loop inputs
-    loop_block = _add_block(loop.node())
-    block_input_iter = _add_input_to_block(loop_block)
-    cond = _add_input_to_block(loop_block)
-
-    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
-    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
-
-    i_split = unsqueeze(loop_block, i_split, dim + 1)
-    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
-                r_split,
-                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
-    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
-    i_split = expand(loop_block, i_split, r_concat, None)
-    i_split = sym_help._reshape_helper(loop_block, i_split,
-                                       g.op("Constant", value_t=torch.LongTensor(output_sizes)))
-
-    # Loop outputs
-    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
-    _add_output_to_block(loop_block, cond_out)
-    _add_output_to_block(loop_block, i_split)
-    loop_out = loop.node().output()
-
-    # In this loop, the outputs are scan outputs and are concatenated along
-    # the zero'th dimension (by default). In order to avoid this and concatenate
-    # along the dimension provided, some post-processing is required
-    loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
-    return sym_help._reshape_helper(g, loop_out,
-                                    g.op("Constant", value_t=torch.LongTensor(output_sizes)))
-
 
 def normal(g, loc, scale, seed):
     # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a
torch/onnx/symbolic_opset13.py
index 7f20833..0baf785 100644
@@ -5,7 +5,9 @@
 import torch
 import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented
-from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero
+from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand, where
+from torch.onnx.symbolic_opset11 import unsqueeze
+from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
 
 
 # EDITING THIS FILE? READ THIS FIRST!
@@ -196,3 +198,116 @@ def unsafe_chunk(g, self, chunks, dim, _outputs=None):
     # user's modules.
     splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
     return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
+
+def repeat_interleave(g, self, repeats, dim=None, output_size=None):
+    input = self
+    final_dim = dim
+    # if dim is None flatten
+    # By default, use the flattened input array, and return a flat output array
+    if sym_help._is_none(dim):
+        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
+        dim = 0
+    else:
+        dim = sym_help._maybe_get_scalar(dim)
+
+    repeats_dim = sym_help._get_tensor_rank(repeats)
+    repeats_sizes = sym_help._get_tensor_sizes(repeats)
+    input_sizes = sym_help._get_tensor_sizes(input)
+    if repeats_dim is None:
+        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+                           "repeats rank.")
+    if repeats_sizes is None:
+        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+                           "repeats size.")
+    if input_sizes is None:
+        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+                           "input size.")
+    # Handle cases where dim is negative
+    if dim < 0:
+        dim += len(input_sizes)
+
+    output_sizes = input_sizes.copy()
+    for idx, input_size in enumerate(input_sizes):
+        if input_size is None:
+            output_sizes[idx], input_sizes[idx] = 0, -1
+
+    cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
+    # If input size is dynamic or repeats vector is dynamic
+    if output_sizes[dim] == 0 or cond_dynamic_repeats:
+        reps = sym_help._size_helper(g, input, dim)
+        reps = unsqueeze(g, reps, 0)
+        # Check if repeats vector is a single integer value
+        # or a single dimension tensor with non-dynamic values
+        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
+            if not sym_help._is_tensor(repeats):
+                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
+            repeats = g.op("Expand", repeats, reps)
+        # Check if repeats is dynamic
+        # As repeats is dynamic, we use a Where node as a substitute for an If statement
+        # If repeats_dim == 1, expand repeats, otherwise use the original tensor
+        elif cond_dynamic_repeats:
+            repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
+            repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])))
+            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
+    # There are cases where repeats is a 1-d tensor with multiple repeat values, and dim
+    # is provided along one of the dynamic axes. A simple example would be
+    # input.shape -> [1, 1, *], where * represents the dynamic axis, and dim = 2.
+    # Repeat interleaving can be performed in PyTorch when the value of * matches
+    # the number of elements in repeats, e.g. if * -> 2, the number of repeats
+    # should be 2 as well.
+    else:
+        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
+
+    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
+                     value_t=torch.tensor([1], dtype=torch.long))
+    r_splits = split(g, repeats, reps_like, 0)
+    i_splits = split(g, input, reps_like, dim)
+
+    output_sizes[dim], input_sizes[dim] = -1, 1
+
+    # Create a loop to iterate over each value along the dimension
+    # and perform individual interleaving using the repeats tensor
+    # Loop is of the following pattern
+    # input (trip_count, cond)
+    #   int trip_count = ...;
+    #   bool cond = ...;
+    #   for (int i=0; i < trip_count && cond; ++i) {
+    #     cond = ...;
+    #   }
+
+    # Loop conditions
+    loop_condition = g.op("Constant", value_t=torch.tensor(1))
+    loop_condition = g.op("Cast", loop_condition, to_i=9)
+    loop_len = reps
+
+    # Create an empty sequence to store final expansions
+    final_splits = g.op("SequenceEmpty")
+    loop = g.op("Loop", loop_len, loop_condition, final_splits)
+
+    # Loop inputs
+    loop_block = _add_block(loop.node())
+    block_input_iter = _add_input_to_block(loop_block)
+    cond = _add_input_to_block(loop_block)
+    final_splits = _add_input_to_block(loop_block)
+
+    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
+    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
+
+    i_split = unsqueeze(loop_block, i_split, dim + 1)
+    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
+                r_split,
+                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
+    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
+    i_split = expand(loop_block, i_split, r_concat, None)
+    i_split = sym_help._reshape_helper(loop_block, i_split,
+                                       g.op("Constant", value_t=torch.LongTensor(output_sizes)))
+    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)
+
+    # Loop outputs
+    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
+    _add_output_to_block(loop_block, cond_out)
+    _add_output_to_block(loop_block, final_splits)
+
+    loop_out = loop.node().output()
+    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
+    return loop_out
torch/onnx/symbolic_opset9.py
index 993284a..ce59e15 100644
@@ -2058,7 +2058,7 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
         if not sym_help._is_tensor(repeats):
             repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
         if input_sizes[dim] == 0:
-            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
+            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
                                                              "Unsupported along dimension with unknown input size")
         else:
             reps = input_sizes[dim]
@@ -2067,8 +2067,11 @@ def repeat_interleave(g, self, repeats, dim=None, output_size=None):
     # Cases where repeats is a 1 dim Tensor
     elif repeats_dim == 1:
         if input_sizes[dim] == 0:
-            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
+            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
                                                              "Unsupported along dimension with unknown input size")
+        if repeats_sizes[0] is None:
+            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
+                                                             "Unsupported for cases with dynamic repeats")
         assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
         reps = repeats_sizes[0]
     else: