Improve usability of `tf.contrib.bayesflow.custom_gradient` by removing need for...
authorJoshua V. Dillon <jvdillon@google.com>
Mon, 12 Mar 2018 18:29:24 +0000 (11:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 18:33:33 +0000 (11:33 -0700)
PiperOrigin-RevId: 188751894

tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py
tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py

index a95df31..1250765 100644 (file)
@@ -83,7 +83,7 @@ class CustomGradientTest(test.TestCase):
       g = lambda z: z[0]**2 * z[1]**2 / 2
 
       z = array_ops.stack([x, y])
-      fz = cg.custom_gradient(f(z), g(z), z, axis=0)
+      fz = cg.custom_gradient(f(z), g(z), z)
       gz = gradients_impl.gradients(fz, variables.trainable_variables())
       [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]])
 
index d44fe65..927cc28 100644 (file)
@@ -24,32 +24,38 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 
 __all__ = [
-    "custom_gradient",
+    'custom_gradient',
 ]
 
 
-def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
-                    name=None):
-  """Enables specifying a custom gradient.
+def is_list_like(x):
+  return isinstance(x, (tuple, list))
+
+
+def identity(x, dtype=None, name=None):
+  return array_ops.identity(ops.convert_to_tensor(
+      x, dtype=dtype, name=name), name=name)
+
+
+def custom_gradient(fx, gx, x, fx_gx_manually_stopped=False, name=None):
+  """Embeds a custom gradient into a `Tensor`.
 
   This function works by clever application of `stop_gradient`. I.e., observe
   that:
 
   ```none
-  h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x))
+  h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
   ```
 
-  is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] =
-  stop_gradient(g(x)).`
+  is such that `h(x) == stop_gradient(f(x))` and
+  `grad[h(x), x] == stop_gradient(g(x)).`
 
   In addition to scalar-domain/scalar-range functions, this function also
-  supports tensor-domain/scalar-range functions. However, in the latter case it
-  is necessary to reduce `x` to a scalar. This can be done by indicating the
-  `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior
-  to calling this function.
+  supports tensor-domain/scalar-range functions.
 
   Partial Custom Gradient:
 
@@ -61,12 +67,8 @@ def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
 
   Args:
     fx: `Tensor`. Output of function evaluated at `x`.
-    gx: `Tensor`. Gradient of function evaluated at `x`.
-    x: `Tensor`. Point of evaluation for `f, g`.
-    axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain
-      of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range.
-      If `None` `f` is assumed to render one scalar given all of `x`. Otherwise
-      `f` is assumed to output one scalar for each of `axis` dimensions of `x`.
+    gx: `Tensor` or list of `Tensor`s. Gradient of function at (each) `x`.
+    x: `Tensor` or list of `Tensor`s. Args of evaluation for `f`.
     fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
       have `stop_gradient` applied.
     name: Python `str` name prefixed to Ops created by this function.
@@ -75,36 +77,62 @@ def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False,
     fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
       `stop_gradient(g(x))`.
   """
-  with ops.name_scope(name, "custom_gradient", [fx, gx, x]):
-    fx = ops.convert_to_tensor(fx, name="fx")
+  def maybe_stop(x):
+    if fx_gx_manually_stopped:
+      return x
+    return array_ops.stop_gradient(x)
+  with ops.name_scope(name, 'custom_gradient', [fx, gx, x]):
+    fx = ops.convert_to_tensor(fx, name='fx')
     # We don't want to bother eagerly computing `gx` since we may not even need
     # it.
     with ops.control_dependencies([fx]):
-      gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx")
-      gx = array_ops.identity(gx, name="gx")
-    # Proof of correctness:
-    #
-    #  f(x) = x * stop[gx] + stop[fx - x * gx]
-    #       = stop[fx]
-    #
-    #  g(x) = grad[fx]
-    #       = stop[gx] + grad[stop[fx - x * gx]]
-    #       = stop[gx] + 0
-    #
-    # Notice that when x is zero it still works:
-    # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
-    #
-    # The proof is similar for the tensor-domain case, except that `x` is
-    # replaced by `reduce_sum(x)`.
-    sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x")
-    if not fx_gx_manually_stopped:
-      fx = array_ops.stop_gradient(fx)
-      gx = array_ops.stop_gradient(gx)
-    # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write
-    # the code this way, rather than, e.g.,
-    # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
-    # For more discussion regarding the relevant portions of the IEEE754
-    # standard, see the StackOverflow question,
-    # "Is there a floating point value of x, for which x-x == 0 is false?"
-    # http://stackoverflow.com/q/2686644
-    return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx
+      if is_list_like(x):
+        x = [identity(x_, name='x') for x_ in x]
+      else:
+        x = [identity(x, name='x')]
+
+      if is_list_like(gx):
+        gx = [identity(gx_, dtype=fx.dtype, name='gx')
+              for gx_ in gx]
+      else:
+        gx = [identity(gx, dtype=fx.dtype, name='gx')]
+
+      override_grad = []
+      for x_, gx_ in zip(x, gx):
+        # Observe: tf.gradients(f(x), x)[i].shape == x[i].shape
+        # thus we check that the user is supplying correct shapes.
+        equal_shape = check_ops.assert_equal(
+            array_ops.shape(x_),
+            array_ops.shape(gx_),
+            message='Each `x` must have the same shape as each `gx`.')
+        with ops.control_dependencies([equal_shape]):
+          # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to
+          # write the code this way, rather than, e.g.,
+          # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
+          # For more discussion regarding the relevant portions of the IEEE754
+          # standard, see the StackOverflow question,
+          # "Is there a floating point value of x, for which x-x == 0 is false?"
+          # http://stackoverflow.com/q/2686644
+          zeros_like_x_ = x_ - array_ops.stop_gradient(x_)
+          override_grad.append(math_ops.reduce_sum(
+              maybe_stop(gx_) * zeros_like_x_))
+      override_grad = sum(override_grad)
+      override_grad /= math_ops.cast(array_ops.size(fx),
+                                     dtype=fx.dtype.base_dtype)
+
+      # Proof of correctness:
+      #
+      #  f(x) = x * stop[gx] + stop[fx - x * gx]
+      #       = stop[fx]
+      #
+      #  g(x) = grad[fx]
+      #       = stop[gx] + grad[stop[fx - x * gx]]
+      #       = stop[gx] + 0
+      #
+      # Notice that when x is zero it still works:
+      # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
+      #
+      # The proof is similar for the tensor-domain case, except that we
+      # `reduce_sum` the `stop[gx] * (x - stop[x])` then rescale by
+      # `tf.size(fx)` since this reduced version is broadcast to `fx`.
+      return maybe_stop(fx) + override_grad