mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs))
+ def test_interpolate_trace(self):
+ class test(nn.Module):
+ def __init__(self):
+ super(test, self).__init__()
+ self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ y = self.conv(x)
+ w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=0.5)
+ return w
+
+ f = test()
+ # no failure
+ g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
+ x = torch.zeros(1, 1, 14, 14)
+ # constants not baked in
+ self.assertEqual(g(x), f(x))
+
def test_trace_nested_datatypes(self):
@torch.jit.script
def foo(x):
return size
scale_factors = _ntuple(dim)(scale_factor)
# math.floor might return float in py2.7
- return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)]
+
+ # make scale_factor a tensor in tracing so constant doesn't get baked in
+ if torch._C._get_tracing_state():
+ return [(torch.floor(input.size(i + 2) * torch.tensor(scale_factors[i]))) for i in range(dim)]
+ else:
+ return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
if mode in ('nearest', 'area'):
if align_corners is not None: