Fix interpolate trace (#18875)
authorElias Ellison <eellison@fb.com>
Sat, 6 Apr 2019 00:52:12 +0000 (17:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 6 Apr 2019 00:55:23 +0000 (17:55 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/10654

The issue is that in tracing `.size` returns an int tensor, and when an int tensor is multiplied by a scalar the int dominates and the scalar gets casted 0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18875

Differential Revision: D14814441

Pulled By: eellison

fbshipit-source-id: a4e96a2698f2fcbf3ec4b2bb4c43a30250f30ad9

test/test_jit.py
torch/nn/functional.py

index c2a347a..d3caa8c 100644 (file)
@@ -7518,6 +7518,24 @@ a")
             mte, (torch.zeros(1, 2, 3),), None, verbose=False,
             example_outputs=outputs))
 
+    def test_interpolate_trace(self):
+        class test(nn.Module):
+            def __init__(self):
+                super(test, self).__init__()
+                self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+
+            def forward(self, x):
+                y = self.conv(x)
+                w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=0.5)
+                return w
+
+        f = test()
+        # no failure
+        g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
+        x = torch.zeros(1, 1, 14, 14)
+        # constants not baked in
+        self.assertEqual(g(x), f(x))
+
     def test_trace_nested_datatypes(self):
         @torch.jit.script
         def foo(x):
index 393fa85..44db0d3 100644 (file)
@@ -2492,7 +2492,12 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
             return size
         scale_factors = _ntuple(dim)(scale_factor)
         # math.floor might return float in py2.7
-        return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)]
+
+        # make scale_factor a tensor in tracing so constant doesn't get baked in
+        if torch._C._get_tracing_state():
+            return [(torch.floor(input.size(i + 2) * torch.tensor(scale_factors[i]))) for i in range(dim)]
+        else:
+            return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
 
     if mode in ('nearest', 'area'):
         if align_corners is not None: