f_vars_idxs = [[] for _ in range(self.num_layers)]
g_vars_idxs = [[] for _ in range(self.num_layers)]
- for i, t in enumerate(variables):
- ref = _underlying_variable_ref(t)
-
+ for i, ref in enumerate(variables):
# Use the name to identify the layer number and function (f or g)
regex = LAYER_RE.match(ref.name)
layer_no = int(regex.group(1))
"""Custom grad fn applying grad_fn for identity Defun."""
fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
defun_inputs, list(op.inputs))
+ fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
dys = list(dys)
assert len(fn_outputs) == len(outputs)
assert len(fn_outputs) == len(dys)
self.assertAllClose(current, g)
current = g
+ def testResourceVariable(self):
+ # Regression-style test: a layer wrapped in recompute_grad (with
+ # tupleize_grads=True) whose variable is created under
+ # use_resource=True (i.e. a ResourceVariable) must still receive a
+ # non-None gradient. Presumably guards the _underlying_variable_ref
+ # handling added elsewhere in this change — TODO confirm.
+ @rev_block_lib.recompute_grad(tupleize_grads=True)
+ def layer_with_recompute(inputs):
+ var = variable_scope.get_variable("var", ())
+ return var * inputs
+
+ # Scalar float32 input; the scope makes get_variable create a
+ # resource variable named "layer/var".
+ inputs = array_ops.ones((), dtypes.float32)
+ with variable_scope.variable_scope("layer", use_resource=True):
+ outputs = layer_with_recompute(inputs)
+ loss = math_ops.square(outputs)
+ grads = gradients_impl.gradients(loss, variables.trainable_variables())
+ # Exactly one trainable variable exists here, and its gradient must
+ # have been computed (not dropped by the recompute machinery).
+ self.assertEqual(1, len(grads))
+ self.assertTrue(grads[0] is not None)
+
class FnWithCustomGradTest(test.TestCase):