internal
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 20 Apr 2018 21:32:07 +0000 (14:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 20 Apr 2018 21:35:48 +0000 (14:35 -0700)
END_PUBLIC

BEGIN_PUBLIC
Automated g4 rollback of changelist 193600682

PiperOrigin-RevId: 193723856

tensorflow/contrib/layers/python/layers/rev_block_lib.py
tensorflow/contrib/layers/python/layers/rev_block_lib_test.py

index 9f904cc..02d294c 100644 (file)
@@ -45,7 +45,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
-from tensorflow.python.util import tf_inspect
 
 __all__ = ["rev_block", "RevBlock", "recompute_grad"]
 
@@ -430,13 +429,12 @@ def enable_with_args(dec):
 
 
 @enable_with_args
-def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False,
-                   tensor_arg_names=None):
+def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """Decorator that recomputes the function on the backwards pass.
 
   Args:
-    fn: the subgraph-producing function to wrap and recompute when computing
-      gradients. Provide `tensor_arg_names` if not all arguments are `Tensor`s.
+    fn: a function that takes Tensors (all as positional arguments) and returns
+      a tuple of Tensors.
     use_data_dep: `bool`, if `True` will use a dummy data dependency to force
       the recompute to happen. If `False` will use a control dependency. By
       default will be `True` if in an XLA context and `False` otherwise. XLA
@@ -445,25 +443,17 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False,
       that all gradients are produced before any are consumed by downstream ops.
       If `use_data_dep` is also `True`, will use a data dependency instead of
       a control dependency.
-    tensor_arg_names: `list<str>`, names of the `Tensor` arguments to `fn`. If
-      `None`, assumes all arguments are `Tensor`s.
 
   Returns:
     A wrapped fn that is identical to fn when called, but its activations will
     be discarded and recomputed on the backwards pass (i.e. on a call to
     tf.gradients).
   """
-  if tensor_arg_names:
-    if not isinstance(tensor_arg_names, (list, tuple)):
-      raise TypeError("tensor_arg_names must be a list")
 
   @functools.wraps(fn)
-  def wrapped(*args, **kwargs):
-    tensor_only_fn, tensor_args = _make_tensor_only(fn, args, kwargs,
-                                                    tensor_arg_names)
+  def wrapped(*args):
     return _recompute_grad(
-        tensor_only_fn, tensor_args, use_data_dep=use_data_dep,
-        tupleize_grads=tupleize_grads)
+        fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads)
 
   return wrapped
 
@@ -473,59 +463,11 @@ def _is_on_tpu():
   return control_flow_util.GetContainingXLAContext(ctxt) is not None
 
 
-def _make_tensor_only(fn, args, kwargs, tensor_arg_names):
-  """Return fn such that it only takes Tensor args for tensor_arg_names."""
-  argspec = tf_inspect.getargspec(fn)
-  if argspec.varargs is not None or argspec.keywords is not None:
-    raise ValueError("Function decorated with recompute_grad must not use "
-                     "*args or **kwargs.")
-  fn_arg_names = list(argspec.args)
-
-  # name_to_arg is a dict of argument name to argument value, including both
-  # positional and keyword arguments passed.
-  name_to_arg = {}
-  # Populate positional arguments.
-  for name, arg in zip(fn_arg_names[:len(args)], args):
-    name_to_arg[name] = arg
-  # Populate keyword arguments.
-  name_to_arg.update(kwargs)
-
-  # Separate the Tensor arguments from the non-Tensor arguments.
-  # The default is that all arguments are Tensor arguments.
-  tensor_arg_names = tensor_arg_names or fn_arg_names
-  for name in tensor_arg_names:
-    if name not in name_to_arg:
-      raise ValueError("Must provide Tensor argument %s" % name)
-  tensor_args = [name_to_arg[name] for name in tensor_arg_names]
-  non_tensor_kwargs = dict([(name, arg) for name, arg in name_to_arg.items()
-                            if name not in tensor_arg_names])
-
-  # Check that Tensor arguments are in fact Tensors and that non-Tensor
-  # arguments are not.
-  for name, arg in zip(tensor_arg_names, tensor_args):
-    if not isinstance(arg, framework_ops.Tensor):
-      raise TypeError("Fn argument %s must be a Tensor." % name)
-  for name, arg in non_tensor_kwargs.items():
-    if isinstance(arg, framework_ops.Tensor):
-      raise TypeError("Fn argument %s must not be a Tensor." % name)
-
-  # Construct a Tensor-only wrapper function that will pass the non-Tensor
-  # arguments as well when called.
-  def tensor_only_fn(*tensors):
-    all_kwargs = dict(zip(tensor_arg_names, tensors))
-    all_kwargs.update(non_tensor_kwargs)
-    return fn(**all_kwargs)
-
-  return tensor_only_fn, tensor_args
-
-
-def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT,
-                    tupleize_grads=False):
+def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
   """See recompute_grad."""
   for arg in args:
     if not isinstance(arg, framework_ops.Tensor):
       raise ValueError("All inputs to function must be Tensors")
-
   use_data_dep_ = use_data_dep
   if use_data_dep_ == _USE_DEFAULT:
     use_data_dep_ = _is_on_tpu()
@@ -559,11 +501,14 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT,
     grad_vars = grads[len(inputs):]
     return grad_inputs, grad_vars
 
-  # TODO(rsepassi): Replace with tf.custom_gradient
   @_fn_with_custom_grad(grad_fn)
   def fn_with_recompute(*args):
     cached_vs.append(variable_scope.get_variable_scope())
-    cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
+    # TODO(rsepassi): Rm conditional in TF 1.4
+    if hasattr(contrib_framework_ops, "current_arg_scope"):
+      cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
+    else:
+      cached_arg_scope.append({})
     return fn(*args)
 
   return fn_with_recompute(*args)
index 66ccc69..392a490 100644 (file)
@@ -318,108 +318,6 @@ class RecomputeTest(test.TestCase):
       self.assertEqual(1, len(grads))
       self.assertTrue(grads[0] is not None)
 
-  def testWithNontensorArgs(self):
-    @rev_block_lib.recompute_grad(tupleize_grads=True,
-                                  tensor_arg_names=["inputs"])
-    def layer_with_recompute(inputs, plus=None):
-      var = variable_scope.get_variable("var", ())
-      self.assertFalse(plus)  # called with False below
-      if plus:
-        return var + inputs
-      else:
-        return var * inputs
-
-    inputs = array_ops.ones((), dtypes.float32)
-    outputs = layer_with_recompute(inputs, plus=False)
-    loss = math_ops.square(outputs)
-    grads = gradients_impl.gradients(loss, variables.trainable_variables())
-    self.assertEqual(1, len(grads))
-    self.assertTrue(grads[0] is not None)
-
-
-class MakeTensorOnlyTest(test.TestCase):
-
-  def testMakeTensorOnly(self):
-    def fn(a, b, c, d=1, e=None, f=7):
-      return (a, b, c, d, e, f)
-
-    t1 = array_ops.ones(())
-    t2 = array_ops.ones(())
-    t3 = array_ops.ones(())
-    args = [1, t1, 3, t2]
-    kwargs = {"e": t3}
-    tensor_only_fn, tensor_args = rev_block_lib._make_tensor_only(
-        fn, args, kwargs, ["b", "d", "e"])
-    self.assertAllEqual(tensor_args, [t1, t2, t3])
-    out = tensor_only_fn(*tensor_args)
-    self.assertAllEqual(out, (1, t1, 3, t2, t3, 7))
-
-  def testMakeTensorOnlyPositionalArgsOnly(self):
-    def fn(a, b, c):
-      return (a, b, c)
-
-    t1 = array_ops.ones(())
-    t2 = array_ops.ones(())
-    args = [t1, 3, t2]
-    tensor_only_fn, tensor_args = rev_block_lib._make_tensor_only(
-        fn, args, {}, ["a", "c"])
-    self.assertAllEqual(tensor_args, [t1, t2])
-    out = tensor_only_fn(*tensor_args)
-    self.assertAllEqual(out, (t1, 3, t2))
-
-  def testMakeTensorOnlyKwargsArgsOnly(self):
-    def fn(a=1, b=2, c=3):
-      return (a, b, c)
-
-    t1 = array_ops.ones(())
-    t2 = array_ops.ones(())
-    args = [t1]
-    kwargs = {"c": t2}
-    tensor_only_fn, tensor_args = rev_block_lib._make_tensor_only(
-        fn, args, kwargs, ["a", "c"])
-    self.assertAllEqual(tensor_args, [t1, t2])
-    out = tensor_only_fn(*tensor_args)
-    self.assertAllEqual(out, (t1, 2, t2))
-
-  def testErrorOnMissingTensorArg(self):
-    def fn(a, b):
-      return (a, b)
-
-    with self.assertRaisesWithPredicateMatch(
-        ValueError, "provide Tensor argument"):
-      rev_block_lib._make_tensor_only(fn, [], {"b": 2}, ["a"])
-
-  def testErrorOnSignatureSplats(self):
-    def fn1(a, *args):
-      return (a, args)
-
-    err_msg = r"must not use \*args or \*\*kwargs"
-    with self.assertRaisesWithPredicateMatch(ValueError, err_msg):
-      rev_block_lib._make_tensor_only(fn1, [1, 2], {}, ["a"])
-
-    def fn2(a, **kwargs):
-      return (a, kwargs)
-
-    with self.assertRaisesWithPredicateMatch(ValueError, err_msg):
-      rev_block_lib._make_tensor_only(fn2, [], {"a": 1, "b": 2}, ["a"])
-
-  def testErrorOnNonTensorForTensor(self):
-    def fn(a, b):
-      return (a, b)
-
-    with self.assertRaisesWithPredicateMatch(TypeError, "must be a Tensor"):
-      rev_block_lib._make_tensor_only(fn, [2, 3], {}, ["a"])
-
-  def testErrorOnTensorForNonTensor(self):
-    def fn(a, b):
-      return (a, b)
-
-    with self.assertRaisesWithPredicateMatch(
-        TypeError, "must not be a Tensor"):
-      t1 = array_ops.ones(())
-      t2 = array_ops.ones(())
-      rev_block_lib._make_tensor_only(fn, [t1, t2], {}, ["a"])
-
 
 class FnWithCustomGradTest(test.TestCase):