x = torch.tensor([[1, 2], [3, 4]])
self.run_test(RepeatsDimsModel2(), (x,))
- @skipIfUnsupportedMinOpsetVersion(11)
+ @skipIfUnsupportedMinOpsetVersion(13)
def test_dynamic_repeat_interleave(self):
class SingleDynamicModel(torch.nn.Module):
def forward(self, x):
self.run_test(NegDynamicModel(), x, test_with_inputs=[another_x],
input_names=["input_1"], dynamic_axes={"input_1" : {1 : "w"}})
- class SingleDynamicModel2(torch.nn.Module):
+ class SingleDynamicModelFloat(torch.nn.Module):
def forward(self, x):
repeats = torch.tensor([4])
return torch.repeat_interleave(x, repeats, dim=0)
- x = torch.tensor([[1, 2], [3, 4]])
- another_x = torch.tensor([[7, 8], [5, 6]])
- self.run_test(SingleDynamicModel2(), x, test_with_inputs=[another_x],
+ x = torch.tensor([[1.1, 2.1], [3.1, 4.1]])
+ another_x = torch.tensor([[7.1, 8.1], [5.1, 6.1]])
+ self.run_test(SingleDynamicModelFloat(), x, test_with_inputs=[another_x],
input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h"}})
- class AllDynamicModel(torch.nn.Module):
- def forward(self, x):
- repeats = torch.tensor([4])
- return torch.repeat_interleave(x, repeats, dim=0)
+ class DynamicRepeatsModel(torch.nn.Module):
+ def forward(self, x, repeats):
+ return torch.repeat_interleave(x, repeats, dim=1)
- x = torch.tensor([[1, 2, 4, 16], [3, 9, 27, 81], [2, 3, 5, 7]])
+ x = torch.tensor([[1, 2, 4], [3, 4, 7]])
another_x = torch.tensor([[7, 8], [5, 6]])
- self.run_test(AllDynamicModel(), x, test_with_inputs=[another_x],
- input_names=["input_1"], dynamic_axes={"input_1" : {0 : "h", 1 : "w"}})
+ repeats = torch.tensor([2])
+ another_repeats = torch.tensor([4])
+ self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(another_x, another_repeats)],
+ input_names=["input_1", "repeats_1"],
+ dynamic_axes={"input_1" : {1 : "w"}, "repeats_1" : {0 : "r"}})
+
+ class DynamicRepeatsModel2(torch.nn.Module):
+ def forward(self, x, repeats):
+ return torch.repeat_interleave(x, repeats, dim=1)
+
+ x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+ repeats = torch.tensor([2])
+ another_repeats = torch.tensor([4])
+ self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+ input_names=["input_1", "repeats_1"],
+ dynamic_axes={"repeats_1" : {0 : "r"}})
+
+ @skipIfUnsupportedMinOpsetVersion(13)
+ def test_multiple_dynamic_repeat_interleave(self):
+ class DynamicRepeatsModel(torch.nn.Module):
+ def forward(self, x, repeats):
+ return torch.repeat_interleave(x, repeats, dim=1)
+
+ x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+ repeats = torch.tensor([2, 3, 4])
+ another_repeats = torch.tensor([4, 3, 2])
+ self.run_test(DynamicRepeatsModel(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+ input_names=["input_1", "repeats_1"],
+ dynamic_axes={"repeats_1" : {0 : "r"}})
+
+ class DynamicRepeatsModel2(torch.nn.Module):
+ def forward(self, x, repeats):
+ return torch.repeat_interleave(x, repeats, dim=0)
+
+ x = torch.tensor([[1, 2, 4], [3, 4, 7]])
+ repeats = torch.tensor([2, 3])
+ another_repeats = torch.tensor([4, 3])
+ self.run_test(DynamicRepeatsModel2(), (x, repeats), test_with_inputs=[(x, another_repeats)],
+ input_names=["input_1", "repeats_1"],
+ dynamic_axes={"repeats_1" : {0 : "r"}})
def test_view(self):
class ViewModel(torch.nn.Module):
chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
return split(g, self, chunk_vec, dim)
-def repeat_interleave(g, self, repeats, dim=None, output_size=None):
- input = self
- final_dim = dim
- # if dim is None flatten
- # By default, use the flattened input array, and return a flat output array
- if sym_help._is_none(dim):
- input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
- dim = 0
- else:
- dim = sym_help._maybe_get_scalar(dim)
-
- repeats_dim = sym_help._get_tensor_rank(repeats)
- repeats_sizes = sym_help._get_tensor_sizes(repeats)
- input_sizes = sym_help._get_tensor_sizes(input)
- if repeats_dim is None:
- raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
- "repeats rank.")
- if repeats_sizes is None:
- raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
- "repeats size.")
- if input_sizes is None:
- raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
- "input size.")
- # Handle cases where dim is negative
- if dim < 0:
- dim += len(input_sizes)
-
- output_sizes = input_sizes.copy()
- perm_i = [0]
- for idx, input_size in enumerate(input_sizes):
- perm_i.append(idx + 1)
- if input_size is None:
- output_sizes[idx], input_sizes[idx] = 0, -1
- perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]
-
- # Cases when repeats is a single value tensor and dim has unknown input size
- if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
- if not sym_help._is_tensor(repeats):
- repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
- reps = sym_help._size_helper(g, input, dim)
- reps = unsqueeze(g, reps, 0)
- repeats = g.op("Expand", repeats, reps)
- # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
- # provided along one of the dynamic axes provided. A simple example would be
- # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
- # Now, repeat interleaving can be performed in pytorch when the value of * matches
- # with the number of elements in repeat, for example if * -> 2, number of repeats
- # should be 2 as well.
- else:
- return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
-
- reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
- value_t=torch.tensor([1], dtype=torch.long))
- r_splits = split(g, repeats, reps_like, 0)
- i_splits = split(g, input, reps_like, dim)
-
- output_sizes[dim], input_sizes[dim] = -1, 1
-
- # Create a loop to iterate over each value along the dimension
- # and perform individual interleaving using the repeats tensor
- # Loop is of the following pattern
- # input (trip_count, cond)
- # int trip_count = ...;
- # bool cond = ...;
- # for (int i=0; i < trip_count && cond; ++i) {
- # cond = ...;
- # }
-
- # Loop conditions
- loop_condition = g.op("Constant", value_t=torch.tensor(1))
- loop_condition = g.op("Cast", loop_condition, to_i=9)
- loop_len = reps
- loop = g.op("Loop", loop_len, loop_condition)
-
- # Loop inputs
- loop_block = _add_block(loop.node())
- block_input_iter = _add_input_to_block(loop_block)
- cond = _add_input_to_block(loop_block)
-
- r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
- i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
-
- i_split = unsqueeze(loop_block, i_split, dim + 1)
- r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
- r_split,
- loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
- r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
- i_split = expand(loop_block, i_split, r_concat, None)
- i_split = sym_help._reshape_helper(loop_block, i_split,
- g.op("Constant", value_t=torch.LongTensor(output_sizes)))
-
- # Loop outputs
- cond_out = loop_block.op("Cast", loop_condition, to_i=9)
- _add_output_to_block(loop_block, cond_out)
- _add_output_to_block(loop_block, i_split)
- loop_out = loop.node().output()
-
- # In this loop, the outputs are scan outputs and are concatenated along
- # the zero'th dimension (by default). In order to avoid this and concatenate
- # along the dimension provided, some post-processing is required
- loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
- return sym_help._reshape_helper(g, loop_out,
- g.op("Constant", value_t=torch.LongTensor(output_sizes)))
-
def normal(g, loc, scale, seed):
# If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
-from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero
+from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero, expand
+from torch.onnx.symbolic_opset11 import unsqueeze
+from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
# EDITING THIS FILE? READ THIS FIRST!
# user's modules.
splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
+
+def repeat_interleave(g, self, repeats, dim=None, output_size=None):
+ input = self
+ final_dim = dim
+ # if dim is None flatten
+ # By default, use the flattened input array, and return a flat output array
+ if sym_help._is_none(dim):
+ input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
+ dim = 0
+ else:
+ dim = sym_help._maybe_get_scalar(dim)
+
+ repeats_dim = sym_help._get_tensor_rank(repeats)
+ repeats_sizes = sym_help._get_tensor_sizes(repeats)
+ input_sizes = sym_help._get_tensor_sizes(input)
+ if repeats_dim is None:
+ raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+ "repeats rank.")
+ if repeats_sizes is None:
+ raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+ "repeats size.")
+ if input_sizes is None:
+ raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
+ "input size.")
+ # Handle cases where dim is negative
+ if dim < 0:
+ dim += len(input_sizes)
+
+ output_sizes = input_sizes.copy()
+ for idx, input_size in enumerate(input_sizes):
+ if input_size is None:
+ output_sizes[idx], input_sizes[idx] = 0, -1
+ print(output_sizes, input_sizes)
+
+ cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
+ # If input size is dynamic or repeats vector is dynamic
+ if output_sizes[dim] == 0 or cond_dynamic_repeats:
+ reps = sym_help._size_helper(g, input, dim)
+ reps = unsqueeze(g, reps, 0)
+ # Check if repeats vector is a single integer value
+ # or a single dimension tensor with non-dynamic values
+ if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
+ if not sym_help._is_tensor(repeats):
+ repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
+ repeats = g.op("Expand", repeats, reps)
+ # Check if repeats is dynamic
+ # As repeats is dynamic, we use a where node as a substitute for the if statement
+ # If repests_dim = 1, expand repeats otherwise use original tensor
+ elif cond_dynamic_repeats:
+ repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
+ repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])))
+ repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
+ # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
+ # provided along one of the dynamic axes provided. A simple example would be
+ # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
+ # Now, repeat interleaving can be performed in pytorch when the value of * matches
+ # with the number of elements in repeat, for example if * -> 2, number of repeats
+ # should be 2 as well.
+ else:
+ return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
+
+ reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
+ value_t=torch.tensor([1], dtype=torch.long))
+ r_splits = split(g, repeats, reps_like, 0)
+ i_splits = split(g, input, reps_like, dim)
+
+ output_sizes[dim], input_sizes[dim] = -1, 1
+
+ # Create a loop to iterate over each value along the dimension
+ # and perform individual interleaving using the repeats tensor
+ # Loop is of the following pattern
+ # input (trip_count, cond)
+ # int trip_count = ...;
+ # bool cond = ...;
+ # for (int i=0; i < trip_count && cond; ++i) {
+ # cond = ...;
+ # }
+
+ # Loop conditions
+ loop_condition = g.op("Constant", value_t=torch.tensor(1))
+ loop_condition = g.op("Cast", loop_condition, to_i=9)
+ loop_len = reps
+
+ # Create an empty sequence to store final expansions
+ final_splits = g.op("SequenceEmpty")
+ loop = g.op("Loop", loop_len, loop_condition, final_splits)
+
+ # Loop inputs
+ loop_block = _add_block(loop.node())
+ block_input_iter = _add_input_to_block(loop_block)
+ cond = _add_input_to_block(loop_block)
+ final_splits = _add_input_to_block(loop_block)
+
+ r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
+ i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
+
+ i_split = unsqueeze(loop_block, i_split, dim + 1)
+ r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
+ r_split,
+ loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
+ r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
+ i_split = expand(loop_block, i_split, r_concat, None)
+ i_split = sym_help._reshape_helper(loop_block, i_split,
+ g.op("Constant", value_t=torch.LongTensor(output_sizes)))
+ final_splits = loop_block.op("SequenceInsert", final_splits, i_split)
+
+ # Loop outputs
+ cond_out = loop_block.op("Cast", loop_condition, to_i=9)
+ _add_output_to_block(loop_block, cond_out)
+ _add_output_to_block(loop_block, final_splits)
+
+ loop_out = loop.node().output()
+ loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
+ return loop_out
if not sym_help._is_tensor(repeats):
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
if input_sizes[dim] == 0:
- return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
+ return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
"Unsupported along dimension with unknown input size")
else:
reps = input_sizes[dim]
# Cases where repeats is a 1 dim Tensor
elif repeats_dim == 1:
if input_sizes[dim] == 0:
- return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
+ return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
"Unsupported along dimension with unknown input size")
+ if repeats_sizes[0] is None:
+ return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 13,
+ "Unsupported for cases with dynamic repeats")
assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
reps = repeats_sizes[0]
else: