def f(x, y):
return torch.sigmoid(torch.tanh(x * (x + y)))
- trace, z = torch.jit.get_trace_graph(f, (x, y))
- self.assertExpectedGraph(trace)
- self.assertExportImport(trace, (x, y))
+ self.checkTrace(f, (x, y))
def test_restore_device(self):
# main purpose is checking map_location works
out = torch.sigmoid(out)
return out
- trace, z = torch.jit.get_trace_graph(f, (x, y))
- self.assertExpectedGraph(trace)
- self.assertExportImport(trace, (x, y))
+ self.checkTrace(f, (x, y))
def test_scopes_intermediate_node(self):
def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
return x + a + b + c
- self.assertExpectedGraph(simple_fn.graph, "simple")
self.assertEqual(
simple_fn(torch.ones(1)),
torch.ones(1) + 0.5 + 10 + (20 + 30))
result = x + a
return result
- self.assertExpectedGraph(bool_fn.graph, "bool")
self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
self.assertEqual(
bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
# type: (Optional[int]) -> Optional[int]
return x
- self.assertExpectedGraph(none_fn.graph, "none")
self.assertEqual(none_fn(), None)
self.assertEqual(none_fn(1), 1)
# type: (Tensor, float, int) -> Tensor
return x + a + b
- self.assertExpectedGraph(hints.graph, "type_hints")
self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
var_int = [2, -2]
var_float = [1.4321, -1.2]
- ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=']
- # TODO: turn this on for py3 (and add PY3 division semantics)
- ops_py2_only = ['/']
- if PY2:
- ops.extend(ops_py2_only)
+ ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']
float_tensor = torch.randn(5, 5, device=device)
double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
# kwargs for function, // optional
# )
nn_functional_tests = [
- # TODO: default arguments for None type not supported, add
- # manually as argument, remove when ATen default arg system ready
- ('conv1d', (S, S, S), ((S, S, S), None)),
- ('conv2d', (S, S, S, S), ((S, S, S, S), None)),
- ('conv3d', (S, S, S, S, S), ((S, S, S, S, S), None)),
- ('conv_transpose1d', (S, S, S), ((S, S, S), None)),
- ('conv_transpose2d', (S, S, S, S), ((S, S, S, S), None)),
- ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S), None)),
+ ('conv1d', (S, S, S), ((S, S, S),)),
+ ('conv2d', (S, S, S, S), ((S, S, S, S),)),
+ ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
+ ('conv_transpose1d', (S, S, S), ((S, S, S),)),
+ ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
+ ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
('avg_pool1d', (S, S, S), (3,)),
('avg_pool2d', (S, S, S, S), (3,)),
('avg_pool3d', (S, S, S, S, S), (3,)),
- ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3], None)),
+ ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
('max_pool1d', (S, S, S), (2, 1)),
('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
('max_pool2d', (S, S, S, S), (2, 1)),
('tanh', (S, S, S), (),),
('sigmoid', (S, S, S), (),),
('log_softmax', (S, S, S), (0,),),
- ('linear', (S, S), ((M, S), None),),
- ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M), None),),
+ ('linear', (S, S), ((M, S),),),
+ ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),),
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
- ('layer_norm', (S, S, S, S), ([5], None, None),),
- ('group_norm', (S, S, S), (1, torch.rand(5), None),),
+ ('layer_norm', (S, S, S, S), ([5],),),
+ ('group_norm', (S, S, S), (1, torch.rand(5),),),
('local_response_norm', (S, S, S), (2, ),),
- ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]), None, None),),
+ ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),),),
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
('gumbel_softmax', (S, S), (2.,),),
('gumbel_softmax', (S, S), (2., True,), 'hard'),
('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
- ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), \
+ ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
1, 1., non_differentiable(torch.randn(S))),),
- ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)), \
+ ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
non_differentiable(torch.randn(3, 2))),),
('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
(non_differentiable(torch.rand(3, 2)),
non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
- ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(), \
- (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long), \
+ ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
+ (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
torch.randint(1, S, (S,), dtype=torch.long))),
('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
- ('upsample', torch.randn(S, S, M, M), (4, None), 'with_size'),
+ ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
- ('interpolate', torch.randn(S, S, M, M), (4, None), 'with_size'),
+ ('interpolate', torch.randn(S, S, M, M), (4,), 'with_size'),
]