From: Elias Ellison Date: Sat, 6 Apr 2019 00:52:12 +0000 (-0700) Subject: Fix interpolate trace (#18875) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~365 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e6bbbb017e1ad829316e7d7ea0f9ca9e78421da2;p=platform%2Fupstream%2Fpytorch.git Fix interpolate trace (#18875) Summary: Fixes https://github.com/pytorch/pytorch/issues/10654 The issue is that in tracing `.size` returns an int tensor, and when an int tensor is multiplied by a scalar the int dominates and the scalar gets casted 0. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18875 Differential Revision: D14814441 Pulled By: eellison fbshipit-source-id: a4e96a2698f2fcbf3ec4b2bb4c43a30250f30ad9 --- diff --git a/test/test_jit.py b/test/test_jit.py index c2a347a..d3caa8c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -7518,6 +7518,24 @@ a") mte, (torch.zeros(1, 2, 3),), None, verbose=False, example_outputs=outputs)) + def test_interpolate_trace(self): + class test(nn.Module): + def __init__(self): + super(test, self).__init__() + self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1) + + def forward(self, x): + y = self.conv(x) + w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=0.5) + return w + + f = test() + # no failure + g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),)) + x = torch.zeros(1, 1, 14, 14) + # constants not baked in + self.assertEqual(g(x), f(x)) + def test_trace_nested_datatypes(self): @torch.jit.script def foo(x): diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 393fa85..44db0d3 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2492,7 +2492,12 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne return size scale_factors = _ntuple(dim)(scale_factor) # math.floor might return float in py2.7 - return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)] + + # make scale_factor a tensor in tracing so constant doesn't get baked in + if torch._C._get_tracing_state(): + return [(torch.floor(input.size(i + 2) * torch.tensor(scale_factors[i]))) for i in range(dim)] + else: + return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] if mode in ('nearest', 'area'): if align_corners is not None: