Support structured source in GradientTape.gradient
authorIgor Ganichev <iga@google.com>
Thu, 29 Mar 2018 03:51:01 +0000 (20:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 03:53:48 +0000 (20:53 -0700)
Before this change, it was easy to forget [] around the source tensor.
This mistake lead to GradientTape.gradient(), returning a list of Nones.
Nones normally tell to the user that the source and the target are
not connected via differentiable operations, which is not the source
of the error in this case.

Instead of adding a check that `sources` is a list of tensors, this CL
adds ability to handle structured source (which includes a lone tensor),
similarly to many existing TensorFlow APIs.

Also, with Alex's help, it fixes a bug where repeated tensors in
`sources` were not handled correctly.

PiperOrigin-RevId: 190878583

tensorflow/c/eager/tape.h
tensorflow/python/eager/backprop.py
tensorflow/python/eager/backprop_test.py
tensorflow/python/eager/pywrap_tfe_src.cc

index c7bd3bd..97c323b 100644 (file)
@@ -601,23 +601,28 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
   }
   CHECK(state.op_tape.empty());
   result->reserve(source_tensor_ids.size());
+  gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
   for (auto is : source_tensor_ids) {
     auto grad_it = gradients.find(is);
     if (grad_it == gradients.end()) {
       result->push_back(nullptr);
     } else {
-      if (grad_it->second.size() == 1) {
-        result->push_back(grad_it->second[0]);
-      } else {
-        result->push_back(vspace.AggregateGradients(grad_it->second));
+      if (grad_it->second.size() > 1) {
+        Gradient* grad = vspace.AggregateGradients(grad_it->second);
+        grad_it->second.clear();
+        grad_it->second.push_back(grad);
       }
-      gradients.erase(grad_it);
+      result->push_back(grad_it->second[0]);
+      used_gradient_ids.insert(is);
     }
   }
-  VLOG(1) << "Final gradients size: " << gradients.size();
+  VLOG(1) << "Final gradients size: "
+          << gradients.size() - used_gradient_ids.size();
   for (auto grad_pair : gradients) {
-    for (const auto& g : grad_pair.second) {
-      vspace.DeleteGradient(g);
+    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
+      for (const auto& g : grad_pair.second) {
+        vspace.DeleteGradient(g);
+      }
     }
   }
   return Status::OK();
index c54a5a1..209b012 100644 (file)
@@ -646,6 +646,13 @@ _default_vspace = imperative_grad.VSpace(
     ones=_ones)
 
 
+def _handle_or_self(x):
+  """If x is ResourceVariable, return its handle, else x."""
+  if isinstance(x, resource_variable_ops.ResourceVariable):
+    x = x.handle
+  return x
+
+
 @tf_export("GradientTape")
 class GradientTape(object):
   """Record operations for automatic differentiation.
@@ -723,9 +730,7 @@ class GradientTape(object):
       tensor: a Tensor or list of Tensors.
     """
     for t in nest.flatten(tensor):
-      if isinstance(t, resource_variable_ops.ResourceVariable):
-        t = t.handle
-      tape.watch(t)
+      tape.watch(_handle_or_self(t))
 
   def watched_variables(self):
     # Sorting variables by id, which is monotonically increasing in construction
@@ -739,14 +744,15 @@ class GradientTape(object):
 
     Args:
       target: Tensor to be differentiated.
-      sources: a list of Tensors or Variables. `target` will be differentiated
-        against elements in `sources`.
+      sources: a list or nested structure of Tensors or Variables. `target`
+        will be differentiated against elements in `sources`.
       output_gradients: a list of gradients, one for each element of
         target. Defaults to None.
 
     Returns:
-      a list of Tensors (or IndexedSlices, or None), one for each element in
-      `sources`.
+      a list or nested structure of Tensors (or IndexedSlices, or None),
+      one for each element in `sources`. Returned structure is the same as
+      the structure of `sources`.
 
     Raises:
       RuntimeError: if called inside the context of the tape, or if called more
@@ -756,12 +762,15 @@ class GradientTape(object):
       raise RuntimeError("GradientTape.gradient can only be called once "
                          "on non-persistent tapes, and "
                          "only when the context manager has exited.")
-    sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
-               else x
-               for x in sources]
-    grad = imperative_grad.imperative_grad(
-        _default_vspace, self._tape, [target], sources,
+    flat_sources = nest.flatten(sources)
+    flat_sources = [_handle_or_self(x) for x in flat_sources]
+
+    flat_grad = imperative_grad.imperative_grad(
+        _default_vspace, self._tape, [target], flat_sources,
         output_gradients=output_gradients)
+
     if not self._persistent:
       self._tape = None
+
+    grad = nest.pack_sequence_as(sources, flat_grad)
     return grad
index f04d89a..991b4db 100644 (file)
@@ -371,6 +371,53 @@ class BackpropTest(test.TestCase):
 
   @test_util.assert_no_new_tensors
   @test_util.run_in_graph_and_eager_modes()
+  def testGradientTapeRepeatedSource(self):
+    with backprop.GradientTape(persistent=False) as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = 2 * x
+    grad = g.gradient(target=y, sources=[x, x])
+    self.assertEqual(self.evaluate(grad), [2.0, 2.0])
+
+  @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
+  def testPersistentGradientTapeRepeatedSource(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant(3.0)
+      y = constant_op.constant(5.0)
+      g.watch(x)
+      g.watch(y)
+      z = x * x + x * y
+    grad = g.gradient(target=z, sources=[x, x])
+    self.assertEqual(self.evaluate(grad), [11.0, 11.0])
+    grad = g.gradient(target=z, sources=[y, x])
+    self.assertEqual(self.evaluate(grad), [3.0, 11.0])
+
+  @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
+  def testGradientTapeStructure(self):
+    with backprop.GradientTape(persistent=True) as g:
+      # Using different constant values because constant tensors are
+      # cached, leading to a different gradient then what one might expect.
+      x1 = constant_op.constant(3.0)
+      x2 = constant_op.constant(3.1)
+      x3 = constant_op.constant(3.2)
+      g.watch(x1)
+      g.watch(x2)
+      g.watch(x3)
+      y = x1  + 2 * x2  + 3 * x3
+    self.assertEqual(self.evaluate(g.gradient(y, x1)), [1.0])
+    self.assertEqual(self.evaluate(g.gradient(y, (x1,))), (1.0,))
+    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2))), (1.0, 2.0))
+    self.assertEqual(self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])),
+                     [(1.0, 2.0), (2.0, 3.0)])
+    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
+                     (1.0, 2.0, [1.0, 3.0]))
+    self.assertEqual(self.evaluate(g.gradient(y, [x1, {'x2': x2, 'x3': x3}])),
+                     [1.0, {'x2': 2.0, 'x3': 3.0}])
+
+  @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
   def testGradientTape(self):
     with backprop.GradientTape() as g:
       x = constant_op.constant(3.0)
index 7348279..8a398f6 100644 (file)
@@ -1372,11 +1372,15 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
   }
   if (!result.empty()) {
     PyObject* py_result = PyList_New(result.size());
+    tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
     for (int i = 0; i < result.size(); ++i) {
       if (result[i] == nullptr) {
         Py_INCREF(Py_None);
         result[i] = Py_None;
+      } else if (seen_results.find(result[i]) != seen_results.end()) {
+        Py_INCREF(result[i]);
       }
+      seen_results.insert(result[i]);
       PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
     }
     return py_result;