Before this change, it was easy to forget the [] around the source tensor.
This mistake led to GradientTape.gradient() returning a list of Nones.
Nones normally tell the user that the source and the target are not
connected via differentiable operations, which is not the actual cause of
the error in this case.
Instead of adding a check that `sources` is a list of tensors, this CL
adds the ability to handle a structured `sources` argument (which includes
a lone tensor), similarly to many existing TensorFlow APIs.
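
For illustration, a minimal sketch of a call this enables, written in the
style of the tests below (eager execution assumed):

  with backprop.GradientTape() as g:
    x = constant_op.constant(3.0)
    g.watch(x)
    y = x * x
  grad = g.gradient(y, x)  # a lone tensor source now works; no [] needed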
Also, with Alex's help, it fixes a bug where repeated tensors in
`sources` were not handled correctly.
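
A short sketch of the repeated-source case that used to be mishandled (the
new tests below exercise it):

  grad = g.gradient(y, [x, x])  # both entries now hold the same gradient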
PiperOrigin-RevId: 190878583
  }
  CHECK(state.op_tape.empty());
  result->reserve(source_tensor_ids.size());
+  gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
  for (auto is : source_tensor_ids) {
    auto grad_it = gradients.find(is);
    if (grad_it == gradients.end()) {
      result->push_back(nullptr);
    } else {
-      if (grad_it->second.size() == 1) {
-        result->push_back(grad_it->second[0]);
-      } else {
-        result->push_back(vspace.AggregateGradients(grad_it->second));
+      if (grad_it->second.size() > 1) {
+        // Aggregate once and cache the result back in the map, so that a
+        // repeated source id reuses the same aggregated gradient.
+        Gradient* grad = vspace.AggregateGradients(grad_it->second);
+        grad_it->second.clear();
+        grad_it->second.push_back(grad);
      }
-      gradients.erase(grad_it);
+      result->push_back(grad_it->second[0]);
+      used_gradient_ids.insert(is);
    }
  }
-  VLOG(1) << "Final gradients size: " << gradients.size();
+  VLOG(1) << "Final gradients size: "
+          << gradients.size() - used_gradient_ids.size();
  for (auto grad_pair : gradients) {
-    for (const auto& g : grad_pair.second) {
-      vspace.DeleteGradient(g);
+    // Delete only the gradients that were never handed to the caller.
+    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
+      for (const auto& g : grad_pair.second) {
+        vspace.DeleteGradient(g);
+      }
    }
  }
  return Status::OK();
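
The bookkeeping above, as an illustrative Python sketch (the function and
parameter names here are hypothetical, not the real API):

  def collect_gradients(source_ids, gradients, aggregate, delete):
    """Mirrors the loop above: one aggregated gradient per source id."""
    result, used = [], set()
    for sid in source_ids:  # source_ids may contain duplicates
      grads = gradients.get(sid)
      if grads is None:
        result.append(None)
        continue
      if len(grads) > 1:
        gradients[sid] = grads = [aggregate(grads)]  # aggregate once, cache
      result.append(grads[0])
      used.add(sid)
    for sid, grads in gradients.items():
      if sid not in used:  # free only gradients that were never returned
        for g in grads:
          delete(g)
    return result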
    ones=_ones)
+def _handle_or_self(x):
+  """If x is a ResourceVariable, return its handle; else return x."""
+  if isinstance(x, resource_variable_ops.ResourceVariable):
+    x = x.handle
+  return x
+
+
@tf_export("GradientTape")
class GradientTape(object):
  """Record operations for automatic differentiation.
      tensor: a Tensor or list of Tensors.
    """
    for t in nest.flatten(tensor):
-      if isinstance(t, resource_variable_ops.ResourceVariable):
-        t = t.handle
-      tape.watch(t)
+      tape.watch(_handle_or_self(t))
  def watched_variables(self):
    # Sorting variables by id, which is monotonically increasing in construction
    Args:
      target: Tensor to be differentiated.
-      sources: a list of Tensors or Variables. `target` will be differentiated
-        against elements in `sources`.
+      sources: a list or nested structure of Tensors or Variables. `target`
+        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.

    Returns:
-      a list of Tensors (or IndexedSlices, or None), one for each element in
-      `sources`.
+      a list or nested structure of Tensors (or IndexedSlices, or None),
+      one for each element in `sources`. The returned structure is the same
+      as the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
      raise RuntimeError("GradientTape.gradient can only be called once "
                         "on non-persistent tapes, and "
                         "only when the context manager has exited.")
-    sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
-               else x
-               for x in sources]
-    grad = imperative_grad.imperative_grad(
-        _default_vspace, self._tape, [target], sources,
+    flat_sources = nest.flatten(sources)
+    flat_sources = [_handle_or_self(x) for x in flat_sources]
+
+    flat_grad = imperative_grad.imperative_grad(
+        _default_vspace, self._tape, [target], flat_sources,
        output_gradients=output_gradients)
+
    if not self._persistent:
      self._tape = None
+
+    grad = nest.pack_sequence_as(sources, flat_grad)
    return grad
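
The flatten/repack round trip is what preserves the caller's structure; a
small sketch of the idea (x1, x2, g1, g2 are placeholder tensors and
gradients, not names from the change):

  sources = (x1, {'x2': x2})
  flat = nest.flatten(sources)                        # [x1, x2]
  flat_grads = [g1, g2]                               # one gradient per flat source
  grads = nest.pack_sequence_as(sources, flat_grads)  # (g1, {'x2': g2})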
  @test_util.assert_no_new_tensors
  @test_util.run_in_graph_and_eager_modes()
+  def testGradientTapeRepeatedSource(self):
+    with backprop.GradientTape(persistent=False) as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = 2 * x
+    grad = g.gradient(target=y, sources=[x, x])
+    self.assertEqual(self.evaluate(grad), [2.0, 2.0])
+
+  @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
+  def testPersistentGradientTapeRepeatedSource(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant(3.0)
+      y = constant_op.constant(5.0)
+      g.watch(x)
+      g.watch(y)
+      z = x * x + x * y
+    grad = g.gradient(target=z, sources=[x, x])
+    self.assertEqual(self.evaluate(grad), [11.0, 11.0])
+    grad = g.gradient(target=z, sources=[y, x])
+    self.assertEqual(self.evaluate(grad), [3.0, 11.0])
+
+  @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
+  def testGradientTapeStructure(self):
+    with backprop.GradientTape(persistent=True) as g:
+      # Using different constant values because constant tensors are
+      # cached, leading to a different gradient than what one might expect.
+      x1 = constant_op.constant(3.0)
+      x2 = constant_op.constant(3.1)
+      x3 = constant_op.constant(3.2)
+      g.watch(x1)
+      g.watch(x2)
+      g.watch(x3)
+      y = x1 + 2 * x2 + 3 * x3
+    self.assertEqual(self.evaluate(g.gradient(y, x1)), [1.0])
+    self.assertEqual(self.evaluate(g.gradient(y, (x1,))), (1.0,))
+    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2))), (1.0, 2.0))
+    self.assertEqual(self.evaluate(g.gradient(y, [(x1, x2), (x2, x3)])),
+                     [(1.0, 2.0), (2.0, 3.0)])
+    self.assertEqual(self.evaluate(g.gradient(y, (x1, x2, [x1, x3]))),
+                     (1.0, 2.0, [1.0, 3.0]))
+    self.assertEqual(self.evaluate(g.gradient(y, [x1, {'x2': x2, 'x3': x3}])),
+                     [1.0, {'x2': 2.0, 'x3': 3.0}])
+
+  @test_util.assert_no_new_tensors
+  @test_util.run_in_graph_and_eager_modes()
  def testGradientTape(self):
    with backprop.GradientTape() as g:
      x = constant_op.constant(3.0)
  }
  if (!result.empty()) {
    PyObject* py_result = PyList_New(result.size());
+    tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
    for (int i = 0; i < result.size(); ++i) {
      if (result[i] == nullptr) {
        Py_INCREF(Py_None);
        result[i] = Py_None;
+      } else if (seen_results.find(result[i]) != seen_results.end()) {
+        // PyList_SET_ITEM below steals a reference, so a duplicated
+        // gradient needs one extra reference per extra list slot.
+        Py_INCREF(result[i]);
      }
+      seen_results.insert(result[i]);
      PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
    }
    return py_result;
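
For intuition, a Python-level sketch of why the duplicate check is needed
(sys.getrefcount is used here only to illustrate; the real fix is the
Py_INCREF above): every slot of the output list must own its own reference,
even when several slots repeat the same object.

  import sys

  obj = object()
  before = sys.getrefcount(obj)
  lst = [obj, obj]  # each list slot holds its own reference to obj
  assert sys.getrefcount(obj) - before == 2  # one new reference per slot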