Fix interpolate tracing (#19034)
authorZachary DeVito <zdevito@fb.com>
Mon, 8 Apr 2019 21:56:26 +0000 (14:56 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 8 Apr 2019 21:59:26 +0000 (14:59 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19034
ghimport-source-id: 874e0b0a8685184416152a77fc1850d9a06516ae

Differential Revision: D14837282

Pulled By: zdevito

fbshipit-source-id: b0ed82b607c288a54eecec3d6ed62c4626e5a563

test/test_jit.py
torch/nn/functional.py

index f6652fb..59e360d 100644 (file)
@@ -7729,7 +7729,7 @@ a")
 
             def forward(self, x):
                 y = self.conv(x)
-                w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=0.5)
+                w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=3)
                 return w
 
         f = test()
index 44db0d3..937e624 100644 (file)
@@ -2495,7 +2495,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
 
         # make scale_factor a tensor in tracing so constant doesn't get baked in
         if torch._C._get_tracing_state():
-            return [(torch.floor(input.size(i + 2) * torch.tensor(scale_factors[i]))) for i in range(dim)]
+            return [(torch.floor(input.size(i + 2) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
         else:
             return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]