Add support for ResourceVariable to recompute_grad
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 23:09:37 +0000 (16:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 23:11:48 +0000 (16:11 -0700)
PiperOrigin-RevId: 191954813

tensorflow/contrib/layers/python/layers/rev_block_lib.py
tensorflow/contrib/layers/python/layers/rev_block_lib_test.py

index e49589d..02d294c 100644 (file)
@@ -247,9 +247,7 @@ class RevBlock(base.Layer):
     f_vars_idxs = [[] for _ in range(self.num_layers)]
     g_vars_idxs = [[] for _ in range(self.num_layers)]
 
-    for i, t in enumerate(variables):
-      ref = _underlying_variable_ref(t)
-
+    for i, ref in enumerate(variables):
       # Use the name to identify the layer number and function (f or g)
       regex = LAYER_RE.match(ref.name)
       layer_no = int(regex.group(1))
@@ -604,6 +602,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
     """Custom grad fn applying grad_fn for identity Defun."""
     fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
         defun_inputs, list(op.inputs))
+    fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
     dys = list(dys)
     assert len(fn_outputs) == len(outputs)
     assert len(fn_outputs) == len(dys)
index d1ad4e8..392a490 100644 (file)
@@ -304,6 +304,20 @@ class RecomputeTest(test.TestCase):
           self.assertAllClose(current, g)
           current = g
 
+  def testResourceVariable(self):
+    @rev_block_lib.recompute_grad(tupleize_grads=True)
+    def layer_with_recompute(inputs):
+      var = variable_scope.get_variable("var", ())
+      return var * inputs
+
+    inputs = array_ops.ones((), dtypes.float32)
+    with variable_scope.variable_scope("layer", use_resource=True):
+      outputs = layer_with_recompute(inputs)
+      loss = math_ops.square(outputs)
+      grads = gradients_impl.gradients(loss, variables.trainable_variables())
+      self.assertEqual(1, len(grads))
+      self.assertTrue(grads[0] is not None)
+
 
 class FnWithCustomGradTest(test.TestCase):