From 34b43baeec3f5d2c0e45ecbdf866f121b6829d15 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Wed, 30 Jan 2019 10:44:59 -0800 Subject: [PATCH] Allow list and tuples to be passed as output_size to max_unpool1d (#16489) Summary: Changelog: - Modify concantenation of [1] to a tuple by using cases for list and non-list types. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16489 Differential Revision: D13875838 Pulled By: soumith fbshipit-source-id: fade65cc47385986b773b9bde9b4601ab93fe1cf --- test/test_nn.py | 6 ++++++ torch/nn/functional.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index ea7fd8e..a263698 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7129,6 +7129,12 @@ class TestNNInit(TestCase): output, indices = F.max_pool1d(torch.randn([1, 1, 4]), 2, stride=2, return_indices=True) self.assertEqual(F.max_unpool1d(output, indices, 2), F.max_unpool1d(output, indices, 2, stride=2)) + # Test list / tuple passed as argument to max_unpool1d + input = torch.randn([1, 1, 5]) + output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True) + self.assertEqual(F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape), + F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size())) + # Test 2D output, indices = F.max_pool2d(torch.randn([1, 1, 4, 4]), 2, stride=2, return_indices=True) self.assertEqual(F.max_unpool2d(output, indices, 2), F.max_unpool2d(output, indices, 2, stride=2)) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 04e6382..0eecd32 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -565,7 +565,12 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0, padding = _single(padding) output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) - return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size + [1]).squeeze(3) + if isinstance(output_size, list): + output_size = output_size + [1] + else: + output_size = output_size + (1,) + return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), + output_size).squeeze(3) @weak_script -- 2.7.4