used with double backwards.
"""
- used_grad = 0
- used_grads = 0
- fully_implemented = True
- used_grads_indices: List[int] = []
+ uses_grad = False # true if any derivative uses "grad"
+ num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]"
+ used_grads_indices: List[int] = [] # which indices of grads are used
for d in derivatives:
formula = d.formula
- used_grad += len(re.findall(IDENT_REGEX.format('grad'), formula))
- used_grads += len(re.findall(IDENT_REGEX.format('grads'), formula))
- fully_implemented = \
- fully_implemented and \
- not re.search(IDENT_REGEX.format('not_implemented'), formula)
+ uses_grad = uses_grad or bool(re.findall(IDENT_REGEX.format('grad'), formula))
+ num_grads_uses += len(re.findall(IDENT_REGEX.format('grads'), formula))
used_grads_indices.extend(used_gradient_indices(formula))
- assert used_grads >= len(used_grads_indices)
- only_used_grads_indices = used_grads == len(used_grads_indices)
-
- if used_grad and used_grads:
+ # This is a basic sanity check: the number of places we see
+ # "grads" should be no fewer than the number of indices we see
+ # inside "grads". They may not be equal because we may use
+ # "grads" without an index.
+ assert num_grads_uses >= len(used_grads_indices)
+ # Thus if the number is equal, every use of grads is also
+ # indexed.
+ only_used_grads_indices = num_grads_uses == len(used_grads_indices)
+
+ if uses_grad and num_grads_uses > 0:
raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml illegally "
"mixes use of 'grad' and 'grads'. Consider replacing "
"occurrences of 'grad' with 'grads[0]'")