Clarifying when is it possible to use a tape while it is still active.
authorAlexandre Passos <apassos@google.com>
Thu, 15 Mar 2018 22:58:44 +0000 (15:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 23:10:15 +0000 (16:10 -0700)
PiperOrigin-RevId: 189260773

tensorflow/c/eager/tape.h
tensorflow/python/eager/backprop_test.py
tensorflow/python/eager/pywrap_tfe_src.cc

index bdb0815..c7bd3bd 100644 (file)
@@ -152,6 +152,8 @@ class GradientTape {
                          gtl::ArraySlice<Gradient*> output_gradients,
                          std::vector<Gradient*>* result);
 
+  bool IsPersistent() const { return persistent_; }
+
  private:
   TensorTape tensor_tape_;
   OpTape<BackwardFunction> op_tape_;
index 07a2155..5934293 100644 (file)
@@ -195,6 +195,17 @@ class BackpropTest(test.TestCase):
     g, = backprop.gradients_function(loss, [0])(logits, labels)
     self.assertAllEqual(g.numpy(), [[-0.5, 0.5]])
 
+  def testGradientWithinTapeBlock(self):
+    v1 = resource_variable_ops.ResourceVariable(1.)
+    with backprop.GradientTape() as t:
+      loss = 2 * v1
+      with self.assertRaises(RuntimeError):
+        t.gradient(loss, [v1])
+    with backprop.GradientTape(persistent=True) as t:
+      loss = 2 * v1
+      grad = t.gradient(loss, [v1])
+    self.assertAllEqual(grad[0], 2.0)
+
   @test_util.assert_no_new_tensors
   def testSecondGrad(self):
 
index fe9785d..701f68b 100644 (file)
@@ -1323,6 +1323,16 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
                               PyObject* target, PyObject* sources,
                               PyObject* output_gradients, TF_Status* status) {
+  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
+  if (!tape_obj->tape->IsPersistent()) {
+    auto* tape_set = GetTapeSet();
+    if (tape_set->find(tape_obj) != tape_set->end()) {
+      PyErr_SetString(PyExc_RuntimeError,
+                      "Trying to call tape.gradient on a non-persistent tape "
+                      "while it is still active.");
+      return nullptr;
+    }
+  }
   PyVSpace c_vspace(vspace);
   if (!c_vspace.Initialize().ok()) {
     return nullptr;
@@ -1348,7 +1358,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
       Py_INCREF(tensor);
     }
   }
-  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
   std::vector<PyObject*> result;
   status->status = tape_obj->tape->ComputeGradient(
       c_vspace, target_vec, sources_vec, outgrad_vec, &result);