from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
-from tensorflow.python.util import tf_inspect
__all__ = ["rev_block", "RevBlock", "recompute_grad"]
@enable_with_args
-def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False,
- tensor_arg_names=None):
+def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""Decorator that recomputes the function on the backwards pass.
Args:
- fn: the subgraph-producing function to wrap and recompute when computing
- gradients. Provide `tensor_arg_names` if not all arguments are `Tensor`s.
+ fn: a function that takes Tensors (all as positional arguments) and returns
+ a tuple of Tensors.
use_data_dep: `bool`, if `True` will use a dummy data dependency to force
the recompute to happen. If `False` will use a control dependency. By
default will be `True` if in an XLA context and `False` otherwise. XLA
ignores control dependencies and so this data dependency is necessary.
tupleize_grads: `bool`, if `True` will use control dependencies to ensure
that all gradients are produced before any are consumed by downstream ops.
If `use_data_dep` is also `True`, will use a data dependency instead of
a control dependency.
- tensor_arg_names: `list<str>`, names of the `Tensor` arguments to `fn`. If
- `None`, assumes all arguments are `Tensor`s.
Returns:
A wrapped fn that is identical to fn when called, but its activations will
be discarded and recomputed on the backwards pass (i.e. on a call to
tf.gradients).
"""
- if tensor_arg_names:
- if not isinstance(tensor_arg_names, (list, tuple)):
- raise TypeError("tensor_arg_names must be a list")
@functools.wraps(fn)
- def wrapped(*args, **kwargs):
- tensor_only_fn, tensor_args = _make_tensor_only(fn, args, kwargs,
- tensor_arg_names)
+ def wrapped(*args):
return _recompute_grad(
- tensor_only_fn, tensor_args, use_data_dep=use_data_dep,
- tupleize_grads=tupleize_grads)
+ fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
return wrapped
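# Illustrative usage sketch (not part of this change): with the simplified
# signature, every argument to the wrapped fn must be a Tensor passed
# positionally. `inputs` and the layer sizes below are hypothetical.
#
#   @recompute_grad
#   def two_layer_block(x):
#     h = tf.layers.dense(x, 256, activation=tf.nn.relu)
#     return tf.layers.dense(h, 256)
#
#   outputs = two_layer_block(inputs)
#   loss = tf.reduce_sum(outputs)
#   # Activations inside the block are discarded after the forward pass and
#   # recomputed when tf.gradients runs the backward pass.
#   grads = tf.gradients(loss, [inputs])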
def _is_on_tpu():
  ctxt = framework_ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
  return control_flow_util.GetContainingXLAContext(ctxt) is not None
-def _make_tensor_only(fn, args, kwargs, tensor_arg_names):
- """Return fn such that it only takes Tensor args for tensor_arg_names."""
- argspec = tf_inspect.getargspec(fn)
- if argspec.varargs is not None or argspec.keywords is not None:
- raise ValueError("Function decorated with recompute_grad must not use "
- "*args or **kwargs.")
- fn_arg_names = list(argspec.args)
-
- # name_to_arg is a dict of argument name to argument value, including both
- # positional and keyword arguments passed.
- name_to_arg = {}
- # Populate positional arguments.
- for name, arg in zip(fn_arg_names[:len(args)], args):
- name_to_arg[name] = arg
- # Populate keyword arguments.
- name_to_arg.update(kwargs)
-
- # Separate the Tensor arguments from the non-Tensor arguments.
- # The default is that all arguments are Tensor arguments.
- tensor_arg_names = tensor_arg_names or fn_arg_names
- for name in tensor_arg_names:
- if name not in name_to_arg:
- raise ValueError("Must provide Tensor argument %s" % name)
- tensor_args = [name_to_arg[name] for name in tensor_arg_names]
- non_tensor_kwargs = dict([(name, arg) for name, arg in name_to_arg.items()
- if name not in tensor_arg_names])
-
- # Check that Tensor arguments are in fact Tensors and that non-Tensor
- # arguments are not.
- for name, arg in zip(tensor_arg_names, tensor_args):
- if not isinstance(arg, framework_ops.Tensor):
- raise TypeError("Fn argument %s must be a Tensor." % name)
- for name, arg in non_tensor_kwargs.items():
- if isinstance(arg, framework_ops.Tensor):
- raise TypeError("Fn argument %s must not be a Tensor." % name)
-
- # Construct a Tensor-only wrapper function that will pass the non-Tensor
- # arguments as well when called.
- def tensor_only_fn(*tensors):
- all_kwargs = dict(zip(tensor_arg_names, tensors))
- all_kwargs.update(non_tensor_kwargs)
- return fn(**all_kwargs)
-
- return tensor_only_fn, tensor_args
-
-
-def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT,
- tupleize_grads=False):
+def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
for arg in args:
if not isinstance(arg, framework_ops.Tensor):
raise ValueError("All inputs to function must be Tensors")
-
use_data_dep_ = use_data_dep
if use_data_dep_ == _USE_DEFAULT:
use_data_dep_ = _is_on_tpu()
grad_vars = grads[len(inputs):]
return grad_inputs, grad_vars
- # TODO(rsepassi): Replace with tf.custom_gradient
@_fn_with_custom_grad(grad_fn)
def fn_with_recompute(*args):
cached_vs.append(variable_scope.get_variable_scope())
- cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
+ # TODO(rsepassi): Rm conditional in TF 1.4
+ if hasattr(contrib_framework_ops, "current_arg_scope"):
+ cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
+ else:
+ cached_arg_scope.append({})
return fn(*args)
return fn_with_recompute(*args)
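# For orientation (illustrative only): the recompute pattern implemented by
# _recompute_grad is, in spirit, what the tf.custom_gradient API named in the
# removed TODO above would express directly. This sketch ignores variable
# handling, which is why the code above goes through _fn_with_custom_grad:
#
#   @tf.custom_gradient
#   def fn_recomputed(*args):
#     outputs = fn(*args)
#     def grad(*output_grads):
#       # Re-run fn on the backward pass rather than keeping activations.
#       with tf.control_dependencies(output_grads):
#         recomputed_outputs = fn(*args)
#       return tf.gradients(recomputed_outputs, args, grad_ys=output_grads)
#     return outputs, grad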
self.assertEqual(1, len(grads))
self.assertTrue(grads[0] is not None)
- def testWithNontensorArgs(self):
- @rev_block_lib.recompute_grad(tupleize_grads=True,
- tensor_arg_names=["inputs"])
- def layer_with_recompute(inputs, plus=None):
- var = variable_scope.get_variable("var", ())
- self.assertFalse(plus) # called with False below
- if plus:
- return var + inputs
- else:
- return var * inputs
-
- inputs = array_ops.ones((), dtypes.float32)
- outputs = layer_with_recompute(inputs, plus=False)
- loss = math_ops.square(outputs)
- grads = gradients_impl.gradients(loss, variables.trainable_variables())
- self.assertEqual(1, len(grads))
- self.assertTrue(grads[0] is not None)
-
-
-class MakeTensorOnlyTest(test.TestCase):
-
- def testMakeTensorOnly(self):
- def fn(a, b, c, d=1, e=None, f=7):
- return (a, b, c, d, e, f)
-
- t1 = array_ops.ones(())
- t2 = array_ops.ones(())
- t3 = array_ops.ones(())
- args = [1, t1, 3, t2]
- kwargs = {"e": t3}
- tensor_only_fn, tensor_args = rev_block_lib._make_tensor_only(
- fn, args, kwargs, ["b", "d", "e"])
- self.assertAllEqual(tensor_args, [t1, t2, t3])
- out = tensor_only_fn(*tensor_args)
- self.assertAllEqual(out, (1, t1, 3, t2, t3, 7))
-
- def testMakeTensorOnlyPositionalArgsOnly(self):
- def fn(a, b, c):
- return (a, b, c)
-
- t1 = array_ops.ones(())
- t2 = array_ops.ones(())
- args = [t1, 3, t2]
- tensor_only_fn, tensor_args = rev_block_lib._make_tensor_only(
- fn, args, {}, ["a", "c"])
- self.assertAllEqual(tensor_args, [t1, t2])
- out = tensor_only_fn(*tensor_args)
- self.assertAllEqual(out, (t1, 3, t2))
-
- def testMakeTensorOnlyKwargsArgsOnly(self):
- def fn(a=1, b=2, c=3):
- return (a, b, c)
-
- t1 = array_ops.ones(())
- t2 = array_ops.ones(())
- args = [t1]
- kwargs = {"c": t2}
- tensor_only_fn, tensor_args = rev_block_lib._make_tensor_only(
- fn, args, kwargs, ["a", "c"])
- self.assertAllEqual(tensor_args, [t1, t2])
- out = tensor_only_fn(*tensor_args)
- self.assertAllEqual(out, (t1, 2, t2))
-
- def testErrorOnMissingTensorArg(self):
- def fn(a, b):
- return (a, b)
-
- with self.assertRaisesWithPredicateMatch(
- ValueError, "provide Tensor argument"):
- rev_block_lib._make_tensor_only(fn, [], {"b": 2}, ["a"])
-
- def testErrorOnSignatureSplats(self):
- def fn1(a, *args):
- return (a, args)
-
- err_msg = r"must not use \*args or \*\*kwargs"
- with self.assertRaisesWithPredicateMatch(ValueError, err_msg):
- rev_block_lib._make_tensor_only(fn1, [1, 2], {}, ["a"])
-
- def fn2(a, **kwargs):
- return (a, kwargs)
-
- with self.assertRaisesWithPredicateMatch(ValueError, err_msg):
- rev_block_lib._make_tensor_only(fn2, [], {"a": 1, "b": 2}, ["a"])
-
- def testErrorOnNonTensorForTensor(self):
- def fn(a, b):
- return (a, b)
-
- with self.assertRaisesWithPredicateMatch(TypeError, "must be a Tensor"):
- rev_block_lib._make_tensor_only(fn, [2, 3], {}, ["a"])
-
- def testErrorOnTensorForNonTensor(self):
- def fn(a, b):
- return (a, b)
-
- with self.assertRaisesWithPredicateMatch(
- TypeError, "must not be a Tensor"):
- t1 = array_ops.ones(())
- t2 = array_ops.ones(())
- rev_block_lib._make_tensor_only(fn, [t1, t2], {}, ["a"])
-
class FnWithCustomGradTest(test.TestCase):