From d1e0a73577b226d2a865a96f1b4ea9f463f3f4ed Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 27 Apr 2018 11:41:21 -0700
Subject: [PATCH] Internally rewrite @recompute_grad to use @custom_gradient

PiperOrigin-RevId: 194571125
---
 .../contrib/layers/python/layers/rev_block_lib.py  | 98 ++++++++++++++--------
 .../layers/python/layers/rev_block_lib_test.py     | 48 ++++++++---
 tensorflow/python/ops/custom_gradient.py           | 42 ++++++++--
 tensorflow/python/ops/gradients_test.py            | 34 ++++++++
 4 files changed, 167 insertions(+), 55 deletions(-)

diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 02d294c..1a439f0 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -33,6 +33,7 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.framework.python import ops as contrib_framework_ops
+from tensorflow.python.eager import backprop
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops as framework_ops
@@ -40,6 +41,7 @@ from tensorflow.python.layers import base
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
@@ -50,6 +52,13 @@ __all__ = ["rev_block", "RevBlock", "recompute_grad"]
 
 LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*")
 _USE_DEFAULT = "__rev_block_lib_default"
+_WRONG_VARS_ERR = """\
+The variables used on recompute were different from the variables originally
+used. The function wrapped with @recompute_grad likely creates its own variable
+scope with a default name and has been called twice in the same enclosing scope.
+To fix, ensure each call to the function happens in its own unique variable
+scope.
+"""
 
 
 def _acc_grads(*lists_of_grads):
@@ -432,6 +441,10 @@ def enable_with_args(dec):
 def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """Decorator that recomputes the function on the backwards pass.
 
+  To use this function, you must use `ResourceVariable`s (i.e.
+  `variable_scope(name, use_resource=True)`), which are the default in Eager
+  mode and when running on TPU.
+
   Args:
     fn: a function that takes Tensors (all as positional arguments) and
       returns a tuple of Tensors.
@@ -472,44 +485,55 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): if use_data_dep_ == _USE_DEFAULT: use_data_dep_ = _is_on_tpu() - cached_vs = [] - cached_arg_scope = [] - - def grad_fn(inputs, variables, outputs, output_grads): - """Recompute outputs for gradient computation.""" - del outputs - # Recompute outputs - with framework_ops.control_dependencies(output_grads): - if use_data_dep_: - inputs = _force_data_dependency(output_grads, inputs) - with contrib_framework_ops.arg_scope(cached_arg_scope[0]): - with variable_scope.variable_scope(cached_vs[0], reuse=True): - outputs = fn(*inputs) - - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): - outputs = [outputs] - outputs = list(outputs) - grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) - - if tupleize_grads: - if use_data_dep_: - grads = _tuple_with_data_dep(grads) - else: - grads = control_flow_ops.tuple(grads) - - grad_inputs = grads[:len(inputs)] - grad_vars = grads[len(inputs):] - return grad_inputs, grad_vars - - @_fn_with_custom_grad(grad_fn) + @custom_gradient.custom_gradient def fn_with_recompute(*args): - cached_vs.append(variable_scope.get_variable_scope()) - # TODO(rsepassi): Rm conditional in TF 1.4 - if hasattr(contrib_framework_ops, "current_arg_scope"): - cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) - else: - cached_arg_scope.append({}) - return fn(*args) + """Wrapper for fn.""" + # Forward pass + vs = variable_scope.get_variable_scope() + arg_scope = contrib_framework_ops.current_arg_scope() + with backprop.GradientTape() as tape: + outputs = fn(*args) + original_vars = set(tape.watched_variables()) + + # Backward pass + def grad_fn(*output_grads, **kwargs): + """Recompute outputs for gradient computation.""" + variables = [] + if original_vars: + variables = kwargs["variables"] + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + del kwargs + inputs = list(args) + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) + with contrib_framework_ops.arg_scope(arg_scope): + with variable_scope.variable_scope(vs, reuse=True): + with backprop.GradientTape() as tape: + outputs = fn(*inputs) + recompute_vars = set(tape.watched_variables()) + if original_vars != recompute_vars: + raise ValueError(_WRONG_VARS_ERR) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, + output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + return outputs, grad_fn return fn_with_recompute(*args) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index 8c11840..8107486 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -278,7 +278,7 @@ class RecomputeTest(test.TestCase): ] outputs_and_vars = [] for name, wrapped_fn in names_and_fns: - with variable_scope.variable_scope(name) as vs: + with variable_scope.variable_scope(name, use_resource=True) as vs: out = math_ops.reduce_sum(wrapped_fn(x)) outputs_and_vars.append((out, 
vs.trainable_variables())) @@ -304,19 +304,45 @@ class RecomputeTest(test.TestCase): self.assertAllClose(current, g) current = g - def testResourceVariable(self): - @rev_block_lib.recompute_grad(tupleize_grads=True) + def testDoubleCallInSameScopeFails(self): + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs): + return core_layers.dense(inputs, 2) + + with variable_scope.variable_scope("layer", use_resource=True): + inputs = array_ops.ones((2, 4), dtypes.float32) + out1 = layer_with_recompute(inputs) + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) + + tvars = variables.trainable_variables() + assert len(tvars) == 4 + with self.assertRaisesWithPredicateMatch( + ValueError, "called twice in the same enclosing scope"): + gradients_impl.gradients(out, [inputs] + tvars) + + def testDoubleCallInUniqueScope(self): + + @rev_block_lib.recompute_grad def layer_with_recompute(inputs): - var = variable_scope.get_variable("var", ()) - return var * inputs + with variable_scope.variable_scope("inner", use_resource=True): + return core_layers.dense(inputs, 2) - inputs = array_ops.ones((), dtypes.float32) with variable_scope.variable_scope("layer", use_resource=True): - outputs = layer_with_recompute(inputs) - loss = math_ops.square(outputs) - grads = gradients_impl.gradients(loss, variables.trainable_variables()) - self.assertEqual(1, len(grads)) - self.assertTrue(grads[0] is not None) + inputs = array_ops.ones((2, 4), dtypes.float32) + + with variable_scope.variable_scope("layer1", use_resource=True): + out1 = layer_with_recompute(inputs) + with variable_scope.variable_scope("layer2", use_resource=True): + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) + + tvars = variables.trainable_variables() + assert len(tvars) == 4 + grads = gradients_impl.gradients(out, [inputs] + tvars) + for grad in grads: + self.assertTrue(grad is not None) class FnWithCustomGradTest(test.TestCase): diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index 446ad1b..d934f27 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -121,17 +122,42 @@ def _graph_mode_decorator(f, *args, **kwargs): "arguments only when eager execution is enabled.") name = "CustomGradient-%s" % ops.uid() args = [ops.convert_to_tensor(x) for x in args] + + # Checking global and local variables attempts to ensure that no non-resource + # Variables are added to the graph. + current_var_scope = variable_scope.get_variable_scope() + before_vars = set(current_var_scope.global_variables() + + current_var_scope.local_variables()) with backprop.GradientTape() as tape: result, grad_fn = f(*args) + after_vars = set(current_var_scope.global_variables() + + current_var_scope.local_variables()) + new_vars = after_vars - before_vars + for v in new_vars: + if not isinstance(v, resource_variable_ops.ResourceVariable): + raise TypeError( + "All variables used by a function wrapped with @custom_gradient must " + "be `ResourceVariable`s. 
Ensure that no `variable_scope` is created " + "with `use_resource=False`.") # The variables that grad_fn needs to return gradients for are the set of # variables used that are *not* part of the inputs. variables = list(set(tape.watched_variables()) - set(args)) grad_argspec = tf_inspect.getargspec(grad_fn) - if "variables" in grad_argspec.args: + variables_in_signature = ("variables" in grad_argspec.args or + grad_argspec.keywords) + if variables and not variables_in_signature: + raise TypeError("If using @custom_gradient with a function that " + "uses variables, then grad_fn must accept a keyword " + "argument 'variables'.") + if variables_in_signature and not variables: + # User seems to intend to use variables but none were captured. if not variable_scope.get_variable_scope().use_resource: raise TypeError("If using @custom_gradient with a function that " - "creates variables, the enclosing variable scope must " + "uses variables, the enclosing variable scope must " "have use_resource=True.") + else: + logging.warn("@custom_gradient grad_fn has 'variables' in signature, but " + "no ResourceVariables were used on the forward pass.") flat_result = nest.flatten(result) all_tensors = flat_result + args + variables @@ -167,11 +193,13 @@ def _eager_mode_decorator(f, *args, **kwargs): all_inputs = list(args) + list(kwargs.values()) # The variables that grad_fn needs to return gradients for are the set of # variables used that are *not* part of the inputs. - variable_inputs = [ - arg for arg in all_inputs - if isinstance(arg, resource_variable_ops.ResourceVariable) - ] - variables = list(set(tape.watched_variables()) - set(variable_inputs)) + variables = [v for v in set(tape.watched_variables()) if v not in all_inputs] + grad_argspec = tf_inspect.getargspec(grad_fn) + if (variables and + not ("variables" in grad_argspec.args or grad_argspec.keywords)): + raise TypeError("If using @custom_gradient with a function that " + "uses variables, then grad_fn must accept a keyword " + "argument 'variables'.") flat_result = nest.flatten(result) # TODO(apassos) consider removing the identity below. flat_result = [gen_array_ops.identity(x) for x in flat_result] diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 9d29617..5e8b882 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -894,6 +894,40 @@ class CustomGradientTest(test_util.TensorFlowTestCase): self.assertEqual(6., math_ops.reduce_sum(dx).numpy()) self.assertEqual(8., math_ops.reduce_sum(dw).numpy()) + def testCustomGradientErrorsWithNonResourceVariables(self): + + def F(x, use_resource=False): + with variable_scope.variable_scope("f", use_resource=use_resource): + out = core_layers.dense(x, 4, use_bias=False) + + def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name + del out_grad + self.assertEqual(1, len(variables)) + return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))]) + + return out, Grad + + @custom_gradient.custom_gradient + def FResource(x): + return F(x, use_resource=True) + + @custom_gradient.custom_gradient + def FNonResource(x): + return F(x, use_resource=False) + + x = array_ops.ones((3, 2)) + 2. + + # Wrapping scope has use_resource=True but inner scope sets to False. Fails. 
+ with variable_scope.variable_scope("vs1", use_resource=True): + with self.assertRaisesWithPredicateMatch(TypeError, + "must be `ResourceVariable`s"): + FNonResource(x) + + # Wrapping scope has use_resource=False but inner scope sets to True. + # Passes. + with variable_scope.variable_scope("vs2", use_resource=False): + FResource(x) + def testWithNumpyInputs(self): with context.eager_mode(): -- 2.7.4
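
Usage sketch (illustrative, not part of the commit): the snippet below shows, with standard TF 1.x graph-mode APIs, how a function wrapped with @recompute_grad is expected to be called after this change (each call in its own variable scope, with ResourceVariables via use_resource=True), and how a @custom_gradient grad_fn now accepts a 'variables' keyword argument and returns variable gradients. The function names, scope names, and layer sizes here are made up for illustration.

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib


@rev_block_lib.recompute_grad
def block(inputs):
  # Forward activations are discarded and recomputed inside grad_fn during
  # the backward pass.
  h = tf.layers.dense(inputs, 4, activation=tf.nn.relu)
  return tf.layers.dense(h, 4)


x = tf.ones((2, 4))
# Each call must live in its own variable scope and create ResourceVariables
# (use_resource=True); otherwise _WRONG_VARS_ERR above is raised.
with tf.variable_scope("block1", use_resource=True):
  out1 = block(x)
with tf.variable_scope("block2", use_resource=True):
  out2 = block(x)
loss = tf.reduce_sum(out1 + out2)
grads = tf.gradients(loss, [x] + tf.trainable_variables())


# With this change, a grad_fn under @custom_gradient must accept a 'variables'
# keyword argument whenever the wrapped function uses (resource) variables,
# and it returns (input_grads, variable_grads).
@tf.custom_gradient
def scaled(inputs):
  with tf.variable_scope("scaled", use_resource=True):
    w = tf.get_variable("w", shape=(), initializer=tf.ones_initializer())
  y = w * inputs

  def grad_fn(dy, variables=None):
    dx = dy * w
    dw = tf.reduce_sum(dy * inputs)
    return dx, [dw]

  return y, grad_fn


z = tf.ones((2, 2))
y = scaled(z)
# Triggers the custom gradient; grad_fn receives the captured variable.
dz = tf.gradients(tf.reduce_sum(y), z)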