--- /dev/null
+ir_version: 4
+producer_name: "pytorch"
+producer_version: "1.1"
+graph {
+ node {
+ input: "0"
+ output: "1"
+ op_type: "Slice"
+ attribute {
+ name: "axes"
+ ints: 0
+ type: INTS
+ }
+ attribute {
+ name: "ends"
+ ints: 2
+ type: INTS
+ }
+ attribute {
+ name: "starts"
+ ints: 0
+ type: INTS
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 10
+}
x = torch.rand(3, 4, requires_grad=True)
self.assertONNX(lambda x: x[:, 1:2], x)
+ def test_narrow(self):
+ x = torch.randn(3, 3, requires_grad=True)
+ self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)
+
def test_atan(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: x.atan(), x)
x = torch.randn(2, 3, requires_grad=True)
self.run_model_test(ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE)
+ def test_narrow(self):
+ class NarrowModel(torch.nn.Module):
+ def forward(self, input):
+ return torch.narrow(input, 0, 0, 2)
+
+ x = torch.randn(3, 3, requires_grad=True)
+ self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
# a bit of metaprogramming to set up all the rnn tests
return g.op('NonZero', input)
+@parse_args('v', 'i', 'i', 'i')
+def narrow(g, input, dim, start, length):
+ return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
+
+
@parse_args('v', 'i', 'i')
def _argmax(g, input, dim, keepdim):
return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)