Enforce check ad in test_jit (#18509)
authorAiling Zhang <ailzhang@fb.com>
Sun, 31 Mar 2019 15:41:46 +0000 (08:41 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 31 Mar 2019 15:51:30 +0000 (08:51 -0700)
Summary:
If a test triggers autodiff, it must have a `DifferentiableGraph` in its differentiated forward graph, and this subgraph must have either the original aten node, or the corresponding nodes used in AD formula.

Typically a forward differentiable graph looks like this:
```
graph(%i0 : Float(),
      %i1 : Float()):
  %3 : Float() = prim::DifferentiableGraph_0(%i0, %i1)
  return (%3)
with prim::DifferentiableGraph_0 = graph(%0 : Float(),
      %1 : Float()):
  %2 : Float() = aten::max(%0, %1)
  return (%2)
```
which tells us `aten::max(Tensor self, Tensor other) -> Tensor` is symbolically differentiable.

Update: there're a lot of cases (fusions/ConstantChunk/python implementations) that breaks it so I decided to make the check optionally take node names if different from function name.
~~[OLD]Theoretically I could also check if `aten::max` is in the differentiable block or not to be more precise, but there're also cases like `chunk` where in a differentiable block it's replaced with a prim node (ConstantChunk) and we will have to special case them. Any suggestions here (to be more precise or no) is very welcome!~~

We used to have a list containing nn tests should be run against AD, I moved it to an field when constructing our test(either torch or nn). I think it's cleaner this way, and it matches the fact that for the same op we support one schema of it but not all, in this way we could just turn on the corresponding test which triggers that supported schema.

cc: apaszke zdevito wanchaol ngimel for a review

[Done] :
- Going through a manual second pass of all tests to check if they should enable AD test or not....
- Add a readme about how to add AD for an op and how to add/enable its test in test_jit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18509

Differential Revision: D14696811

Pulled By: ailzhang

fbshipit-source-id: c5e693277baac585cd3aed5ab2c0e7faa5e6f29f

test/common_methods_invocations.py
test/test_autograd.py
test/test_jit.py
torch/csrc/jit/README.md
torch/csrc/jit/init.cpp
torch/csrc/jit/symbolic_script.cpp

index 43ef996..d1f0809 100644 (file)
@@ -102,77 +102,85 @@ S = 5
 #   input size/constructing fn,
 #   args (tuple represents shape of a tensor arg),
 #   test variant name (will be used at test name suffix),    // optional
+#   (True, nonfusible_nodes, fusible_nodes) for autodiff,    // optional
 #   indices for possible dim arg,                            // optional
 #   fn mapping output to part that should be gradcheck'ed,   // optional
 # )
+# Note: some functions have separate schema for (Tensor other) and (Scalar other),
+#       and it's possible that we only support AD for Scalar version but not Tensor
+#       version, and vice versa.
+#       When writing tests, only scalar(float/int) input triggers the Scalar schema.
+#       uniform_scalar produces a scalar **Tensor** which won't match Scalar input.
 def method_tests():
     set_rng_seed(0)
     return [
-        ('add', (S, S, S), ((S, S, S),)),
-        ('add', (S, S, S), ((S, S),), 'broadcast_rhs'),
-        ('add', (S, S), ((S, S, S),), 'broadcast_lhs'),
-        ('add', (S, 1, S), ((M, S),), 'broadcast_all'),
-        ('add', (), ((),), 'scalar'),
-        ('add', (S, S, S), ((),), 'scalar_broadcast_rhs'),
-        ('add', (), ((S, S, S),), 'scalar_broadcast_lhs'),
-        ('add', (S, S, S), (3.14,), 'constant'),
-        ('add', (), (3.14,), 'scalar_constant'),
-        ('__radd__', (S, S, S), (3.14,), 'constant'),
-        ('__radd__', (), (3.14,), 'scalar_constant'),
-        ('sub', (S, S, S), ((S, S, S),)),
-        ('sub', (S, S, S), ((S, S),), 'broadcast_rhs'),
-        ('sub', (S, S), ((S, S, S),), 'broadcast_lhs'),
-        ('sub', (S, 1, S), ((M, S),), 'broadcast_all'),
-        ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs'),
-        ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs'),
-        ('sub', (S, S, S), (3.14,), 'constant'),
-        ('sub', (), (3.14,), 'scalar_constant'),
-        ('__rsub__', (S, S, S), (3.14,), 'constant'),
-        ('__rsub__', (), (3.14,), 'scalar_constant'),
-        ('mul', (S, S, S), ((S, S, S),)),
-        ('mul', (), ((),), 'scalar'),
-        ('mul', (S, S, S), ((S, S),), 'broadcast_rhs'),
-        ('mul', (S, S), ((S, S, S),), 'broadcast_lhs'),
-        ('mul', (S, 1, S), ((M, S),), 'broadcast_all'),
-        ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs'),
-        ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs'),
-        ('mul', (S, S, S), (3.14,), 'constant'),
-        ('mul', (), (3.14,), 'scalar_constant'),
-        ('__rmul__', (S, S, S), (3.14,), 'constant'),
-        ('__rmul__', (), (3.14,), 'scalar_constant'),
-        ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,)),
-        ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs'),
-        ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
-        ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all'),
-        ('div', (), (uniform_scalar(0.1),), 'scalar'),
-        ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
-        ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs'),
-        ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'),
-        ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'),
-        ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'),
-        ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'),
-        ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,)),
-        ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
-        ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
-        ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
-        ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
-        ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
-        ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
-        ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
-        ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
-        ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'),
-        ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'),
-        ('transpose', (1, 2, 3), (1, 2), 'dim', [0, 1]),
-        ('transpose', (), (0, 0), 'scalar'),
-        ('transpose', (1,), (0, 0), '1d'),
-        ('transpose', torch.rand(L, L), (0, 1), '2d'),
-        ('transpose', torch.rand(S, S, S), (2, 0), '3d'),
-        ('t', (1, 2), NO_ARGS),
-        ('view', (S, S, S), (S * S, S),),
-        ('view', (S, S, S), (torch.Size([S * S, S]),), 'size'),
-        ('view', (S,), (S,), '1d'),
-        ('view', (), (dont_convert(()),), 'scalar_to_scalar'),
-        ('view', (), (1,), 'scalar_to_1d'),
+        ('add', (S, S, S), ((S, S, S),), '', (True,)),
+        ('add', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)),
+        ('add', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)),
+        ('add', (S, 1, S), ((M, S),), 'broadcast_all', (True,)),
+        ('add', (), ((),), 'scalar', (True,)),
+        ('add', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)),
+        ('add', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
+        ('add', (S, S, S), (3.14,), 'constant', (True,)),
+        ('add', (), (3.14,), 'scalar_constant', (True,)),
+        ('__radd__', (S, S, S), (3.14,), 'constant', (True, 'aten::add')),
+        ('__radd__', (), (3.14,), 'scalar_constant', (True, 'aten::add')),
+        ('sub', (S, S, S), ((S, S, S),), '', (True,)),
+        ('sub', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)),
+        ('sub', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)),
+        ('sub', (S, 1, S), ((M, S),), 'broadcast_all', (True,)),
+        ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)),
+        ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
+        ('sub', (S, S, S), (3.14,), 'constant', (True,)),
+        ('sub', (), (3.14,), 'scalar_constant', (True,)),
+        ('__rsub__', (S, S, S), (3.14,), 'constant', (True, 'aten::rsub')),
+        ('__rsub__', (), (3.14,), 'scalar_constant', (True, 'aten::rsub')),
+        ('mul', (S, S, S), ((S, S, S),), '', (True,)),
+        ('mul', (), ((),), 'scalar', (True,)),
+        ('mul', (S, S, S), ((S, S),), 'broadcast_rhs', (True,)),
+        ('mul', (S, S), ((S, S, S),), 'broadcast_lhs', (True,)),
+        ('mul', (S, 1, S), ((M, S),), 'broadcast_all', (True,)),
+        ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs', (True,)),
+        ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs', (True,)),
+        ('mul', (S, S, S), (3.14,), 'constant', (True,)),
+        ('mul', (), (3.14,), 'scalar_constant', (True,)),
+        ('__rmul__', (S, S, S), (3.14,), 'constant', (True, 'aten::mul')),
+        ('__rmul__', (), (3.14,), 'scalar_constant', (True, 'aten::mul')),
+        ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,), '', (True,)),
+        ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs', (True,)),
+        ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)),
+        ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all', (True,)),
+        ('div', (), (uniform_scalar(0.1),), 'scalar', (True,)),
+        ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)),
+        ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs', (True,)),
+        ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant', (True,)),
+        ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant',
+            (True, [], ['aten::mul', 'aten::reciprocal'])),
+        ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
+        ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant',
+            (True, [], ['aten::mul', 'aten::reciprocal'])),
+        ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), '', (True,)),
+        ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs', (True,)),
+        ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)),
+        ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all', (True,)),
+        ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar', (True,)),
+        ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)),
+        ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs', (True,)),
+        ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True,)),
+        ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant', (True, 'aten::pow')),
+        ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
+        ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant', (True, 'aten::pow')),
+        ('transpose', (1, 2, 3), (1, 2), 'dim', (True,), [0, 1]),
+        ('transpose', (), (0, 0), 'scalar', (True,)),
+        ('transpose', (1,), (0, 0), '1d', (True,)),
+        ('transpose', torch.rand(L, L), (0, 1), '2d', (True,)),
+        ('transpose', torch.rand(S, S, S), (2, 0), '3d', (True,)),
+        ('t', (1, 2), NO_ARGS, '', (True,)),
+        ('view', (S, S, S), (S * S, S), '', (True,)),
+        ('view', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)),
+        ('view', (S,), (S,), '1d', (True,)),
+        ('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
+        ('view', (), (1,), 'scalar_to_1d', (True,)),
         ('reshape', (S, S, S), (S * S, S),),
         ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'),
         ('reshape', (S,), (S,), '1d'),
@@ -201,82 +209,82 @@ def method_tests():
         ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
         ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
         ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
-        ('expand', (S, 1, 1), (S, S, S)),
-        ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
-        ('expand', (S, 1), (S, S, S), 'new_dim'),
-        ('expand', (1,), (S, S, S), '1_element'),
-        ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'),
-        ('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
-        ('expand', (), (1, 3, 2), 'scalar_to_dims'),
-        ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
-        ('exp', (S, S, S), NO_ARGS),
-        ('exp', (), NO_ARGS, 'scalar'),
-        ('expm1', (S, S, S), NO_ARGS),
-        ('expm1', (), NO_ARGS, 'scalar'),
-        ('erf', torch.rand(S, S, S), NO_ARGS),
-        ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
-        ('erfc', torch.rand(S, S, S), NO_ARGS),
-        ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
+        ('expand', (S, 1, 1), (S, S, S), '', (True,)),
+        ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size', (True,)),
+        ('expand', (S, 1), (S, S, S), 'new_dim', (True,)),
+        ('expand', (1,), (S, S, S), '1_element', (True,)),
+        ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1', (True,)),
+        ('expand', (), (dont_convert(()),), 'scalar_to_scalar', (True,)),
+        ('expand', (), (1, 3, 2), 'scalar_to_dims', (True,)),
+        ('expand_as', (S, 1, 1), (torch.rand(S, S, S),), '', (True,)),
+        ('exp', (S, S, S), NO_ARGS, '', (True,)),
+        ('exp', (), NO_ARGS, 'scalar', (True,)),
+        ('expm1', (S, S, S), NO_ARGS, '', (True,)),
+        ('expm1', (), NO_ARGS, 'scalar', (True,)),
+        ('erf', torch.rand(S, S, S), NO_ARGS, '', (True,)),
+        ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('erfc', torch.rand(S, S, S), NO_ARGS, '', (True,)),
+        ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
         ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
         ('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'),
-        ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS),
-        ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
-        ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS),
-        ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
-        ('log1p', torch.rand(S, S, S), NO_ARGS),
-        ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
-        ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS),
-        ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
-        ('tanh', (S, S, S), NO_ARGS),
-        ('tanh', (), NO_ARGS, 'scalar'),
-        ('sigmoid', (S, S, S), NO_ARGS),
-        ('sigmoid', (), NO_ARGS, 'scalar'),
-        ('sinh', (S, S, S), NO_ARGS),
-        ('sinh', (), NO_ARGS, 'scalar'),
-        ('cosh', (S, S, S), NO_ARGS),
-        ('cosh', (), NO_ARGS, 'scalar'),
-        ('abs', (S, S, S), NO_ARGS),
-        ('abs', (), NO_ARGS, 'scalar'),
-        ('clamp', (S, S, S), (0, 1)),
-        ('clamp', (S, S, S), (None, 0.5), 'min'),
-        ('clamp', (S, S, S), (0.5, None), 'max'),
-        ('clamp', (), (0, 1), 'scalar'),
-        ('clamp', (), (None, 0.5), 'min_scalar'),
-        ('clamp', (), (0.5, None), 'max_scalar'),
-        ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS),
-        ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar'),
-        ('sin', (S, S, S), NO_ARGS),
-        ('sin', (), NO_ARGS, 'scalar'),
-        ('cos', (S, S, S), NO_ARGS),
-        ('cos', (), NO_ARGS, 'scalar'),
-        ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS),
-        ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
-        ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
-        ('atan', (S, S, S), NO_ARGS),
-        ('atan', (), NO_ARGS, 'scalar'),
+        ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+        ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+        ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('log1p', torch.rand(S, S, S), NO_ARGS, '', (True,)),
+        ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+        ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('tanh', (S, S, S), NO_ARGS, '', (True,)),
+        ('tanh', (), NO_ARGS, 'scalar', (True,)),
+        ('sigmoid', (S, S, S), NO_ARGS, '', (True,)),
+        ('sigmoid', (), NO_ARGS, 'scalar', (True,)),
+        ('sinh', (S, S, S), NO_ARGS, '', (True,)),
+        ('sinh', (), NO_ARGS, 'scalar', (True,)),
+        ('cosh', (S, S, S), NO_ARGS, '', (True,)),
+        ('cosh', (), NO_ARGS, 'scalar', (True,)),
+        ('abs', (S, S, S), NO_ARGS, '', (True,)),
+        ('abs', (), NO_ARGS, 'scalar', (True,)),
+        ('clamp', (S, S, S), (0, 1), '', (True,)),
+        ('clamp', (S, S, S), (None, 0.5), 'min', (True,)),
+        ('clamp', (S, S, S), (0.5, None), 'max', (True,)),
+        ('clamp', (), (0, 1), 'scalar', (True,)),
+        ('clamp', (), (None, 0.5), 'min_scalar', (True,)),
+        ('clamp', (), (0.5, None), 'max_scalar', (True,)),
+        ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS, '', (True,)),
+        ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('sin', (S, S, S), NO_ARGS, '', (True,)),
+        ('sin', (), NO_ARGS, 'scalar', (True,)),
+        ('cos', (S, S, S), NO_ARGS, '', (True,)),
+        ('cos', (), NO_ARGS, 'scalar', (True,)),
+        ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS, '', (True,)),
+        ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)),
+        ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS, '', (True,)),
+        ('atan', (S, S, S), NO_ARGS, '', (True,)),
+        ('atan', (), NO_ARGS, 'scalar', (True,)),
         ('atan2', (S, S, S), ((S, S, S),)),
         ('atan2', (), ((),), 'scalar'),
         ('atan2', (S, S, S), ((S,),), 'broadcast_rhs'),
         ('atan2', (S,), ((S, S, S),), 'broadcast_lhs'),
         ('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
-        ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS),
-        ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar'),
-        ('round', (S, S, S), NO_ARGS),
-        ('round', (), NO_ARGS, 'scalar'),
+        ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS, '', (True,)),
+        ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('round', (S, S, S), NO_ARGS, '', (True,)),
+        ('round', (), NO_ARGS, 'scalar', (True,)),
         ('sign', (S, S, S), NO_ARGS),
         ('sign', (), NO_ARGS, 'scalar'),
-        ('trunc', (S, S, S), NO_ARGS),
-        ('trunc', (), NO_ARGS, 'scalar'),
-        ('floor', (S, S, S), NO_ARGS),
-        ('floor', (), NO_ARGS, 'scalar'),
-        ('ceil', (S, S, S), NO_ARGS),
-        ('ceil', (), NO_ARGS, 'scalar'),
-        ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS),
-        ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
-        ('frac', (S, S, S), NO_ARGS),
-        ('frac', (), NO_ARGS, 'scalar'),
-        ('fmod', (S, S, S), (1.5,)),
-        ('fmod', (), (1.5,), 'scalar'),
+        ('trunc', (S, S, S), NO_ARGS, '', (True,)),
+        ('trunc', (), NO_ARGS, 'scalar', (True,)),
+        ('floor', (S, S, S), NO_ARGS, '', (True,)),
+        ('floor', (), NO_ARGS, 'scalar', (True,)),
+        ('ceil', (S, S, S), NO_ARGS, '', (True,)),
+        ('ceil', (), NO_ARGS, 'scalar', (True,)),
+        ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS, '', (True,)),
+        ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar', (True,)),
+        ('frac', (S, S, S), NO_ARGS, '', (True,)),
+        ('frac', (), NO_ARGS, 'scalar', (True,)),
+        ('fmod', (S, S, S), (1.5,), '', (True,)),
+        ('fmod', (), (1.5,), 'scalar', (True,)),
         ('fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
         ('fmod', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
         ('fmod', (S, S, S), (non_differentiable(torch.rand(S) + 1.5),), 'tensor_broadcast_rhs'),
@@ -284,217 +292,217 @@ def method_tests():
         ('fmod', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
         ('fmod', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
         ('fmod', (S, S, S), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor_broadcast_rhs'),
-        ('remainder', (S, S, S), (1.5,)),
-        ('remainder', (), (1.5,), 'scalar'),
+        ('remainder', (S, S, S), (1.5,), '', (True,)),
+        ('remainder', (), (1.5,), 'scalar', (True,)),
         ('remainder', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
         ('remainder', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
         ('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
         ('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
         ('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
-        ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast'),
-        ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'),
-        ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'),
-        ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all'),
-        ('lerp', (), ((), 0.4), 'scalar'),
-        ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs'),
-        ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs'),
+        ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast', (True,)),
+        ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs', (True,)),
+        ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs', (True,)),
+        ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all', (True,)),
+        ('lerp', (), ((), 0.4), 'scalar', (True,)),
+        ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs', (True,)),
+        ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs', (True,)),
         ('max', (S, S, S), NO_ARGS),
-        ('max', (S, S, S), (1,), 'dim', [0]),
-        ('max', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+        ('max', (S, S, S), (1,), 'dim', (), [0]),
+        ('max', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
         ('max', (), NO_ARGS, 'scalar'),
-        ('max', (), (0,), 'scalar_dim', [0]),
-        ('max', (), (0, True,), 'scalar_keepdim_dim', [0]),
-        ('max', (S, S, S), ((S, S, S),), 'elementwise'),
-        ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'),
-        ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'),
-        ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'),
-        ('max', (), ((),), 'scalar_elementwise'),
-        ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'),
-        ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'),
-        ('min', (S, S, S), NO_ARGS),
-        ('min', (S, S, S), (1,), 'dim', [0]),
-        ('min', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+        ('max', (), (0,), 'scalar_dim', (), [0]),
+        ('max', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
+        ('max', (S, S, S), ((S, S, S),), 'elementwise', (True,)),
+        ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs', (True,)),
+        ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs', (True,)),
+        ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all', (True,)),
+        ('max', (), ((),), 'scalar_elementwise', (True,)),
+        ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs', (True,)),
+        ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs', (True,)),
+        ('min', (S, S, S), NO_ARGS),
+        ('min', (S, S, S), (1,), 'dim', (), [0]),
+        ('min', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
         ('min', (), NO_ARGS, 'scalar'),
-        ('min', (), (0,), 'scalar_dim', [0]),
-        ('min', (), (0, True,), 'scalar_keepdim_dim', [0]),
-        ('min', (S, S, S), ((S, S, S),), 'elementwise'),
-        ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'),
-        ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'),
-        ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'),
-        ('min', (), ((),), 'scalar_elementwise'),
-        ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'),
-        ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'),
-        ('mean', (S, S, S), NO_ARGS),
-        ('mean', (S, S, S), (1,), 'dim', [0]),
-        ('mean', (S, S, S), (1, True,), 'keepdim_dim', [0]),
-        ('mean', (), NO_ARGS, 'scalar'),
-        ('mean', (), (0,), 'scalar_dim', [0]),
-        ('mean', (), (0, True,), 'scalar_keepdim_dim', [0]),
+        ('min', (), (0,), 'scalar_dim', (), [0]),
+        ('min', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
+        ('min', (S, S, S), ((S, S, S),), 'elementwise', (True,)),
+        ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs', (True,)),
+        ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs', (True,)),
+        ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all', (True,)),
+        ('min', (), ((),), 'scalar_elementwise', (True,)),
+        ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs', (True,)),
+        ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs', (True,)),
+        ('mean', (S, S, S), NO_ARGS, '', (True,)),
+        ('mean', (S, S, S), (1,), 'dim', (True,), [0]),
+        ('mean', (S, S, S), (1, True,), 'keepdim_dim', (True,), [0]),
+        ('mean', (), NO_ARGS, 'scalar', (True,)),
+        ('mean', (), (0,), 'scalar_dim', (True,), [0]),
+        ('mean', (), (0, True,), 'scalar_keepdim_dim', (True,), [0]),
         ('kthvalue', (S, S, S), (2,)),
         ('kthvalue', (), (1,), 'scalar'),
-        ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]),
-        ('kthvalue', (), (1, 0,), 'scalar_dim', [1]),
-        ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', [1]),
-        ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', [1]),
-        ('kthvalue', (S,), (2, 0,), 'dim_1d', [1]),
-        ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', [1]),
+        ('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]),
+        ('kthvalue', (), (1, 0,), 'scalar_dim', (), [1]),
+        ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', (), [1]),
+        ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', (), [1]),
+        ('kthvalue', (S,), (2, 0,), 'dim_1d', (), [1]),
+        ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', (), [1]),
         ('median', (S, S, S), NO_ARGS),
-        ('median', (S, S, S), (1,), 'dim', [0]),
-        ('median', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+        ('median', (S, S, S), (1,), 'dim', (), [0]),
+        ('median', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
         ('median', (), NO_ARGS, 'scalar'),
-        ('median', (), (0,), 'scalar_dim', [0]),
-        ('median', (), (0, True,), 'scalar_keepdim_dim', [0]),
+        ('median', (), (0,), 'scalar_dim', (), [0]),
+        ('median', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
         ('mode', (S, S, S), NO_ARGS),
-        ('mode', (S, S, S), (1,), 'dim', [0]),
-        ('mode', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+        ('mode', (S, S, S), (1,), 'dim', (), [0]),
+        ('mode', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
         ('mode', (), NO_ARGS, 'scalar'),
-        ('mode', (), (0,), 'scalar_dim', [0]),
-        ('mode', (), (0, True,), 'scalar_keepdim_dim', [0]),
+        ('mode', (), (0,), 'scalar_dim', (), [0]),
+        ('mode', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
         ('sum', (S, S, S), NO_ARGS),
-        ('sum', (S, S, S), (1,), 'dim', [0]),
-        ('sum', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+        ('sum', (S, S, S), (1,), 'dim', (), [0]),
+        ('sum', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
         ('sum', (), NO_ARGS, 'scalar'),
-        ('sum', (), (0,), 'scalar_dim', [0]),
-        ('sum', (), (0, True,), 'scalar_keepdim_dim', [0]),
+        ('sum', (), (0,), 'scalar_dim', (), [0]),
+        ('sum', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
         ('sum', (S, S, S), ([1, 2],), 'multi_dim'),
         ('sum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
         ('prod', (S, S, S), NO_ARGS),
-        ('prod', (S, S, S), (1,), 'dim', [0]),
-        ('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]),
+        ('prod', (S, S, S), (1,), 'dim', (), [0]),
+        ('prod', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
         ('prod', (), NO_ARGS, 'scalar'),
-        ('prod', (), (0,), 'scalar_dim', [0]),
-        ('prod', (), (0, True,), 'scalar_keepdim_dim', [0]),
+        ('prod', (), (0,), 'scalar_dim', (), [0]),
+        ('prod', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
         ('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'),
         ('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'),
         ('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'),
-        ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', [0]),
-        ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', [0]),
-        ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', [0]),
-        ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', [0]),
-        ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', [0]),
-        ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', [0]),
+        ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', (), [0]),
+        ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', (), [0]),
+        ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', (), [0]),
+        ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', (), [0]),
+        ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', (), [0]),
+        ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', (), [0]),
         ('prod', prod_single_zero(S), NO_ARGS, 'single_zero'),
         ('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'),
-        ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', [0]),
-        ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', [0]),
-        ('var', (S, S, S), NO_ARGS),
-        ('var', (S, S, S), (1,), 'dim', [0]),
-        ('var', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
-        ('var', (S,), (0,), 'dim_1d', [0]),
-        ('var', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
-        ('std', (S, S, S), NO_ARGS),
-        ('std', (S, S, S), (1,), 'dim', [0]),
-        ('std', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
-        ('std', (S,), (0,), 'dim_1d', [0]),
-        ('std', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
-        ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]),
+        ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]),
+        ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]),
+        ('var', (S, S, S), NO_ARGS, '', (True,)),
+        ('var', (S, S, S), (1,), 'dim', (True,), [0]),
+        ('var', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]),
+        ('var', (S,), (0,), 'dim_1d', (True,), [0]),
+        ('var', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]),
+        ('std', (S, S, S), NO_ARGS, '', (True,)),
+        ('std', (S, S, S), (1,), 'dim', (True,), [0]),
+        ('std', (S, S, S), (1, True, True), 'keepdim_dim', (True,), [0]),
+        ('std', (S,), (0,), 'dim_1d', (True,), [0]),
+        ('std', (S,), (0, True, True), 'keepdim_dim_1d', (True,), [0]),
+        ('renorm', (S, S, S), (2, 1, 0.5), 'dim', (), [1]),
         ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
         ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
         ('repeat', (S,), (2,), 'single_number'),
         ('repeat', (), (2, 3), 'scalar'),
         ('repeat', (2, 2), (3, 2)),
         ('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'),
-        ('cumsum', (S, S, S), (0,), 'dim0', [0]),
-        ('cumsum', (S, S, S), (1,), 'dim1', [0]),
-        ('cumsum', (S, S, S), (1,), 'dim1_cast', [0], (), lambda x: x, {'dtype': torch.float64}),
-        ('cumsum', (), (0,), 'dim0_scalar', [0]),
+        ('cumsum', (S, S, S), (0,), 'dim0', (), [0]),
+        ('cumsum', (S, S, S), (1,), 'dim1', (), [0]),
+        ('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), lambda x: x, {'dtype': torch.float64}),
+        ('cumsum', (), (0,), 'dim0_scalar', (), [0]),
         ('cumprod', (S, S, S), (0,)),
-        ('cumprod', (S, S, S), (1,), 'dim1', [0]),
+        ('cumprod', (S, S, S), (1,), 'dim1', (), [0]),
         ('cumprod', (), (0,), 'scalar'),
         ('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'),
-        ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', [0]),
-        ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', [0]),
-        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', [0]),
-        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', [0], (), lambda x: x, {'dtype': torch.float64}),
-        ('unfold', (), (0, 1, 1), 'scalar', [0]),
-        ('unfold', (S, S, S, S), (1, 3, 1), '', [0]),
-        ('unfold', (S, S, S), (2, 3, 2), 'lastdim', [0]),
-        ('addmm', (S, M), ((S, S), (S, M)),),
-        ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs'),
-        ('addmm', (S, M), ((S, S), (S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs'),
-        ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', (), [0]),
+        ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', (), [0]),
+        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', (), [0]),
+        ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', (), [0], (), lambda x: x, {'dtype': torch.float64}),
+        ('unfold', (), (0, 1, 1), 'scalar', (), [0]),
+        ('unfold', (S, S, S, S), (1, 3, 1), '', (), [0]),
+        ('unfold', (S, S, S), (2, 3, 2), 'lastdim', (), [0]),
+        ('addmm', (S, M), ((S, S), (S, M)), '', (True, ['aten::add', 'aten::mm'])),
+        ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs', (True, ['aten::add', 'aten::mm'])),
+        ('addmm', (S, M), ((S, S), (S, M)), 'coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs', (True, ['aten::add', 'aten::mm'])),
+        ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (True,), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
         ('addbmm', (S, M), ((S, S, S), (S, S, M)),),
         ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'),
-        ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef',
+        ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', (),
          (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
         ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
-        ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x,
+        ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x,
          {'beta': 0.2, 'alpha': 0.6}),
         ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
         ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'),
-        ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef',
+        ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef', (),
          (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
         ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
-        ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x,
+        ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x,
          {'beta': 0.2, 'alpha': 0.6}),
         ('addmv', (S,), ((S, M), (M,)),),
         ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs'),
-        ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
         ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs'),
-        ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
         ('addr', (S, M), ((S,), (M,)),),
         ('addr', (), ((S,), (M,)), 'broadcast_lhs'),
-        ('addr', (S, M), ((S,), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
-        ('dot', (L,), ((L,),),),
-        ('mm', (S, M), ((M, S),)),
-        ('bmm', (M, S, M), ((M, M, S),)),
-        ('mv', (S, M), ((M,),)),
+        ('addr', (S, M), ((S,), (M,)), 'coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
+        ('dot', (L,), ((L,),), '', (True,)),
+        ('mm', (S, M), ((M, S),), '', (True,)),
+        ('bmm', (M, S, M), ((M, M, S),), '', (True,)),
+        ('mv', (S, M), ((M,),), '', (True,)),
         ('ger', (S,), ((M,),)),
-        ('matmul', (L,), ((L,),),),
-        ('matmul', (S, M), ((M,),), "2d_1d"),
-        ('matmul', (M, ), ((M, S),), "1d_2d"),
-        ('matmul', (S, M), ((M, S),), "2d_2d"),
-        ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
-        ('matmul', (S, S, M, M), ((M,),), "4d_1d"),
-        ('matmul', (M,), ((S, S, M, S),), "1d_4d"),
+        ('matmul', (L,), ((L,),), '', (True,)),
+        ('matmul', (S, M), ((M,),), "2d_1d", (True,)),
+        ('matmul', (M, ), ((M, S),), "1d_2d", (True,)),
+        ('matmul', (S, M), ((M, S),), "2d_2d", (True,)),
+        ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d", (True,)),
+        ('matmul', (S, S, M, M), ((M,),), "4d_1d", (True,)),
+        ('matmul', (M,), ((S, S, M, S),), "1d_4d", (True,)),
         ('matrix_power', (S, S), [2], "n=2"),
         ('matrix_power', (S, S, S), [3], "n=3"),
         ('matrix_power', (S, S, S), [1], "n=1"),
         ('matrix_power', (S, S, S), [0], "n=0"),
-        ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1",
+        ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1", (),
          NO_ARGS, [skipIfNoLapack]),
-        ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3",
+        ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3", (),
          NO_ARGS, [skipIfNoLapack]),
-        ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2",
+        ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2", (),
          NO_ARGS, [skipIfNoLapack]),
         ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
         ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
         ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
         ('mvlgamma', torch.empty(S, S).uniform_(2.5, 5), [5], "p=5"),
-        ('addcmul', (S, S), ((S, S), (S, S))),
-        ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
-        ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
-        ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}),
-        ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
-        ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}),
-        ('addcmul', (), ((), ()), 'scalar'),
-        ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs'),
-        ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'),
-        ('addcmul', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}),
-        ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
-        ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}),
+        ('addcmul', (S, S), ((S, S), (S, S)), '', (True,)),
+        ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs', (True,)),
+        ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all', (True,)),
+        ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (True,), (), (), lambda x: x, {'value': 0.5}),
+        ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (True,), (), (), lambda x: x, {'value': 0.5}),
+        ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (True,), (), (), lambda x: x, {'value': 0.5}),
+        ('addcmul', (), ((), ()), 'scalar', (True,)),
+        ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs', (True,)),
+        ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs', (True,)),
+        ('addcmul', (), ((), ()), 'scalar_scale', (True,), (), (), lambda x: x, {'value': 0.5}),
+        ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (True,), (), (), lambda x: x, {'value': 0.5}),
+        ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (True,), (), (), lambda x: x, {'value': 0.5}),
         ('addcdiv', (S, S), ((S, S), (S, S))),
         ('addcdiv', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
         ('addcdiv', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
-        ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}),
-        ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
-        ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}),
+        ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), (), lambda x: x, {'value': 0.5}),
+        ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), (), lambda x: x, {'value': 0.5}),
+        ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), (), lambda x: x, {'value': 0.5}),
         ('addcdiv', (), ((), ()), 'scalar'),
         ('addcdiv', (S, S), ((), ()), 'scalar_broadcast_rhs'),
         ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'),
-        ('addcdiv', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}),
-        ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
-        ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}),
+        ('addcdiv', (), ((), ()), 'scalar_scale', (), (), (), lambda x: x, {'value': 0.5}),
+        ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), (), lambda x: x, {'value': 0.5}),
+        ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), (), lambda x: x, {'value': 0.5}),
         ('zero_', (S, S, S), NO_ARGS),
         ('zero_', (), NO_ARGS, 'scalar'),
-        ('logsumexp', (S, S), (1,)),
-        ('logsumexp', (), (0,), 'scalar'),
+        ('logsumexp', (S, S), (1,), '', (True,)),
+        ('logsumexp', (), (0,), 'scalar', (True,)),
         ('norm', (S, S), (), 'default'),
         ('norm', (S, S), (2,), '2'),
         ('norm', (S, S), (0,), '0'),
@@ -505,29 +513,29 @@ def method_tests():
         ('norm', (S, S), (-inf,), '-inf'),
         ('norm', (S, S), ('fro',), 'fro_default'),
         ('norm', (S, S), ('fro', [0, 1],), 'fro'),
-        ('norm', (S, S), ('nuc',), 'nuc', NO_ARGS, [skipIfNoLapack]),
+        ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
         ('norm', (S, S), (-1,), 'neg_1'),
         ('norm', (S, S), (-2,), 'neg_2'),
         ('norm', (S, S), (-0.5,), 'neg_0_5'),
         ('norm', (S, S), (-1.5,), 'neg_1_5'),
-        ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', [1]),
-        ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', [1]),
-        ('norm', (S, S), (0, 1,), '0_2_dim', [1]),
-        ('norm', (S, S), (1, 1,), '1_2_dim', [1]),
-        ('norm', (S, S), (2, 1,), '2_2_dim', [1]),
-        ('norm', (S, S), (3, 1,), '3_2_dim', [1]),
+        ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', (), [1]),
+        ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', (), [1]),
+        ('norm', (S, S), (0, 1,), '0_2_dim', (), [1]),
+        ('norm', (S, S), (1, 1,), '1_2_dim', (), [1]),
+        ('norm', (S, S), (2, 1,), '2_2_dim', (), [1]),
+        ('norm', (S, S), (3, 1,), '3_2_dim', (), [1]),
         ('norm', (S, S), (inf, 1,), 'inf_2_dim'),
         ('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5_default'),
-        ('norm', (S, S, S), (2, 1), '2_dim', [1]),
-        ('norm', (S, S, S), (3, 1), '3_dim', [1]),
-        ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', [1]),
-        ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', [1]),
-        ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', [1]),
-        ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', [1]),
-        ('norm', (), (2, 0), '2_dim_scalar', [1]),
-        ('norm', (), (3, 0), '3_dim_scalar', [1]),
-        ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', [1]),
-        ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', [1]),
+        ('norm', (S, S, S), (2, 1), '2_dim', (), [1]),
+        ('norm', (S, S, S), (3, 1), '3_dim', (), [1]),
+        ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', (), [1]),
+        ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', (), [1]),
+        ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', (), [1]),
+        ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', (), [1]),
+        ('norm', (), (2, 0), '2_dim_scalar', (), [1]),
+        ('norm', (), (3, 0), '3_dim_scalar', (), [1]),
+        ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', (), [1]),
+        ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', (), [1]),
         ('clone', (S, M, S), NO_ARGS),
         ('clone', (), NO_ARGS, 'scalar'),
         ('dist', (S, S, S), ((S, S, S),)),
@@ -580,84 +588,84 @@ def method_tests():
         ('trace', (M, M), NO_ARGS),
         ('cross', (S, 3), ((S, 3),)),
         ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
-        ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]),
-        ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', [0]),
-        ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', [0]),
-        ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', [0]),
-        ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]),
-        ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]),
-        ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
-        ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]),
-        ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]),
-        ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
-        ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', [0]),
-        ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', [0]),
-        ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', [0]),
-        ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', [0]),
-        ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
+        ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', (), [0]),
+        ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', (), [0]),
+        ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', (), [0]),
+        ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', (), [0]),
+        ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]),
+        ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]),
+        ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', (), [0]),
+        ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', (), [0]),
+        ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', (), [0]),
+        ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', (), [0]),
+        ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', (), [0]),
+        ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', (), [0]),
+        ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', (), [0]),
+        ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', (), [0]),
+        ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
         ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3),
-         NO_ARGS, 'batched', NO_ARGS, [skipIfNoLapack]),
-        ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
-        ('det', (1, 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]),
-        ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]),
-        ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', NO_ARGS, [skipIfNoLapack]),
-        ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
-        ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', NO_ARGS, [skipIfNoLapack]),
-        ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', NO_ARGS, [skipIfNoLapack]),
-        ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', NO_ARGS, [skipIfNoLapack]),
+         NO_ARGS, 'batched', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', (S, S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', (1, 1), NO_ARGS, '1x1', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', (), NO_ARGS, [skipIfNoLapack]),
+        ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', (), NO_ARGS, [skipIfNoLapack]),
         ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
-         'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
+         'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
         # For `logdet` and `slogdet`, the function at det=0 is not smooth.
         # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use
         # `make_nonzero_det` to make the random matrices have nonzero det. For
         # `logdet`, we also set `make_nonzero_det(matrix, sign=1)` to make the
         # matrix have positive det.
-        ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
-        ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
+        ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', (), NO_ARGS, [skipIfNoLapack]),
         ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS,
-         'symmetric', NO_ARGS, [skipIfNoLapack]),
+         'symmetric', (), NO_ARGS, [skipIfNoLapack]),
         ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS,
-         'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
+         'symmetric_pd', (), NO_ARGS, [skipIfNoLapack]),
         ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS,
-         'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
+         'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack]),
         ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS,
-         '1x1_pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+         '1x1_pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS,
-         '1x1_neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+         '1x1_neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS,
-         'pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+         'pos_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS,
-         'neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+         'neg_det', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS,
-         'symmetric', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+         'symmetric', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: random_symmetric_pd_matrix(S), NO_ARGS,
-         'symmetric_pd', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+         'symmetric_pd', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
         ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
-         'distinct_singular_values', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
-        ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', NO_ARGS, [skipIfNoLapack]),
-        ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', NO_ARGS, [skipIfNoLapack]),
-        ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', NO_ARGS, [skipIfNoLapack]),
-        ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
+         'distinct_singular_values', (), NO_ARGS, [skipIfNoLapack], itemgetter(1)),
+        ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', (), NO_ARGS, [skipIfNoLapack]),
+        ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', (), NO_ARGS, [skipIfNoLapack]),
+        ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', (), NO_ARGS, [skipIfNoLapack]),
+        ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', (), NO_ARGS, [skipIfNoLapack]),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS,
-         'wide', NO_ARGS, [skipIfNoLapack]),
+         'wide', (), NO_ARGS, [skipIfNoLapack]),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS,
-         'tall', NO_ARGS, [skipIfNoLapack]),
+         'tall', (), NO_ARGS, [skipIfNoLapack]),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,),
-         'wide_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])),
+         'wide_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,),
-         'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
+         'tall_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
-         'large', NO_ARGS, [skipIfNoLapack]),
+         'large', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
-            S, silent=True),), '', NO_ARGS, [skipIfNoLapack]),
+            S, silent=True),), '', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
-         'batched', NO_ARGS, [skipIfNoLapack]),
+         'batched', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),),
-         'batched_dims', NO_ARGS, [skipIfNoLapack]),
+         'batched_dims', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),),
-         'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
+         'batched_broadcast_A', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),),
-         'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
+         'batched_broadcast_b', (), NO_ARGS, [skipIfNoLapack]),
         ('fill_', (S, S, S), (1,), 'number'),
         ('fill_', (), (1,), 'number_scalar'),
         ('fill_', (S, S, S), ((),), 'variable'),
@@ -697,41 +705,41 @@ def method_tests():
         ('ge_', (), (0,), 'pyscalar_scalar'),
         ('lt_', (), (0,), 'pyscalar_scalar'),
         ('le_', (), (0,), 'pyscalar_scalar'),
-        ('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
-        ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim'),
-        ('permute', (), (dont_convert(()),), 'scalar'),
-        ('select', (S, S, S), (1, 2), 'dim', [0]),
-        ('select', (S, S, S), (1, -1), 'wrap_dim', [0]),
+        ('permute', (1, 2, 3, 4), (0, 2, 3, 1), '', (True,)),
+        ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim', (True,)),
+        ('permute', (), (dont_convert(()),), 'scalar', (True,)),
+        ('select', (S, S, S), (1, 2), 'dim', (), [0]),
+        ('select', (S, S, S), (1, -1), 'wrap_dim', (), [0]),
         ('select', (S,), (0, 2), '1d'),
-        ('narrow', (S, S, S), (1, 2, 2), 'dim', [0]),
-        ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', [0]),
-        ('squeeze', (S, 1, S, 1), NO_ARGS),
-        ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones'),
-        ('squeeze', (S, 1, S, 1), (1,), '1_dim', [0]),
-        ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', [0]),
-        ('squeeze', (), (0,), 'scalar', [0]),
-        ('unsqueeze', (S, S, S), (0,), 'first', [0]),
-        ('unsqueeze', (S, S, S), (1,), 'middle', [0]),
-        ('unsqueeze', (S, S, S), (3,), 'last', [0]),
-        ('unsqueeze', (), (0,), 'scalar', [0]),
-        ('chunk', (S, S, S), (2,)),
-        ('chunk', (S, S, S), (S, 1), 'dim', [1]),
+        ('narrow', (S, S, S), (1, 2, 2), 'dim', (), [0]),
+        ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', (), [0]),
+        ('squeeze', (S, 1, S, 1), NO_ARGS, '', (True,)),
+        ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones', (True,)),
+        ('squeeze', (S, 1, S, 1), (1,), '1_dim', (True,), [0]),
+        ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', (True,), [0]),
+        ('squeeze', (), (0,), 'scalar', (True,), [0]),
+        ('unsqueeze', (S, S, S), (0,), 'first', (True,), [0]),
+        ('unsqueeze', (S, S, S), (1,), 'middle', (True,), [0]),
+        ('unsqueeze', (S, S, S), (3,), 'last', (True,), [0]),
+        ('unsqueeze', (), (0,), 'scalar', (True,), [0]),
+        ('chunk', (S, S, S), (2,), '', (True, 'prim::ConstantChunk')),
+        ('chunk', (S, S, S), (S, 1), 'dim', (True, 'prim::ConstantChunk'), [1]),
         ('split', (S, S, S), (2,)),
-        ('split', (S, S, S), (S, 1), 'dim', [1]),
+        ('split', (S, S, S), (S, 1), 'dim', (), [1]),
         ('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list'),
-        ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', [1]),
-        ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', [0]),
-        ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', [0]),
-        ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', [0]),
-        ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', [0]),
-        ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]),
-        ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
-        ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
-        ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]),
-        ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]),
-        ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
-        ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
-        ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
+        ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', (), [1]),
+        ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]),
+        ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]),
+        ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]),
+        ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', (), [0]),
+        ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', (), [0]),
+        ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]),
+        ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]),
+        ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', (), [0]),
+        ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', (), [0]),
+        ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', (), [0]),
+        ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', (), [0]),
+        ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', (), [0]),
         ('masked_select', (M, M), (mask_not_all_zeros((M, M)),)),
         ('masked_select', (M, M), (mask_not_all_zeros((M,)),), 'broadcast_rhs'),
         ('masked_select', (M,), (mask_not_all_zeros((M, M)),), 'broadcast_lhs'),
@@ -770,22 +778,22 @@ def method_tests():
         ('sort', (), (0,), 'dim_scalar'),
         ('sort', (), (0, True), 'dim_desc_scalar'),
         ('topk', (S, M, S), (3,)),
-        ('topk', (S, M, S), (3, 1), 'dim', [1]),
-        ('topk', (S, M, S), (3, 1, True), 'dim_desc', [1]),
-        ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', [1]),
+        ('topk', (S, M, S), (3, 1), 'dim', (), [1]),
+        ('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]),
+        ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', (), [1]),
         ('topk', (), (1,), 'scalar'),
-        ('topk', (), (1, 0), 'dim_scalar', [1]),
-        ('topk', (), (1, 0, True), 'dim_desc_scalar', [1]),
-        ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', [1]),
+        ('topk', (), (1, 0), 'dim_scalar', (), [1]),
+        ('topk', (), (1, 0, True), 'dim_desc_scalar', (), [1]),
+        ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', (), [1]),
         ('take', (S, S, S), (torch.LongTensor([[-3, 2], [20, 2]]),)),
         ('take', (S, S, S), (torch.tensor(0, dtype=torch.int64),), 'scalar_index'),
         ('take', (), (torch.LongTensor([0]),), 'scalar_data'),
         ('take', (), (torch.tensor(0, dtype=torch.int64),), 'scalar_both'),
-        ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M))),
-        ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all'),
-        ('where', (), (bernoulli_scalar(), ()), 'scalar'),
-        ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask'),
-        ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask'),
+        ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M)), '', (True,)),
+        ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all', (True,)),
+        ('where', (), (bernoulli_scalar(), ()), 'scalar', (True,)),
+        ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask', (True,)),
+        ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask', (True,)),
         ('__getitem__', torch.randn(S, S, S), (dont_convert([1, 2]),)),
         ('__getitem__', torch.randn(S, S, S), (slice(0, 3),), 'slice'),
         ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(0, 3), 1]),), 'slice_index'),
index d966685..0a2c08e 100644 (file)
@@ -2986,6 +2986,7 @@ def add_test(
         self_size,
         args,
         variant_name='',
+        check_ad=(),  # only used in test_jit
         dim_args_idx=(),
         skipTestIf=(),
         output_process_fn=lambda x: x,
index 379594a..498cf0e 100644 (file)
@@ -68,6 +68,10 @@ except ImportError:
 
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
+# Note: creating FusionGroups is currently device-independent.
+# FusionGroup creation with CPU is disabled.
+FUSION_ENABLED = torch._C._jit_can_fuse_on_cpu() or torch._C._jit_can_fuse_on_gpu()
+
 RUN_CUDA = torch.cuda.is_available()
 RUN_CUDA_HALF = RUN_CUDA
 if torch.cuda.is_available():
@@ -385,6 +389,25 @@ class JitTestCase(TestCase):
         torch._C._jit_pass_lint(graph)
         self.assertExpected(str(graph), *args, **kwargs)
 
+    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
+        if not FUSION_ENABLED:
+            nonfusible_nodes = nonfusible_nodes + fusible_nodes
+            fusible_nodes = []
+        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
+        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
+
+        # For any non-fusible node, it must show up in one of the DifferentiableGraph.
+        found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0)\
+            or all([any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes])
+
+        # For any fusible node, it must show up in one of the FusionGroup in the DifferentiableGraph.
+        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
+        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
+        found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0)\
+            or all([any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes])
+
+        self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes)
+
     def run_pass(self, name, trace):
         if isinstance(trace, torch._C.Graph):
             graph = trace
@@ -11283,6 +11306,15 @@ EXCLUDE_SCRIPT = {
     'test_nn_fold',
 }
 
+# chunk returns a list in scripting and we don't unpack the list,
+# Thus it won't be replaced by ConstantChunk and run AD.
+# It's explicitly checked in test_chunk_constant_script_ad
+EXCLUDE_SCRIPT_AD_CHECK = {
+    'test_chunk',
+    'test_chunk_dim',
+    'test_chunk_dim_neg0',
+}
+
 EXCLUDE_PYTHON_PRINT = {
     # no support for BroadcastingList in python printer
     'test_nn_max_unpool1d',
@@ -11301,23 +11333,6 @@ EXCLUDE_SCRIPT_MODULES = {
     'test_nn_AdaptiveMaxPool3d_tuple_none',
 }
 
-DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
-    'test_nn_avg_pool2d',
-    'test_nn_adaptive_avg_pool1d',
-    'test_nn_adaptive_avg_pool2d',
-    'test_nn_adaptive_avg_pool3d',
-    'test_nn_batch_norm',
-    'test_nn_embedding',
-    'test_nn_log_softmax',
-    'test_nn_softmax',
-    'test_nn_softmax_with_all_args',
-    'test_nn_threshold',
-    'test_nn_nll_loss',
-    # Should have added all test_nn_interpolate_* here,
-    # but it's using autodiff since its subgraph is over
-    # 2 nodes.
-}
-
 
 # make a new function where all non-tensor arguments in 'args' have been partially
 # applied, and all tensor arguments remain.
@@ -11528,6 +11543,17 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
     def assertGraphSize(self, graph, size):
         self.assertEqual(len(list(graph.nodes())), size)
 
+    def test_chunk_constant_script_ad(self):
+        @torch.jit.script
+        def func(x):
+            x1, x2 = torch.chunk(x, 2)
+            return (x1, x2)
+
+        input = torch.rand(6, 10).requires_grad_()
+        func.debug_disable_autodiff_subgraph_inlining()
+        output = func(input)
+        self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
+
     def test_simple_merge(self):
         # o --> o
         def fn(x, y, z):
@@ -11867,6 +11893,7 @@ EXCLUDE_MODULE_EXPORT_IMPORT = {
 #   args (tuple represents shape of a tensor arg),
 #   test variant name(will be used at test name suffix,
 #       'inplace' skips grad tests),                         // optional
+#   (True, nonfusible_nodes, fusible_nodes) for autodiff     // optional
 #   fn to determine if test should be skipped,               // optional
 #   fn mapping output to part that should be gradcheck'ed,   // optional
 #   kwargs for function,                                     // optional
@@ -11880,7 +11907,7 @@ nn_functional_tests = [
     ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
     ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
     ('avg_pool1d', (S, S, S), (3,)),
-    ('avg_pool2d', (S, S, S, S), (3,)),
+    ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
     ('avg_pool3d', (S, S, S, S, S), (3,)),
     ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
     ('max_pool1d', (S, S, S), (2, 1)),
@@ -11895,15 +11922,17 @@ nn_functional_tests = [
     ('adaptive_max_pool1d', (S, S, S), (5,)),
     ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
     ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
-    ('adaptive_avg_pool1d', (S, S, S), (5,)),
-    ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],)),
-    ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
-    ('dropout', (S, S, S), (0.5,)),
+    ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
+    ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
+    ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
+    ('dropout', (S, S, S), (0.5,), '', (True,
+                                        ['prim::is_cuda', 'aten::bernoulli_'],
+                                        ['aten::rand_like', 'aten::lt', 'aten::type_as', 'aten::mul', 'aten::div'])),
     ('alpha_dropout', (S, S, S), (0.5,)),
     ('dropout2d', (S, S, S), (0.5,)),
     ('dropout3d', (S, S, S), (0.5,)),
     ('feature_alpha_dropout', (S, S, S), (0.5,)),
-    ('threshold', (S, S, S), (0.1, 2.),),
+    ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
     ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
     ('relu', (S, S, S), (),),
     ('relu', (S, S, S), (), 'inplace'),
@@ -11927,16 +11956,17 @@ nn_functional_tests = [
     ('softsign', (S, S, S), (),),
     ('softplus', (S, S, S), (),),
     ('softmin', (S, S, S), (0,),),
-    ('softmax', (S, S, S), (0,),),
-    ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args'),
+    ('softmax', (S, S, S), (0,), '', (True,)),
+    ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
     ('tanh', (S, S, S), (),),
     ('sigmoid', (S, S, S), (),),
-    ('log_softmax', (S, S, S), (0,),),
+    ('log_softmax', (S, S, S), (0,), '', (True,)),
     ('linear', (S, S), ((M, S),),),
     ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
-    ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),),
+    ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
     ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
-    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),),
+    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
+        '', (True, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],),),
     ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
@@ -11944,7 +11974,7 @@ nn_functional_tests = [
     ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
-    ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),),),
+    ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
     ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
     ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
     ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
@@ -11972,8 +12002,8 @@ nn_functional_tests = [
     ('unfold', (S, S, S, S), ([2, 3]),),
     ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
     ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
-    ('gumbel_softmax', (S, S), (2.,),),
-    ('gumbel_softmax', (S, S), (2., True,), 'hard'),
+    ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])),
+    ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])),
     ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
     ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
                                    1, 1., non_differentiable(torch.randn(S))),),
@@ -11987,35 +12017,35 @@ nn_functional_tests = [
       torch.randint(1, S, (S,), dtype=torch.long))),
     ('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
     ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
-    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
-    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
-    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
-    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
-    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
-    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
-    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
-    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
-    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
-    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
-    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
-    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
-    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size', (True, 'aten::__interpolate')),
+    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale', (True, 'aten::__interpolate')),
+    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size', (True, 'aten::__interpolate')),
 ]
 
 
@@ -12063,6 +12093,7 @@ def add_autograd_test(
         self_size,
         args,
         variant_name='',
+        check_ad=(),
         dim_args_idx=(),
         skipTestIf=(),
         output_process_fn=lambda x: x,
@@ -12080,7 +12111,7 @@ def add_autograd_test(
         # for-loop bodies don't define scopes, so we have to save the variables
         # we want to close over in some way
         def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
-                    output_process_fn=output_process_fn):
+                    check_ad=check_ad, output_process_fn=output_process_fn):
             def check(name):
                 set_rng_seed(2)
                 is_magic_method = name[:2] == '__' and name[-2:] == '__'
@@ -12104,20 +12135,27 @@ def add_autograd_test(
                     # Test with disable_autodiff_subgraph_inlining, which forces the graph
                     # to contain DifferentiableGraph nodes whenever possible. This allows us
                     # to test autodiff; we assume that autograd is correct and use autodiff for backprop
+                    should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
                     if test_name not in EXCLUDE_TRACED:
-                        check_against_reference(self,
-                                                create_traced_fn(self, fn,
-                                                                 disable_autodiff_subgraph_inlining=True),
+                        traced_fn = create_traced_fn(self, fn, disable_autodiff_subgraph_inlining=True)
+
+                        check_against_reference(self, traced_fn,
                                                 fn, (self_variable,) + args_variable, kwargs_variable,
                                                 check_types=check_types)
+                        self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
                     if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
-                        check_against_reference(self,
-                                                create_script_fn(self, name, 'method', output_process_fn,
-                                                                 disable_autodiff_subgraph_inlining=True),
+                        script_fn = create_script_fn(self, name, 'method', output_process_fn,
+                                                     disable_autodiff_subgraph_inlining=True)
+                        check_against_reference(self, script_fn,
                                                 fn, (self_variable,) + args_variable, kwargs_variable,
                                                 check_types=check_types)
 
+                        self.assertAutodiffNode(script_fn.last_graph,
+                                                should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
+                                                autodiff_nodes,
+                                                fusible_nodes)
+
                 # functional interface tests
                 if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
                     def fn(*inputs, **kwargs):
@@ -12162,7 +12200,7 @@ def suppress_warnings(fn):
     return wrapper
 
 
-def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=(),
+def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(), skipTestIf=(),
                            output_process_fn=lambda x: x, kwargs=None):
     test_name = 'test_nn_' + name
 
@@ -12172,7 +12210,7 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
     no_grad = variant_name == 'inplace'
 
     @suppress_warnings
-    def do_test(self, name=name, args=args, test_name=test_name):
+    def do_test(self, name=name, args=args, test_name=test_name, check_ad=check_ad):
         torch.manual_seed(2)
 
         self_variable = create_input((self_size,))[0][0]
@@ -12193,13 +12231,14 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
         f_args_variable = (self_variable,) + args_variable
         f_args_tensor = (self_tensor,) + args_tensor
 
+        should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
         if test_name not in EXCLUDE_SCRIPT:
-            disable_ad_subgraph_inlining = test_name in DISABLE_AUTODIFF_SUBGRAPH_INLINING
-
             def run_test():
                 script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
-                                             disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
+                                             disable_autodiff_subgraph_inlining=should_autodiff_node)
                 check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+                # For tests we disabled AD subgraph inlining, make sure it's not falling back to autograd
+                self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
             if test_name in EXCLUDE_PYTHON_PRINT:
                 with self.disableModuleHook():
@@ -12330,6 +12369,24 @@ def post_add_test(test_name, skipTestIf, do_test, test_class):
         setattr(test_class, test_name, do_test)
 
 
+def normalize_check_ad(check_ad, name):
+    # normalized check_ad is 3-element tuple: (bool, List[str], List[str])
+    if len(check_ad) == 0:
+        check_ad = [False, ['aten::' + name], []]
+    elif len(check_ad) == 1:
+        check_ad = [check_ad[0], ['aten::' + name], []]
+    elif len(check_ad) == 2:
+        check_ad = [check_ad[0], check_ad[1], []]
+    elif len(check_ad) == 3:
+        check_ad = list(check_ad)
+    else:
+        raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')
+
+    check_ad = [[t] if isinstance(t, str) else t for t in check_ad]
+
+    return check_ad
+
+
 class TestAsync(JitTestCase):
     def test_async_python(self):
         @torch.jit.script
index 70e9ecd..0a403bd 100644 (file)
@@ -66,6 +66,8 @@ Sections start with a reference to the source file where the code related to the
     + [`tensors/`](#-tensors--)
     + [`attributes.pkl`](#-attributespkl-)
     + [Implementation Details](#implementation-details)
+- [Testing Programs](#testing-programs)
+  * [Test Autodiff](#testautodiff)
 - [Python Bindings](#python-bindings)
 
 
@@ -1162,6 +1164,36 @@ TODO: fusion, operators
 
 # Saving Programs
 
+# Testing Programs
+## Test Autodiff ##
+[symbolic_script.cpp](symbolic_script.cpp)
+
+When differentiating a graph, each node that has a symbolic gradient will be included in a `prim::DifferentiableGraph`. We fall back to use autograd for the node if there isn't a gradient formula for it.
+Adding/updating symbolic gradient functions must be tested carefully as it's easy to get CI green by comparing autograd result with itself, but potentially cause autodiff support regression.
+
+If your PR adds/updates a gradient formula for `torch`/`nn` functions, you **MUST** enable/update the corresponding tests in
+- `torch` functions: `method_tests` in [common_method_tests.py](../../../test/common_method_tests.py)
+- `nn` functions: `nn_functional_tests` in [test_jit.py](../../../test/test_jit.py)
+
+To turn on autodiff check, you can add an optional `check_ad(should_check_autodiff[bool], nonfusible_nodes[str|list[str]], fusible_nodes[str|list[str]])` tuple after the optional test variant name field.
+If `should_check_autodiff=True`, the differentiated traced/script forward graph must have a `prim::DifferentiableGraph`.
+
+All nodes in `nonfusible_nodes` should show up in at least once in `prim::DifferentiableGraph` subgraphs.
+When fusion is enabled, all nodes in `fusible_nodes` should show up in one of `prim::FusionGroup` graphs attached to `prim::DifferentiableGraph`,
+otherwise they're checked as `nonfusible_nodes` as well.
+On the other hand, if `should_check_autodiff=False`, the graph can still have `prim::DifferentiableGraph` with other nodes, but not `nonfusible_nodes` and `fusible_nodes`.
+
+To make writing test easier, you only need to write out node names if it's different from the function name. Below are a few examples:
+```python
+('conv1d', ...), # No symbolic gradient formula
+('avg_pool2d', ..., (True,)), # Has symbolic gradient formula, only has one nonfusible node aten::avg_pool2d
+('nll_loss', ..., (True, 'aten::nll_loss_forward')), # Is replaced by a different node in its symbolic gradient formula
+('dropout', ..., (True, ['prim::is_cuda', 'aten::bernoulli_'], ['aten::rand_like', ..., 'aten::div'])), # Some op are fused when fusion is enabled.
+```
+
+Note that even for the same function, different tests could trigger different function schemas (e.g `aten::add`) while only a few of them have symbolic gradient formulas.
+You should only turn on autodiff check in tests who have symbolic gradient. If you are not sure, uncomment the debugging line in [symbolic_script.cpp](symbolic_script.cpp)
+to check which function schema the test triggers.
 
 ## Python Printer
 
index 489c438..19c6af6 100644 (file)
@@ -216,6 +216,8 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
       .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
       .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
+      .def("_jit_can_fuse_on_cpu", canFuseOnCPU)
+      .def("_jit_can_fuse_on_gpu", canFuseOnGPU)
       .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
       .def(
           "_jit_differentiate",
index 6974936..a87470b 100644 (file)
@@ -1008,6 +1008,8 @@ c10::optional<GradientPair> gradientInfoForSchema(
     //    value since scalar/int aren't differentiable either way.
     //
     c10::ReplaceAll(schema_str, "Scalar", "float");
+    // For debugging AD change:
+    // std::cout << "Looking for " << schema_str << std::endl;
 
     auto sym_script_it = schema_to_graphs.find(schema_str);