From: Ailing Zhang
Date: Sun, 31 Mar 2019 15:41:46 +0000 (-0700)
Subject: Enforce check ad in test_jit (#18509)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~522
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9c87543124f0b15db0a55af66c5553bc5af014a7;p=platform%2Fupstream%2Fpytorch.git

Enforce check ad in test_jit (#18509)

Summary:
If a test triggers autodiff, it must have a `DifferentiableGraph` in its differentiated forward graph, and this subgraph must contain either the original aten node or the corresponding nodes used in its AD formula. Typically a forward differentiable graph looks like this:
```
graph(%i0 : Float(),
      %i1 : Float()):
  %3 : Float() = prim::DifferentiableGraph_0(%i0, %i1)
  return (%3)
with prim::DifferentiableGraph_0 = graph(%0 : Float(),
      %1 : Float()):
  %2 : Float() = aten::max(%0, %1)
  return (%2)
```
which tells us `aten::max(Tensor self, Tensor other) -> Tensor` is symbolically differentiable.

Update: there are a lot of cases (fusions/ConstantChunk/Python implementations) that break this, so I made the check optionally take node names when they differ from the function name.

~~[OLD] Theoretically I could also check whether `aten::max` is in the differentiable block or not, to be more precise, but there are also cases like `chunk` where, inside a differentiable block, it's replaced with a prim node (ConstantChunk), and we would have to special-case them. Any suggestions here (to be more precise or not) are very welcome!~~

We used to have a list of the nn tests that should be run against AD; I moved it to a field set when constructing each test (either torch or nn). I think it's cleaner this way, and it matches the fact that for the same op we may support one of its schemas but not all of them; this way we can just turn on the corresponding test that triggers the supported schema.

cc: apaszke zdevito wanchaol ngimel for a review

[Done]:
- Went through a manual second pass of all tests to check whether each should enable the AD test.
- Added a README about how to add AD for an op and how to add/enable its test in test_jit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18509

Differential Revision: D14696811

Pulled By: ailzhang

fbshipit-source-id: c5e693277baac585cd3aed5ab2c0e7faa5e6f29f
---

diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index 43ef996..d1f0809 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -102,77 +102,85 @@ S = 5
 # input size/constructing fn,
 # args (tuple represents shape of a tensor arg),
 # test variant name (will be used at test name suffix), // optional
+# (True, nonfusible_nodes, fusible_nodes) for autodiff, // optional
 # indices for possible dim arg, // optional
 # fn mapping output to part that should be gradcheck'ed, // optional
 # )
+# Note: some functions have separate schema for (Tensor other) and (Scalar other),
+# and it's possible that we only support AD for Scalar version but not Tensor
+# version, and vice versa.
+# When writing tests, only scalar(float/int) input triggers the Scalar schema.
+# uniform_scalar produces a scalar **Tensor** which won't match Scalar input.
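
As a reading aid (not part of the patch), the optional autodiff tuple described in the note above sits right after the variant-name string. The entries below are copied from the hunks that follow; the trailing comments and the `autodiff_entries` name are illustrative only:

```python
import torch

S = 5  # input size, as defined in common_methods_invocations.py

# (name, self size/tensor, args, variant name, autodiff tuple, ...)
autodiff_entries = [
    ('add', (S, S, S), ((S, S, S),), '', (True,)),                      # expect aten::add in a DifferentiableGraph
    ('__radd__', (S, S, S), (3.14,), 'constant', (True, 'aten::add')),  # checked node differs from the function name
    ('chunk', (S, S, S), (2,), '', (True, 'prim::ConstantChunk')),      # aten op replaced by a prim node
    ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant',
     (True, [], ['aten::mul', 'aten::reciprocal'])),                    # no nonfusible nodes, two fusible nodes
]
```
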
def method_tests(): set_rng_seed(0) return [ - ('add', (S, S, S), ((S, S, S),)), - ('add', (S, S, S), ((S, S),), 'broadcast_rhs'), - ('add', (S, S), ((S, S, S),), 'broadcast_lhs'), - ('add', (S, 1, S), ((M, S),), 'broadcast_all'), - ('add', (), ((),), 'scalar'), - ('add', (S, S, S), ((),), 'scalar_broadcast_rhs'), - ('add', (), ((S, S, S),), 'scalar_broadcast_lhs'), - ('add', (S, S, S), (3.14,), 'constant'), - ('add', (), (3.14,), 'scalar_constant'), - ('__radd__', (S, S, S), (3.14,), 'constant'), - ('__radd__', (), (3.14,), 'scalar_constant'), - ('sub', (S, S, S), ((S, S, S),)), - ('sub', (S, S, S), ((S, S),), 'broadcast_rhs'), - ('sub', (S, S), ((S, S, S),), 'broadcast_lhs'), - ('sub', (S, 1, S), ((M, S),), 'broadcast_all'), - ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs'), - ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs'), - ('sub', (S, S, S), (3.14,), 'constant'), - ('sub', (), (3.14,), 'scalar_constant'), - ('__rsub__', (S, S, S), (3.14,), 'constant'), - ('__rsub__', (), (3.14,), 'scalar_constant'), - ('mul', (S, S, S), ((S, S, S),)), - ('mul', (), ((),), 'scalar'), - ('mul', (S, S, S), ((S, S),), 'broadcast_rhs'), - ('mul', (S, S), ((S, S, S),), 'broadcast_lhs'), - ('mul', (S, 1, S), ((M, S),), 'broadcast_all'), - ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs'), - ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs'), - ('mul', (S, S, S), (3.14,), 'constant'), - ('mul', (), (3.14,), 'scalar_constant'), - ('__rmul__', (S, S, S), (3.14,), 'constant'), - ('__rmul__', (), (3.14,), 'scalar_constant'), - ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,)), - ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs'), - ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'), - ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all'), - ('div', (), (uniform_scalar(0.1),), 'scalar'), - ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs'), - ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs'), - ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'), - ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'), - ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'), - ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'), - ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,)), - ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'), - ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'), - ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'), - ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'), - ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'), - ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'), - ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'), - ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'), - ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'), - ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'), - ('transpose', (1, 2, 3), (1, 2), 'dim', [0, 1]), - ('transpose', (), (0, 0), 'scalar'), - ('transpose', (1,), (0, 0), '1d'), - ('transpose', torch.rand(L, L), (0, 1), '2d'), - ('transpose', torch.rand(S, S, S), (2, 0), '3d'), - ('t', (1, 2), NO_ARGS), - ('view', (S, S, S), (S * S, S),), - ('view', (S, S, S), (torch.Size([S * S, S]),), 'size'), - ('view', (S,), (S,), '1d'), - ('view', (), 
(dont_convert(()),), 'scalar_to_scalar'), - ('view', (), (1,), 'scalar_to_1d'), + ('add', (S, S, S), ((S, S, S),), '', (True,)), + ('add', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)), + ('add', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)), + ('add', (S, 1, S), ((M, S),), 'broadcast_all', (True,)), + ('add', (), ((),), 'scalar', (True,)), + ('add', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)), + ('add', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), + ('add', (S, S, S), (3.14,), 'constant', (True,)), + ('add', (), (3.14,), 'scalar_constant', (True,)), + ('__radd__', (S, S, S), (3.14,), 'constant', (True, 'aten::add')), + ('__radd__', (), (3.14,), 'scalar_constant', (True, 'aten::add')), + ('sub', (S, S, S), ((S, S, S),), '', (True,)), + ('sub', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)), + ('sub', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)), + ('sub', (S, 1, S), ((M, S),), 'broadcast_all', (True,)), + ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)), + ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), + ('sub', (S, S, S), (3.14,), 'constant', (True,)), + ('sub', (), (3.14,), 'scalar_constant', (True,)), + ('__rsub__', (S, S, S), (3.14,), 'constant', (True, 'aten::rsub')), + ('__rsub__', (), (3.14,), 'scalar_constant', (True, 'aten::rsub')), + ('mul', (S, S, S), ((S, S, S),), '', (True,)), + ('mul', (), ((),), 'scalar', (True,)), + ('mul', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)), + ('mul', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)), + ('mul', (S, 1, S), ((M, S),), 'broadcast_all', (True,)), + ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)), + ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)), + ('mul', (S, S, S), (3.14,), 'constant', (True,)), + ('mul', (), (3.14,), 'scalar_constant', (True,)), + ('__rmul__', (S, S, S), (3.14,), 'constant', (True, 'aten::mul')), + ('__rmul__', (), (3.14,), 'scalar_constant', (True, 'aten::mul')), + ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,), '', (True,)), + ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs', (True,)), + ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)), + ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all', (True,)), + ('div', (), (uniform_scalar(0.1),), 'scalar', (True,)), + ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)), + ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs', (True,)), + ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant', (True,)), + ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant', + (True, [], ['aten::mul', 'aten::reciprocal'])), + ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True,)), + ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', + (True, [], ['aten::mul', 'aten::reciprocal'])), + ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), '', (True,)), + ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs', (True,)), + ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)), + ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all', (True,)), + ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar', (True,)), + ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)), + ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs', (True,)), + ('pow', torch.rand(S, S, S) + 1e-3, 
(3.14,), 'constant', (True,)), + ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')), + ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)), + ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')), + ('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]), + ('transpose', (), (0, 0), 'scalar', (True,)), + ('transpose', (1,), (0, 0), '1d', (True,)), + ('transpose', torch.rand(L, L), (0, 1), '2d', (True,)), + ('transpose', torch.rand(S, S, S), (2, 0), '3d', (True,)), + ('t', (1, 2), NO_ARGS, '', (True,)), + ('view', (S, S, S), (S * S, S), '', (True,)), + ('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)), + ('view', (S,), (S,), '1d', (True,)), + ('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)), + ('view', (), (1,), 'scalar_to_1d', (True,)), ('reshape', (S, S, S), (S * S, S),), ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'), ('reshape', (S,), (S,), '1d'), @@ -201,82 +209,82 @@ def method_tests(): ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)), ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'), ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'), - ('expand', (S, 1, 1), (S, S, S)), - ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'), - ('expand', (S, 1), (S, S, S), 'new_dim'), - ('expand', (1,), (S, S, S), '1_element'), - ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'), - ('expand', (), (dont_convert(()),), 'scalar_to_scalar'), - ('expand', (), (1, 3, 2), 'scalar_to_dims'), - ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)), - ('exp', (S, S, S), NO_ARGS), - ('exp', (), NO_ARGS, 'scalar'), - ('expm1', (S, S, S), NO_ARGS), - ('expm1', (), NO_ARGS, 'scalar'), - ('erf', torch.rand(S, S, S), NO_ARGS), - ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'), - ('erfc', torch.rand(S, S, S), NO_ARGS), - ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'), + ('expand', (S, 1, 1), (S, S, S), '', (True,)), + ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (True,)), + ('expand', (S, 1), (S, S, S), 'new_dim', (True,)), + ('expand', (1,), (S, S, S), '1_element', (True,)), + ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)), + ('expand', (), (dont_convert(()),), 'scalar_to_scalar', (True,)), + ('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)), + ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)), + ('exp', (S, S, S), NO_ARGS, '', (True,)), + ('exp', (), NO_ARGS, 'scalar', (True,)), + ('expm1', (S, S, S), NO_ARGS, '', (True,)), + ('expm1', (), NO_ARGS, 'scalar', (True,)), + ('erf', torch.rand(S, S, S), NO_ARGS, '', (True,)), + ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('erfc', torch.rand(S, S, S), NO_ARGS, '', (True,)), + ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS), ('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'), - ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS), - ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'), - ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS), - ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'), - ('log1p', torch.rand(S, S, S), NO_ARGS), - ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'), - ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS), - ('log2', uniform_scalar(1e-2, requires_grad=True), 
NO_ARGS, 'scalar'), - ('tanh', (S, S, S), NO_ARGS), - ('tanh', (), NO_ARGS, 'scalar'), - ('sigmoid', (S, S, S), NO_ARGS), - ('sigmoid', (), NO_ARGS, 'scalar'), - ('sinh', (S, S, S), NO_ARGS), - ('sinh', (), NO_ARGS, 'scalar'), - ('cosh', (S, S, S), NO_ARGS), - ('cosh', (), NO_ARGS, 'scalar'), - ('abs', (S, S, S), NO_ARGS), - ('abs', (), NO_ARGS, 'scalar'), - ('clamp', (S, S, S), (0, 1)), - ('clamp', (S, S, S), (None, 0.5), 'min'), - ('clamp', (S, S, S), (0.5, None), 'max'), - ('clamp', (), (0, 1), 'scalar'), - ('clamp', (), (None, 0.5), 'min_scalar'), - ('clamp', (), (0.5, None), 'max_scalar'), - ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS), - ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar'), - ('sin', (S, S, S), NO_ARGS), - ('sin', (), NO_ARGS, 'scalar'), - ('cos', (S, S, S), NO_ARGS), - ('cos', (), NO_ARGS, 'scalar'), - ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS), - ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS), - ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS), - ('atan', (S, S, S), NO_ARGS), - ('atan', (), NO_ARGS, 'scalar'), + ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), + ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), + ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('log1p', torch.rand(S, S, S), NO_ARGS, '', (True,)), + ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), + ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('tanh', (S, S, S), NO_ARGS, '', (True,)), + ('tanh', (), NO_ARGS, 'scalar', (True,)), + ('sigmoid', (S, S, S), NO_ARGS, '', (True,)), + ('sigmoid', (), NO_ARGS, 'scalar', (True,)), + ('sinh', (S, S, S), NO_ARGS, '', (True,)), + ('sinh', (), NO_ARGS, 'scalar', (True,)), + ('cosh', (S, S, S), NO_ARGS, '', (True,)), + ('cosh', (), NO_ARGS, 'scalar', (True,)), + ('abs', (S, S, S), NO_ARGS, '', (True,)), + ('abs', (), NO_ARGS, 'scalar', (True,)), + ('clamp', (S, S, S), (0, 1), '', (True,)), + ('clamp', (S, S, S), (None, 0.5), 'min', (True,)), + ('clamp', (S, S, S), (0.5, None), 'max', (True,)), + ('clamp', (), (0, 1), 'scalar', (True,)), + ('clamp', (), (None, 0.5), 'min_scalar', (True,)), + ('clamp', (), (0.5, None), 'max_scalar', (True,)), + ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS, '', (True,)), + ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('sin', (S, S, S), NO_ARGS, '', (True,)), + ('sin', (), NO_ARGS, 'scalar', (True,)), + ('cos', (S, S, S), NO_ARGS, '', (True,)), + ('cos', (), NO_ARGS, 'scalar', (True,)), + ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS, '', (True,)), + ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)), + ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)), + ('atan', (S, S, S), NO_ARGS, '', (True,)), + ('atan', (), NO_ARGS, 'scalar', (True,)), ('atan2', (S, S, S), ((S, S, S),)), ('atan2', (), ((),), 'scalar'), ('atan2', (S, S, S), ((S,),), 'broadcast_rhs'), ('atan2', (S,), ((S, S, S),), 'broadcast_lhs'), ('atan2', (S, 1, S), ((S, S),), 'broadcast_all'), - ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS), - ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar'), - ('round', (S, S, S), NO_ARGS), - ('round', (), NO_ARGS, 'scalar'), + ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)), + ('reciprocal', uniform_scalar(0.1, 
requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('round', (S, S, S), NO_ARGS, '', (True,)), + ('round', (), NO_ARGS, 'scalar', (True,)), ('sign', (S, S, S), NO_ARGS), ('sign', (), NO_ARGS, 'scalar'), - ('trunc', (S, S, S), NO_ARGS), - ('trunc', (), NO_ARGS, 'scalar'), - ('floor', (S, S, S), NO_ARGS), - ('floor', (), NO_ARGS, 'scalar'), - ('ceil', (S, S, S), NO_ARGS), - ('ceil', (), NO_ARGS, 'scalar'), - ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS), - ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'), - ('frac', (S, S, S), NO_ARGS), - ('frac', (), NO_ARGS, 'scalar'), - ('fmod', (S, S, S), (1.5,)), - ('fmod', (), (1.5,), 'scalar'), + ('trunc', (S, S, S), NO_ARGS, '', (True,)), + ('trunc', (), NO_ARGS, 'scalar', (True,)), + ('floor', (S, S, S), NO_ARGS, '', (True,)), + ('floor', (), NO_ARGS, 'scalar', (True,)), + ('ceil', (S, S, S), NO_ARGS, '', (True,)), + ('ceil', (), NO_ARGS, 'scalar', (True,)), + ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)), + ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)), + ('frac', (S, S, S), NO_ARGS, '', (True,)), + ('frac', (), NO_ARGS, 'scalar', (True,)), + ('fmod', (S, S, S), (1.5,), '', (True,)), + ('fmod', (), (1.5,), 'scalar', (True,)), ('fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'), ('fmod', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'), ('fmod', (S, S, S), (non_differentiable(torch.rand(S) + 1.5),), 'tensor_broadcast_rhs'), @@ -284,217 +292,217 @@ def method_tests(): ('fmod', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'), ('fmod', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'), ('fmod', (S, S, S), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor_broadcast_rhs'), - ('remainder', (S, S, S), (1.5,)), - ('remainder', (), (1.5,), 'scalar'), + ('remainder', (S, S, S), (1.5,), '', (True,)), + ('remainder', (), (1.5,), 'scalar', (True,)), ('remainder', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'), ('remainder', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'), ('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'), ('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'), ('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'), - ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast'), - ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'), - ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'), - ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all'), - ('lerp', (), ((), 0.4), 'scalar'), - ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs'), - ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs'), + ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast', (True,)), + ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs', (True,)), + ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs', (True,)), + ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all', (True,)), + ('lerp', (), ((), 0.4), 'scalar', (True,)), + ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs', (True,)), + ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs', (True,)), ('max', (S, S, S), NO_ARGS), - ('max', (S, S, S), (1,), 'dim', [0]), - ('max', (S, S, S), (1, True,), 'keepdim_dim', [0]), + ('max', (S, S, S), (1,), 'dim', (), [0]), + ('max', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('max', (), NO_ARGS, 'scalar'), - ('max', (), (0,), 
'scalar_dim', [0]), - ('max', (), (0, True,), 'scalar_keepdim_dim', [0]), - ('max', (S, S, S), ((S, S, S),), 'elementwise'), - ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'), - ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'), - ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'), - ('max', (), ((),), 'scalar_elementwise'), - ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'), - ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'), - ('min', (S, S, S), NO_ARGS), - ('min', (S, S, S), (1,), 'dim', [0]), - ('min', (S, S, S), (1, True,), 'keepdim_dim', [0]), + ('max', (), (0,), 'scalar_dim', (), [0]), + ('max', (), (0, True,), 'scalar_keepdim_dim', (), [0]), + ('max', (S, S, S), ((S, S, S),), 'elementwise', (True,)), + ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs', (True,)), + ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs', (True,)), + ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all', (True,)), + ('max', (), ((),), 'scalar_elementwise', (True,)), + ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs', (True,)), + ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs', (True,)), + ('min', (S, S, S), NO_ARGS, ), + ('min', (S, S, S), (1,), 'dim', (), [0]), + ('min', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('min', (), NO_ARGS, 'scalar'), - ('min', (), (0,), 'scalar_dim', [0]), - ('min', (), (0, True,), 'scalar_keepdim_dim', [0]), - ('min', (S, S, S), ((S, S, S),), 'elementwise'), - ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'), - ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'), - ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'), - ('min', (), ((),), 'scalar_elementwise'), - ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'), - ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'), - ('mean', (S, S, S), NO_ARGS), - ('mean', (S, S, S), (1,), 'dim', [0]), - ('mean', (S, S, S), (1, True,), 'keepdim_dim', [0]), - ('mean', (), NO_ARGS, 'scalar'), - ('mean', (), (0,), 'scalar_dim', [0]), - ('mean', (), (0, True,), 'scalar_keepdim_dim', [0]), + ('min', (), (0,), 'scalar_dim', (), [0]), + ('min', (), (0, True,), 'scalar_keepdim_dim', (), [0]), + ('min', (S, S, S), ((S, S, S),), 'elementwise', (True,)), + ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs', (True,)), + ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs', (True,)), + ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all', (True,)), + ('min', (), ((),), 'scalar_elementwise', (True,)), + ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs', (True,)), + ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs', (True,)), + ('mean', (S, S, S), NO_ARGS, '', (True,)), + ('mean', (S, S, S), (1,), 'dim', (True,), [0]), + ('mean', (S, S, S), (1, True,), 'keepdim_dim', (True,), [0]), + ('mean', (), NO_ARGS, 'scalar', (True,)), + ('mean', (), (0,), 'scalar_dim', (True,), [0]), + ('mean', (), (0, True,), 'scalar_keepdim_dim', (True,), [0]), ('kthvalue', (S, S, S), (2,)), ('kthvalue', (), (1,), 'scalar'), - ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]), - ('kthvalue', (), (1, 0,), 'scalar_dim', [1]), - ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', [1]), - ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', [1]), - ('kthvalue', (S,), (2, 0,), 'dim_1d', [1]), - ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', [1]), + ('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]), + ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]), + ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]), 
+ ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]), + ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]), + ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]), ('median', (S, S, S), NO_ARGS), - ('median', (S, S, S), (1,), 'dim', [0]), - ('median', (S, S, S), (1, True,), 'keepdim_dim', [0]), + ('median', (S, S, S), (1,), 'dim', (), [0]), + ('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('median', (), NO_ARGS, 'scalar'), - ('median', (), (0,), 'scalar_dim', [0]), - ('median', (), (0, True,), 'scalar_keepdim_dim', [0]), + ('median', (), (0,), 'scalar_dim', (), [0]), + ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0]), ('mode', (S, S, S), NO_ARGS), - ('mode', (S, S, S), (1,), 'dim', [0]), - ('mode', (S, S, S), (1, True,), 'keepdim_dim', [0]), + ('mode', (S, S, S), (1,), 'dim', (), [0]), + ('mode', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('mode', (), NO_ARGS, 'scalar'), - ('mode', (), (0,), 'scalar_dim', [0]), - ('mode', (), (0, True,), 'scalar_keepdim_dim', [0]), + ('mode', (), (0,), 'scalar_dim', (), [0]), + ('mode', (), (0, True,), 'scalar_keepdim_dim', (), [0]), ('sum', (S, S, S), NO_ARGS), - ('sum', (S, S, S), (1,), 'dim', [0]), - ('sum', (S, S, S), (1, True,), 'keepdim_dim', [0]), + ('sum', (S, S, S), (1,), 'dim', (), [0]), + ('sum', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('sum', (), NO_ARGS, 'scalar'), - ('sum', (), (0,), 'scalar_dim', [0]), - ('sum', (), (0, True,), 'scalar_keepdim_dim', [0]), + ('sum', (), (0,), 'scalar_dim', (), [0]), + ('sum', (), (0, True,), 'scalar_keepdim_dim', (), [0]), ('sum', (S, S, S), ([1, 2],), 'multi_dim'), ('sum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'), ('prod', (S, S, S), NO_ARGS), - ('prod', (S, S, S), (1,), 'dim', [0]), - ('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]), + ('prod', (S, S, S), (1,), 'dim', (), [0]), + ('prod', (S, S, S), (1, True,), 'keepdim_dim', (), [0]), ('prod', (), NO_ARGS, 'scalar'), - ('prod', (), (0,), 'scalar_dim', [0]), - ('prod', (), (0, True,), 'scalar_keepdim_dim', [0]), + ('prod', (), (0,), 'scalar_dim', (), [0]), + ('prod', (), (0, True,), 'scalar_keepdim_dim', (), [0]), ('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'), ('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'), ('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'), - ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', [0]), - ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', [0]), - ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', [0]), - ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', [0]), - ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', [0]), - ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', [0]), + ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', (), [0]), + ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', (), [0]), + ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', (), [0]), + ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', (), [0]), + ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', (), [0]), + ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', (), [0]), ('prod', prod_single_zero(S), NO_ARGS, 'single_zero'), ('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'), - ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', [0]), - ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', [0]), - ('var', (S, S, S), NO_ARGS), - ('var', (S, S, S), (1,), 'dim', [0]), - ('var', (S, S, S), (1, True, 
True), 'keepdim_dim', [0]), - ('var', (S,), (0,), 'dim_1d', [0]), - ('var', (S,), (0, True, True), 'keepdim_dim_1d', [0]), - ('std', (S, S, S), NO_ARGS), - ('std', (S, S, S), (1,), 'dim', [0]), - ('std', (S, S, S), (1, True, True), 'keepdim_dim', [0]), - ('std', (S,), (0,), 'dim_1d', [0]), - ('std', (S,), (0, True, True), 'keepdim_dim_1d', [0]), - ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]), + ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]), + ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]), + ('var', (S, S, S), NO_ARGS, '', (True,)), + ('var', (S, S, S), (1,), 'dim', (True,), [0]), + ('var', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]), + ('var', (S,), (0,), 'dim_1d', (True,), [0]), + ('var', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]), + ('std', (S, S, S), NO_ARGS, '', (True,)), + ('std', (S, S, S), (1,), 'dim', (True,), [0]), + ('std', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]), + ('std', (S,), (0,), 'dim_1d', (True,), [0]), + ('std', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]), + ('renorm', (S, S, S), (2, 1, 0.5), 'dim', (), [1]), ('renorm', (S, S, S), (1, 2, 3), 'norm_1'), ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'), ('repeat', (S,), (2,), 'single_number'), ('repeat', (), (2, 3), 'scalar'), ('repeat', (2, 2), (3, 2)), ('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'), - ('cumsum', (S, S, S), (0,), 'dim0', [0]), - ('cumsum', (S, S, S), (1,), 'dim1', [0]), - ('cumsum', (S, S, S), (1,), 'dim1_cast', [0], (), lambda x: x, {'dtype': torch.float64}), - ('cumsum', (), (0,), 'dim0_scalar', [0]), + ('cumsum', (S, S, S), (0,), 'dim0', (), [0]), + ('cumsum', (S, S, S), (1,), 'dim1', (), [0]), + ('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), lambda x: x, {'dtype': torch.float64}), + ('cumsum', (), (0,), 'dim0_scalar', (), [0]), ('cumprod', (S, S, S), (0,)), - ('cumprod', (S, S, S), (1,), 'dim1', [0]), + ('cumprod', (S, S, S), (1,), 'dim1', (), [0]), ('cumprod', (), (0,), 'scalar'), ('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'), - ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', [0]), - ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', [0]), - ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', [0]), - ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', [0], (), lambda x: x, {'dtype': torch.float64}), - ('unfold', (), (0, 1, 1), 'scalar', [0]), - ('unfold', (S, S, S, S), (1, 3, 1), '', [0]), - ('unfold', (S, S, S), (2, 3, 2), 'lastdim', [0]), - ('addmm', (S, M), ((S, S), (S, M)),), - ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs'), - ('addmm', (S, M), ((S, S), (S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs'), - ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', (), [0]), + ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', (), [0]), + ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', (), [0]), + ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', (), [0], (), lambda x: x, {'dtype': torch.float64}), + ('unfold', (), (0, 1, 1), 'scalar', (), [0]), + ('unfold', (S, S, S, S), (1, 3, 1), '', (), [0]), + ('unfold', (S, S, S), (2, 3, 2), 'lastdim', (), [0]), + ('addmm', (S, M), ((S, S), (S, M)), 
'', (True, ['aten::add', 'aten::mm'])), + ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs', (True, ['aten::add', 'aten::mm'])), + ('addmm', (S, M), ((S, S), (S, M)), 'coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs', (True, ['aten::add', 'aten::mm'])), + ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('addbmm', (S, M), ((S, S, S), (S, S, M)),), ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'), - ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', + ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'), - ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, + ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),), ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'), - ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', + ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'), - ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, + ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('addmv', (S,), ((S, M), (M,)),), ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs'), - ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs'), - ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), ('addr', (S, M), ((S,), (M,)),), ('addr', (), ((S,), (M,)), 'broadcast_lhs'), - ('addr', (S, M), ((S,), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), - ('dot', (L,), ((L,),),), - ('mm', (S, M), ((M, S),)), - ('bmm', (M, S, M), ((M, M, S),)), - ('mv', (S, M), ((M,),)), + ('addr', (S, M), ((S,), (M,)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}), + ('dot', (L,), ((L,),), '', 
(True,)), + ('mm', (S, M), ((M, S),), '', (True,)), + ('bmm', (M, S, M), ((M, M, S),), '', (True,)), + ('mv', (S, M), ((M,),), '', (True,)), ('ger', (S,), ((M,),)), - ('matmul', (L,), ((L,),),), - ('matmul', (S, M), ((M,),), "2d_1d"), - ('matmul', (M, ), ((M, S),), "1d_2d"), - ('matmul', (S, M), ((M, S),), "2d_2d"), - ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"), - ('matmul', (S, S, M, M), ((M,),), "4d_1d"), - ('matmul', (M,), ((S, S, M, S),), "1d_4d"), + ('matmul', (L,), ((L,),), '', (True,)), + ('matmul', (S, M), ((M,),), "2d_1d", (True,)), + ('matmul', (M, ), ((M, S),), "1d_2d", (True,)), + ('matmul', (S, M), ((M, S),), "2d_2d", (True,)), + ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d", (True,)), + ('matmul', (S, S, M, M), ((M,),), "4d_1d", (True,)), + ('matmul', (M,), ((S, S, M, S),), "1d_4d", (True,)), ('matrix_power', (S, S), [2], "n=2"), ('matrix_power', (S, S, S), [3], "n=3"), ('matrix_power', (S, S, S), [1], "n=1"), ('matrix_power', (S, S, S), [0], "n=0"), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", + ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (), NO_ARGS, [skipIfNoLapack]), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", + ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (), NO_ARGS, [skipIfNoLapack]), - ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", + ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (), NO_ARGS, [skipIfNoLapack]), ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"), ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"), ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"), ('mvlgamma', torch.empty(S, S).uniform_(2.5, 5), [5], "p=5"), - ('addcmul', (S, S), ((S, S), (S, S))), - ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'), - ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'), - ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}), - ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}), - ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}), - ('addcmul', (), ((), ()), 'scalar'), - ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs'), - ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'), - ('addcmul', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}), - ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}), - ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}), + ('addcmul', (S, S), ((S, S), (S, S)), '', (True,)), + ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs', (True,)), + ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all', (True,)), + ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (True,), (), (), lambda x: x, {'value': 0.5}), + ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (True,), (), (), lambda x: x, {'value': 0.5}), + ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (True,), (), (), lambda x: x, {'value': 0.5}), + ('addcmul', (), ((), ()), 'scalar', (True,)), + ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs', (True,)), + ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs', (True,)), + ('addcmul', (), ((), ()), 'scalar_scale', (True,), (), (), lambda x: x, 
{'value': 0.5}), + ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (True,), (), (), lambda x: x, {'value': 0.5}), + ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (True,), (), (), lambda x: x, {'value': 0.5}), ('addcdiv', (S, S), ((S, S), (S, S))), ('addcdiv', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'), ('addcdiv', (1,), ((S, S, 1), (1, S)), 'broadcast_all'), - ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}), - ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}), - ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}), + ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), (), lambda x: x, {'value': 0.5}), + ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), (), lambda x: x, {'value': 0.5}), + ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), (), lambda x: x, {'value': 0.5}), ('addcdiv', (), ((), ()), 'scalar'), ('addcdiv', (S, S), ((), ()), 'scalar_broadcast_rhs'), ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'), - ('addcdiv', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}), - ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}), - ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}), + ('addcdiv', (), ((), ()), 'scalar_scale', (), (), (), lambda x: x, {'value': 0.5}), + ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), (), lambda x: x, {'value': 0.5}), + ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), (), lambda x: x, {'value': 0.5}), ('zero_', (S, S, S), NO_ARGS), ('zero_', (), NO_ARGS, 'scalar'), - ('logsumexp', (S, S), (1,)), - ('logsumexp', (), (0,), 'scalar'), + ('logsumexp', (S, S), (1,), '', (True,)), + ('logsumexp', (), (0,), 'scalar', (True,)), ('norm', (S, S), (), 'default'), ('norm', (S, S), (2,), '2'), ('norm', (S, S), (0,), '0'), @@ -505,29 +513,29 @@ def method_tests(): ('norm', (S, S), (-inf,), '-inf'), ('norm', (S, S), ('fro',), 'fro_default'), ('norm', (S, S), ('fro', [0, 1],), 'fro'), - ('norm', (S, S), ('nuc',), 'nuc', NO_ARGS, [skipIfNoLapack]), + ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]), ('norm', (S, S), (-1,), 'neg_1'), ('norm', (S, S), (-2,), 'neg_2'), ('norm', (S, S), (-0.5,), 'neg_0_5'), ('norm', (S, S), (-1.5,), 'neg_1_5'), - ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', [1]), - ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', [1]), - ('norm', (S, S), (0, 1,), '0_2_dim', [1]), - ('norm', (S, S), (1, 1,), '1_2_dim', [1]), - ('norm', (S, S), (2, 1,), '2_2_dim', [1]), - ('norm', (S, S), (3, 1,), '3_2_dim', [1]), + ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', (), [1]), + ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', (), [1]), + ('norm', (S, S), (0, 1,), '0_2_dim', (), [1]), + ('norm', (S, S), (1, 1,), '1_2_dim', (), [1]), + ('norm', (S, S), (2, 1,), '2_2_dim', (), [1]), + ('norm', (S, S), (3, 1,), '3_2_dim', (), [1]), ('norm', (S, S), (inf, 1,), 'inf_2_dim'), ('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5_default'), - ('norm', (S, S, S), (2, 1), '2_dim', [1]), - ('norm', (S, S, S), (3, 1), '3_dim', [1]), - ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', [1]), - ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', [1]), - ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', [1]), - ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', [1]), - ('norm', (), (2, 0), 
'2_dim_scalar', [1]), - ('norm', (), (3, 0), '3_dim_scalar', [1]), - ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', [1]), - ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', [1]), + ('norm', (S, S, S), (2, 1), '2_dim', (), [1]), + ('norm', (S, S, S), (3, 1), '3_dim', (), [1]), + ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', (), [1]), + ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', (), [1]), + ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', (), [1]), + ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', (), [1]), + ('norm', (), (2, 0), '2_dim_scalar', (), [1]), + ('norm', (), (3, 0), '3_dim_scalar', (), [1]), + ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', (), [1]), + ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', (), [1]), ('clone', (S, M, S), NO_ARGS), ('clone', (), NO_ARGS, 'scalar'), ('dist', (S, S, S), ((S, S, S),)), @@ -580,84 +588,84 @@ def method_tests(): ('trace', (M, M), NO_ARGS), ('cross', (S, 3), ((S, 3),)), ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'), - ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]), - ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', [0]), - ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', [0]), - ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', [0]), - ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]), - ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]), - ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]), - ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]), - ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]), - ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]), - ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', [0]), - ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', [0]), - ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', [0]), - ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', [0]), - ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), + ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', (), [0]), + ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', (), [0]), + ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', (), [0]), + ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', (), [0]), + ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]), + ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]), + ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', (), [0]), + ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]), + ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]), + ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', (), [0]), + ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', (), [0]), + ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', (), [0]), + ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', (), [0]), + ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 
'scalar_both_dim', (), [0]), + ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]), ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3), - NO_ARGS, 'batched', NO_ARGS, [skipIfNoLapack]), - ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), - ('det', (1, 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]), - ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]), - ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', NO_ARGS, [skipIfNoLapack]), - ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', NO_ARGS, [skipIfNoLapack]), - ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', NO_ARGS, [skipIfNoLapack]), - ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', NO_ARGS, [skipIfNoLapack]), - ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', NO_ARGS, [skipIfNoLapack]), + NO_ARGS, 'batched', (), NO_ARGS, [skipIfNoLapack]), + ('det', (S, S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]), + ('det', (1, 1), NO_ARGS, '1x1', (), NO_ARGS, [skipIfNoLapack]), + ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), NO_ARGS, [skipIfNoLapack]), + ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', (), NO_ARGS, [skipIfNoLapack]), + ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]), + ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', (), NO_ARGS, [skipIfNoLapack]), + ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), NO_ARGS, [skipIfNoLapack]), + ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), NO_ARGS, [skipIfNoLapack]), ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, - 'distinct_singular_values', NO_ARGS, [skipIfNoLapack]), + 'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]), # For `logdet` and `slogdet`, the function at det=0 is not smooth. # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use # `make_nonzero_det` to make the random matrices have nonzero det. For # `logdet`, we also set `make_nonzero_det(matrix, sign=1)` to make the # matrix have positive det. 
- ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), - ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]), + ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]), + ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', (), NO_ARGS, [skipIfNoLapack]), ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS, - 'symmetric', NO_ARGS, [skipIfNoLapack]), + 'symmetric', (), NO_ARGS, [skipIfNoLapack]), ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS, - 'symmetric_pd', NO_ARGS, [skipIfNoLapack]), + 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]), ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS, - 'distinct_singular_values', NO_ARGS, [skipIfNoLapack]), + 'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]), ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, - '1x1_pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)), + '1x1_pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS, - '1x1_neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)), + '1x1_neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, - 'pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)), + 'pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS, - 'neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)), + 'neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS, - 'symmetric', NO_ARGS, [skipIfNoLapack], itemgetter(1)), + 'symmetric', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), ('slogdet', lambda: random_symmetric_pd_matrix(S), NO_ARGS, - 'symmetric_pd', NO_ARGS, [skipIfNoLapack], itemgetter(1)), + 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, - 'distinct_singular_values', NO_ARGS, [skipIfNoLapack], itemgetter(1)), - ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', NO_ARGS, [skipIfNoLapack]), - ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', NO_ARGS, [skipIfNoLapack]), - ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', NO_ARGS, [skipIfNoLapack]), - ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]), + 'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)), + ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', (), NO_ARGS, [skipIfNoLapack]), + ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', (), NO_ARGS, [skipIfNoLapack]), + ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', (), NO_ARGS, [skipIfNoLapack]), + ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]), ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS, - 'wide', NO_ARGS, [skipIfNoLapack]), + 'wide', (), NO_ARGS, [skipIfNoLapack]), ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS, - 'tall', NO_ARGS, [skipIfNoLapack]), + 'tall', (), NO_ARGS, [skipIfNoLapack]), ('svd', lambda: 
random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,), - 'wide_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])), + 'wide_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])), ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,), - 'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])), + 'tall_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])), ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS, - 'large', NO_ARGS, [skipIfNoLapack]), + 'large', (), NO_ARGS, [skipIfNoLapack]), ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value( - S, silent=True),), '', NO_ARGS, [skipIfNoLapack]), + S, silent=True),), '', (), NO_ARGS, [skipIfNoLapack]), ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),), - 'batched', NO_ARGS, [skipIfNoLapack]), + 'batched', (), NO_ARGS, [skipIfNoLapack]), ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),), - 'batched_dims', NO_ARGS, [skipIfNoLapack]), + 'batched_dims', (), NO_ARGS, [skipIfNoLapack]), ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),), - 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]), + 'batched_broadcast_A', (), NO_ARGS, [skipIfNoLapack]), ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),), - 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]), + 'batched_broadcast_b', (), NO_ARGS, [skipIfNoLapack]), ('fill_', (S, S, S), (1,), 'number'), ('fill_', (), (1,), 'number_scalar'), ('fill_', (S, S, S), ((),), 'variable'), @@ -697,41 +705,41 @@ def method_tests(): ('ge_', (), (0,), 'pyscalar_scalar'), ('lt_', (), (0,), 'pyscalar_scalar'), ('le_', (), (0,), 'pyscalar_scalar'), - ('permute', (1, 2, 3, 4), (0, 2, 3, 1)), - ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim'), - ('permute', (), (dont_convert(()),), 'scalar'), - ('select', (S, S, S), (1, 2), 'dim', [0]), - ('select', (S, S, S), (1, -1), 'wrap_dim', [0]), + ('permute', (1, 2, 3, 4), (0, 2, 3, 1), '', (True,)), + ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim', (True,)), + ('permute', (), (dont_convert(()),), 'scalar', (True,)), + ('select', (S, S, S), (1, 2), 'dim', (), [0]), + ('select', (S, S, S), (1, -1), 'wrap_dim', (), [0]), ('select', (S,), (0, 2), '1d'), - ('narrow', (S, S, S), (1, 2, 2), 'dim', [0]), - ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', [0]), - ('squeeze', (S, 1, S, 1), NO_ARGS), - ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones'), - ('squeeze', (S, 1, S, 1), (1,), '1_dim', [0]), - ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', [0]), - ('squeeze', (), (0,), 'scalar', [0]), - ('unsqueeze', (S, S, S), (0,), 'first', [0]), - ('unsqueeze', (S, S, S), (1,), 'middle', [0]), - ('unsqueeze', (S, S, S), (3,), 'last', [0]), - ('unsqueeze', (), (0,), 'scalar', [0]), - ('chunk', (S, S, S), (2,)), - ('chunk', (S, S, S), (S, 1), 'dim', [1]), + ('narrow', (S, S, S), (1, 2, 2), 'dim', (), [0]), + ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', (), [0]), + ('squeeze', (S, 1, S, 1), NO_ARGS, '', (True,)), + ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones', (True,)), + ('squeeze', (S, 1, S, 1), (1,), '1_dim', (True,), [0]), + ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', (True,), [0]), + ('squeeze', (), (0,), 'scalar', (True,), [0]), + ('unsqueeze', (S, S, S), (0,), 'first', (True,), [0]), + 
('unsqueeze', (S, S, S), (1,), 'middle', (True,), [0]), + ('unsqueeze', (S, S, S), (3,), 'last', (True,), [0]), + ('unsqueeze', (), (0,), 'scalar', (True,), [0]), + ('chunk', (S, S, S), (2,), '', (True, 'prim::ConstantChunk')), + ('chunk', (S, S, S), (S, 1), 'dim', (True, 'prim::ConstantChunk'), [1]), ('split', (S, S, S), (2,)), - ('split', (S, S, S), (S, 1), 'dim', [1]), + ('split', (S, S, S), (S, 1), 'dim', (), [1]), ('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list'), - ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', [1]), - ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', [0]), - ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', [0]), - ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', [0]), - ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', [0]), - ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]), - ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]), - ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]), - ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]), - ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]), - ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]), - ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]), - ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]), + ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', (), [1]), + ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]), + ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]), + ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]), + ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', (), [0]), + ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', (), [0]), + ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]), + ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]), + ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', (), [0]), + ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', (), [0]), + ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]), + ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]), + ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', (), [0]), ('masked_select', (M, M), (mask_not_all_zeros((M, M)),)), ('masked_select', (M, M), (mask_not_all_zeros((M,)),), 'broadcast_rhs'), ('masked_select', (M,), (mask_not_all_zeros((M, M)),), 'broadcast_lhs'), @@ -770,22 +778,22 @@ def method_tests(): ('sort', (), (0,), 'dim_scalar'), ('sort', (), (0, True), 'dim_desc_scalar'), ('topk', (S, M, S), (3,)), - ('topk', (S, M, S), (3, 1), 'dim', [1]), - ('topk', (S, M, S), (3, 1, True), 'dim_desc', [1]), - ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', [1]), + ('topk', (S, M, S), (3, 1), 'dim', (), [1]), + ('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]), + ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', (), [1]), ('topk', (), (1,), 'scalar'), - ('topk', (), (1, 0), 
'dim_scalar', [1]),
-    ('topk', (), (1, 0, True), 'dim_desc_scalar', [1]),
-    ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', [1]),
+    ('topk', (), (1, 0), 'dim_scalar', (), [1]),
+    ('topk', (), (1, 0, True), 'dim_desc_scalar', (), [1]),
+    ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', (), [1]),
     ('take', (S, S, S), (torch.LongTensor([[-3, 2], [20, 2]]),)),
     ('take', (S, S, S), (torch.tensor(0, dtype=torch.int64),), 'scalar_index'),
     ('take', (), (torch.LongTensor([0]),), 'scalar_data'),
     ('take', (), (torch.tensor(0, dtype=torch.int64),), 'scalar_both'),
-    ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M))),
-    ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all'),
-    ('where', (), (bernoulli_scalar(), ()), 'scalar'),
-    ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask'),
-    ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask'),
+    ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M)), '', (True,)),
+    ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all', (True,)),
+    ('where', (), (bernoulli_scalar(), ()), 'scalar', (True,)),
+    ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask', (True,)),
+    ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask', (True,)),
     ('__getitem__', torch.randn(S, S, S), (dont_convert([1, 2]),)),
     ('__getitem__', torch.randn(S, S, S), (slice(0, 3),), 'slice'),
     ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(0, 3), 1]),), 'slice_index'),
diff --git a/test/test_autograd.py b/test/test_autograd.py
index d966685..0a2c08e 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2986,6 +2986,7 @@ def add_test(
         self_size,
         args,
         variant_name='',
+        check_ad=(),  # only used in test_jit
         dim_args_idx=(),
         skipTestIf=(),
         output_process_fn=lambda x: x,
diff --git a/test/test_jit.py b/test/test_jit.py
index 379594a..498cf0e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -68,6 +68,10 @@ except ImportError:
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
+# Note: the decision to create FusionGroups is currently device-independent, so
+# fusion counts as enabled if either backend can fuse (CPU fusion is disabled by default).
+FUSION_ENABLED = torch._C._jit_can_fuse_on_cpu() or torch._C._jit_can_fuse_on_gpu()
+
 RUN_CUDA = torch.cuda.is_available()
 RUN_CUDA_HALF = RUN_CUDA
 if torch.cuda.is_available():
@@ -385,6 +389,25 @@ class JitTestCase(TestCase):
         torch._C._jit_pass_lint(graph)
         self.assertExpected(str(graph), *args, **kwargs)
 
+    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
+        if not FUSION_ENABLED:
+            nonfusible_nodes = nonfusible_nodes + fusible_nodes
+            fusible_nodes = []
+        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
+        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
+
+        # Every non-fusible node must show up in at least one DifferentiableGraph subgraph.
+        found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0)\
+            or all([any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes])
+
+        # Every fusible node must show up in one of the FusionGroups inside a DifferentiableGraph subgraph.
+ fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) + fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] + found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0)\ + or all([any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes]) + + self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes) + def run_pass(self, name, trace): if isinstance(trace, torch._C.Graph): graph = trace @@ -11283,6 +11306,15 @@ EXCLUDE_SCRIPT = { 'test_nn_fold', } +# chunk returns a list in scripting and we don't unpack the list, +# Thus it won't be replaced by ConstantChunk and run AD. +# It's explicitly checked in test_chunk_constant_script_ad +EXCLUDE_SCRIPT_AD_CHECK = { + 'test_chunk', + 'test_chunk_dim', + 'test_chunk_dim_neg0', +} + EXCLUDE_PYTHON_PRINT = { # no support for BroadcastingList in python printer 'test_nn_max_unpool1d', @@ -11301,23 +11333,6 @@ EXCLUDE_SCRIPT_MODULES = { 'test_nn_AdaptiveMaxPool3d_tuple_none', } -DISABLE_AUTODIFF_SUBGRAPH_INLINING = { - 'test_nn_avg_pool2d', - 'test_nn_adaptive_avg_pool1d', - 'test_nn_adaptive_avg_pool2d', - 'test_nn_adaptive_avg_pool3d', - 'test_nn_batch_norm', - 'test_nn_embedding', - 'test_nn_log_softmax', - 'test_nn_softmax', - 'test_nn_softmax_with_all_args', - 'test_nn_threshold', - 'test_nn_nll_loss', - # Should have added all test_nn_interpolate_* here, - # but it's using autodiff since its subgraph is over - # 2 nodes. -} - # make a new function where all non-tensor arguments in 'args' have been partially # applied, and all tensor arguments remain. @@ -11528,6 +11543,17 @@ class TestAutodiffSubgraphSlicing(JitTestCase): def assertGraphSize(self, graph, size): self.assertEqual(len(list(graph.nodes())), size) + def test_chunk_constant_script_ad(self): + @torch.jit.script + def func(x): + x1, x2 = torch.chunk(x, 2) + return (x1, x2) + + input = torch.rand(6, 10).requires_grad_() + func.debug_disable_autodiff_subgraph_inlining() + output = func(input) + self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], []) + def test_simple_merge(self): # o --> o def fn(x, y, z): @@ -11867,6 +11893,7 @@ EXCLUDE_MODULE_EXPORT_IMPORT = { # args (tuple represents shape of a tensor arg), # test variant name(will be used at test name suffix, # 'inplace' skips grad tests), // optional +# (True, nonfusible_nodes, fusible_nodes) for autodiff // optional # fn to determine if test should be skipped, // optional # fn mapping output to part that should be gradcheck'ed, // optional # kwargs for function, // optional @@ -11880,7 +11907,7 @@ nn_functional_tests = [ ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)), ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)), ('avg_pool1d', (S, S, S), (3,)), - ('avg_pool2d', (S, S, S, S), (3,)), + ('avg_pool2d', (S, S, S, S), (3,), '', (True,)), ('avg_pool3d', (S, S, S, S, S), (3,)), ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)), ('max_pool1d', (S, S, S), (2, 1)), @@ -11895,15 +11922,17 @@ nn_functional_tests = [ ('adaptive_max_pool1d', (S, S, S), (5,)), ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)), ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)), - ('adaptive_avg_pool1d', (S, S, S), (5,)), - ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],)), - ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],)), - ('dropout', (S, S, S), (0.5,)), + ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)), + ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)), 
+ ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)), + ('dropout', (S, S, S), (0.5,), '', (True, + ['prim::is_cuda', 'aten::bernoulli_'], + ['aten::rand_like', 'aten::lt', 'aten::type_as', 'aten::mul', 'aten::div'])), ('alpha_dropout', (S, S, S), (0.5,)), ('dropout2d', (S, S, S), (0.5,)), ('dropout3d', (S, S, S), (0.5,)), ('feature_alpha_dropout', (S, S, S), (0.5,)), - ('threshold', (S, S, S), (0.1, 2.),), + ('threshold', (S, S, S), (0.1, 2.), '', (True,)), ('threshold', (S, S, S), (0.1, 2., True), 'inplace'), ('relu', (S, S, S), (),), ('relu', (S, S, S), (), 'inplace'), @@ -11927,16 +11956,17 @@ nn_functional_tests = [ ('softsign', (S, S, S), (),), ('softplus', (S, S, S), (),), ('softmin', (S, S, S), (0,),), - ('softmax', (S, S, S), (0,),), - ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args'), + ('softmax', (S, S, S), (0,), '', (True,)), + ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)), ('tanh', (S, S, S), (),), ('sigmoid', (S, S, S), (),), - ('log_softmax', (S, S, S), (0,),), + ('log_softmax', (S, S, S), (0,), '', (True,)), ('linear', (S, S), ((M, S),),), ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),), - ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),), + ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)), ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),), - ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),), + ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ), + '', (True, 'aten::_batch_norm_impl_index')), ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),), ('layer_norm', (S, S, S, S), ([5],),), ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'), @@ -11944,7 +11974,7 @@ nn_functional_tests = [ ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'), ('group_norm', (S, S, S), (1, torch.rand(5),),), ('local_response_norm', (S, S, S), (2, ),), - ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),),), + ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')), ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),), ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'), ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),), @@ -11972,8 +12002,8 @@ nn_functional_tests = [ ('unfold', (S, S, S, S), ([2, 3]),), ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),), ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),), - ('gumbel_softmax', (S, S), (2.,),), - ('gumbel_softmax', (S, S), (2., True,), 'hard'), + ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])), + ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])), ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),), ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), 1, 1., non_differentiable(torch.randn(S))),), @@ -11987,35 +12017,35 @@ nn_functional_tests = [ torch.randint(1, S, (S,), dtype=torch.long))), ('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'), ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'), - 
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'), - ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'), - ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'), - ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'), - ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'), - ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'), - ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'), - ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'), - ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'), - ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'), - ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'), - ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'), - ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'), - ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'), - ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'), - ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 
'nearest_3d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size', (True, 'aten::__interpolate')), + ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale', (True, 'aten::__interpolate')), + ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size', (True, 'aten::__interpolate')), ] @@ -12063,6 +12093,7 @@ def add_autograd_test( self_size, args, variant_name='', + check_ad=(), dim_args_idx=(), skipTestIf=(), output_process_fn=lambda x: x, @@ -12080,7 +12111,7 @@ def add_autograd_test( # for-loop bodies don't define scopes, so we have to save the variables # we want to close over in some way def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name, - output_process_fn=output_process_fn): + check_ad=check_ad, output_process_fn=output_process_fn): def check(name): set_rng_seed(2) is_magic_method = name[:2] == '__' and name[-2:] == '__' @@ -12104,20 +12135,27 @@ def add_autograd_test( # Test with disable_autodiff_subgraph_inlining, which forces the graph # to contain DifferentiableGraph nodes whenever possible. 
This allows us
            # to test autodiff; we assume that autograd is correct and use autodiff for backprop
+            should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
             if test_name not in EXCLUDE_TRACED:
-                check_against_reference(self,
-                                        create_traced_fn(self, fn,
-                                                         disable_autodiff_subgraph_inlining=True),
+                traced_fn = create_traced_fn(self, fn, disable_autodiff_subgraph_inlining=True)
+
+                check_against_reference(self, traced_fn,
                                         fn, (self_variable,) + args_variable, kwargs_variable,
                                         check_types=check_types)
+                self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
             if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
-                check_against_reference(self,
-                                        create_script_fn(self, name, 'method', output_process_fn,
-                                                         disable_autodiff_subgraph_inlining=True),
+                script_fn = create_script_fn(self, name, 'method', output_process_fn,
+                                             disable_autodiff_subgraph_inlining=True)
+                check_against_reference(self, script_fn,
                                         fn, (self_variable,) + args_variable, kwargs_variable,
                                         check_types=check_types)
+                self.assertAutodiffNode(script_fn.last_graph,
+                                        should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
+                                        autodiff_nodes,
+                                        fusible_nodes)
+
             # functional interface tests
             if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
                 def fn(*inputs, **kwargs):
@@ -12162,7 +12200,7 @@ def suppress_warnings(fn):
     return wrapper
 
-def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=(),
+def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(), skipTestIf=(),
                            output_process_fn=lambda x: x, kwargs=None):
     test_name = 'test_nn_' + name
@@ -12172,7 +12210,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
     no_grad = variant_name == 'inplace'
 
     @suppress_warnings
-    def do_test(self, name=name, args=args, test_name=test_name):
+    def do_test(self, name=name, args=args, test_name=test_name, check_ad=check_ad):
         torch.manual_seed(2)
 
         self_variable = create_input((self_size,))[0][0]
@@ -12193,13 +12231,14 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
         f_args_variable = (self_variable,) + args_variable
         f_args_tensor = (self_tensor,) + args_tensor
 
+        should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
         if test_name not in EXCLUDE_SCRIPT:
-            disable_ad_subgraph_inlining = test_name in DISABLE_AUTODIFF_SUBGRAPH_INLINING
-
             def run_test():
                 script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
-                                             disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
+                                             disable_autodiff_subgraph_inlining=should_autodiff_node)
                 check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+                # For tests where we disabled AD subgraph inlining, make sure the graph is not falling back to autograd
+                self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
             if test_name in EXCLUDE_PYTHON_PRINT:
                 with self.disableModuleHook():
@@ -12330,6 +12369,24 @@ def post_add_test(test_name, skipTestIf, do_test, test_class):
     setattr(test_class, test_name, do_test)
 
+def normalize_check_ad(check_ad, name):
+    # The normalized check_ad is a 3-element tuple: (bool, List[str], List[str])
+    if len(check_ad) == 0:
+        check_ad = [False, ['aten::' + name], []]
+    elif len(check_ad) == 1:
+        check_ad = [check_ad[0], ['aten::' + name], []]
+    elif len(check_ad) == 2:
+        check_ad = [check_ad[0], check_ad[1], []]
+    elif len(check_ad) == 3:
+        check_ad = list(check_ad)
+    else:
+        raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')
+
+    check_ad = [[t] if isinstance(t, str) else t for t in check_ad]
+
+    return check_ad
+
+
 class TestAsync(JitTestCase):
     def test_async_python(self):
         @torch.jit.script
diff --git a/torch/csrc/jit/README.md b/torch/csrc/jit/README.md
index 70e9ecd..0a403bd 100644
--- a/torch/csrc/jit/README.md
+++ b/torch/csrc/jit/README.md
@@ -66,6 +66,8 @@ Sections start with a reference to the source file where the code related to the
     + [`tensors/`](#-tensors--)
     + [`attributes.pkl`](#-attributespkl-)
     + [Implementation Details](#implementation-details)
+- [Testing Programs](#testing-programs)
+  * [Test Autodiff](#testautodiff)
 - [Python Bindings](#python-bindings)
@@ -1162,6 +1164,36 @@ TODO: fusion, operators
 # Saving Programs
 
+# Testing Programs
+## Test Autodiff ##
+[symbolic_script.cpp](symbolic_script.cpp)
+
+When differentiating a graph, each node that has a symbolic gradient will be included in a `prim::DifferentiableGraph`; we fall back to autograd for any node that has no gradient formula.
+Adding or updating symbolic gradient functions must be tested carefully: it's easy to keep CI green by comparing the autograd result with itself while silently regressing autodiff support.
+
+If your PR adds or updates a gradient formula for `torch`/`nn` functions, you **MUST** enable/update the corresponding tests in
+- `torch` functions: `method_tests` in [common_methods_invocations.py](../../../test/common_methods_invocations.py)
+- `nn` functions: `nn_functional_tests` in [test_jit.py](../../../test/test_jit.py)
+
+To turn on the autodiff check, add an optional `check_ad` tuple `(should_check_autodiff[bool], nonfusible_nodes[str|list[str]], fusible_nodes[str|list[str]])` after the optional test variant name field.
+If `should_check_autodiff=True`, the differentiated traced/script forward graph must have a `prim::DifferentiableGraph`.
+
+All nodes in `nonfusible_nodes` should show up at least once in the `prim::DifferentiableGraph` subgraphs.
+When fusion is enabled, all nodes in `fusible_nodes` should show up in one of the `prim::FusionGroup` subgraphs attached to a `prim::DifferentiableGraph`;
+otherwise they are checked as `nonfusible_nodes` as well.
+On the other hand, if `should_check_autodiff=False`, the graph may still contain a `prim::DifferentiableGraph` with other nodes, but it must not contain any of the `nonfusible_nodes` or `fusible_nodes`.
+
+To make writing tests easier, you only need to write out node names when they differ from the function name. Below are a few examples:
+```python
+('conv1d', ...),  # No symbolic gradient formula
+('avg_pool2d', ..., (True,)),  # Has a symbolic gradient formula whose only nonfusible node is aten::avg_pool2d
+('nll_loss', ..., (True, 'aten::nll_loss_forward')),  # Is replaced by a different node in its symbolic gradient formula
+('dropout', ..., (True, ['prim::is_cuda', 'aten::bernoulli_'], ['aten::rand_like', ..., 'aten::div'])),  # Some ops are fused when fusion is enabled
+```
+
+Note that even for the same function, different tests can trigger different function schemas (e.g. `aten::add`), while only some of them have symbolic gradient formulas.
+You should only turn on the autodiff check in tests that exercise a schema with a symbolic gradient. If you are not sure, uncomment the debugging line in [symbolic_script.cpp](symbolic_script.cpp)
+to check which function schema the test triggers.
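+
+For reference, `normalize_check_ad` in [test_jit.py](../../../test/test_jit.py) expands the shorthand tuple into the normalized `(bool, List[str], List[str])` triple before the check runs. The sketch below mirrors those expansion rules; it is illustrative only, and `expand_check_ad` is a hypothetical stand-in for the real helper (it assumes a 0-3 element tuple):
+```python
+# Illustrative sketch of the normalization rules in test/test_jit.py;
+# not the actual helper used by the test suite.
+def expand_check_ad(check_ad, name):
+    defaults = [False, ['aten::' + name], []]
+    # fields missing from the tail fall back to their defaults
+    expanded = list(check_ad) + defaults[len(check_ad):]
+    # a bare string is promoted to a single-element list
+    return [[t] if isinstance(t, str) else t for t in expanded]
+
+assert expand_check_ad((), 'conv1d') == [False, ['aten::conv1d'], []]
+assert expand_check_ad((True,), 'avg_pool2d') == [True, ['aten::avg_pool2d'], []]
+assert expand_check_ad((True, 'aten::nll_loss_forward'), 'nll_loss') == [True, ['aten::nll_loss_forward'], []]
+```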
 ## Python Printer
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 489c438..19c6af6 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -216,6 +216,8 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
       .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
       .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
+      .def("_jit_can_fuse_on_cpu", canFuseOnCPU)
+      .def("_jit_can_fuse_on_gpu", canFuseOnGPU)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
      .def(
          "_jit_differentiate",
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 6974936..a87470b 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -1008,6 +1008,8 @@ c10::optional<GradientPair> gradientInfoForSchema(
   // value since scalar/int aren't differentiable either way.
   // c10::ReplaceAll(schema_str, "Scalar", "float");
 
+  // For debugging AD changes:
+  // std::cout << "Looking for " << schema_str << std::endl;
   auto sym_script_it = schema_to_graphs.find(schema_str);