Internally rewrite RevBlock to use @custom_gradient
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sun, 29 Apr 2018 02:47:42 +0000 (19:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 29 Apr 2018 02:50:17 +0000 (19:50 -0700)
PiperOrigin-RevId: 194679657

tensorflow/contrib/layers/python/layers/rev_block_lib.py
tensorflow/contrib/layers/python/layers/rev_block_lib_test.py

index 1a439f0..8ed9f44 100644 (file)
@@ -35,7 +35,6 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.contrib.framework.python import ops as contrib_framework_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
 from tensorflow.python.framework import ops as framework_ops
 from tensorflow.python.layers import base
 from tensorflow.python.ops import array_ops
@@ -155,7 +154,7 @@ def _scope_wrap(fn, scope):
 
   @functools.wraps(fn)
   def wrap(*args, **kwargs):
-    with variable_scope.variable_scope(scope):
+    with variable_scope.variable_scope(scope, use_resource=True):
       return fn(*args, **kwargs)
 
   return wrap
@@ -230,95 +229,95 @@ class RevBlock(base.Layer):
                  "build.")
     self.built = True
 
-  def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
-    """Custom gradient fn for a block of reversible residual layers."""
-    # Inputs have passed through an Identity. Recover the original Tensors to
-    # be able to match up side inputs.
-    assert [u"Identity"] == list(set([x.op.type for x in inputs]))
-    inputs = [x.op.inputs[0] for x in inputs]
-    side_inputs = inputs[2:]
-    del inputs
-
-    f_side_idxs = [None] * len(self.f_side_input)
-    g_side_idxs = [None] * len(self.g_side_input)
-    assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
-
-    for i, t in enumerate(side_inputs):
-      if t in self.f_side_input:
-        f_side_idxs[self.f_side_input.index(t)] = i
-      elif t in self.g_side_input:
-        g_side_idxs[self.g_side_input.index(t)] = i
-      else:
-        assert False
-
-    f_vars = [[] for _ in range(self.num_layers)]
-    g_vars = [[] for _ in range(self.num_layers)]
-    f_vars_idxs = [[] for _ in range(self.num_layers)]
-    g_vars_idxs = [[] for _ in range(self.num_layers)]
-
-    for i, ref in enumerate(variables):
-      # Use the name to identify the layer number and function (f or g)
-      regex = LAYER_RE.match(ref.name)
-      layer_no = int(regex.group(1))
-      fn_name = regex.group(2)
-      if fn_name == "f":
-        f_vars[layer_no].append(ref)
-        f_vars_idxs[layer_no].append(i)
-      else:
-        assert fn_name == "g"
-        g_vars[layer_no].append(ref)
-        g_vars_idxs[layer_no].append(i)
-
-    f_var_grads = []
-    g_var_grads = []
-    f_side_grads = []
-    g_side_grads = []
-
-    # Reverse variable containers to go backward
-    f_vars.reverse()
-    g_vars.reverse()
-    f = list(self.f)
-    g = list(self.g)
-    f.reverse()
-    g.reverse()
-
-    with variable_scope.variable_scope(self.scope_name, reuse=True):
-      for i in xrange(self.num_layers):
-        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
-            ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
-            self.g_side_input)
-
-        grad_f_vars, grad_f_side = f_ret
-        grad_g_vars, grad_g_side = g_ret
-        f_var_grads.append(grad_f_vars)
-        g_var_grads.append(grad_g_vars)
-        f_side_grads.append(grad_f_side)
-        g_side_grads.append(grad_g_side)
-
-    # Accumulate layer gradients for f_side_input and g_side_input
-    acc_f_side_grads = _acc_grads(*f_side_grads)
-    acc_g_side_grads = _acc_grads(*g_side_grads)
-
-    # Use the stored idxs to put gradients in the passed-in order.
-    side_input_grads = [None] * len(side_inputs)
-    variable_grads = [None] * len(variables)
-
-    # Variable gradients were collected in reverse layer order. Reverse to match
-    # idxs.
-    f_var_grads.reverse()
-    g_var_grads.reverse()
-    for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
-        zip(g_vars_idxs, g_var_grads)):
-      for i, grad in zip(idxs, grads):
-        variable_grads[i] = grad
-
-    for i, grad in zip(f_side_idxs, acc_f_side_grads):
-      side_input_grads[i] = grad
-    for i, grad in zip(g_side_idxs, acc_g_side_grads):
-      side_input_grads[i] = grad
-
-    grad_x1, grad_x2 = grad_ys
-    return [grad_x1, grad_x2] + side_input_grads, variable_grads
+  def _make_efficient_grad_fn(self, inputs_, ys_):
+    def _efficient_grad_fn(*grad_ys, **kwargs):
+      """Custom gradient fn for a block of reversible residual layers."""
+      inputs = inputs_
+      ys = ys_
+      variables = kwargs["variables"]
+      side_inputs = inputs[2:]
+
+      f_side_idxs = [None] * len(self.f_side_input)
+      g_side_idxs = [None] * len(self.g_side_input)
+      assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)
+
+      for i, t in enumerate(side_inputs):
+        if t in self.f_side_input:
+          f_side_idxs[self.f_side_input.index(t)] = i
+        elif t in self.g_side_input:
+          g_side_idxs[self.g_side_input.index(t)] = i
+        else:
+          assert False
+
+      f_vars = [[] for _ in range(self.num_layers)]
+      g_vars = [[] for _ in range(self.num_layers)]
+      f_vars_idxs = [[] for _ in range(self.num_layers)]
+      g_vars_idxs = [[] for _ in range(self.num_layers)]
+
+      for i, ref in enumerate(variables):
+        # Use the name to identify the layer number and function (f or g)
+        regex = LAYER_RE.match(ref.name)
+        layer_no = int(regex.group(1))
+        fn_name = regex.group(2)
+        if fn_name == "f":
+          f_vars[layer_no].append(ref)
+          f_vars_idxs[layer_no].append(i)
+        else:
+          assert fn_name == "g"
+          g_vars[layer_no].append(ref)
+          g_vars_idxs[layer_no].append(i)
+
+      f_var_grads = []
+      g_var_grads = []
+      f_side_grads = []
+      g_side_grads = []
+
+      # Reverse variable containers to go backward
+      f_vars.reverse()
+      g_vars.reverse()
+      f = list(self.f)
+      g = list(self.g)
+      f.reverse()
+      g.reverse()
+
+      with variable_scope.variable_scope(self.scope_name, reuse=True):
+        for i in xrange(self.num_layers):
+          ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
+              ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
+              self.g_side_input)
+
+          grad_f_vars, grad_f_side = f_ret
+          grad_g_vars, grad_g_side = g_ret
+          f_var_grads.append(grad_f_vars)
+          g_var_grads.append(grad_g_vars)
+          f_side_grads.append(grad_f_side)
+          g_side_grads.append(grad_g_side)
+
+      # Accumulate layer gradients for f_side_input and g_side_input
+      acc_f_side_grads = _acc_grads(*f_side_grads)
+      acc_g_side_grads = _acc_grads(*g_side_grads)
+
+      # Use the stored idxs to put gradients in the passed-in order.
+      side_input_grads = [None] * len(side_inputs)
+      variable_grads = [None] * len(variables)
+
+      # Variable gradients were collected in reverse layer order. Reverse to
+      # match idxs.
+      f_var_grads.reverse()
+      g_var_grads.reverse()
+      for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
+          zip(g_vars_idxs, g_var_grads)):
+        for i, grad in zip(idxs, grads):
+          variable_grads[i] = grad
+
+      for i, grad in zip(f_side_idxs, acc_f_side_grads):
+        side_input_grads[i] = grad
+      for i, grad in zip(g_side_idxs, acc_g_side_grads):
+        side_input_grads[i] = grad
+
+      grad_x1, grad_x2 = grad_ys
+      return [grad_x1, grad_x2] + side_input_grads, variable_grads
+    return _efficient_grad_fn
 
   def _forward(self, x1, x2):
     """Run forward through the reversible layers."""
@@ -326,10 +325,6 @@ class RevBlock(base.Layer):
     side_inputs = [self.f_side_input, self.g_side_input]
     flat_side_inputs = nest.flatten(side_inputs)
 
-    custom_grad_fn = (
-        self._efficient_grad_fn if self._use_efficient_backprop else None)
-
-    @_fn_with_custom_grad(custom_grad_fn)
     def _forward_wrap(x1_, x2_, *flat_side_inputs):
       f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs)
       return _rev_block_forward(
@@ -342,7 +337,16 @@ class RevBlock(base.Layer):
           g_side_input=g_side,
           gate_outputs=self._use_efficient_backprop)
 
-    return _forward_wrap(x1, x2, *flat_side_inputs)
+    @custom_gradient.custom_gradient
+    def _forward_with_custom_grad(*args):
+      out = _forward_wrap(*args)  # pylint: disable=no-value-for-parameter
+      grad_fn = self._make_efficient_grad_fn(args, out)
+      return out, grad_fn
+
+    if self._use_efficient_backprop:
+      return _forward_with_custom_grad(x1, x2, *flat_side_inputs)
+    else:
+      return _forward_wrap(x1, x2, *flat_side_inputs)
 
   def _backward(self, y1, y2):
     """Run backward through the reversible layers."""
@@ -560,107 +564,6 @@ def _underlying_variable_ref(t):
     return None
 
 
-def _fn_with_custom_grad(grad_fn, use_global_vars=False):
-  """Decorator to create a subgraph with a custom gradient function.
-
-  The subgraph created by the decorated function is NOT put in a Defun and so
-  does not suffer from the limitations of the Defun (all subgraph ops on the
-  same device, no summaries).
-
-  Args:
-    grad_fn: function with signature
-      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
-      all of which are lists of Tensors.
-    use_global_vars: if True, variables will be the global variables created.
-      If False, will be the trainable variables.
-
-  Returns:
-    Decorator for function such that the gradient is defined by grad_fn.
-  """
-
-  def dec(fn):
-
-    @functools.wraps(fn)
-    def wrapped(*args):
-      return _fn_with_custom_grad_internal(
-          fn, args, grad_fn, use_global_vars=use_global_vars)
-
-    return wrapped
-
-  return dec
-
-
-def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
-  """Create a subgraph with a custom gradient.
-
-  Args:
-    fn: function that takes inputs as arguments and produces 1 or more Tensors.
-    inputs: list<Tensor>, will be passed as fn(*inputs).
-    grad_fn: function with signature
-      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
-      all of which are lists of Tensors.
-    use_global_vars: if True, variables will be the global variables created.
-      If False, will be the trainable variables.
-
-  Returns:
-    fn(*inputs)
-  """
-  vs = variable_scope.get_variable_scope()
-  get_vars_fn = (
-      vs.global_variables if use_global_vars else vs.trainable_variables)
-  len_before_vars = len(get_vars_fn())
-  inputs = [array_ops.identity(x) for x in inputs]
-  outputs = fn(*inputs)
-  train_vars = get_vars_fn()[len_before_vars:]
-
-  if grad_fn is None:
-    return outputs
-
-  if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
-    outputs = [outputs]
-  outputs = list(outputs)
-
-  defun_inputs = [inputs, train_vars, outputs]
-
-  def custom_grad_fn(op, *dys):
-    """Custom grad fn applying grad_fn for identity Defun."""
-    fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
-        defun_inputs, list(op.inputs))
-    fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
-    dys = list(dys)
-    assert len(fn_outputs) == len(outputs)
-    assert len(fn_outputs) == len(dys)
-
-    grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
-    grad_outputs = [None] * len(fn_outputs)
-    return tuple(grad_inputs + grad_vars + grad_outputs)
-
-  # The Defun takes as input the original inputs, the trainable variables
-  # created in fn, and the outputs. In the forward it passes through the
-  # outputs. In the backwards, it produces gradients for the original inputs
-  # and the trainable variables.
-  in_types = [t.dtype for t in inputs]
-  out_types = [t.dtype for t in outputs]
-  var_types = [t.dtype for t in train_vars]
-
-  # Get a unique name for the Defun
-  with framework_ops.name_scope("identity_custom_grad") as ns:
-    defun_name = ns
-
-  @function.Defun(
-      *(in_types + var_types + out_types),
-      func_name=defun_name,
-      python_grad_func=custom_grad_fn,
-      shape_func=lambda _: [t.get_shape() for t in outputs])
-  def identity(*args):
-    _, _, outs = nest.pack_sequence_as(defun_inputs, args)
-    return tuple([array_ops.identity(t) for t in outs])
-
-  flat_inputs = nest.flatten(defun_inputs)
-  id_out = identity(*flat_inputs)
-  return id_out
-
-
 def _force_data_dependency(first_compute, then_compute):
   """Force all of `then_compute` to depend on all of `first_compute`.
 
index 8107486..997f53b 100644 (file)
@@ -83,8 +83,8 @@ class RevBlockTest(test.TestCase):
       sess.run(variables.global_variables_initializer())
       y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
 
-      self.assertAllClose(y1, y1_inv)
-      self.assertAllClose(y2, y2_inv)
+      self.assertAllClose(y1, y1_inv, rtol=1e-5)
+      self.assertAllClose(y2, y2_inv, rtol=1e-5)
 
   def _testRevBlock(self,
                     x=None,
@@ -179,18 +179,16 @@ class RevBlockTest(test.TestCase):
 
     self._testRevBlock(f=[f1, f2, f1, f2])
 
-  # TODO(rsepassi): Recent change to conv seems to have broken this test. Find
-  # out why.
-  def _testConvAndBatchNorm(self):
+  def testConvAndBatchNorm(self):
 
     x = random_ops.random_uniform(
         [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32)
 
     def f(x):
       x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same")
-      x = layers.batch_norm(x, is_training=True)
+      x = layers.batch_norm(x, is_training=False)
       x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same")
-      x = layers.batch_norm(x, is_training=True)
+      x = layers.batch_norm(x, is_training=False)
       return x
 
     self._testRevBlock(x=x, f=f)
@@ -345,89 +343,5 @@ class RecomputeTest(test.TestCase):
       self.assertTrue(grad is not None)
 
 
-class FnWithCustomGradTest(test.TestCase):
-
-  def testCorrectness(self):
-
-    w = random_ops.random_uniform([6, 10])
-
-    def fn(a, b, c):
-      return core_layers.dense(
-          a,
-          10,
-          use_bias=False,
-          kernel_initializer=lambda shape, dtype, partition_info: w
-      ) + math_ops.matmul(b, c)
-
-    def grad_fn(inputs, trainable_variables, outputs, grad_outputs):
-      outputs = outputs[0]
-      grad_outputs = grad_outputs[0]
-      grad_inputs = gradients_impl.gradients(
-          outputs, inputs, grad_ys=grad_outputs)
-      grad_vars = gradients_impl.gradients(
-          outputs, trainable_variables, grad_ys=grad_outputs)
-      return grad_inputs, grad_vars
-
-    custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)
-
-    a = random_ops.random_uniform([11, 6])
-    b = random_ops.random_uniform([11, 7])
-    c = random_ops.random_uniform([7, 10])
-
-    out = fn(a, b, c)
-    custom_out = custom_fn(a, b, c)
-    self.assertEqual(out.get_shape().as_list(),
-                     custom_out.get_shape().as_list())
-
-    loss = math_ops.reduce_mean(out)
-    custom_loss = math_ops.reduce_mean(custom_out)
-
-    grads = gradients_impl.gradients(
-        loss, [a, b, c] + [variables.trainable_variables()[0]])
-    custom_grads = gradients_impl.gradients(
-        custom_loss, [a, b, c] + [variables.trainable_variables()[1]])
-
-    with self.test_session() as sess:
-      sess.run(variables.global_variables_initializer())
-      out_val, custom_out_val, grads_val, custom_grads_val = sess.run(
-          [out, custom_out, grads, custom_grads])
-      self.assertAllClose(out_val, custom_out_val)
-      for g1, g2 in zip(grads_val, custom_grads_val):
-        self.assertAllClose(g1, g2)
-
-  def testCustomGrad(self):
-
-    def fn(a, b, c):
-      return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c)
-
-    def grad_fn(inputs, trainable_variables, unused_outputs,
-                unused_grad_outputs):
-      grad_inputs = [
-          array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs)
-      ]
-      grad_vars = [
-          array_ops.ones_like(t) * (i + len(inputs) + 1.)
-          for i, t in enumerate(trainable_variables)
-      ]
-      return grad_inputs, grad_vars
-
-    a = random_ops.random_uniform([11, 6])
-    b = random_ops.random_uniform([11, 7])
-    c = random_ops.random_uniform([7, 10])
-    w = random_ops.random_uniform([6, 10])
-    out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c)
-    loss = math_ops.reduce_mean(out)
-    grads = gradients_impl.gradients(
-        loss, [a, b, c, variables.trainable_variables()[0]])
-    expected_grads = [
-        array_ops.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w])
-    ]
-    with self.test_session() as sess:
-      sess.run(variables.global_variables_initializer())
-      g_val, eg_val = sess.run([grads, expected_grads])
-      for g1, g2 in zip(g_val, eg_val):
-        self.assertAllClose(g1, g2)
-
-
 if __name__ == "__main__":
   test.main()