From eff672ef063f5d02783fad59254f28b2726e01d5 Mon Sep 17 00:00:00 2001
From: Christian Puhrsch
Date: Tue, 26 Feb 2019 17:41:56 -0800
Subject: [PATCH] Remove Bool/IndexTensor from schema for native functions
 with derivatives (#17193)

Summary:
This only deals with four functions, but is an important first step towards removing BoolTensor and IndexTensor entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17193

Differential Revision: D14157829

Pulled By: cpuhrsch

fbshipit-source-id: a36f16d1d88171036c44cc7de60ac9dfed9d14f2
---
 aten/src/ATen/native/native_functions.yaml | 12 ++++---
 tools/autograd/derivatives.yaml            | 13 ++++++++
 tools/autograd/gen_autograd_functions.py   |  2 +-
 tools/autograd/gen_variable_type.py        | 44 +++++++++++++++-----------
 tools/autograd/load_derivatives.py         | 51 ++++++++++++++++++------------
 5 files changed, 79 insertions(+), 43 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4f0d27a..6466d6a 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -47,7 +47,8 @@
   dispatch:
     CUDA: _cudnn_rnn_flatten_weight

-- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, BoolTensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  matches_jit_signature: True
   dispatch:
     CUDA: _cudnn_rnn

@@ -780,7 +781,8 @@
 - func: einsum(str equation, Tensor[] tensors) -> Tensor
   matches_jit_signature: True

-- func: embedding(Tensor weight, IndexTensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+- func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+  matches_jit_signature: True

 - func: embedding_backward(Tensor grad, IndexTensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor

@@ -789,7 +791,8 @@
     CPU: embedding_dense_backward_cpu
     CUDA: embedding_dense_backward_cuda

-- func: embedding_renorm_(Tensor(a!) self, IndexTensor indices, float max_norm, float norm_type) -> Tensor(a!)
+- func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)
+  matches_jit_signature: True
   dispatch:
     CPU: embedding_renorm_cpu_
     CUDA: embedding_renorm_cuda_
@@ -2289,7 +2292,8 @@
 - func: where(BoolTensor condition, Tensor self, Tensor other) -> Tensor
   variants: function, method

-- func: _s_where(BoolTensor condition, Tensor self, Tensor other) -> Tensor
+- func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
+  matches_jit_signature: True
   variants: function
   dispatch:
     CPU: _s_where_cpu
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index a884b6d..7912d66 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -8,6 +8,15 @@
 # Note that a single gradient entry can specify the gradient
 # formula for multiple input names, by specifying a key
 # "input1, input2" (see atan2 for an example).
+# - An argument can be flagged as 'not_differentiable'.
+#   In general there are 3 possibilities:
+#   1. An argument has an entry with a specified gradient
+#   2. An argument has an entry specified as not differentiable
+#   3. An argument has no entry
+#   Using the flag 'not_differentiable' resolves to the second case.
+#   The second case was introduced in support for arguments of
+#   type e.g. IndexTensor for 'embedding', that are not differentiable.
+#   TODO: Determine whether case 3 and case 2 can be replaced by one concept.
 # - Optional entry with key 'output_differentiability' and value a list of the
 #   same length as the number of outputs from the forward function. The list
 #   should contain only booleans, specifying whether each of the output Tensor
@@ -870,6 +879,7 @@
   self: grad.reshape(self.sizes())

 - name: _s_where(Tensor condition, Tensor self, Tensor other)
+  condition: not_differentiable
   self: where(condition, grad, zeros_like(grad))
   other: where(condition, zeros_like(grad), grad)

@@ -921,12 +931,14 @@
   target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction)

 - name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
+  indices: not_differentiable
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)

 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
   weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse)

 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
+  indices: not_differentiable
   self: not_implemented("embedding_renorm")

 - name: kl_div(Tensor self, Tensor target, int64_t reduction)
@@ -1377,6 +1389,7 @@
 # Only frst three of _cudnn_rnn outputs can have gradients.
 # _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf)
 - name: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, Tensor dropout_state)
+  dropout_state: not_differentiable
   output_differentiability: [True, True, True, False, False]
   input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"

diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py
index 8db0c60..b5a6a00 100644
--- a/tools/autograd/gen_autograd_functions.py
+++ b/tools/autograd/gen_autograd_functions.py
@@ -124,7 +124,7 @@ def process_function(func):
     unpack = []

     env['compute_index_ranges'] = []
-    for arg in func['args_with_gradients']:
+    for arg in func['args_with_derivatives']:
         if arg['type'] == 'TensorList':
             size = '{}_size_'.format(arg['name'])
             saved_list_sizes.append('size_t {}_size_;'.format(arg['name']))
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 73d5a58..2cd53a6 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -483,11 +483,29 @@ def emit_body(declaration):
         if 'Tensor' not in arg['type']:
             return False
         if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
+            # TODO: Enable this after native_functions.yaml schema unification.
+            # These are necessary for legacy code and should be
+            # used by legacy code only!
+            # assert name.startswith('_th_'), \
+            #     "IndexTensor and BoolTensor are restricted to legacy _th_ functions only.
             return False
         return True

+    def find_args_with_derivatives(differentiable_inputs):
+        """Find arguments that have derivative definitions"""
+        if func is None:
+            return differentiable_inputs
+        names = set(name for d in func['derivatives'] for name in d['var_names'])
+        differentiable = [arg for arg in differentiable_inputs if arg['name'] in names]
+        if len(differentiable) != len(names):
+            missing = names - set(arg['name'] for arg in differentiable)
+            raise RuntimeError('Missing arguments for derivatives: {} in {}'.format(missing, func['name']))
+        return differentiable
+
     inputs = [arg for arg in arguments if not arg.get('output', False)]
     differentiable_inputs = list(filter(is_differentiable, inputs))
+    args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
+    not_differentiable_args_names = func['not_differentiable_args_names'] if func else []
     candidate_differentiable_outputs = list(filter(is_differentiable, returns))

     if func is not None and func.get('output_differentiability') is not None:
@@ -514,7 +532,7 @@ def emit_body(declaration):
         if func is None:
             return setup

-        has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_gradients'])
+        has_tensorlist_arg = any(arg['type'] == 'TensorList' for arg in func['args_with_derivatives'])

         # We don't want to save tensors if we know that they will never be used
         # when computing the derivative, so we add guards to those statements
@@ -534,7 +552,7 @@ def emit_body(declaration):

             # If there's a single derivative we could compute, we already have
             # a requires_grad check that is sufficient
-            if len(func['args_with_gradients']) <= 1:
+            if len(func['args_with_derivatives']) <= 1:
                 return None

             # We really only care about trimming down the amount of tensors we save
@@ -553,7 +571,7 @@ def emit_body(declaration):
             derivative_var_name = derivative['var_names'][0]

             # Figure out the offset of the edge that uses this variable
-            for edge_off, arg in enumerate(func['args_with_gradients']):
+            for edge_off, arg in enumerate(func['args_with_derivatives']):
                 if arg['name'] == derivative_var_name:
                     break
             else:
@@ -562,14 +580,13 @@ def emit_body(declaration):
             return 'grad_fn->should_compute_output({})'.format(edge_off)

         setup.extend(save_variables(func['saved_inputs'], False, guard_for))
-        for arg in func['args_with_gradients']:
+        for arg in func['args_with_derivatives']:
             if arg['type'] == 'TensorList':
                 setup.append("grad_fn->{}_size_ = {}.size();".format(arg['name'], arg['name']))

         return setup

-    def setup_derivative():
-        args_with_derivatives = find_args_with_derivatives()
+    def setup_derivative(differentiable_inputs):
         env = {}
         env['args_with_derivatives'] = reference_args(args_with_derivatives)

@@ -598,17 +615,6 @@ def emit_body(declaration):
             body.append(SETUP_DERIVATIVE.substitute(env, setup=setup))
         return body

-    def find_args_with_derivatives():
-        """Find arguments that have derivative definitions"""
-        if func is None:
-            return differentiable_inputs
-        names = set(name for d in func['derivatives'] for name in d['var_names'])
-        differentiable = [arg for arg in differentiable_inputs if arg['name'] in names]
-        if len(differentiable) != len(names):
-            missing = names - set(arg['name'] for arg in differentiable)
-            raise RuntimeError('Missing arguments for derivatives: {} in {}'.format(missing, func['name']))
-        return differentiable
-
     def emit_check_no_requires_grad(tensor_args, args_with_derivatives):
         """Checks that arguments without derivatives don't require grad"""
         body = []
@@ -616,6 +622,8 @@ def emit_body(declaration):
             if arg in args_with_derivatives:
                 continue
             name = arg['name']
+            if name in not_differentiable_args_names:
+                continue
             if name == 'output':
                 # Double-backwards definitions sometimes take in 'input' and
                 # 'output', but only define the derivative for input.
@@ -847,7 +855,7 @@ def emit_body(declaration):
     body.extend(unpack_args(env, declaration))
     if requires_derivative:
         body.extend(emit_check_inplace())
-        body.extend(setup_derivative())
+        body.extend(setup_derivative(differentiable_inputs))
     body.append(declare_returned_variables())

     pre_record_trace, post_record_trace = emit_record_trace(env)
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index acff6bd..1a3bf1b 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -29,15 +29,16 @@ def load_derivatives(path, declarations):


 # How do you feel about pasting declaration inside autograd function...
-def create_autograd_function(name, derivatives, args_with_gradients, signature,
-                             declaration, output_differentiability):
+def create_autograd_function(name, derivatives, args_with_derivatives, not_differentiable_args_names,
+                             signature, declaration, output_differentiability):
     op = to_camel_case(name) + 'Backward'
     op = op.replace('ForwardBackward', 'Backward')
     return {
         'name': name,
         'op': op,
         'declaration': declaration,
-        'args_with_gradients': args_with_gradients,
+        'args_with_derivatives': args_with_derivatives,
+        'not_differentiable_args_names': not_differentiable_args_names,
         'signature': signature,
         'derivatives': derivatives,
         'saved_inputs': all_saved_variables(derivatives, 'saved_inputs'),
@@ -46,7 +47,7 @@ def create_autograd_function(name, derivatives, args_with_gradients, signature,
     }


-def create_derivative(declaration, formula, var_names):
+def create_derivative(arguments, returns, name, formula, var_names):
     def transform_return(r):
         # In-place functions take in and return self. Call the modified version
         # "output" so that it can be referred to in derivative definitions.
@@ -55,18 +56,17 @@ def create_derivative(declaration, formula, var_names):
         r['name'] = 'result'
         return r

-    returns = [transform_return(r) for r in declaration['returns']]
-    arguments = declaration['arguments']
+    returns = [transform_return(r) for r in returns]
     formula, saved_inputs = saved_variables(formula, arguments)
     formula, saved_outputs = saved_variables(formula, returns)

-    # Check that the referenced gradients in the formula are in bounds
+    # Check that the referenced derivatives in the formula are in bounds
     for i in used_gradient_indices(formula):
-        if i >= len(declaration['returns']):
+        if i >= len(returns):
             raise RuntimeError(
                 "Out of bounds grads access: derivative formula for {} "
                 "used grads[{}], but the forward only returns {} outputs."
-                .format(declaration['name'], i, len(declaration['returns'])))
+                .format(name, i, len(returns)))

     return {
         'formula': formula,
@@ -97,7 +97,7 @@ def process_definition(defn, declarations_by_signature):

     def check_grad_usage(defn_name, declaration, derivatives):
         """
-        Check for some subtle mistakes one might make when writing gradients.
+        Check for some subtle mistakes one might make when writing derivatives.
         These mistakes will compile, but will be latent until a function is
         used with double backwards.
         """
@@ -130,29 +130,40 @@ def process_definition(defn, declarations_by_signature):
                             "declaration.".format(defn_name))

     def set_up_derivatives(defn_name, defn, declaration):
-        # Determine the set of inputs which have gradients
-        args_with_gradients_set = set()
+        # Determine the set of inputs which have derivatives
+        args_with_derivatives_set = set()
         for raw_names in defn:
-            args_with_gradients_set |= set(split_names(raw_names))
+            args_with_derivatives_set |= set(split_names(raw_names))

         # Next, let us determine the list of inputs in order.
-        args_with_gradients = []
+        args_with_derivatives = []
         for arg in declaration['arguments']:
-            if arg['name'] not in args_with_gradients_set:
+            if arg['name'] not in args_with_derivatives_set:
                 continue
-            args_with_gradients.append(arg)
+            args_with_derivatives.append(arg)

         # Set up the derivative information
         derivatives = []
+        not_differentiable_args_names = []
         for raw_names in sorted(defn.keys()):
             formula = defn[raw_names]
             names = split_names(raw_names)
-            derivatives.append(create_derivative(declaration, formula, names))
+            derivative = create_derivative(declaration['arguments'], declaration['returns'],
+                                           declaration['name'], formula, names)
+            if formula.lower().strip() == 'not_differentiable':
+                assert not sum([type(var_name) == list
+                                for var_name in derivative['var_names']]), \
+                    "Variable names associated to a formula should be a flat list"
+                not_differentiable_args_names += derivative['var_names']
+            else:
+                derivatives.append(derivative)
+        args_with_derivatives = list(filter(lambda x: x['name'] not in not_differentiable_args_names,
+                                            args_with_derivatives))

         # Test to see if the use of 'grads' makes sense.
         check_grad_usage(defn_name, declaration, derivatives)

-        return derivatives, args_with_gradients
+        return derivatives, args_with_derivatives, not_differentiable_args_names

     def unzip(xs):
         return zip(*xs)
@@ -195,8 +206,8 @@ def process_definition(defn, declarations_by_signature):
                     'Declarations.yaml ({})'
                     .format(i, defn_name, x, y))

-    derivatives, args_with_gradients = set_up_derivatives(defn_name, defn, canonical)
-    return create_autograd_function(defn_name, derivatives, args_with_gradients,
+    derivatives, args_with_derivatives, not_differentiable_args_names = set_up_derivatives(defn_name, defn, canonical)
+    return create_autograd_function(defn_name, derivatives, args_with_derivatives, not_differentiable_args_names,
                                     signature, canonical, output_differentiability)


--
2.7.4
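A minimal usage sketch of the 'not_differentiable' flag this patch introduces, taken from the derivatives.yaml entries in the diff above (the 'embedding' entry is copied verbatim; any new entry would follow the same shape): an input that should never receive a gradient gets 'not_differentiable' in place of a formula. load_derivatives.py then records the argument name in not_differentiable_args_names instead of building a derivative for it, and gen_variable_type.py skips the requires-grad check for that argument.

    - name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
      indices: not_differentiable
      weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)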