clarify implementation of check_grad_usage (#64439)
authorMichael Dagitses <mikeyd@fb.com>
Sat, 18 Sep 2021 13:47:20 +0000 (06:47 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 18 Sep 2021 14:30:30 +0000 (07:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64439

1) remove unused fully_implemented
2) rename used_grad to uses_grad and make it a boolean
3) rename used_grads to num_grads_uses
4) add comments explaining what some of the checks mean

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30733904

Pulled By: dagitses

fbshipit-source-id: dccbbef8a4be8713215ef91aa97a34124f06a7a1

tools/autograd/load_derivatives.py

index 8a5904b..0ed5fea 100644 (file)
@@ -285,22 +285,24 @@ def create_differentiability_info(
         used with double backwards.
         """
 
-        used_grad = 0
-        used_grads = 0
-        fully_implemented = True
-        used_grads_indices: List[int] = []
+        uses_grad = False                   # true if any derivative uses "grad"
+        num_grads_uses = 0                  # count of uses of "grads" or "grads[INDEX]"
+        used_grads_indices: List[int] = []  # which indices of grads are used
         for d in derivatives:
             formula = d.formula
-            used_grad += len(re.findall(IDENT_REGEX.format('grad'), formula))
-            used_grads += len(re.findall(IDENT_REGEX.format('grads'), formula))
-            fully_implemented = \
-                fully_implemented and \
-                not re.search(IDENT_REGEX.format('not_implemented'), formula)
+            uses_grad = uses_grad or bool(re.findall(IDENT_REGEX.format('grad'), formula))
+            num_grads_uses += len(re.findall(IDENT_REGEX.format('grads'), formula))
             used_grads_indices.extend(used_gradient_indices(formula))
-        assert used_grads >= len(used_grads_indices)
-        only_used_grads_indices = used_grads == len(used_grads_indices)
-
-        if used_grad and used_grads:
+        # This is a basic sanity check: the number of places we see
+        # "grads" should be no fewer than the number of indices we see
+        # inside "grads". They may not be equal because we may use
+        # "grads" without an index.
+        assert num_grads_uses >= len(used_grads_indices)
+        # Thus if the number is equal, every use of grads is also
+        # indexed.
+        only_used_grads_indices = num_grads_uses == len(used_grads_indices)
+
+        if uses_grad and num_grads_uses > 0:
             raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                                "mixes use of 'grad' and 'grads'. Consider replacing "
                                "occurrences of 'grad' with 'grads[0]'")