import functools
import re
+import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
__all__ = ["rev_block", "RevBlock", "recompute_grad"]
LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*")
+_USE_DEFAULT = "__rev_block_lib_default"
def _acc_grads(*lists_of_grads):
def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
"""Custom gradient fn for a block of reversible residual layers."""
+ # Inputs have passed through an Identity. Recover the original Tensors to
+ # be able to match up side inputs.
+ assert [u"Identity"] == list(set([x.op.type for x in inputs]))
+ inputs = [x.op.inputs[0] for x in inputs]
side_inputs = inputs[2:]
+ del inputs
+
f_side_idxs = [None] * len(self.f_side_input)
g_side_idxs = [None] * len(self.g_side_input)
assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
return block.forward(x1, x2)
-def recompute_grad(fn):
def enable_with_args(dec):
  """Lets a decorator be applied either bare or with configuration arguments.

  Args:
    dec: the decorator to adapt. It is called as `dec(fn)` for bare usage and
      as `dec(fn, *args, **kwargs)` when arguments were supplied.

  Returns:
    A wrapper around `dec` (metadata copied with `functools.wraps`). When the
    wrapper receives exactly one positional argument that is callable and no
    keyword arguments, it applies `dec` to that callable directly. Otherwise
    it returns a one-argument function that forwards the captured arguments
    to `dec` after the decorated function.
  """

  @functools.wraps(dec)
  def new_dec(*args, **kwargs):
    # A lone callable with no keywords is treated as `@dec` without
    # parentheses; anything else is treated as configuration.
    bare_usage = len(args) == 1 and not kwargs and callable(args[0])
    if not bare_usage:
      # Defer the call until the function being decorated is provided.
      return lambda fn: dec(fn, *args, **kwargs)

    (target,) = args
    return dec(target)

  return new_dec
+
+
+@enable_with_args
+def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""Decorator that recomputes the function on the backwards pass.
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
a tuple of Tensors.
+ use_data_dep: `bool`, if `True` will use a dummy data dependency to force
+ the recompute to happen. If `False` will use a control dependency. By
+ default will be `True` if in an XLA context and `False` otherwise. XLA
+ ignores control dependencies and so this data dependency is necessary.
+ tupleize_grads: `bool`, if `True` will use control dependencies to ensure
+ that all gradients are produced before any are consumed by downstream ops.
+ If `use_data_dep` is also `True`, will use a data dependency instead of
+ a control dependency.
Returns:
A wrapped fn that is identical to fn when called, but its activations will
@functools.wraps(fn)
def wrapped(*args):
- return _recompute_grad(fn, args)
+ return _recompute_grad(
+ fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
return wrapped
-def _recompute_grad(fn, args):
def _is_on_tpu():
  """Reports whether the default graph is currently inside an XLA context.

  Returns:
    `True` when `control_flow_util.GetContainingXLAContext` returns something
    other than `None` for the default graph's current control flow context,
    `False` otherwise.
  """
  # NOTE(review): despite the name, the check is for an enclosing XLA context,
  # not specifically a TPU device -- confirm this is the intended proxy.
  graph = framework_ops.get_default_graph()
  control_flow_ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
  xla_ctx = control_flow_util.GetContainingXLAContext(control_flow_ctx)
  return xla_ctx is not None
+
+
+def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
+ for arg in args:
+ if not isinstance(arg, framework_ops.Tensor):
+ raise ValueError("All inputs to function must be Tensors")
+ use_data_dep_ = use_data_dep
+ if use_data_dep_ == _USE_DEFAULT:
+ use_data_dep_ = _is_on_tpu()
cached_vs = []
cached_arg_scope = []
del outputs
# Recompute outputs
with framework_ops.control_dependencies(output_grads):
+ if use_data_dep_:
+ inputs = _force_data_dependency(output_grads, inputs)
with contrib_framework_ops.arg_scope(cached_arg_scope[0]):
with variable_scope.variable_scope(cached_vs[0], reuse=True):
outputs = fn(*inputs)
outputs = [outputs]
outputs = list(outputs)
grads = gradients_impl.gradients(outputs, inputs + variables, output_grads)
+
+ if tupleize_grads:
+ if use_data_dep_:
+ grads = _tuple_with_data_dep(grads)
+ else:
+ grads = control_flow_ops.tuple(grads)
+
grad_inputs = grads[:len(inputs)]
grad_vars = grads[len(inputs):]
return grad_inputs, grad_vars
get_vars_fn = (
vs.global_variables if use_global_vars else vs.trainable_variables)
len_before_vars = len(get_vars_fn())
- inputs = list(inputs)
+ inputs = [array_ops.identity(x) for x in inputs]
outputs = fn(*inputs)
train_vars = get_vars_fn()[len_before_vars:]
flat_inputs = nest.flatten(defun_inputs)
id_out = identity(*flat_inputs)
return id_out
+
+
def _force_data_dependency(first_compute, then_compute):
  """Makes every Tensor in `then_compute` depend on all of `first_compute`.

  The dependency is a data dependency rather than a control dependency: the
  first element of each non-`None` Tensor in `first_compute` is summed, scaled
  by the smallest positive normal value of the sum's dtype, wrapped in
  `stop_gradient`, and added to an identity of each `then_compute` Tensor.

  Args:
    first_compute: `list<Tensor>`. `None` entries are skipped. Each remaining
      Tensor must have a known rank.
    then_compute: `list<Tensor>`. `None` entries are passed through as `None`.

  Returns:
    `list`, same length and order as `then_compute`, holding the augmented
    Tensors (or `None` where the input entry was `None`).

  Raises:
    ValueError: if a Tensor in `first_compute` has unknown rank, or if the
      dtype of the summed first elements is not floating.
  """

  def _leading_scalar(tensor):
    # Slice out the element at index 0 along every axis and drop all axes.
    rank = tensor.get_shape().ndims
    if rank is None:
      raise ValueError("Rank of Tensor %s must be known" % tensor)
    begin = [0] * rank
    size = [1] * rank
    return array_ops.reshape(array_ops.slice(tensor, begin, size), [])

  scalars = []
  for tensor in first_compute:
    if tensor is None:
      continue
    scalars.append(_leading_scalar(tensor))
  dependency_sum = math_ops.add_n(scalars)

  sum_dtype = dependency_sum.dtype
  if not sum_dtype.is_floating:
    raise ValueError("_force_data_dependency only supports floating dtypes.")

  # Scaling by the dtype's tiniest normal value keeps the added term
  # negligible; stop_gradient keeps it out of backprop.
  tiny = np.finfo(sum_dtype.as_numpy_dtype).tiny
  zero = array_ops.stop_gradient(tiny * dependency_sum)

  dependent = []
  for tensor in then_compute:
    if tensor is None:
      dependent.append(None)
    else:
      dependent.append(array_ops.identity(tensor) + zero)
  return dependent
+
+
def _tuple_with_data_dep(tensors):
  """Makes each Tensor in `tensors` data-dependent on all of `tensors`.

  Args:
    tensors: the list passed to `_force_data_dependency` as both its
      `first_compute` and `then_compute` arguments.

  Returns:
    Whatever `_force_data_dependency` returns for that call.
  """
  return _force_data_dependency(first_compute=tensors, then_compute=tensors)
y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
self.assertAllClose(y_val, yd_val)
for g1, g2 in zip(gd_val, g_val):
- self.assertAllClose(g1, g2)
+ self.assertAllClose(g1, g2, rtol=1e-5)
def testRevBlock(self):
self._testRevBlock()
def fn_recompute(x):
return fn(x)
+ @rev_block_lib.recompute_grad(use_data_dep=True)
+ def fn_use_data_dep(x):
+ return fn(x)
+
+ @rev_block_lib.recompute_grad(tupleize_grads=True)
+ def fn_tupleize(x):
+ return fn(x)
+
+ @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True)
+ def fn_both(x):
+ return fn(x)
+
x = random_ops.random_uniform((3, 1, 3))
- recompute_vars = None
- with variable_scope.variable_scope("recompute") as vs:
- out1 = math_ops.reduce_sum(fn_recompute(x))
- recompute_vars = vs.trainable_variables()
- reg_vars = None
- with variable_scope.variable_scope("regular") as vs:
- out2 = math_ops.reduce_sum(fn(x))
- reg_vars = vs.trainable_variables()
-
- grad1 = gradients_impl.gradients(out1, recompute_vars)
- grad2 = gradients_impl.gradients(out2, reg_vars)
+
+ names_and_fns = [
+ ("recompute", fn_recompute),
+ ("regular", fn),
+ ("use_data_dep", fn_use_data_dep),
+ ("tupleize", fn_tupleize),
+ ("tuple_and_data_dep", fn_both),
+ ]
+ outputs_and_vars = []
+ for name, wrapped_fn in names_and_fns:
+ with variable_scope.variable_scope(name) as vs:
+ out = math_ops.reduce_sum(wrapped_fn(x))
+ outputs_and_vars.append((out, vs.trainable_variables()))
+
+ all_grads = []
+ for out, scope_vars in outputs_and_vars:
+ all_grads.append(gradients_impl.gradients(out, scope_vars))
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
- outs = sess.run([out1, out2, grad1, grad2])
- self.assertAllClose(outs[0], outs[1])
- for g1, g2 in zip(outs[2], outs[3]):
- self.assertAllClose(g1, g2)
+ outputs = list(zip(*outputs_and_vars))[0]
+ outs, all_grads_val = sess.run([outputs, all_grads])
+
+ # All outputs are the same
+ current = outs[0]
+ for out in outs[1:]:
+ self.assertAllClose(current, out)
+ current = out
+
# All gradients are the same. `all_grads_val` holds one list of gradient
# values per wrapped fn, so it is unpacked into zip() to line up the i-th
# gradient of every variant. zip(all_grads_val) without the `*` yields
# 1-tuples, which leaves `grads[1:]` empty and skips every comparison.
for grads in zip(*all_grads_val):
  current = grads[0]
  for g in grads[1:]:
    self.assertAllClose(current, g)
    current = g
class FnWithCustomGradTest(test.TestCase):