gradient_penalty.backward()
fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
- self.assertEqual(fn.name(), "ThresholdBackwardBackward")
+ self.assertEqual(fn.name(), "ThresholdBackwardBackward0")
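(Illustration only, not part of the patch: a minimal sketch of how the suffixed node names surface at the Python level, assuming a build where every generated autograd node carries a numeric suffix.)
import torch
x = torch.randn(3, requires_grad=True)
y = x.relu()
# Generated nodes are now always suffixed, even when there is only one overload,
# so this is expected to print something like 'ReluBackward0'.
print(y.grad_fn.name())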
def test_inplace_on_view_weak_grad_fn(self):
# Issue 23502: Test that b's grad_fn is preserved.
# The 3 elements are for view_as, the first output of unbind, and the second output of unbind
run_test(grad_mode=True, requires_grad=False, is_view=True,
should_raise_tuple=(None, None, None))
- inp_change_err = "Output {} of UnbindBackward is a view and is being modified inplace."
+ inp_change_err = "Output {} of UnbindBackward0 is a view and is being modified inplace."
run_test(grad_mode=True, requires_grad=True, is_view=True,
should_raise_tuple=(None, inp_change_err.format("0"), inp_change_err.format("1")))
leaf_grad_err = "A view was created in no_grad mode and is being modified inplace"
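(For context, the error exercised above can be reproduced standalone; this is an illustration with made-up tensors, not part of the test suite.)
import torch
base = torch.randn(2, 3, requires_grad=True)
a, b = base.unbind(0)   # both outputs are views produced by UnbindBackward0
try:
    b.mul_(2)           # in-place on one output of a multi-output view
except RuntimeError as e:
    # expected to start with "Output 1 of UnbindBackward0 is a view and is being modified inplace."
    print(e)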
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `tools.codegen.api.autograd` for the data models.
-from collections import defaultdict, Counter
+from collections import defaultdict
import re
-from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional
+from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional
import yaml
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
assert str(function.func) not in functions_by_schema
functions_by_schema[str(function.func)] = function
+ # Keep track of how many of which ops we've seen so we can
+ # disambiguate them with a numeric suffix.
+ op_counter = Counter[str]()
+
infos = [
- create_differentiability_info(defn, functions_by_signature, functions_by_schema)
+ create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
for defn in definitions]
- # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate
- # step. We only assign op names to those with differentiable args, and only append suffix to
- # duplicated op names. This can be simplified if the first of the duplicates can be named
- # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons.
- op_names = create_op_names(infos)
- res = [
- DifferentiabilityInfo(
- name=info.name,
- func=info.func,
- op=op_name,
- derivatives=info.derivatives,
- forward_derivatives=info.forward_derivatives,
- all_saved_inputs=info.all_saved_inputs,
- all_saved_outputs=info.all_saved_outputs,
- args_with_derivatives=info.args_with_derivatives,
- non_differentiable_arg_names=info.non_differentiable_arg_names,
- output_differentiability=info.output_differentiability,
- output_differentiability_conditions=info.output_differentiability_conditions,
- )
- for info, op_name in zip(infos, op_names)]
-
- _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res
+ _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos
return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
defn: Dict[Any, Any],
functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
functions_by_schema: Dict[str, NativeFunction],
+ op_counter: Counter[str],
) -> DifferentiabilityInfo:
"""Processes a single entry `defn` in derivatives.yaml"""
derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical)
+ # only assign an op name if we are actually going to calculate a derivative
+ op = None
+ if args_with_derivatives:
+ op_prefix = _create_op_prefix(defn_name)
+ op = f'{op_prefix}{op_counter[op_prefix]}'
+ op_counter[op_prefix] += 1
+
return DifferentiabilityInfo(
name=defn_name,
func=canonical,
- op=None,
+ op=op,
derivatives=derivatives,
forward_derivatives=forward_derivatives,
all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
return formula, tuple(saved)
-def create_op_name(info: DifferentiabilityInfo) -> Optional[str]:
- # only assign an op name if we are actually going to calculate a derivative
- if not info.args_with_derivatives:
- return None
- name = info.name
+def _create_op_prefix(name: str) -> str:
+    """Takes a native function name and converts it to an op prefix name.
+
+ Note that the "name" parameter must be the native function name
+ without the optional variant suffix, so "add" instead of
+ "add.out".
+
+    Op names correspond to classes, hence the change to title case.
+
+ Example::
+ >>> _create_op_prefix('add')
+ 'AddBackward'
+ """
camel_case = ''.join([p.title() for p in name.split('_')])
return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')
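(Putting _create_op_prefix together with the op_counter bookkeeping above, the naming scheme works roughly as in this self-contained sketch; the input names are hypothetical and this is not the real codegen entry point.)
from typing import Counter

def sketch_prefix(name: str) -> str:
    # same transformation as _create_op_prefix above
    camel_case = ''.join(p.title() for p in name.split('_'))
    return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')

op_counter = Counter[str]()
for defn_name in ('add', 'add', 'mul'):          # e.g. two overloads of add, one of mul
    prefix = sketch_prefix(defn_name)
    print(f'{prefix}{op_counter[prefix]}')       # AddBackward0, AddBackward1, MulBackward0
    op_counter[prefix] += 1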
-def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]:
- names = list(map(create_op_name, infos))
- dups = set(item for item, count in Counter(names).items() if count > 1)
-
- # de-duplicate operation names
- # you end up with something like:
- # AddBackward0
- # AddBackward1
- # one for each overload
- counter: Dict[str, int] = Counter()
- dedup: List[Optional[str]] = []
- for name in names:
- if name is None:
- # Keep a placeholder
- dedup.append(None)
- elif name in dups:
- dedup.append(f'{name}{counter[name]}')
- counter[name] += 1
- else:
- dedup.append(name)
- return dedup
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
seen: Set[str] = set()
// self = view_op_n(view_n-1)
// self = inplace_op(self)
//
- // For CPU/CUDA backends, we employ one AsStridedBackward Node to represent the chain of
+ // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to represent the chain of
// view backward ops for efficiency.
//
- // However in XLA backend we don't have full support of AsStridedBackward, we instead run a full
+ // However, the XLA backend doesn't fully support AsStridedBackward0; we instead run a full
// forward pass with a tensor that requires gradient to get proper grad_fn setup,
// then save it to DifferentiableViewMeta for future use.
// This is fairly cheap for XLA lazy tensor approach (but would be really expensive for CPU/CUDA).
auto diff_view = view_fn(view_info.base_);
diff_view_meta->grad_fn_ = diff_view.grad_fn();
} else {
- auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward>();
+ auto fn = std::make_shared<torch::autograd::generated::AsStridedBackward0>();
fn->self_geometry = at::TensorGeometry(view_info.base_);
fn->size = self.sizes().vec();
fn->stride = self.strides().vec();
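(The CPU/CUDA path described in the comment above can be observed from Python; an illustrative sketch, not part of this change.)
import torch
base = torch.randn(4, requires_grad=True).clone()   # non-leaf, so in-place ops are allowed
view = base.view(2, 2)                               # differentiable view of base
base.mul_(2)                                         # in-place op on the base invalidates the view's grad_fn
# Accessing grad_fn rebuilds it lazily; on CPU/CUDA this is expected to be an AsStridedBackward0 node.
print(view.grad_fn)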