Fixes to tape gradient for providing outputs and having multiple targets.
authorAlexandre Passos <apassos@google.com>
Mon, 30 Apr 2018 16:29:31 +0000 (09:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 16:32:36 +0000 (09:32 -0700)
PiperOrigin-RevId: 194796304

tensorflow/c/eager/tape.h
tensorflow/python/eager/backprop.py
tensorflow/python/eager/backprop_test.py

index 97c323b..8026076 100644 (file)
@@ -380,49 +380,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
                         gtl::ArraySlice<Gradient*> output_gradients,
                         const TensorTape& tensor_tape,
                         const OpTape<BackwardFunction>& op_tape,
-                        const gtl::FlatMap<int64, int64>& tensor_usage_counts,
                         gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
   for (int i = 0; i < target_tensor_ids.size(); ++i) {
     const int64 id = target_tensor_ids[i];
-    if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
-      if (!output_gradients.empty() && output_gradients[i] != nullptr) {
-        // TODO(apassos) figure out how to print debugging information here.
-        return errors::InvalidArgument(
-            "A gradient was provided for a tensor which is used as part of the "
-            "computation.");
-      }
-    } else {
-      if (output_gradients.empty() || output_gradients[i] == nullptr) {
-        auto tensor_it = tensor_tape.find(id);
-        if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
-          auto op_it = op_tape.find(tensor_it->second);
-          if (op_it == op_tape.end()) {
-            return errors::Internal(
-                "Internal state of the gradient tape is invalid: "
-                "failed to find operation producing a tensor");
-          }
-          bool found = false;
-          for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
-            if (op_it->second.output_tensor_info[j].id == id) {
-              found = true;
-              (*result)[id].push_back(
-                  vspace.Ones(op_it->second.output_tensor_info[j].shape,
-                              op_it->second.output_tensor_info[j].dtype));
-              break;
-            }
-          }
-          if (!found) {
-            return errors::Internal(
-                "Internal state of the gradient tape is invalid: "
-                "none of operations outputs match expected tensor");
+    if (output_gradients.empty() || output_gradients[i] == nullptr) {
+      auto tensor_it = tensor_tape.find(id);
+      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+        auto op_it = op_tape.find(tensor_it->second);
+        if (op_it == op_tape.end()) {
+          return errors::Internal(
+              "Internal state of the gradient tape is invalid: "
+              "failed to find operation producing a tensor");
+        }
+        bool found = false;
+        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+          if (op_it->second.output_tensor_info[j].id == id) {
+            found = true;
+            (*result)[id].push_back(
+                vspace.Ones(op_it->second.output_tensor_info[j].shape,
+                            op_it->second.output_tensor_info[j].dtype));
+            break;
           }
-        } else {
-          // No record of the target tensor found on the tape, so no gradient
-          // needs to be computed from it. Do nothing.
+        }
+        if (!found) {
+          return errors::Internal(
+              "Internal state of the gradient tape is invalid: "
+              "none of operations outputs match expected tensor");
         }
       } else {
-        (*result)[id].push_back(output_gradients[i]);
+        // No record of the target tensor found on the tape, so no gradient
+        // needs to be computed from it. Do nothing.
       }
+    } else {
+      (*result)[id].push_back(output_gradients[i]);
     }
   }
   return Status::OK();
@@ -451,8 +441,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
       InitialStack(state.op_tape, state.op_missing_tensor);
   gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
   Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
-                              tensor_tape_, state.op_tape,
-                              state.tensor_usage_counts, &gradients);
+                              tensor_tape_, state.op_tape, &gradients);
   auto cleanup = [this, &state]() {
     if (!persistent_) {
       // Release all backprop functions
index 92774d4..07aec59 100644 (file)
@@ -740,7 +740,7 @@ class GradientTape(object):
     """Computes the gradient using operations recorded in context of this tape.
 
     Args:
-      target: Tensor to be differentiated.
+      target: Tensor (or list of tensors) to be differentiated.
       sources: a list or nested structure of Tensors or Variables. `target`
         will be differentiated against elements in `sources`.
       output_gradients: a list of gradients, one for each element of
@@ -762,8 +762,12 @@ class GradientTape(object):
     flat_sources = nest.flatten(sources)
     flat_sources = [_handle_or_self(x) for x in flat_sources]
 
+    if output_gradients is not None:
+      output_gradients = [None if x is None else ops.convert_to_tensor(x)
+                          for x in nest.flatten(output_gradients)]
+
     flat_grad = imperative_grad.imperative_grad(
-        _default_vspace, self._tape, [target], flat_sources,
+        _default_vspace, self._tape, nest.flatten(target), flat_sources,
         output_gradients=output_gradients)
 
     if not self._persistent:
index 991b4db..8d9959f 100644 (file)
@@ -96,6 +96,26 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(grads_and_vars[0][0], 1.0)
     self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
 
+  def testTwoTargets(self):
+    with backprop.GradientTape() as t:
+      x = constant_op.constant(3.0)
+      y = constant_op.constant(2.0)
+      t.watch([x, y])
+      xx = 2 * x
+      yy = 3 * y
+    dx, dy = t.gradient([xx, yy], [x, y])
+    self.assertAllEqual(dx, 2.0)
+    self.assertAllEqual(dy, 3.0)
+
+  def testOutputGradUsedInComputation(self):
+    with backprop.GradientTape() as t:
+      x = constant_op.constant(3.0)
+      y = constant_op.constant(2.0)
+      t.watch([x, y])
+      loss = x * y
+    dx, = t.gradient([loss, x], [x], output_gradients=[1.0, 2.0])
+    self.assertAllEqual(dx, 4.0)
+
   def testDy(self):
 
     def f(x):