Update recompute_grad for TPU
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 26 Mar 2018 22:34:21 +0000 (15:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 26 Mar 2018 22:36:40 +0000 (15:36 -0700)
PiperOrigin-RevId: 190536468

tensorflow/contrib/layers/python/layers/rev_block_lib.py
tensorflow/contrib/layers/python/layers/rev_block_lib_test.py

index 123275e..0b38c0c 100644 (file)
@@ -29,6 +29,7 @@ from __future__ import print_function
 import functools
 import re
 
+import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -37,6 +38,7 @@ from tensorflow.python.framework import ops as framework_ops
 from tensorflow.python.layers import base
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
@@ -46,6 +48,7 @@ from tensorflow.python.util import nest
 __all__ = ["rev_block", "RevBlock", "recompute_grad"]
 
 LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*")
+_USE_DEFAULT = "__rev_block_lib_default"
 
 
 def _acc_grads(*lists_of_grads):
@@ -219,7 +222,13 @@ class RevBlock(base.Layer):
 
   def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
     """Custom gradient fn for a block of reversible residual layers."""
+    # Inputs have passed through an Identity. Recover the original Tensors to
+    # be able to match up side inputs.
+    assert [u"Identity"] == list(set([x.op.type for x in inputs]))
+    inputs = [x.op.inputs[0] for x in inputs]
     side_inputs = inputs[2:]
+    del inputs
+
     f_side_idxs = [None] * len(self.f_side_input)
     g_side_idxs = [None] * len(self.g_side_input)
     assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
@@ -405,12 +414,36 @@ def rev_block(x1,
   return block.forward(x1, x2)
 
 
-def recompute_grad(fn):
+def enable_with_args(dec):
+  """A decorator for decorators to enable their usage with or without args."""
+
+  @functools.wraps(dec)
+  def new_dec(*args, **kwargs):
+    if len(args) == 1 and not kwargs and callable(args[0]):
+      # Used as decorator without args
+      fn = args[0]
+      return dec(fn)
+    else:
+      return lambda fn: dec(fn, *args, **kwargs)
+
+  return new_dec
+
+
+@enable_with_args
+def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """Decorator that recomputes the function on the backwards pass.
 
   Args:
     fn: a function that takes Tensors (all as positional arguments) and returns
       a tuple of Tensors.
+    use_data_dep: `bool`, if `True` will use a dummy data dependency to force
+      the recompute to happen. If `False` will use a control dependency. By
+      default will be `True` if in an XLA context and `False` otherwise. XLA
+      ignores control dependencies and so this data dependency is necessary.
+    tupleize_grads: `bool`, if `True` will use control dependencies to ensure
+      that all gradients are produced before any are consumed by downstream ops.
+      If `use_data_dep` is also `True`, will use a data dependency instead of
+      a control dependency.
 
   Returns:
     A wrapped fn that is identical to fn when called, but its activations will
@@ -420,13 +453,25 @@ def recompute_grad(fn):
 
   @functools.wraps(fn)
   def wrapped(*args):
-    return _recompute_grad(fn, args)
+    return _recompute_grad(
+        fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
 
   return wrapped
 
 
-def _recompute_grad(fn, args):
+def _is_on_tpu():
+  ctxt = framework_ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
+  return control_flow_util.GetContainingXLAContext(ctxt) is not None
+
+
+def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """See recompute_grad."""
+  for arg in args:
+    if not isinstance(arg, framework_ops.Tensor):
+      raise ValueError("All inputs to function must be Tensors")
+  use_data_dep_ = use_data_dep
+  if use_data_dep_ == _USE_DEFAULT:
+    use_data_dep_ = _is_on_tpu()
 
   cached_vs = []
   cached_arg_scope = []
@@ -436,6 +481,8 @@ def _recompute_grad(fn, args):
     del outputs
     # Recompute outputs
     with framework_ops.control_dependencies(output_grads):
+      if use_data_dep_:
+        inputs = _force_data_dependency(output_grads, inputs)
       with contrib_framework_ops.arg_scope(cached_arg_scope[0]):
         with variable_scope.variable_scope(cached_vs[0], reuse=True):
           outputs = fn(*inputs)
@@ -444,6 +491,13 @@ def _recompute_grad(fn, args):
       outputs = [outputs]
     outputs = list(outputs)
     grads = gradients_impl.gradients(outputs, inputs + variables, output_grads)
+
+    if tupleize_grads:
+      if use_data_dep_:
+        grads = _tuple_with_data_dep(grads)
+      else:
+        grads = control_flow_ops.tuple(grads)
+
     grad_inputs = grads[:len(inputs)]
     grad_vars = grads[len(inputs):]
     return grad_inputs, grad_vars
@@ -532,7 +586,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
   get_vars_fn = (
       vs.global_variables if use_global_vars else vs.trainable_variables)
   len_before_vars = len(get_vars_fn())
-  inputs = list(inputs)
+  inputs = [array_ops.identity(x) for x in inputs]
   outputs = fn(*inputs)
   train_vars = get_vars_fn()[len_before_vars:]
 
@@ -581,3 +635,46 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
   flat_inputs = nest.flatten(defun_inputs)
   id_out = identity(*flat_inputs)
   return id_out
+
+
+def _force_data_dependency(first_compute, then_compute):
+  """Force all of `then_compute` to depend on all of `first_compute`.
+
+  Uses a dummy data dependency, which is useful when running on TPUs because
+  XLA ignores control dependencies. Only supports float arguments.
+
+  Args:
+    first_compute: `list<Tensor>`. These will be made to run before the
+      `Tensor`s `then_compute`.
+    then_compute: `list<Tensor>`. These will run after all the `Tensor`s in
+      `first_compute`.
+
+  Returns:
+    `list<Tensor>`, same length as `then_compute`.
+
+  Raises:
+    ValueError: if ranks are unknown or types are not floating.
+  """
+
+  def _first_element(x):
+    if x.get_shape().ndims is None:
+      raise ValueError("Rank of Tensor %s must be known" % x)
+    ndims = x.get_shape().ndims
+    return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), [])
+
+  first_compute_sum = math_ops.add_n(
+      [_first_element(x) for x in first_compute if x is not None])
+  dtype = first_compute_sum.dtype
+  if not dtype.is_floating:
+    raise ValueError("_force_data_dependency only supports floating dtypes.")
+  epsilon = np.finfo(dtype.as_numpy_dtype).tiny
+  zero = array_ops.stop_gradient(epsilon * first_compute_sum)
+
+  return [
+      array_ops.identity(x) + zero if x is not None else None
+      for x in then_compute
+  ]
+
+
+def _tuple_with_data_dep(tensors):
+  return _force_data_dependency(tensors, tensors)
index cbcbcd7..d1ad4e8 100644 (file)
@@ -154,7 +154,7 @@ class RevBlockTest(test.TestCase):
       y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
       self.assertAllClose(y_val, yd_val)
       for g1, g2 in zip(gd_val, g_val):
-        self.assertAllClose(g1, g2)
+        self.assertAllClose(g1, g2, rtol=1e-5)
 
   def testRevBlock(self):
     self._testRevBlock()
@@ -255,25 +255,54 @@ class RecomputeTest(test.TestCase):
     def fn_recompute(x):
       return fn(x)
 
+    @rev_block_lib.recompute_grad(use_data_dep=True)
+    def fn_use_data_dep(x):
+      return fn(x)
+
+    @rev_block_lib.recompute_grad(tupleize_grads=True)
+    def fn_tupleize(x):
+      return fn(x)
+
+    @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True)
+    def fn_both(x):
+      return fn(x)
+
     x = random_ops.random_uniform((3, 1, 3))
-    recompute_vars = None
-    with variable_scope.variable_scope("recompute") as vs:
-      out1 = math_ops.reduce_sum(fn_recompute(x))
-      recompute_vars = vs.trainable_variables()
-    reg_vars = None
-    with variable_scope.variable_scope("regular") as vs:
-      out2 = math_ops.reduce_sum(fn(x))
-      reg_vars = vs.trainable_variables()
-
-    grad1 = gradients_impl.gradients(out1, recompute_vars)
-    grad2 = gradients_impl.gradients(out2, reg_vars)
+
+    names_and_fns = [
+        ("recompute", fn_recompute),
+        ("regular", fn),
+        ("use_data_dep", fn_use_data_dep),
+        ("tupleize", fn_tupleize),
+        ("tuple_and_data_dep", fn_both),
+    ]
+    outputs_and_vars = []
+    for name, wrapped_fn in names_and_fns:
+      with variable_scope.variable_scope(name) as vs:
+        out = math_ops.reduce_sum(wrapped_fn(x))
+        outputs_and_vars.append((out, vs.trainable_variables()))
+
+    all_grads = []
+    for out, scope_vars in outputs_and_vars:
+      all_grads.append(gradients_impl.gradients(out, scope_vars))
 
     with self.test_session() as sess:
       sess.run(variables.global_variables_initializer())
-      outs = sess.run([out1, out2, grad1, grad2])
-      self.assertAllClose(outs[0], outs[1])
-      for g1, g2 in zip(outs[2], outs[3]):
-        self.assertAllClose(g1, g2)
+      outputs = list(zip(*outputs_and_vars))[0]
+      outs, all_grads_val = sess.run([outputs, all_grads])
+
+      # All outputs are the same
+      current = outs[0]
+      for out in outs[1:]:
+        self.assertAllClose(current, out)
+        current = out
+
+      # All gradients are the same
+      for grads in zip(all_grads_val):
+        current = grads[0]
+        for g in grads[1:]:
+          self.assertAllClose(current, g)
+          current = g
 
 
 class FnWithCustomGradTest(test.TestCase):