From 59ce970732e7f8f1a22c12e52819ee43a4d3fec3 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 6 Apr 2018 16:09:37 -0700
Subject: [PATCH] Add support for ResourceVariable to recompute_grad

PiperOrigin-RevId: 191954813
---
 tensorflow/contrib/layers/python/layers/rev_block_lib.py      |  5 ++---
 .../contrib/layers/python/layers/rev_block_lib_test.py        | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index e49589d..02d294c 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -247,9 +247,7 @@ class RevBlock(base.Layer):
     f_vars_idxs = [[] for _ in range(self.num_layers)]
     g_vars_idxs = [[] for _ in range(self.num_layers)]
 
-    for i, t in enumerate(variables):
-      ref = _underlying_variable_ref(t)
-
+    for i, ref in enumerate(variables):
       # Use the name to identify the layer number and function (f or g)
       regex = LAYER_RE.match(ref.name)
       layer_no = int(regex.group(1))
@@ -604,6 +602,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
     """Custom grad fn applying grad_fn for identity Defun."""
     fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
         defun_inputs, list(op.inputs))
+    fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
     dys = list(dys)
     assert len(fn_outputs) == len(outputs)
     assert len(fn_outputs) == len(dys)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d1ad4e8..392a490 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -304,6 +304,20 @@ class RecomputeTest(test.TestCase):
       self.assertAllClose(current, g)
       current = g
 
+  def testResourceVariable(self):
+    @rev_block_lib.recompute_grad(tupleize_grads=True)
+    def layer_with_recompute(inputs):
+      var = variable_scope.get_variable("var", ())
+      return var * inputs
+
+    inputs = array_ops.ones((), dtypes.float32)
+    with variable_scope.variable_scope("layer", use_resource=True):
+      outputs = layer_with_recompute(inputs)
+      loss = math_ops.square(outputs)
+    grads = gradients_impl.gradients(loss, variables.trainable_variables())
+    self.assertEqual(1, len(grads))
+    self.assertTrue(grads[0] is not None)
+
 
 class FnWithCustomGradTest(test.TestCase):
 
-- 
2.7.4
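
Note (not part of the patch): a minimal standalone sketch of what the new
test exercises, using the same TF 1.x contrib API. The import paths follow
the module paths in the diff; building the graph outside a session is an
assumption for brevity.

# Illustrative sketch mirroring testResourceVariable above; not part of
# the patch itself. With use_resource=True, get_variable returns a
# ResourceVariable, which the custom gradient in rev_block_lib must map
# back to the underlying variable ref (the _underlying_variable_ref change).
from tensorflow.contrib.layers.python.layers import rev_block_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


@rev_block_lib.recompute_grad(tupleize_grads=True)
def layer_with_recompute(inputs):
  # Forward activations are recomputed during the backward pass instead of
  # being stored; the variable itself is reused, not recreated.
  var = variable_scope.get_variable("var", ())
  return var * inputs


inputs = array_ops.ones((), dtypes.float32)
with variable_scope.variable_scope("layer", use_resource=True):
  outputs = layer_with_recompute(inputs)
  loss = math_ops.square(outputs)

# With this patch, a gradient is produced for the resource variable.
grads = gradients_impl.gradients(loss, variables.trainable_variables())
assert len(grads) == 1 and grads[0] is not None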