gradients: Export tf.custom_gradients
authorAsim Shankar <ashankar@google.com>
Tue, 6 Mar 2018 01:28:12 +0000 (17:28 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 01:32:21 +0000 (17:32 -0800)
(Moved from the tf.contrib.eager namespace)

PiperOrigin-RevId: 187950503

13 files changed:
tensorflow/contrib/eager/python/BUILD
tensorflow/contrib/eager/python/tfe.py
tensorflow/python/BUILD
tensorflow/python/eager/BUILD
tensorflow/python/eager/backprop_test.py
tensorflow/python/eager/custom_gradient.py [deleted file]
tensorflow/python/eager/tape_test.py
tensorflow/python/ops/custom_gradient.py [new file with mode: 0644]
tensorflow/python/ops/gradients.py
tensorflow/python/ops/gradients_test.py
tensorflow/python/ops/standard_ops.py
tensorflow/python/training/training.py
tensorflow/tools/api/golden/tensorflow.pbtxt

index 7fde534..fcb14be 100644 (file)
@@ -18,6 +18,7 @@ py_library(
         ":saver",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
         "//tensorflow/python:numerics",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:script_ops",
@@ -27,7 +28,6 @@ py_library(
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:core",
-        "//tensorflow/python/eager:custom_gradient",
         "//tensorflow/python/eager:execution_callbacks",
         "//tensorflow/python/eager:function",
     ],
index fce7a60..5bddd26 100644 (file)
@@ -97,7 +97,6 @@ from tensorflow.python.eager.context import in_eager_mode
 from tensorflow.python.eager.context import in_graph_mode
 from tensorflow.python.eager.context import list_devices
 from tensorflow.python.eager.context import num_gpus
-from tensorflow.python.eager.custom_gradient import custom_gradient
 from tensorflow.python.eager.execution_callbacks import add_execution_callback
 from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks
 from tensorflow.python.eager.execution_callbacks import inf_callback
@@ -107,6 +106,7 @@ from tensorflow.python.eager.execution_callbacks import seterr
 from tensorflow.python.framework.ops import enable_eager_execution
 from tensorflow.python.framework.ops import eager_run as run
 from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
+from tensorflow.python.ops.custom_gradient import custom_gradient
 from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
 from tensorflow.python.ops.variable_scope import EagerVariableStore
 from tensorflow.python.ops import script_ops
index db17a3f..4fdfacb 100644 (file)
@@ -1775,6 +1775,7 @@ py_library(
 py_library(
     name = "gradients",
     srcs = [
+        "ops/custom_gradient.py",
         "ops/gradients.py",
         "ops/gradients_impl.py",
     ],
@@ -1788,6 +1789,7 @@ py_library(
         ":control_flow_util",
         ":framework",
         ":framework_for_generated_wrappers",
+        ":framework_ops",
         ":functional_ops",
         ":image_grad",
         ":linalg_grad",
@@ -1800,6 +1802,8 @@ py_library(
         ":platform",
         ":spectral_grad",
         ":util",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:tape",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
index ab81d40..5bedf9c 100644 (file)
@@ -42,7 +42,6 @@ py_library(
         ":backprop",
         ":context",
         ":core",
-        ":custom_gradient",
         ":execute",
         ":function",
         ":graph_callable",
@@ -103,7 +102,6 @@ cuda_py_test(
     additional_deps = [
         ":backprop",
         ":context",
-        ":custom_gradient",
         ":test",
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:array_ops",
@@ -207,21 +205,6 @@ cc_library(
 )
 
 py_library(
-    name = "custom_gradient",
-    srcs = ["custom_gradient.py"],
-    srcs_version = "PY2AND3",
-    visibility = ["//tensorflow:internal"],
-    deps = [
-        ":context",
-        ":tape",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:util",
-    ],
-)
-
-py_library(
     name = "graph_only_ops",
     srcs = ["graph_only_ops.py"],
     srcs_version = "PY2AND3",
@@ -364,7 +347,6 @@ py_test(
     deps = [
         ":backprop",
         ":context",
-        ":custom_gradient",
         ":test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
index 48fd170..07a2155 100644 (file)
@@ -23,7 +23,6 @@ import numpy as np
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import custom_gradient
 from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
@@ -32,6 +31,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
deleted file mode 100644 (file)
index fb932a9..0000000
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Decorator to overrides the gradient for a function."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.util import nest
-from tensorflow.python.util import tf_decorator
-
-
-def custom_gradient(f):
-  """Decorator to define a function with a custom gradient.
-
-  The input function is expected to return the tuple
-    (results, gradient_function).
-
-  The output function will return results while possibly recording the
-  gradient_function and inputs in the tape.
-
-  Args:
-    f: function to be decorated.
-
-  Returns:
-    decorated function.
-  """
-
-  def decorated(*args, **kwargs):
-    """Decorated function with custom gradient."""
-    if context.in_graph_mode():
-      if kwargs:
-        raise ValueError(
-            "custom_gradient in graph mode doesn't support keyword arguments.")
-      name = "CustomGradient-%s" % tf_ops.uid()
-      args = [tf_ops.convert_to_tensor(x) for x in args]
-      result, grad_fn = f(*args)
-      flat_result = nest.flatten(result)
-      all_tensors = flat_result + args
-
-      @tf_ops.RegisterGradient(name)
-      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
-        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
-        # Need to return one value per input to the IdentityN, so pad the
-        # gradients of the inputs of the custom_gradient function with the
-        # gradients of the outputs as well.
-        return ([None] * len(flat_result)) + gradients
-
-      with tf_ops.get_default_graph().gradient_override_map(
-          {"IdentityN": name}):
-        all_tensors = array_ops.identity_n(all_tensors)
-      return nest.pack_sequence_as(
-          structure=result, flat_sequence=all_tensors[:len(flat_result)])
-
-    input_tensors = [tf_ops.convert_to_tensor(x) for x in args]
-
-    result, grad_fn = f(*args, **kwargs)
-    flat_result = nest.flatten(result)
-    # TODO(apassos) consider removing the identity below.
-    flat_result = [gen_array_ops.identity(x) for x in flat_result]
-
-    def actual_grad_fn(*outputs):
-      return nest.flatten(grad_fn(*outputs))
-
-    tape.record_operation(
-        f.__name__,
-        flat_result,
-        input_tensors,
-        actual_grad_fn)
-    flat_result = list(flat_result)
-    return nest.pack_sequence_as(result, flat_result)
-
-  return tf_decorator.make_decorator(f, decorated)
index b490bac..4326d5e 100644 (file)
@@ -21,11 +21,11 @@ from __future__ import print_function
 
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import custom_gradient
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 # Importing nn_grad for the registration functions.
@@ -165,21 +165,6 @@ class TapeTest(test.TestCase):
     g, = backprop.gradients_function(fn, [0])(t)
     self.assertAllEqual(g, 1.0)
 
-  def testCustomGradientGraphMode(self):
-    with context.graph_mode(), self.test_session():
-
-      @custom_gradient.custom_gradient
-      def f(x):
-
-        def grad(dresult):
-          return dresult * 10.0
-
-        return x, grad
-
-      inp = constant_op.constant(1.0)
-      grad = gradients_impl.gradients(f(inp), inp)
-      self.assertAllEqual(grad[0].eval(), 10.0)
-
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
new file mode 100644 (file)
index 0000000..f199ba8
--- /dev/null
@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Decorator to overrides the gradient for a function."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.util import nest
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("custom_gradient")
+def custom_gradient(f):
+  """Decorator to define a function with a custom gradient.
+
+  This decorator allows fine grained control over the gradients of a sequence
+  for operations.  This may be useful for multiple reasons, including providing
+  a more efficient or numerically stable gradient for a sequence of operations.
+
+  For example, consider the following function that commonly occurs in the
+  computation of cross entropy and log likelihoods:
+
+  ```python
+  def log1pexp(x):
+    return tf.log(1 + tf.exp(x))
+  ```
+
+  Due to numerical instability, the gradient this function evaluated at x=100 is
+  NaN.  For example:
+
+  ```python
+  x = tf.constant(100.)
+  y = log1pexp(x)
+  dy = tf.gradients(y, x) # Will be NaN when evaluated.
+  ```
+
+  The gradient expression can be analytically simplified to provide numerical
+  stability:
+
+  ```python
+  @tf.custom_gradient
+  def log1pexp(x):
+    e = tf.exp(x)
+    def grad(dy):
+      return dy * (1 - 1 / (1 + e))
+    return tf.log(1 + e), grad
+  ```
+
+  With this definition, the gradient at x=100 will be correctly evaluated as
+  1.0.
+
+  See also @{tf.RegisterGradient} which registers a gradient function for a
+  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
+  for fine grained control over the gradient computation of a sequence of
+  operations.
+
+  Args:
+    f: function `f(x)` that returns a tuple `(y, grad_fn)` where:
+       - `x` is a `Tensor` or sequence of `Tensor` inputs to the function.
+       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
+         TensorFlow
+         operations in `f` to `x`.
+       - `grad_fn` is a function with the signature `g(grad_ys)` which returns
+         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
+         to the `Tensor`s in `x.  `grad_ys` is a `Tensor` or sequence of
+         `Tensor`s the same size as `y` holding the initial value gradients for
+         each `Tensor` in `y`.
+
+  Returns:
+    A function `h(x)` which returns the same value as `f(x)[0]` and whose
+    gradient (as calculated by @{tf.gradients}) is determined by `f(x)[1]`.
+  """
+
+  def decorated(*args, **kwargs):
+    """Decorated function with custom gradient."""
+    if context.in_graph_mode():
+      if kwargs:
+        raise ValueError(
+            "The custom_gradient decorator currently suports keywords "
+            "arguments only when eager execution is enabled.")
+      name = "CustomGradient-%s" % ops.uid()
+      args = [ops.convert_to_tensor(x) for x in args]
+      result, grad_fn = f(*args)
+      flat_result = nest.flatten(result)
+      all_tensors = flat_result + args
+
+      @ops.RegisterGradient(name)
+      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
+        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
+        # Need to return one value per input to the IdentityN, so pad the
+        # gradients of the inputs of the custom_gradient function with the
+        # gradients of the outputs as well.
+        return ([None] * len(flat_result)) + gradients
+
+      with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
+        all_tensors = array_ops.identity_n(all_tensors)
+      return nest.pack_sequence_as(
+          structure=result, flat_sequence=all_tensors[:len(flat_result)])
+
+    input_tensors = [ops.convert_to_tensor(x) for x in args]
+
+    result, grad_fn = f(*args, **kwargs)
+    flat_result = nest.flatten(result)
+    # TODO(apassos) consider removing the identity below.
+    flat_result = [gen_array_ops.identity(x) for x in flat_result]
+
+    def actual_grad_fn(*outputs):
+      return nest.flatten(grad_fn(*outputs))
+
+    tape.record_operation(f.__name__, flat_result, input_tensors,
+                          actual_grad_fn)
+    flat_result = list(flat_result)
+    return nest.pack_sequence_as(result, flat_result)
+
+  return tf_decorator.make_decorator(f, decorated)
index 921fd50..63d9a23 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import
+from tensorflow.python.ops.custom_gradient import custom_gradient
 from tensorflow.python.ops.gradients_impl import AggregationMethod
 from tensorflow.python.ops.gradients_impl import gradients
 from tensorflow.python.ops.gradients_impl import hessians
@@ -28,6 +29,7 @@ from tensorflow.python.util.all_util import remove_undocumented
 _allowed_symbols = [
     # TODO(drpng): find a good place to reference this.
     "AggregationMethod",
+    "custom_gradient",
     "gradients",  # tf.gradients.gradients.
     "hessians",  # tf.gradients.hessians
 ]
index d39b934..c94f139 100644 (file)
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
@@ -661,6 +662,7 @@ class HessianTest(test_util.TensorFlowTestCase):
     self.assertAllEqual((m, n, m, n), hess_actual.shape)
     self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))
 
+
 @test_util.with_c_api
 class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
 
@@ -741,6 +743,59 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
         "of unknown shape. This may consume a large amount of memory." in
         str(w[0].message))
 
+  def testCustomGradientTrivial(self):
+
+    @custom_gradient.custom_gradient
+    def MyIdentity(x):
+
+      def Grad(dy):
+        return [3 * dy]
+
+      return x, Grad
+
+    with ops.Graph().as_default():
+      x = constant(3.)
+      y = MyIdentity(MyIdentity(x))
+      dy = gradients.gradients(y, x)[0]
+      with session.Session():
+        self.assertEqual(9., dy.eval())
+
+  def testCustomGradient(self):
+
+    @custom_gradient.custom_gradient
+    def MyMultiply(x1, x2):
+      result = x1 * x2
+
+      def Grad(dy):
+        # Switched the ordering here.
+        return [dy * x1, dy * x2]
+
+      return result, Grad
+
+    with ops.Graph().as_default():
+      x1 = constant(3.)
+      x2 = constant(5.)
+      y = MyMultiply(x1, x2)
+      dy = gradients.gradients(y, [x1, x2])
+      with session.Session() as sess:
+        self.assertAllEqual([3., 5.], sess.run(dy))
+
+  def testCustomGradientErrors(self):
+
+    @custom_gradient.custom_gradient
+    def F(x):
+
+      def Grad(_):
+        raise RuntimeError("x")
+
+      return x, Grad
+
+    with ops.Graph().as_default():
+      x = constant(1.0)
+      y = F(x)
+      with self.assertRaises(RuntimeError):
+        gradients.gradients(y, x)
+
 
 @test_util.with_c_api
 class OnlyRealGradientsTest(test_util.TensorFlowTestCase):
index 65b788c..60a98ac 100644 (file)
@@ -218,6 +218,7 @@ _allowed_symbols_gradients = [
     # Documented in training.py:
     # Not importing training.py to avoid complex graph dependencies.
     "AggregationMethod",
+    "custom_gradient",
     "gradients",  # tf.gradients = gradients.gradients
     "hessians",
 ]
index 78c8ce9..e623e27 100644 (file)
@@ -28,6 +28,7 @@ See the @{$python/train} guide.
 @@ProximalGradientDescentOptimizer
 @@ProximalAdagradOptimizer
 @@RMSPropOptimizer
+@@custom_gradient
 @@gradients
 @@AggregationMethod
 @@stop_gradient
index 8c9e7af..a88a87b 100644 (file)
@@ -893,6 +893,10 @@ tf_module {
     argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
   }
   member_method {
+    name: "custom_gradient"
+    argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "decode_base64"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }