From d576afdcd38dcfd9d0f6ce6d6cb262d22e2b11dd Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Mon, 5 Mar 2018 17:28:12 -0800 Subject: [PATCH] gradients: Export tf.custom_gradients (Moved from the tf.contrib.eager namespace) PiperOrigin-RevId: 187950503 --- tensorflow/contrib/eager/python/BUILD | 2 +- tensorflow/contrib/eager/python/tfe.py | 2 +- tensorflow/python/BUILD | 4 + tensorflow/python/eager/BUILD | 18 ---- tensorflow/python/eager/backprop_test.py | 2 +- tensorflow/python/eager/custom_gradient.py | 90 ------------------ tensorflow/python/eager/tape_test.py | 17 +--- tensorflow/python/ops/custom_gradient.py | 134 +++++++++++++++++++++++++++ tensorflow/python/ops/gradients.py | 2 + tensorflow/python/ops/gradients_test.py | 55 +++++++++++ tensorflow/python/ops/standard_ops.py | 1 + tensorflow/python/training/training.py | 1 + tensorflow/tools/api/golden/tensorflow.pbtxt | 4 + 13 files changed, 205 insertions(+), 127 deletions(-) delete mode 100644 tensorflow/python/eager/custom_gradient.py create mode 100644 tensorflow/python/ops/custom_gradient.py diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 7fde534..fcb14be 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -18,6 +18,7 @@ py_library( ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", @@ -27,7 +28,6 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", - "//tensorflow/python/eager:custom_gradient", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", ], diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index fce7a60..5bddd26 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -97,7 +97,6 @@ from tensorflow.python.eager.context import in_eager_mode from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback @@ -107,6 +106,7 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index db17a3f..4fdfacb 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1775,6 +1775,7 @@ py_library( py_library( name = "gradients", srcs = [ + "ops/custom_gradient.py", "ops/gradients.py", "ops/gradients_impl.py", ], @@ -1788,6 +1789,7 @@ py_library( ":control_flow_util", ":framework", 
":framework_for_generated_wrappers", + ":framework_ops", ":functional_ops", ":image_grad", ":linalg_grad", @@ -1800,6 +1802,8 @@ py_library( ":platform", ":spectral_grad", ":util", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index ab81d40..5bedf9c 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -42,7 +42,6 @@ py_library( ":backprop", ":context", ":core", - ":custom_gradient", ":execute", ":function", ":graph_callable", @@ -103,7 +102,6 @@ cuda_py_test( additional_deps = [ ":backprop", ":context", - ":custom_gradient", ":test", "//tensorflow/python:embedding_ops", "//tensorflow/python:array_ops", @@ -207,21 +205,6 @@ cc_library( ) py_library( - name = "custom_gradient", - srcs = ["custom_gradient.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - ":context", - ":tape", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:util", - ], -) - -py_library( name = "graph_only_ops", srcs = ["graph_only_ops.py"], srcs_version = "PY2AND3", @@ -364,7 +347,6 @@ py_test( deps = [ ":backprop", ":context", - ":custom_gradient", ":test", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 48fd170..07a2155 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import custom_gradient from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -32,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py deleted file mode 100644 index fb932a9..0000000 --- a/tensorflow/python/eager/custom_gradient.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Decorator to overrides the gradient for a function.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.util import nest -from tensorflow.python.util import tf_decorator - - -def custom_gradient(f): - """Decorator to define a function with a custom gradient. - - The input function is expected to return the tuple - (results, gradient_function). - - The output function will return results while possibly recording the - gradient_function and inputs in the tape. - - Args: - f: function to be decorated. - - Returns: - decorated function. - """ - - def decorated(*args, **kwargs): - """Decorated function with custom gradient.""" - if context.in_graph_mode(): - if kwargs: - raise ValueError( - "custom_gradient in graph mode doesn't support keyword arguments.") - name = "CustomGradient-%s" % tf_ops.uid() - args = [tf_ops.convert_to_tensor(x) for x in args] - result, grad_fn = f(*args) - flat_result = nest.flatten(result) - all_tensors = flat_result + args - - @tf_ops.RegisterGradient(name) - def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable - gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)])) - # Need to return one value per input to the IdentityN, so pad the - # gradients of the inputs of the custom_gradient function with the - # gradients of the outputs as well. - return ([None] * len(flat_result)) + gradients - - with tf_ops.get_default_graph().gradient_override_map( - {"IdentityN": name}): - all_tensors = array_ops.identity_n(all_tensors) - return nest.pack_sequence_as( - structure=result, flat_sequence=all_tensors[:len(flat_result)]) - - input_tensors = [tf_ops.convert_to_tensor(x) for x in args] - - result, grad_fn = f(*args, **kwargs) - flat_result = nest.flatten(result) - # TODO(apassos) consider removing the identity below. - flat_result = [gen_array_ops.identity(x) for x in flat_result] - - def actual_grad_fn(*outputs): - return nest.flatten(grad_fn(*outputs)) - - tape.record_operation( - f.__name__, - flat_result, - input_tensors, - actual_grad_fn) - flat_result = list(flat_result) - return nest.pack_sequence_as(result, flat_result) - - return tf_decorator.make_decorator(f, decorated) diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py index b490bac..4326d5e 100644 --- a/tensorflow/python/eager/tape_test.py +++ b/tensorflow/python/eager/tape_test.py @@ -21,11 +21,11 @@ from __future__ import print_function from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import custom_gradient from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops # Importing nn_grad for the registration functions. 
@@ -165,21 +165,6 @@ class TapeTest(test.TestCase):
     g, = backprop.gradients_function(fn, [0])(t)
     self.assertAllEqual(g, 1.0)
 
-  def testCustomGradientGraphMode(self):
-    with context.graph_mode(), self.test_session():
-
-      @custom_gradient.custom_gradient
-      def f(x):
-
-        def grad(dresult):
-          return dresult * 10.0
-
-        return x, grad
-
-      inp = constant_op.constant(1.0)
-      grad = gradients_impl.gradients(f(inp), inp)
-      self.assertAllEqual(grad[0].eval(), 10.0)
-
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
new file mode 100644
index 0000000..f199ba8
--- /dev/null
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Decorator to override the gradient of a function."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.util import nest
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("custom_gradient")
+def custom_gradient(f):
+  """Decorator to define a function with a custom gradient.
+
+  This decorator allows fine-grained control over the gradients of a sequence
+  of operations. This may be useful for multiple reasons, including providing
+  a more efficient or numerically stable gradient for a sequence of operations.
+
+  For example, consider the following function that commonly occurs in the
+  computation of cross entropy and log likelihoods:
+
+  ```python
+  def log1pexp(x):
+    return tf.log(1 + tf.exp(x))
+  ```
+
+  Due to numerical instability, the gradient of this function evaluated at
+  x=100 is NaN. For example:
+
+  ```python
+  x = tf.constant(100.)
+  y = log1pexp(x)
+  dy = tf.gradients(y, x) # Will be NaN when evaluated.
+  ```
+
+  The gradient expression can be analytically simplified to provide numerical
+  stability:
+
+  ```python
+  @tf.custom_gradient
+  def log1pexp(x):
+    e = tf.exp(x)
+    def grad(dy):
+      return dy * (1 - 1 / (1 + e))
+    return tf.log(1 + e), grad
+  ```
+
+  With this definition, the gradient at x=100 will be correctly evaluated as
+  1.0.
+
+  See also @{tf.RegisterGradient} which registers a gradient function for a
+  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
+  for fine-grained control over the gradient computation of a sequence of
+  operations.
+
+  Args:
+    f: function `f(x)` that returns a tuple `(y, grad_fn)` where:
+       - `x` is a `Tensor` or sequence of `Tensor` inputs to the function.
+       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
+         TensorFlow
+         operations in `f` to `x`.
+       - `grad_fn` is a function with the signature `g(grad_ys)` which returns
+         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
+         to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
+         `Tensor`s the same size as `y` holding the initial value gradients for
+         each `Tensor` in `y`.
+
+  Returns:
+    A function `h(x)` which returns the same value as `f(x)[0]` and whose
+    gradient (as calculated by @{tf.gradients}) is determined by `f(x)[1]`.
+  """
+
+  def decorated(*args, **kwargs):
+    """Decorated function with custom gradient."""
+    if context.in_graph_mode():
+      if kwargs:
+        raise ValueError(
+            "The custom_gradient decorator currently supports keyword "
+            "arguments only when eager execution is enabled.")
+      name = "CustomGradient-%s" % ops.uid()
+      args = [ops.convert_to_tensor(x) for x in args]
+      result, grad_fn = f(*args)
+      flat_result = nest.flatten(result)
+      all_tensors = flat_result + args
+
+      @ops.RegisterGradient(name)
+      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
+        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
+        # Need to return one value per input to the IdentityN, so pad the
+        # gradients of the inputs of the custom_gradient function with the
+        # gradients of the outputs as well.
+        return ([None] * len(flat_result)) + gradients
+
+      with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
+        all_tensors = array_ops.identity_n(all_tensors)
+      return nest.pack_sequence_as(
+          structure=result, flat_sequence=all_tensors[:len(flat_result)])
+
+    input_tensors = [ops.convert_to_tensor(x) for x in args]
+
+    result, grad_fn = f(*args, **kwargs)
+    flat_result = nest.flatten(result)
+    # TODO(apassos) consider removing the identity below.
+    flat_result = [gen_array_ops.identity(x) for x in flat_result]
+
+    def actual_grad_fn(*outputs):
+      return nest.flatten(grad_fn(*outputs))
+
+    tape.record_operation(f.__name__, flat_result, input_tensors,
+                          actual_grad_fn)
+    flat_result = list(flat_result)
+    return nest.pack_sequence_as(result, flat_result)
+
+  return tf_decorator.make_decorator(f, decorated)
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 921fd50..63d9a23 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import
+from tensorflow.python.ops.custom_gradient import custom_gradient
 from tensorflow.python.ops.gradients_impl import AggregationMethod
 from tensorflow.python.ops.gradients_impl import gradients
 from tensorflow.python.ops.gradients_impl import hessians
@@ -28,6 +29,7 @@ from tensorflow.python.util.all_util import remove_undocumented
 _allowed_symbols = [
     # TODO(drpng): find a good place to reference this.
     "AggregationMethod",
+    "custom_gradient",
     "gradients",  # tf.gradients.gradients.
"hessians", # tf.gradients.hessians ] diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index d39b934..c94f139 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import from tensorflow.python.ops import functional_ops # pylint: disable=unused-import @@ -661,6 +662,7 @@ class HessianTest(test_util.TensorFlowTestCase): self.assertAllEqual((m, n, m, n), hess_actual.shape) self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n))) + @test_util.with_c_api class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): @@ -741,6 +743,59 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): "of unknown shape. This may consume a large amount of memory." in str(w[0].message)) + def testCustomGradientTrivial(self): + + @custom_gradient.custom_gradient + def MyIdentity(x): + + def Grad(dy): + return [3 * dy] + + return x, Grad + + with ops.Graph().as_default(): + x = constant(3.) + y = MyIdentity(MyIdentity(x)) + dy = gradients.gradients(y, x)[0] + with session.Session(): + self.assertEqual(9., dy.eval()) + + def testCustomGradient(self): + + @custom_gradient.custom_gradient + def MyMultiply(x1, x2): + result = x1 * x2 + + def Grad(dy): + # Switched the ordering here. + return [dy * x1, dy * x2] + + return result, Grad + + with ops.Graph().as_default(): + x1 = constant(3.) + x2 = constant(5.) + y = MyMultiply(x1, x2) + dy = gradients.gradients(y, [x1, x2]) + with session.Session() as sess: + self.assertAllEqual([3., 5.], sess.run(dy)) + + def testCustomGradientErrors(self): + + @custom_gradient.custom_gradient + def F(x): + + def Grad(_): + raise RuntimeError("x") + + return x, Grad + + with ops.Graph().as_default(): + x = constant(1.0) + y = F(x) + with self.assertRaises(RuntimeError): + gradients.gradients(y, x) + @test_util.with_c_api class OnlyRealGradientsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 65b788c..60a98ac 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -218,6 +218,7 @@ _allowed_symbols_gradients = [ # Documented in training.py: # Not importing training.py to avoid complex graph dependencies. "AggregationMethod", + "custom_gradient", "gradients", # tf.gradients = gradients.gradients "hessians", ] diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 78c8ce9..e623e27 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -28,6 +28,7 @@ See the @{$python/train} guide. 
@@ProximalGradientDescentOptimizer @@ProximalAdagradOptimizer @@RMSPropOptimizer +@@custom_gradient @@gradients @@AggregationMethod @@stop_gradient diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 8c9e7af..a88a87b 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -893,6 +893,10 @@ tf_module { argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " } member_method { + name: "custom_gradient" + argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" + } + member_method { name: "decode_base64" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } -- 2.7.4
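
A minimal usage sketch of the newly exported `tf.custom_gradient`, modeled on the
`testCustomGradientTrivial` test and the `log1pexp` docstring example added above.
It assumes a TensorFlow build that includes this patch, running in graph mode; the
`scaled_identity` name is illustrative only, not part of the change.

```python
import tensorflow as tf


@tf.custom_gradient
def scaled_identity(x):
  """Identity-like op whose gradient is overridden to 3 * upstream gradient."""

  def grad(dy):
    return 3.0 * dy

  return tf.identity(x), grad


x = tf.constant(2.0)
y = scaled_identity(scaled_identity(x))  # two nested custom-gradient ops
dy_dx = tf.gradients(y, x)[0]  # uses grad() above, not the default gradient

with tf.Session() as sess:
  print(sess.run(dy_dx))  # 9.0: the two custom gradients chain (3.0 * 3.0)
```

When eager execution is enabled, the same decorator also accepts keyword
arguments, as the error message in the new `custom_gradient.py` indicates.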