From 753cc5b3f7461b0b3f59605cba10b965aca0e3ad Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Mon, 21 May 2018 16:37:17 -0700
Subject: [PATCH] Fixes issue with gradient tape when asking for the gradient
 of an intermediate tensor.

PiperOrigin-RevId: 197481473
---
 tensorflow/c/eager/tape.h                 | 14 +++++++++++---
 tensorflow/python/eager/backprop_test.py  | 12 ++++++++++++
 tensorflow/python/eager/pywrap_tfe_src.cc |  2 ++
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index dcc2357..1833b25 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -104,6 +104,10 @@ class VSpace {
       gtl::ArraySlice<Gradient*> output_gradients,
       std::vector<Gradient*>* result) const = 0;
 
+  // Marks the following gradient as a result so it's not consumed by backward
+  // functions.
+  virtual void MarkAsResult(Gradient* gradient) const = 0;
+
   // Deletes the input tensor.
   virtual void DeleteGradient(Gradient* gradient) const = 0;
 
@@ -356,8 +360,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
       count_it->second++;
     } else {
       result.tensor_usage_counts[it] = 1;
-      if (sources_set.find(it) == sources_set.end() &&
-          tensor_tape.find(it) != tensor_tape.end()) {
+      if (tensor_tape.find(it) != tensor_tape.end()) {
         tensor_stack.push_back(it);
       }
     }
@@ -522,10 +525,15 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
       }
     } else {
       any_gradient_nonzero = true;
-      out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
+      auto new_gradients = vspace.AggregateGradients(grad_it->second);
       if (sources_set.find(grad_it->first) == sources_set.end()) {
         gradients.erase(grad_it);
+      } else {
+        grad_it->second.clear();
+        grad_it->second.push_back(new_gradients);
+        vspace.MarkAsResult(new_gradients);
       }
+      out_gradients.push_back(new_gradients);
     }
   }
   std::vector<Gradient*> in_gradients;
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 9aaa2e3..826c668 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -615,6 +615,18 @@ class BackpropTest(test.TestCase):
     self.assertAllEqual(self.evaluate(grad), 2.0)
 
   @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
+  def testNestedGradients(self):
+    x = constant_op.constant(3.0)
+    with backprop.GradientTape() as g:
+      g.watch(x)
+      y = x * x
+      z = y * y
+    dz_dx, dz_dy = g.gradient(z, [x, y])
+    self.assertEqual(self.evaluate(dz_dx), 108.0)
+    self.assertEqual(self.evaluate(dz_dy), 18.0)
+
+  @test_util.assert_no_new_tensors
   def testEmptyParamsForValueAndGradFunction(self):
     def fn(a, b):
       return a * b
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index f78043e..62deb41 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1348,6 +1348,8 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
     return result;
   }
 
+  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
   PyObject* Zeros(tensorflow::TensorShape shape,
                   tensorflow::DataType dtype) const final {
     PyObject* py_shape = PyTuple_New(shape.dims());
-- 
2.7.4
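
Note (illustrative sketch, not part of the patch): the snippet below shows the
user-visible behavior this change addresses, assuming a TensorFlow build of
this era (1.7+) with eager execution enabled. Before the change,
PrepareBackprop skipped requested tensors that were intermediates of the
computation, and ComputeGradient let backward functions consume the aggregated
gradient of such a tensor; MarkAsResult now retains it (Py_INCREF in the
Python VSpace) so it survives to be returned to the caller. The values mirror
the testNestedGradients case above.

import tensorflow as tf

# Needed on TF 1.x (1.7+); omit on TF 2.x, where eager is the default.
tf.enable_eager_execution()

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = x * x  # y is an intermediate tensor recorded on the tape
  z = y * y

# Request gradients with respect to both a source (x) and an intermediate (y).
# Asking for dz/dy is the case this patch fixes.
dz_dx, dz_dy = tape.gradient(z, [x, y])
print(dz_dx)  # 4 * x**3 = 108.0
print(dz_dy)  # 2 * y = 18.0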