simplify op name determination into a single forward pass (#64261)
author     Michael Dagitses <mikeyd@fb.com>
           Thu, 2 Sep 2021 13:49:09 +0000 (06:49 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Thu, 2 Sep 2021 14:32:11 +0000 (07:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64261

Note that this does not preserve byte-for-byte compatibility with the
existing generated op names: every backward node now receives a numeric
suffix, so, for example, ExpandBackward becomes ExpandBackward0.
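
The reworked codegen boils down to the following standalone sketch. It is
illustrative only: the assign_op_name wrapper and the sample op names are
not part of the actual codegen API, while _create_op_prefix mirrors the
helper added to tools/autograd/load_derivatives.py in this change.

    from typing import Counter

    def _create_op_prefix(name: str) -> str:
        # "native_batch_norm" -> "NativeBatchNormBackward"; a trailing
        # "_forward" collapses so we never emit "...ForwardBackward".
        camel_case = ''.join(p.title() for p in name.split('_'))
        return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')

    # One counter shared across the single pass over derivatives.yaml.
    op_counter = Counter[str]()

    def assign_op_name(defn_name: str, has_differentiable_args: bool):
        # Ops without differentiable args get no backward node name at all;
        # everything else gets a numeric suffix, singletons included.
        if not has_differentiable_args:
            return None
        prefix = _create_op_prefix(defn_name)
        op = f'{prefix}{op_counter[prefix]}'
        op_counter[prefix] += 1
        return op

    print(assign_op_name('expand', True))  # ExpandBackward0
    print(assign_op_name('add', True))     # AddBackward0
    print(assign_op_name('add', True))     # AddBackward1

Compared with the old two-pass approach, duplicates and singletons are
treated uniformly, which is why previously bare names such as
ExpandBackward pick up a trailing "0".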

Test Plan:
* Rely on CI to catch gross errors.
* Merge after release cut to catch subtle issues.

Reviewed By: albanD

Differential Revision: D30700647

Pulled By: dagitses

fbshipit-source-id: 7b02f34b8fae3041240cc78fbc6bcae498c3acd4

test/expect/TestAutograd.test_function-x_grad_desc.expect
test/expect/TestAutograd.test_function-y_grad_desc.expect
test/test_autograd.py
test/test_cuda.py
tools/autograd/load_derivatives.py
torch/csrc/autograd/variable.cpp

diff --git a/test/expect/TestAutograd.test_function-x_grad_desc.expect b/test/expect/TestAutograd.test_function-x_grad_desc.expect
index b6fdb63..68242e2 100644
@@ -1 +1 @@
-CopyBackwards(None, AddBackward0(ExpandBackward(AccumulateGrad()), MulBackward0(ExpandBackward(AccumulateGrad()), AccumulateGrad())))
\ No newline at end of file
+CopyBackwards(None, AddBackward0(ExpandBackward0(AccumulateGrad()), MulBackward0(ExpandBackward0(AccumulateGrad()), AccumulateGrad())))
\ No newline at end of file
diff --git a/test/expect/TestAutograd.test_function-y_grad_desc.expect b/test/expect/TestAutograd.test_function-y_grad_desc.expect
index e32d588..88db873 100644
@@ -1 +1 @@
-CopyBackwards(None, AddBackward0(MulBackward0(ExpandBackward(AccumulateGrad()), None), MulBackward0(ExpandBackward(AccumulateGrad()), AccumulateGrad())))
\ No newline at end of file
+CopyBackwards(None, AddBackward0(MulBackward0(ExpandBackward0(AccumulateGrad()), None), MulBackward0(ExpandBackward0(AccumulateGrad()), AccumulateGrad())))
\ No newline at end of file
diff --git a/test/test_autograd.py b/test/test_autograd.py
index fde64b0..2da74cb 100644
@@ -3450,7 +3450,7 @@ class TestAutograd(TestCase):
         gradient_penalty.backward()
 
         fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
-        self.assertEqual(fn.name(), "ThresholdBackwardBackward")
+        self.assertEqual(fn.name(), "ThresholdBackwardBackward0")
 
     def test_inplace_on_view_weak_grad_fn(self):
         # Issue 23502: Test that b's grad_fn is preserved.
@@ -4859,7 +4859,7 @@ for shape in [(1,), ()]:
         # The 3 elements are for view_as, first output of unbind and second output of unbind
         run_test(grad_mode=True, requires_grad=False, is_view=True,
                  should_raise_tuple=(None, None, None))
-        inp_change_err = "Output {} of UnbindBackward is a view and is being modified inplace."
+        inp_change_err = "Output {} of UnbindBackward0 is a view and is being modified inplace."
         run_test(grad_mode=True, requires_grad=True, is_view=True,
                  should_raise_tuple=(None, inp_change_err.format("0"), inp_change_err.format("1")))
         leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 33dbade..cddd15a 100644
@@ -3049,7 +3049,7 @@ torch.cuda.synchronize()
                 # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed.  This check can't guarantee
                 # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has
                 # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed.
-                self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward")
+                self.assertEqual(out.grad_fn.name(), "CudnnRnnBackward0")
                 out.sum().backward()
                 grads = [p.grad.clone() for p in rnn.parameters()]
 
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 3ff11f4..8a5904b 100644
@@ -2,9 +2,9 @@
 #
 # Each autograd function is represented by `DifferentiabilityInfo` containing
 # a list of `Derivative`. See `tools.codegen.api.autograd` for the data models.
-from collections import defaultdict, Counter
+from collections import defaultdict
 import re
-from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional
+from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional
 import yaml
 
 from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
@@ -43,32 +43,15 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
             assert str(function.func) not in functions_by_schema
             functions_by_schema[str(function.func)] = function
 
+        # Keep track of how many of which ops we've seen so we can
+        # disambiguate them with a numeric suffix.
+        op_counter = Counter[str]()
+
         infos = [
-            create_differentiability_info(defn, functions_by_signature, functions_by_schema)
+            create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
             for defn in definitions]
 
-        # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate
-        # step. We only assign op names to those with differentiable args, and only append suffix to
-        # duplicated op names. This can be simplified if the first of the duplicates can be named
-        # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons.
-        op_names = create_op_names(infos)
-        res = [
-            DifferentiabilityInfo(
-                name=info.name,
-                func=info.func,
-                op=op_name,
-                derivatives=info.derivatives,
-                forward_derivatives=info.forward_derivatives,
-                all_saved_inputs=info.all_saved_inputs,
-                all_saved_outputs=info.all_saved_outputs,
-                args_with_derivatives=info.args_with_derivatives,
-                non_differentiable_arg_names=info.non_differentiable_arg_names,
-                output_differentiability=info.output_differentiability,
-                output_differentiability_conditions=info.output_differentiability_conditions,
-            )
-            for info, op_name in zip(infos, op_names)]
-
-        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res
+        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos
 
     return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
 
@@ -279,6 +262,7 @@ def create_differentiability_info(
     defn: Dict[Any, Any],
     functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
     functions_by_schema: Dict[str, NativeFunction],
+    op_counter: Counter[str],
 ) -> DifferentiabilityInfo:
     """Processes a single entry `defn` in derivatives.yaml"""
 
@@ -424,10 +408,17 @@ def create_differentiability_info(
 
     derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical)
 
+    # only assign an op name if we are actually going to calculate a derivative
+    op = None
+    if args_with_derivatives:
+        op_prefix = _create_op_prefix(defn_name)
+        op = f'{op_prefix}{op_counter[op_prefix]}'
+        op_counter[op_prefix] += 1
+
     return DifferentiabilityInfo(
         name=defn_name,
         func=canonical,
-        op=None,
+        op=op,
         derivatives=derivatives,
         forward_derivatives=forward_derivatives,
         all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
@@ -566,35 +557,22 @@ def saved_variables(
 
     return formula, tuple(saved)
 
-def create_op_name(info: DifferentiabilityInfo) -> Optional[str]:
-    # only assign an op name if we are actually going to calculate a derivative
-    if not info.args_with_derivatives:
-        return None
-    name = info.name
+def _create_op_prefix(name: str) -> str:
+    """Takes a native function name converts to a op prefix name.
+
+    Note that the "name" parameter must be the native function name
+    without the optional variant suffix, so "add" instead of
+    "add.out".
+
+    Op names correspond to classes, hence the change to title case.
+
+    Example::
+    >>> _create_op_prefix('add')
+    'AddBackward'
+    """
     camel_case = ''.join([p.title() for p in name.split('_')])
     return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')
 
-def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]:
-    names = list(map(create_op_name, infos))
-    dups = set(item for item, count in Counter(names).items() if count > 1)
-
-    # de-duplicate operation names
-    # you end up with something like:
-    #   AddBackward0
-    #   AddBackward1
-    # one for each overload
-    counter: Dict[str, int] = Counter()
-    dedup: List[Optional[str]] = []
-    for name in names:
-        if name is None:
-            # Keep a placeholder
-            dedup.append(None)
-        elif name in dups:
-            dedup.append(f'{name}{counter[name]}')
-            counter[name] += 1
-        else:
-            dedup.append(name)
-    return dedup
 
 def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
     seen: Set[str] = set()
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 32af5f9..7ae1ac0 100644
@@ -551,10 +551,10 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
       //   self = view_op_n(view_n-1)
       //   self = inplace_op(self)
       //
-      // For CPU/CUDA backends, we employ one AsStridedBackward Node to represent the chain of
+      // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to represent the chain of
       // view backward ops for efficiency.
       //
-      // However in XLA backend we don't have full support of AsStridedBackward, we instead run a full
+      // However in XLA backend we don't have full support of AsStridedBackward0, we instead run a full
       // forward pass with a tensor that requires gradient to get proper grad_fn setup,
       // then save it to DifferentiableViewMeta for future use.
       // This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
@@ -572,7 +572,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
         auto diff_view = view_fn(view_info.base_);
         diff_view_meta->grad_fn_ = diff_view.grad_fn();
       } else {
-        auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
+        auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward0>();
         fn->self_geometry = at::TensorGeometry(view_info.base_);
         fn->size = self.sizes().vec();
         fn->stride = self.strides().vec();
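
As the test expectation updates above show, generated backward node names
now always carry a numeric suffix, even for ops with a single derivative
definition. A quick, illustrative way to observe this from Python (node
names assume the post-change behavior):

    import torch

    x = torch.randn(3, requires_grad=True)
    y = x.expand(2, 3).sum()
    # Before this change the expand node was reported as "ExpandBackward";
    # afterwards it is "ExpandBackward0".
    print(y.grad_fn.next_functions[0][0].name())  # ExpandBackward0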