# input size/constructing fn,
# args (tuple represents shape of a tensor arg),
# test variant name (will be used at test name suffix), // optional
# (True, nonfusible_nodes, fusible_nodes) for autodiff, // optional
# indices for possible dim arg, // optional
# fn mapping output to part that should be gradcheck'ed, // optional
# )
# Note: some functions have separate schemas for (Tensor other) and (Scalar other),
# and it's possible that we only support AD for the Scalar version but not the
# Tensor version, and vice versa.
# When writing tests, only a scalar (float/int) input triggers the Scalar schema.
# uniform_scalar produces a scalar **Tensor**, which will not match a Scalar input.
def method_tests():
set_rng_seed(0)
return [
- ('add', (S, S, S), ((S, S, S),)),
- ('add', (S, S, S), ((S, S),), 'broadcast_rhs'),
- ('add', (S, S), ((S, S, S),), 'broadcast_lhs'),
- ('add', (S, 1, S), ((M, S),), 'broadcast_all'),
- ('add', (), ((),), 'scalar'),
- ('add', (S, S, S), ((),), 'scalar_broadcast_rhs'),
- ('add', (), ((S, S, S),), 'scalar_broadcast_lhs'),
- ('add', (S, S, S), (3.14,), 'constant'),
- ('add', (), (3.14,), 'scalar_constant'),
- ('__radd__', (S, S, S), (3.14,), 'constant'),
- ('__radd__', (), (3.14,), 'scalar_constant'),
- ('sub', (S, S, S), ((S, S, S),)),
- ('sub', (S, S, S), ((S, S),), 'broadcast_rhs'),
- ('sub', (S, S), ((S, S, S),), 'broadcast_lhs'),
- ('sub', (S, 1, S), ((M, S),), 'broadcast_all'),
- ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs'),
- ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs'),
- ('sub', (S, S, S), (3.14,), 'constant'),
- ('sub', (), (3.14,), 'scalar_constant'),
- ('__rsub__', (S, S, S), (3.14,), 'constant'),
- ('__rsub__', (), (3.14,), 'scalar_constant'),
- ('mul', (S, S, S), ((S, S, S),)),
- ('mul', (), ((),), 'scalar'),
- ('mul', (S, S, S), ((S, S),), 'broadcast_rhs'),
- ('mul', (S, S), ((S, S, S),), 'broadcast_lhs'),
- ('mul', (S, 1, S), ((M, S),), 'broadcast_all'),
- ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs'),
- ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs'),
- ('mul', (S, S, S), (3.14,), 'constant'),
- ('mul', (), (3.14,), 'scalar_constant'),
- ('__rmul__', (S, S, S), (3.14,), 'constant'),
- ('__rmul__', (), (3.14,), 'scalar_constant'),
- ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,)),
- ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs'),
- ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
- ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all'),
- ('div', (), (uniform_scalar(0.1),), 'scalar'),
- ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
- ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs'),
- ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'),
- ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'),
- ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'),
- ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'),
- ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,)),
- ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
- ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
- ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
- ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
- ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
- ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
- ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
- ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
- ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'),
- ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'),
- ('transpose', (1, 2, 3), (1, 2), 'dim', [0, 1]),
- ('transpose', (), (0, 0), 'scalar'),
- ('transpose', (1,), (0, 0), '1d'),
- ('transpose', torch.rand(L, L), (0, 1), '2d'),
- ('transpose', torch.rand(S, S, S), (2, 0), '3d'),
- ('t', (1, 2), NO_ARGS),
- ('view', (S, S, S), (S * S, S),),
- ('view', (S, S, S), (torch.Size([S * S, S]),), 'size'),
- ('view', (S,), (S,), '1d'),
- ('view', (), (dont_convert(()),), 'scalar_to_scalar'),
- ('view', (), (1,), 'scalar_to_1d'),
+ ('add', (S, S, S), ((S, S, S),), '', (True,)),
+ ('add', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)),
+ ('add', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)),
+ ('add', (S, 1, S), ((M, S),), 'broadcast_all', (True,)),
+ ('add', (), ((),), 'scalar', (True,)),
+ ('add', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)),
+ ('add', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
+ ('add', (S, S, S), (3.14,), 'constant', (True,)),
+ ('add', (), (3.14,), 'scalar_constant', (True,)),
+ ('__radd__', (S, S, S), (3.14,), 'constant', (True, 'aten::add')),
+ ('__radd__', (), (3.14,), 'scalar_constant', (True, 'aten::add')),
+ ('sub', (S, S, S), ((S, S, S),), '', (True,)),
+ ('sub', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)),
+ ('sub', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)),
+ ('sub', (S, 1, S), ((M, S),), 'broadcast_all', (True,)),
+ ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)),
+ ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
+ ('sub', (S, S, S), (3.14,), 'constant', (True,)),
+ ('sub', (), (3.14,), 'scalar_constant', (True,)),
+ ('__rsub__', (S, S, S), (3.14,), 'constant', (True, 'aten::rsub')),
+ ('__rsub__', (), (3.14,), 'scalar_constant', (True, 'aten::rsub')),
+ ('mul', (S, S, S), ((S, S, S),), '', (True,)),
+ ('mul', (), ((),), 'scalar', (True,)),
+ ('mul', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)),
+ ('mul', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)),
+ ('mul', (S, 1, S), ((M, S),), 'broadcast_all', (True,)),
+ ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)),
+ ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
+ ('mul', (S, S, S), (3.14,), 'constant', (True,)),
+ ('mul', (), (3.14,), 'scalar_constant', (True,)),
+ ('__rmul__', (S, S, S), (3.14,), 'constant', (True, 'aten::mul')),
+ ('__rmul__', (), (3.14,), 'scalar_constant', (True, 'aten::mul')),
+ ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,), '', (True,)),
+ ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs', (True,)),
+ ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)),
+ ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all', (True,)),
+ ('div', (), (uniform_scalar(0.1),), 'scalar', (True,)),
+ ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)),
+ ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs', (True,)),
+ ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant', (True,)),
+ ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant',
+ (True, [], ['aten::mul', 'aten::reciprocal'])),
+ ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
+ ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant',
+ (True, [], ['aten::mul', 'aten::reciprocal'])),
+ ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), '', (True,)),
+ ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs', (True,)),
+ ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)),
+ ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all', (True,)),
+ ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar', (True,)),
+ ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)),
+ ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs', (True,)),
+ ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True,)),
+ ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')),
+ ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
+ ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
+ ('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]),
+ ('transpose', (), (0, 0), 'scalar', (True,)),
+ ('transpose', (1,), (0, 0), '1d', (True,)),
+ ('transpose', torch.rand(L, L), (0, 1), '2d', (True,)),
+ ('transpose', torch.rand(S, S, S), (2, 0), '3d', (True,)),
+ ('t', (1, 2), NO_ARGS, '', (True,)),
+ ('view', (S, S, S), (S * S, S), '', (True,)),
+ ('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
+ ('view', (S,), (S,), '1d', (True,)),
+ ('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
+ ('view', (), (1,), 'scalar_to_1d', (True,)),
('reshape', (S, S, S), (S * S, S),),
('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'),
('reshape', (S,), (S,), '1d'),
('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
- ('expand', (S, 1, 1), (S, S, S)),
- ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
- ('expand', (S, 1), (S, S, S), 'new_dim'),
- ('expand', (1,), (S, S, S), '1_element'),
- ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'),
- ('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
- ('expand', (), (1, 3, 2), 'scalar_to_dims'),
- ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
- ('exp', (S, S, S), NO_ARGS),
- ('exp', (), NO_ARGS, 'scalar'),
- ('expm1', (S, S, S), NO_ARGS),
- ('expm1', (), NO_ARGS, 'scalar'),
- ('erf', torch.rand(S, S, S), NO_ARGS),
- ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
- ('erfc', torch.rand(S, S, S), NO_ARGS),
- ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
+ ('expand', (S, 1, 1), (S, S, S), '', (True,)),
+ ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (True,)),
+ ('expand', (S, 1), (S, S, S), 'new_dim', (True,)),
+ ('expand', (1,), (S, S, S), '1_element', (True,)),
+ ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)),
+ ('expand', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
+ ('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)),
+ ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)),
+ ('exp', (S, S, S), NO_ARGS, '', (True,)),
+ ('exp', (), NO_ARGS, 'scalar', (True,)),
+ ('expm1', (S, S, S), NO_ARGS, '', (True,)),
+ ('expm1', (), NO_ARGS, 'scalar', (True,)),
+ ('erf', torch.rand(S, S, S), NO_ARGS, '', (True,)),
+ ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('erfc', torch.rand(S, S, S), NO_ARGS, '', (True,)),
+ ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'),
- ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS),
- ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
- ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS),
- ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
- ('log1p', torch.rand(S, S, S), NO_ARGS),
- ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
- ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS),
- ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
- ('tanh', (S, S, S), NO_ARGS),
- ('tanh', (), NO_ARGS, 'scalar'),
- ('sigmoid', (S, S, S), NO_ARGS),
- ('sigmoid', (), NO_ARGS, 'scalar'),
- ('sinh', (S, S, S), NO_ARGS),
- ('sinh', (), NO_ARGS, 'scalar'),
- ('cosh', (S, S, S), NO_ARGS),
- ('cosh', (), NO_ARGS, 'scalar'),
- ('abs', (S, S, S), NO_ARGS),
- ('abs', (), NO_ARGS, 'scalar'),
- ('clamp', (S, S, S), (0, 1)),
- ('clamp', (S, S, S), (None, 0.5), 'min'),
- ('clamp', (S, S, S), (0.5, None), 'max'),
- ('clamp', (), (0, 1), 'scalar'),
- ('clamp', (), (None, 0.5), 'min_scalar'),
- ('clamp', (), (0.5, None), 'max_scalar'),
- ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS),
- ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar'),
- ('sin', (S, S, S), NO_ARGS),
- ('sin', (), NO_ARGS, 'scalar'),
- ('cos', (S, S, S), NO_ARGS),
- ('cos', (), NO_ARGS, 'scalar'),
- ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS),
- ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
- ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
- ('atan', (S, S, S), NO_ARGS),
- ('atan', (), NO_ARGS, 'scalar'),
+ ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+ ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+ ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('log1p', torch.rand(S, S, S), NO_ARGS, '', (True,)),
+ ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+ ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('tanh', (S, S, S), NO_ARGS, '', (True,)),
+ ('tanh', (), NO_ARGS, 'scalar', (True,)),
+ ('sigmoid', (S, S, S), NO_ARGS, '', (True,)),
+ ('sigmoid', (), NO_ARGS, 'scalar', (True,)),
+ ('sinh', (S, S, S), NO_ARGS, '', (True,)),
+ ('sinh', (), NO_ARGS, 'scalar', (True,)),
+ ('cosh', (S, S, S), NO_ARGS, '', (True,)),
+ ('cosh', (), NO_ARGS, 'scalar', (True,)),
+ ('abs', (S, S, S), NO_ARGS, '', (True,)),
+ ('abs', (), NO_ARGS, 'scalar', (True,)),
+ ('clamp', (S, S, S), (0, 1), '', (True,)),
+ ('clamp', (S, S, S), (None, 0.5), 'min', (True,)),
+ ('clamp', (S, S, S), (0.5, None), 'max', (True,)),
+ ('clamp', (), (0, 1), 'scalar', (True,)),
+ ('clamp', (), (None, 0.5), 'min_scalar', (True,)),
+ ('clamp', (), (0.5, None), 'max_scalar', (True,)),
+ ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS, '', (True,)),
+ ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('sin', (S, S, S), NO_ARGS, '', (True,)),
+ ('sin', (), NO_ARGS, 'scalar', (True,)),
+ ('cos', (S, S, S), NO_ARGS, '', (True,)),
+ ('cos', (), NO_ARGS, 'scalar', (True,)),
+ ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS, '', (True,)),
+ ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)),
+ ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)),
+ ('atan', (S, S, S), NO_ARGS, '', (True,)),
+ ('atan', (), NO_ARGS, 'scalar', (True,)),
('atan2', (S, S, S), ((S, S, S),)),
('atan2', (), ((),), 'scalar'),
('atan2', (S, S, S), ((S,),), 'broadcast_rhs'),
('atan2', (S,), ((S, S, S),), 'broadcast_lhs'),
('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
- ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS),
- ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar'),
- ('round', (S, S, S), NO_ARGS),
- ('round', (), NO_ARGS, 'scalar'),
+ ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)),
+ ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('round', (S, S, S), NO_ARGS, '', (True,)),
+ ('round', (), NO_ARGS, 'scalar', (True,)),
('sign', (S, S, S), NO_ARGS),
('sign', (), NO_ARGS, 'scalar'),
- ('trunc', (S, S, S), NO_ARGS),
- ('trunc', (), NO_ARGS, 'scalar'),
- ('floor', (S, S, S), NO_ARGS),
- ('floor', (), NO_ARGS, 'scalar'),
- ('ceil', (S, S, S), NO_ARGS),
- ('ceil', (), NO_ARGS, 'scalar'),
- ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS),
- ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
- ('frac', (S, S, S), NO_ARGS),
- ('frac', (), NO_ARGS, 'scalar'),
- ('fmod', (S, S, S), (1.5,)),
- ('fmod', (), (1.5,), 'scalar'),
+ ('trunc', (S, S, S), NO_ARGS, '', (True,)),
+ ('trunc', (), NO_ARGS, 'scalar', (True,)),
+ ('floor', (S, S, S), NO_ARGS, '', (True,)),
+ ('floor', (), NO_ARGS, 'scalar', (True,)),
+ ('ceil', (S, S, S), NO_ARGS, '', (True,)),
+ ('ceil', (), NO_ARGS, 'scalar', (True,)),
+ ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+ ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+ ('frac', (S, S, S), NO_ARGS, '', (True,)),
+ ('frac', (), NO_ARGS, 'scalar', (True,)),
+ ('fmod', (S, S, S), (1.5,), '', (True,)),
+ ('fmod', (), (1.5,), 'scalar', (True,)),
('fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
('fmod', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
('fmod', (S, S, S), (non_differentiable(torch.rand(S) + 1.5),), 'tensor_broadcast_rhs'),
('fmod', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
('fmod', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
('fmod', (S, S, S), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor_broadcast_rhs'),
- ('remainder', (S, S, S), (1.5,)),
- ('remainder', (), (1.5,), 'scalar'),
+ ('remainder', (S, S, S), (1.5,), '', (True,)),
+ ('remainder', (), (1.5,), 'scalar', (True,)),
('remainder', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
('remainder', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
- ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast'),
- ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'),
- ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'),
- ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all'),
- ('lerp', (), ((), 0.4), 'scalar'),
- ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs'),
- ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs'),
+ ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast', (True,)),
+ ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs', (True,)),
+ ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs', (True,)),
+ ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all', (True,)),
+ ('lerp', (), ((), 0.4), 'scalar', (True,)),
+ ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs', (True,)),
+ ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs', (True,)),
('max', (S, S, S), NO_ARGS),
- ('max', (S, S, S), (1,), 'dim', [0]),
- ('max', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+ ('max', (S, S, S), (1,), 'dim', (), [0]),
+ ('max', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('max', (), NO_ARGS, 'scalar'),
- ('max', (), (0,), 'scalar_dim', [0]),
- ('max', (), (0, True,), 'scalar_keepdim_dim', [0]),
- ('max', (S, S, S), ((S, S, S),), 'elementwise'),
- ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'),
- ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'),
- ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'),
- ('max', (), ((),), 'scalar_elementwise'),
- ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'),
- ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'),
- ('min', (S, S, S), NO_ARGS),
- ('min', (S, S, S), (1,), 'dim', [0]),
- ('min', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+ ('max', (), (0,), 'scalar_dim', (), [0]),
+ ('max', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
+ ('max', (S, S, S), ((S, S, S),), 'elementwise', (True,)),
+ ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs', (True,)),
+ ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs', (True,)),
+ ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all', (True,)),
+ ('max', (), ((),), 'scalar_elementwise', (True,)),
+ ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs', (True,)),
+ ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs', (True,)),
+ ('min', (S, S, S), NO_ARGS, ),
+ ('min', (S, S, S), (1,), 'dim', (), [0]),
+ ('min', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('min', (), NO_ARGS, 'scalar'),
- ('min', (), (0,), 'scalar_dim', [0]),
- ('min', (), (0, True,), 'scalar_keepdim_dim', [0]),
- ('min', (S, S, S), ((S, S, S),), 'elementwise'),
- ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'),
- ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'),
- ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'),
- ('min', (), ((),), 'scalar_elementwise'),
- ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'),
- ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'),
- ('mean', (S, S, S), NO_ARGS),
- ('mean', (S, S, S), (1,), 'dim', [0]),
- ('mean', (S, S, S), (1, True,), 'keepdim_dim', [0]),
- ('mean', (), NO_ARGS, 'scalar'),
- ('mean', (), (0,), 'scalar_dim', [0]),
- ('mean', (), (0, True,), 'scalar_keepdim_dim', [0]),
+ ('min', (), (0,), 'scalar_dim', (), [0]),
+ ('min', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
+ ('min', (S, S, S), ((S, S, S),), 'elementwise', (True,)),
+ ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs', (True,)),
+ ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs', (True,)),
+ ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all', (True,)),
+ ('min', (), ((),), 'scalar_elementwise', (True,)),
+ ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs', (True,)),
+ ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs', (True,)),
+ ('mean', (S, S, S), NO_ARGS, '', (True,)),
+ ('mean', (S, S, S), (1,), 'dim', (True,), [0]),
+ ('mean', (S, S, S), (1, True,), 'keepdim_dim', (True,), [0]),
+ ('mean', (), NO_ARGS, 'scalar', (True,)),
+ ('mean', (), (0,), 'scalar_dim', (True,), [0]),
+ ('mean', (), (0, True,), 'scalar_keepdim_dim', (True,), [0]),
('kthvalue', (S, S, S), (2,)),
('kthvalue', (), (1,), 'scalar'),
- ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]),
- ('kthvalue', (), (1, 0,), 'scalar_dim', [1]),
- ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', [1]),
- ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', [1]),
- ('kthvalue', (S,), (2, 0,), 'dim_1d', [1]),
- ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', [1]),
+ ('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]),
+ ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]),
+ ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]),
+ ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]),
+ ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]),
+ ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]),
('median', (S, S, S), NO_ARGS),
- ('median', (S, S, S), (1,), 'dim', [0]),
- ('median', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+ ('median', (S, S, S), (1,), 'dim', (), [0]),
+ ('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('median', (), NO_ARGS, 'scalar'),
- ('median', (), (0,), 'scalar_dim', [0]),
- ('median', (), (0, True,), 'scalar_keepdim_dim', [0]),
+ ('median', (), (0,), 'scalar_dim', (), [0]),
+ ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
('mode', (S, S, S), NO_ARGS),
- ('mode', (S, S, S), (1,), 'dim', [0]),
- ('mode', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+ ('mode', (S, S, S), (1,), 'dim', (), [0]),
+ ('mode', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('mode', (), NO_ARGS, 'scalar'),
- ('mode', (), (0,), 'scalar_dim', [0]),
- ('mode', (), (0, True,), 'scalar_keepdim_dim', [0]),
+ ('mode', (), (0,), 'scalar_dim', (), [0]),
+ ('mode', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
('sum', (S, S, S), NO_ARGS),
- ('sum', (S, S, S), (1,), 'dim', [0]),
- ('sum', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+ ('sum', (S, S, S), (1,), 'dim', (), [0]),
+ ('sum', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('sum', (), NO_ARGS, 'scalar'),
- ('sum', (), (0,), 'scalar_dim', [0]),
- ('sum', (), (0, True,), 'scalar_keepdim_dim', [0]),
+ ('sum', (), (0,), 'scalar_dim', (), [0]),
+ ('sum', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
('sum', (S, S, S), ([1, 2],), 'multi_dim'),
('sum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
('prod', (S, S, S), NO_ARGS),
- ('prod', (S, S, S), (1,), 'dim', [0]),
- ('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+ ('prod', (S, S, S), (1,), 'dim', (), [0]),
+ ('prod', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('prod', (), NO_ARGS, 'scalar'),
- ('prod', (), (0,), 'scalar_dim', [0]),
- ('prod', (), (0, True,), 'scalar_keepdim_dim', [0]),
+ ('prod', (), (0,), 'scalar_dim', (), [0]),
+ ('prod', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'),
('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'),
('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'),
- ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', [0]),
- ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', [0]),
- ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', [0]),
- ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', [0]),
- ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', [0]),
- ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', [0]),
+ ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', (), [0]),
+ ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', (), [0]),
+ ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', (), [0]),
+ ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', (), [0]),
+ ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', (), [0]),
+ ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', (), [0]),
('prod', prod_single_zero(S), NO_ARGS, 'single_zero'),
('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'),
- ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', [0]),
- ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', [0]),
- ('var', (S, S, S), NO_ARGS),
- ('var', (S, S, S), (1,), 'dim', [0]),
- ('var', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
- ('var', (S,), (0,), 'dim_1d', [0]),
- ('var', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
- ('std', (S, S, S), NO_ARGS),
- ('std', (S, S, S), (1,), 'dim', [0]),
- ('std', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
- ('std', (S,), (0,), 'dim_1d', [0]),
- ('std', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
- ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]),
+ ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]),
+ ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]),
+ ('var', (S, S, S), NO_ARGS, '', (True,)),
+ ('var', (S, S, S), (1,), 'dim', (True,), [0]),
+ ('var', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]),
+ ('var', (S,), (0,), 'dim_1d', (True,), [0]),
+ ('var', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]),
+ ('std', (S, S, S), NO_ARGS, '', (True,)),
+ ('std', (S, S, S), (1,), 'dim', (True,), [0]),
+ ('std', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]),
+ ('std', (S,), (0,), 'dim_1d', (True,), [0]),
+ ('std', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]),
+ ('renorm', (S, S, S), (2, 1, 0.5), 'dim', (), [1]),
('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
('repeat', (S,), (2,), 'single_number'),
('repeat', (), (2, 3), 'scalar'),
('repeat', (2, 2), (3, 2)),
('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'),
- ('cumsum', (S, S, S), (0,), 'dim0', [0]),
- ('cumsum', (S, S, S), (1,), 'dim1', [0]),
- ('cumsum', (S, S, S), (1,), 'dim1_cast', [0], (), lambda x: x, {'dtype': torch.float64}),
- ('cumsum', (), (0,), 'dim0_scalar', [0]),
+ ('cumsum', (S, S, S), (0,), 'dim0', (), [0]),
+ ('cumsum', (S, S, S), (1,), 'dim1', (), [0]),
+ ('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), lambda x: x, {'dtype': torch.float64}),
+ ('cumsum', (), (0,), 'dim0_scalar', (), [0]),
('cumprod', (S, S, S), (0,)),
- ('cumprod', (S, S, S), (1,), 'dim1', [0]),
+ ('cumprod', (S, S, S), (1,), 'dim1', (), [0]),
('cumprod', (), (0,), 'scalar'),
('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'),
- ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', [0]),
- ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', [0]),
- ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', [0]),
- ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', [0], (), lambda x: x, {'dtype': torch.float64}),
- ('unfold', (), (0, 1, 1), 'scalar', [0]),
- ('unfold', (S, S, S, S), (1, 3, 1), '', [0]),
- ('unfold', (S, S, S), (2, 3, 2), 'lastdim', [0]),
- ('addmm', (S, M), ((S, S), (S, M)),),
- ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs'),
- ('addmm', (S, M), ((S, S), (S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs'),
- ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', (), [0]),
+ ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', (), [0]),
+ ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', (), [0]),
+ ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', (), [0], (), lambda x: x, {'dtype': torch.float64}),
+ ('unfold', (), (0, 1, 1), 'scalar', (), [0]),
+ ('unfold', (S, S, S, S), (1, 3, 1), '', (), [0]),
+ ('unfold', (S, S, S), (2, 3, 2), 'lastdim', (), [0]),
+ ('addmm', (S, M), ((S, S), (S, M)), '', (True, ['aten::add', 'aten::mm'])),
+ ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs', (True, ['aten::add', 'aten::mm'])),
+ ('addmm', (S, M), ((S, S), (S, M)), 'coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs', (True, ['aten::add', 'aten::mm'])),
+ ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
('addbmm', (S, M), ((S, S, S), (S, S, M)),),
('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'),
- ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef',
+ ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', (),
(), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
- ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x,
+ ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x,
{'beta': 0.2, 'alpha': 0.6}),
('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'),
- ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef',
+ ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', (),
(), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
- ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x,
+ ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x,
{'beta': 0.2, 'alpha': 0.6}),
('addmv', (S,), ((S, M), (M,)),),
('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs'),
- ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs'),
- ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
('addr', (S, M), ((S,), (M,)),),
('addr', (), ((S,), (M,)), 'broadcast_lhs'),
- ('addr', (S, M), ((S,), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
- ('dot', (L,), ((L,),),),
- ('mm', (S, M), ((M, S),)),
- ('bmm', (M, S, M), ((M, M, S),)),
- ('mv', (S, M), ((M,),)),
+ ('addr', (S, M), ((S,), (M,)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+ ('dot', (L,), ((L,),), '', (True,)),
+ ('mm', (S, M), ((M, S),), '', (True,)),
+ ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
+ ('mv', (S, M), ((M,),), '', (True,)),
('ger', (S,), ((M,),)),
- ('matmul', (L,), ((L,),),),
- ('matmul', (S, M), ((M,),), "2d_1d"),
- ('matmul', (M, ), ((M, S),), "1d_2d"),
- ('matmul', (S, M), ((M, S),), "2d_2d"),
- ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
- ('matmul', (S, S, M, M), ((M,),), "4d_1d"),
- ('matmul', (M,), ((S, S, M, S),), "1d_4d"),
+ ('matmul', (L,), ((L,),), '', (True,)),
+ ('matmul', (S, M), ((M,),), "2d_1d", (True,)),
+ ('matmul', (M, ), ((M, S),), "1d_2d", (True,)),
+ ('matmul', (S, M), ((M, S),), "2d_2d", (True,)),
+ ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d", (True,)),
+ ('matmul', (S, S, M, M), ((M,),), "4d_1d", (True,)),
+ ('matmul', (M,), ((S, S, M, S),), "1d_4d", (True,)),
('matrix_power', (S, S), [2], "n=2"),
('matrix_power', (S, S, S), [3], "n=3"),
('matrix_power', (S, S, S), [1], "n=1"),
('matrix_power', (S, S, S), [0], "n=0"),
- ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1",
+ ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (),
NO_ARGS, [skipIfNoLapack]),
- ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3",
+ ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (),
NO_ARGS, [skipIfNoLapack]),
- ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2",
+ ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (),
NO_ARGS, [skipIfNoLapack]),
('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
('mvlgamma', torch.empty(S, S).uniform_(2.5, 5), [5], "p=5"),
- ('addcmul', (S, S), ((S, S), (S, S))),
- ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
- ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
- ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}),
- ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
- ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}),
- ('addcmul', (), ((), ()), 'scalar'),
- ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs'),
- ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'),
- ('addcmul', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}),
- ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
- ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}),
+ ('addcmul', (S, S), ((S, S), (S, S)), '', (True,)),
+ ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs', (True,)),
+ ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all', (True,)),
+ ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (True,), (), (), lambda x: x, {'value': 0.5}),
+ ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (True,), (), (), lambda x: x, {'value': 0.5}),
+ ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (True,), (), (), lambda x: x, {'value': 0.5}),
+ ('addcmul', (), ((), ()), 'scalar', (True,)),
+ ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs', (True,)),
+ ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs', (True,)),
+ ('addcmul', (), ((), ()), 'scalar_scale', (True,), (), (), lambda x: x, {'value': 0.5}),
+ ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (True,), (), (), lambda x: x, {'value': 0.5}),
+ ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (True,), (), (), lambda x: x, {'value': 0.5}),
('addcdiv', (S, S), ((S, S), (S, S))),
('addcdiv', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
('addcdiv', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
- ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}),
- ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
- ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}),
+ ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), (), lambda x: x, {'value': 0.5}),
+ ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), (), lambda x: x, {'value': 0.5}),
+ ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), (), lambda x: x, {'value': 0.5}),
('addcdiv', (), ((), ()), 'scalar'),
('addcdiv', (S, S), ((), ()), 'scalar_broadcast_rhs'),
('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'),
- ('addcdiv', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}),
- ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
- ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}),
+ ('addcdiv', (), ((), ()), 'scalar_scale', (), (), (), lambda x: x, {'value': 0.5}),
+ ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), (), lambda x: x, {'value': 0.5}),
+ ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), (), lambda x: x, {'value': 0.5}),
('zero_', (S, S, S), NO_ARGS),
('zero_', (), NO_ARGS, 'scalar'),
- ('logsumexp', (S, S), (1,)),
- ('logsumexp', (), (0,), 'scalar'),
+ ('logsumexp', (S, S), (1,), '', (True,)),
+ ('logsumexp', (), (0,), 'scalar', (True,)),
('norm', (S, S), (), 'default'),
('norm', (S, S), (2,), '2'),
('norm', (S, S), (0,), '0'),
('norm', (S, S), (-inf,), '-inf'),
('norm', (S, S), ('fro',), 'fro_default'),
('norm', (S, S), ('fro', [0, 1],), 'fro'),
- ('norm', (S, S), ('nuc',), 'nuc', NO_ARGS, [skipIfNoLapack]),
+ ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
('norm', (S, S), (-1,), 'neg_1'),
('norm', (S, S), (-2,), 'neg_2'),
('norm', (S, S), (-0.5,), 'neg_0_5'),
('norm', (S, S), (-1.5,), 'neg_1_5'),
- ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', [1]),
- ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', [1]),
- ('norm', (S, S), (0, 1,), '0_2_dim', [1]),
- ('norm', (S, S), (1, 1,), '1_2_dim', [1]),
- ('norm', (S, S), (2, 1,), '2_2_dim', [1]),
- ('norm', (S, S), (3, 1,), '3_2_dim', [1]),
+ ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', (), [1]),
+ ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', (), [1]),
+ ('norm', (S, S), (0, 1,), '0_2_dim', (), [1]),
+ ('norm', (S, S), (1, 1,), '1_2_dim', (), [1]),
+ ('norm', (S, S), (2, 1,), '2_2_dim', (), [1]),
+ ('norm', (S, S), (3, 1,), '3_2_dim', (), [1]),
('norm', (S, S), (inf, 1,), 'inf_2_dim'),
('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5_default'),
- ('norm', (S, S, S), (2, 1), '2_dim', [1]),
- ('norm', (S, S, S), (3, 1), '3_dim', [1]),
- ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', [1]),
- ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', [1]),
- ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', [1]),
- ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', [1]),
- ('norm', (), (2, 0), '2_dim_scalar', [1]),
- ('norm', (), (3, 0), '3_dim_scalar', [1]),
- ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', [1]),
- ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', [1]),
+ ('norm', (S, S, S), (2, 1), '2_dim', (), [1]),
+ ('norm', (S, S, S), (3, 1), '3_dim', (), [1]),
+ ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', (), [1]),
+ ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', (), [1]),
+ ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', (), [1]),
+ ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', (), [1]),
+ ('norm', (), (2, 0), '2_dim_scalar', (), [1]),
+ ('norm', (), (3, 0), '3_dim_scalar', (), [1]),
+ ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', (), [1]),
+ ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', (), [1]),
('clone', (S, M, S), NO_ARGS),
('clone', (), NO_ARGS, 'scalar'),
('dist', (S, S, S), ((S, S, S),)),
('trace', (M, M), NO_ARGS),
('cross', (S, 3), ((S, 3),)),
('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
- ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]),
- ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', [0]),
- ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', [0]),
- ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', [0]),
- ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]),
- ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]),
- ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
- ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]),
- ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]),
- ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
- ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', [0]),
- ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', [0]),
- ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', [0]),
- ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', [0]),
- ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
+ ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', (), [0]),
+ ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', (), [0]),
+ ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', (), [0]),
+ ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', (), [0]),
+ ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]),
+ ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]),
+ ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', (), [0]),
+ ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]),
+ ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]),
+ ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', (), [0]),
+ ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', (), [0]),
+ ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', (), [0]),
+ ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', (), [0]),
+ ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', (), [0]),
+ ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3),
- NO_ARGS, 'batched', NO_ARGS, [skipIfNoLapack]),
- ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
- ('det', (1, 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]),
- ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]),
- ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', NO_ARGS, [skipIfNoLapack]),
- ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
- ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', NO_ARGS, [skipIfNoLapack]),
- ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', NO_ARGS, [skipIfNoLapack]),
- ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', NO_ARGS, [skipIfNoLapack]),
+ NO_ARGS, 'batched', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', (S, S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', (1, 1), NO_ARGS, '1x1', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), NO_ARGS, [skipIfNoLapack]),
+ ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), NO_ARGS, [skipIfNoLapack]),
('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
- 'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
+ 'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
# For `logdet` and `slogdet`, the function at det=0 is not smooth.
# We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use
# `make_nonzero_det` to make the random matrices have nonzero det. For
# `logdet`, we also set `make_nonzero_det(matrix, sign=1)` to make the
# matrix have positive det.
- ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
- ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]),
+ ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
+ ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', (), NO_ARGS, [skipIfNoLapack]),
('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS,
- 'symmetric', NO_ARGS, [skipIfNoLapack]),
+ 'symmetric', (), NO_ARGS, [skipIfNoLapack]),
('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS,
- 'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
+ 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS,
- 'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
+ 'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS,
- '1x1_pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ '1x1_pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS,
- '1x1_neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ '1x1_neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
('slogdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS,
- 'pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ 'pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
('slogdet', lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS,
- 'neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ 'neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS,
- 'symmetric', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ 'symmetric', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
('slogdet', lambda: random_symmetric_pd_matrix(S), NO_ARGS,
- 'symmetric_pd', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
- 'distinct_singular_values', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
- ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', NO_ARGS, [skipIfNoLapack]),
- ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', NO_ARGS, [skipIfNoLapack]),
- ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', NO_ARGS, [skipIfNoLapack]),
- ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
+ 'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+ ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', (), NO_ARGS, [skipIfNoLapack]),
+ ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', (), NO_ARGS, [skipIfNoLapack]),
+ ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', (), NO_ARGS, [skipIfNoLapack]),
+ ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS,
- 'wide', NO_ARGS, [skipIfNoLapack]),
+ 'wide', (), NO_ARGS, [skipIfNoLapack]),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS,
- 'tall', NO_ARGS, [skipIfNoLapack]),
+ 'tall', (), NO_ARGS, [skipIfNoLapack]),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,),
- 'wide_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])),
+ 'wide_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,),
- 'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
+ 'tall_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
- 'large', NO_ARGS, [skipIfNoLapack]),
+ 'large', (), NO_ARGS, [skipIfNoLapack]),
('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
- S, silent=True),), '', NO_ARGS, [skipIfNoLapack]),
+ S, silent=True),), '', (), NO_ARGS, [skipIfNoLapack]),
('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
- 'batched', NO_ARGS, [skipIfNoLapack]),
+ 'batched', (), NO_ARGS, [skipIfNoLapack]),
('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),),
- 'batched_dims', NO_ARGS, [skipIfNoLapack]),
+ 'batched_dims', (), NO_ARGS, [skipIfNoLapack]),
('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),),
- 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
+ 'batched_broadcast_A', (), NO_ARGS, [skipIfNoLapack]),
('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),),
- 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
+ 'batched_broadcast_b', (), NO_ARGS, [skipIfNoLapack]),
('fill_', (S, S, S), (1,), 'number'),
('fill_', (), (1,), 'number_scalar'),
('fill_', (S, S, S), ((),), 'variable'),
('ge_', (), (0,), 'pyscalar_scalar'),
('lt_', (), (0,), 'pyscalar_scalar'),
('le_', (), (0,), 'pyscalar_scalar'),
- ('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
- ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim'),
- ('permute', (), (dont_convert(()),), 'scalar'),
- ('select', (S, S, S), (1, 2), 'dim', [0]),
- ('select', (S, S, S), (1, -1), 'wrap_dim', [0]),
+ ('permute', (1, 2, 3, 4), (0, 2, 3, 1), '', (True,)),
+ ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim', (True,)),
+ ('permute', (), (dont_convert(()),), 'scalar', (True,)),
+ ('select', (S, S, S), (1, 2), 'dim', (), [0]),
+ ('select', (S, S, S), (1, -1), 'wrap_dim', (), [0]),
('select', (S,), (0, 2), '1d'),
- ('narrow', (S, S, S), (1, 2, 2), 'dim', [0]),
- ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', [0]),
- ('squeeze', (S, 1, S, 1), NO_ARGS),
- ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones'),
- ('squeeze', (S, 1, S, 1), (1,), '1_dim', [0]),
- ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', [0]),
- ('squeeze', (), (0,), 'scalar', [0]),
- ('unsqueeze', (S, S, S), (0,), 'first', [0]),
- ('unsqueeze', (S, S, S), (1,), 'middle', [0]),
- ('unsqueeze', (S, S, S), (3,), 'last', [0]),
- ('unsqueeze', (), (0,), 'scalar', [0]),
- ('chunk', (S, S, S), (2,)),
- ('chunk', (S, S, S), (S, 1), 'dim', [1]),
+ ('narrow', (S, S, S), (1, 2, 2), 'dim', (), [0]),
+ ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', (), [0]),
+ ('squeeze', (S, 1, S, 1), NO_ARGS, '', (True,)),
+ ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones', (True,)),
+ ('squeeze', (S, 1, S, 1), (1,), '1_dim', (True,), [0]),
+ ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', (True,), [0]),
+ ('squeeze', (), (0,), 'scalar', (True,), [0]),
+ ('unsqueeze', (S, S, S), (0,), 'first', (True,), [0]),
+ ('unsqueeze', (S, S, S), (1,), 'middle', (True,), [0]),
+ ('unsqueeze', (S, S, S), (3,), 'last', (True,), [0]),
+ ('unsqueeze', (), (0,), 'scalar', (True,), [0]),
+ ('chunk', (S, S, S), (2,), '', (True, 'prim::ConstantChunk')),
+ ('chunk', (S, S, S), (S, 1), 'dim', (True, 'prim::ConstantChunk'), [1]),
('split', (S, S, S), (2,)),
- ('split', (S, S, S), (S, 1), 'dim', [1]),
+ ('split', (S, S, S), (S, 1), 'dim', (), [1]),
('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list'),
- ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', [1]),
- ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', [0]),
- ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', [0]),
- ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', [0]),
- ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', [0]),
- ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]),
- ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
- ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
- ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]),
- ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]),
- ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
- ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
- ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
+ ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', (), [1]),
+ ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]),
+ ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]),
+ ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]),
+ ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', (), [0]),
+ ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', (), [0]),
+ ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]),
+ ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]),
+ ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', (), [0]),
+ ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', (), [0]),
+ ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]),
+ ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]),
+ ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', (), [0]),
('masked_select', (M, M), (mask_not_all_zeros((M, M)),)),
('masked_select', (M, M), (mask_not_all_zeros((M,)),), 'broadcast_rhs'),
('masked_select', (M,), (mask_not_all_zeros((M, M)),), 'broadcast_lhs'),
('sort', (), (0,), 'dim_scalar'),
('sort', (), (0, True), 'dim_desc_scalar'),
('topk', (S, M, S), (3,)),
- ('topk', (S, M, S), (3, 1), 'dim', [1]),
- ('topk', (S, M, S), (3, 1, True), 'dim_desc', [1]),
- ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', [1]),
+ ('topk', (S, M, S), (3, 1), 'dim', (), [1]),
+ ('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]),
+ ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', (), [1]),
('topk', (), (1,), 'scalar'),
- ('topk', (), (1, 0), 'dim_scalar', [1]),
- ('topk', (), (1, 0, True), 'dim_desc_scalar', [1]),
- ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', [1]),
+ ('topk', (), (1, 0), 'dim_scalar', (), [1]),
+ ('topk', (), (1, 0, True), 'dim_desc_scalar', (), [1]),
+ ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', (), [1]),
('take', (S, S, S), (torch.LongTensor([[-3, 2], [20, 2]]),)),
('take', (S, S, S), (torch.tensor(0, dtype=torch.int64),), 'scalar_index'),
('take', (), (torch.LongTensor([0]),), 'scalar_data'),
('take', (), (torch.tensor(0, dtype=torch.int64),), 'scalar_both'),
- ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M))),
- ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all'),
- ('where', (), (bernoulli_scalar(), ()), 'scalar'),
- ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask'),
- ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask'),
+ ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M)), '', (True,)),
+ ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all', (True,)),
+ ('where', (), (bernoulli_scalar(), ()), 'scalar', (True,)),
+ ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask', (True,)),
+ ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask', (True,)),
('__getitem__', torch.randn(S, S, S), (dont_convert([1, 2]),)),
('__getitem__', torch.randn(S, S, S), (slice(0, 3),), 'slice'),
('__getitem__', torch.randn(S, S, S), (dont_convert([slice(0, 3), 1]),), 'slice_index'),
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+# Note: whether FusionGroups get created is currently device-independent,
+# but FusionGroup creation on CPU may be disabled, so fusion is treated
+# as enabled whenever either the CPU or the GPU fuser is available.
+FUSION_ENABLED = torch._C._jit_can_fuse_on_cpu() or torch._C._jit_can_fuse_on_gpu()
+
RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_HALF = RUN_CUDA
if torch.cuda.is_available():
torch._C._jit_pass_lint(graph)
self.assertExpected(str(graph), *args, **kwargs)
+ def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
+ # Assert that `graph` was partitioned for autodiff exactly as expected:
+ # `should_autodiff_node` says whether every name in `nonfusible_nodes`
+ # must appear in some prim::DifferentiableGraph subgraph and every name
+ # in `fusible_nodes` in some prim::FusionGroup nested inside one.
+ if not FUSION_ENABLED:
+ # With no fuser available, fusible ops are expected to show up as
+ # plain nodes in the DifferentiableGraph rather than in a FusionGroup.
+ nonfusible_nodes = nonfusible_nodes + fusible_nodes
+ fusible_nodes = []
+ diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
+ diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
+
+ # For any non-fusible node, it must show up in one of the DifferentiableGraph.
+ # (The len == 0 clause accepts the vacuous case: nothing expected, nothing found.)
+ found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0)\
+ or all([any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes])
+
+ # For any fusible node, it must show up in one of the FusionGroup in the DifferentiableGraph.
+ fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
+ fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
+ found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0)\
+ or all([any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes])
+
+ self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes)
+
def run_pass(self, name, trace):
if isinstance(trace, torch._C.Graph):
graph = trace
'test_nn_fold',
}
+# In scripting, chunk returns a list that we do not unpack, so it is not
+# replaced by prim::ConstantChunk and does not run through AD.
+# That case is checked explicitly in test_chunk_constant_script_ad instead.
+EXCLUDE_SCRIPT_AD_CHECK = {
+ 'test_chunk',
+ 'test_chunk_dim',
+ 'test_chunk_dim_neg0',
+}
+
EXCLUDE_PYTHON_PRINT = {
# no support for BroadcastingList in python printer
'test_nn_max_unpool1d',
'test_nn_AdaptiveMaxPool3d_tuple_none',
}
-DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
- 'test_nn_avg_pool2d',
- 'test_nn_adaptive_avg_pool1d',
- 'test_nn_adaptive_avg_pool2d',
- 'test_nn_adaptive_avg_pool3d',
- 'test_nn_batch_norm',
- 'test_nn_embedding',
- 'test_nn_log_softmax',
- 'test_nn_softmax',
- 'test_nn_softmax_with_all_args',
- 'test_nn_threshold',
- 'test_nn_nll_loss',
- # Should have added all test_nn_interpolate_* here,
- # but it's using autodiff since its subgraph is over
- # 2 nodes.
-}
-
# make a new function where all non-tensor arguments in 'args' have been partially
# applied, and all tensor arguments remain.
def assertGraphSize(self, graph, size):
self.assertEqual(len(list(graph.nodes())), size)
+ def test_chunk_constant_script_ad(self):
+ # When the outputs of chunk are unpacked in a scripted function, the
+ # call should become prim::ConstantChunk inside a DifferentiableGraph.
+ # This is verified here because the generic method tests skip chunk
+ # (see EXCLUDE_SCRIPT_AD_CHECK): there the returned list is not unpacked.
+ @torch.jit.script
+ def func(x):
+ x1, x2 = torch.chunk(x, 2)
+ return (x1, x2)
+
+ input = torch.rand(6, 10).requires_grad_()
+ # Keep the DifferentiableGraph visible so assertAutodiffNode can find it.
+ func.debug_disable_autodiff_subgraph_inlining()
+ output = func(input)
+ self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
+
def test_simple_merge(self):
# o --> o
def fn(x, y, z):
# args (tuple represents shape of a tensor arg),
# test variant name (will be used at test name suffix,
# 'inplace' skips grad tests), // optional
+# (True, nonfusible_nodes, fusible_nodes) for autodiff // optional
# fn to determine if test should be skipped, // optional
# fn mapping output to part that should be gradcheck'ed, // optional
# kwargs for function, // optional
('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
('avg_pool1d', (S, S, S), (3,)),
- ('avg_pool2d', (S, S, S, S), (3,)),
+ ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
('avg_pool3d', (S, S, S, S, S), (3,)),
('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
('max_pool1d', (S, S, S), (2, 1)),
('adaptive_max_pool1d', (S, S, S), (5,)),
('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
- ('adaptive_avg_pool1d', (S, S, S), (5,)),
- ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],)),
- ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
- ('dropout', (S, S, S), (0.5,)),
+ ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
+ ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
+ ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
+ ('dropout', (S, S, S), (0.5,), '', (True,
+ ['prim::is_cuda', 'aten::bernoulli_'],
+ ['aten::rand_like', 'aten::lt', 'aten::type_as', 'aten::mul', 'aten::div'])),
('alpha_dropout', (S, S, S), (0.5,)),
('dropout2d', (S, S, S), (0.5,)),
('dropout3d', (S, S, S), (0.5,)),
('feature_alpha_dropout', (S, S, S), (0.5,)),
- ('threshold', (S, S, S), (0.1, 2.),),
+ ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
('relu', (S, S, S), (),),
('relu', (S, S, S), (), 'inplace'),
('softsign', (S, S, S), (),),
('softplus', (S, S, S), (),),
('softmin', (S, S, S), (0,),),
- ('softmax', (S, S, S), (0,),),
- ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args'),
+ ('softmax', (S, S, S), (0,), '', (True,)),
+ ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
('tanh', (S, S, S), (),),
('sigmoid', (S, S, S), (),),
- ('log_softmax', (S, S, S), (0,),),
+ ('log_softmax', (S, S, S), (0,), '', (True,)),
('linear', (S, S), ((M, S),),),
('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
- ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),),
+ ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
- ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
+ '', (True, 'aten::_batch_norm_impl_index')),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
('layer_norm', (S, S, S, S), ([5],),),
('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
('group_norm', (S, S, S), (1, torch.rand(5),),),
('local_response_norm', (S, S, S), (2, ),),
- ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),),),
+ ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
('unfold', (S, S, S, S), ([2, 3]),),
('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
- ('gumbel_softmax', (S, S), (2.,),),
- ('gumbel_softmax', (S, S), (2., True,), 'hard'),
+ ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])),
+ ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])),
('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
1, 1., non_differentiable(torch.randn(S))),),
torch.randint(1, S, (S,), dtype=torch.long))),
('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
- ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
- ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
- ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
- ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
- ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
- ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
- ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
- ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
- ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
- ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
- ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
- ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
- ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
- ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
- ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
- ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
- ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
- ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
- ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
- ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
- ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
- ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
- ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
+ ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size', (True, 'aten::__interpolate')),
+ ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale', (True, 'aten::__interpolate')),
+ ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size', (True, 'aten::__interpolate')),
]
self_size,
args,
variant_name='',
+ check_ad=(),
dim_args_idx=(),
skipTestIf=(),
output_process_fn=lambda x: x,
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
- output_process_fn=output_process_fn):
+ check_ad=check_ad, output_process_fn=output_process_fn):
def check(name):
set_rng_seed(2)
is_magic_method = name[:2] == '__' and name[-2:] == '__'
# Test with disable_autodiff_subgraph_inlining, which forces the graph
# to contain DifferentiableGraph nodes whenever possible. This allows us
# to test autodiff; we assume that autograd is correct and use autodiff for backprop
+ should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
if test_name not in EXCLUDE_TRACED:
- check_against_reference(self,
- create_traced_fn(self, fn,
- disable_autodiff_subgraph_inlining=True),
+ traced_fn = create_traced_fn(self, fn, disable_autodiff_subgraph_inlining=True)
+
+ check_against_reference(self, traced_fn,
fn, (self_variable,) + args_variable, kwargs_variable,
check_types=check_types)
+ self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
- check_against_reference(self,
- create_script_fn(self, name, 'method', output_process_fn,
- disable_autodiff_subgraph_inlining=True),
+ script_fn = create_script_fn(self, name, 'method', output_process_fn,
+ disable_autodiff_subgraph_inlining=True)
+ check_against_reference(self, script_fn,
fn, (self_variable,) + args_variable, kwargs_variable,
check_types=check_types)
+ self.assertAutodiffNode(script_fn.last_graph,
+ should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
+ autodiff_nodes,
+ fusible_nodes)
+
# functional interface tests
if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
def fn(*inputs, **kwargs):
return wrapper
-def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=(),
+def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(), skipTestIf=(),
output_process_fn=lambda x: x, kwargs=None):
test_name = 'test_nn_' + name
no_grad = variant_name == 'inplace'
@suppress_warnings
- def do_test(self, name=name, args=args, test_name=test_name):
+ def do_test(self, name=name, args=args, test_name=test_name, check_ad=check_ad):
torch.manual_seed(2)
self_variable = create_input((self_size,))[0][0]
f_args_variable = (self_variable,) + args_variable
f_args_tensor = (self_tensor,) + args_tensor
+ should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
if test_name not in EXCLUDE_SCRIPT:
- disable_ad_subgraph_inlining = test_name in DISABLE_AUTODIFF_SUBGRAPH_INLINING
-
def run_test():
script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
- disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
+ disable_autodiff_subgraph_inlining=should_autodiff_node)
check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+ # For tests where AD subgraph inlining is disabled, make sure we're not falling back to autograd
+ self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
if test_name in EXCLUDE_PYTHON_PRINT:
with self.disableModuleHook():
setattr(test_class, test_name, do_test)
def normalize_check_ad(check_ad, name):
    """Normalize a test tuple's check_ad spec to (bool, List[str], List[str]).

    Missing trailing elements are filled with defaults: autodiff check
    disabled, nonfusible nodes defaulting to ['aten::<name>'], and no
    fusible nodes.  A bare string is shorthand for a one-element list.

    Raises:
        ValueError: if check_ad has more than 3 elements.
    """
    if len(check_ad) > 3:
        raise ValueError('Invalid check_ad, requires (bool, str|List[str], str|List[str])')
    defaults = [False, ['aten::' + name], []]
    # Pad unspecified trailing positions with their defaults.
    normalized = list(check_ad) + defaults[len(check_ad):]
    # Wrap bare strings into single-element lists of node kinds.
    return [[t] if isinstance(t, str) else t for t in normalized]
+
+
class TestAsync(JitTestCase):
def test_async_python(self):
@torch.jit.script