gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
- const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
- if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
- if (!output_gradients.empty() && output_gradients[i] != nullptr) {
- // TODO(apassos) figure out how to print debugging information here.
- return errors::InvalidArgument(
- "A gradient was provided for a tensor which is used as part of the "
- "computation.");
- }
- } else {
- if (output_gradients.empty() || output_gradients[i] == nullptr) {
- auto tensor_it = tensor_tape.find(id);
- if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
- auto op_it = op_tape.find(tensor_it->second);
- if (op_it == op_tape.end()) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "failed to find operation producing a tensor");
- }
- bool found = false;
- for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
- found = true;
- (*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
- break;
- }
- }
- if (!found) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "none of operations outputs match expected tensor");
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ auto tensor_it = tensor_tape.find(id);
+ if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+ auto op_it = op_tape.find(tensor_it->second);
+ if (op_it == op_tape.end()) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "failed to find operation producing a tensor");
+ }
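+        // Scan the producing op's outputs for this tensor id to recover the
+        // shape and dtype needed for the ones seed.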
+ bool found = false;
+ for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+ if (op_it->second.output_tensor_info[j].id == id) {
+ found = true;
+ (*result)[id].push_back(
+ vspace.Ones(op_it->second.output_tensor_info[j].shape,
+ op_it->second.output_tensor_info[j].dtype));
+ break;
}
- } else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ }
+ if (!found) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "none of operations outputs match expected tensor");
}
} else {
- (*result)[id].push_back(output_gradients[i]);
+ // No record of the target tensor found on the tape, so no gradient
+ // needs to be computed from it. Do nothing.
}
+ } else {
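+      // A caller-supplied gradient seeds backpropagation for this target
+      // directly, even when the tensor also feeds later computation (the
+      // old tensor_usage_counts check above rejected that case).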
+ (*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
- tensor_tape_, state.op_tape,
- state.tensor_usage_counts, &gradients);
+ tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions
"""Computes the gradient using operations recorded in context of this tape.
Args:
- target: Tensor to be differentiated.
+      target: Tensor (or list of Tensors) to be differentiated.
sources: a list or nested structure of Tensors or Variables. `target`
will be differentiated against elements in `sources`.
output_gradients: a list of gradients, one for each element of
flat_sources = nest.flatten(sources)
flat_sources = [_handle_or_self(x) for x in flat_sources]
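+    # Convert user-supplied output gradients to tensors eagerly, keeping None
+    # entries so those targets get a default ones seed on the tape side.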
+ if output_gradients is not None:
+ output_gradients = [None if x is None else ops.convert_to_tensor(x)
+ for x in nest.flatten(output_gradients)]
+
flat_grad = imperative_grad.imperative_grad(
- _default_vspace, self._tape, [target], flat_sources,
+ _default_vspace, self._tape, nest.flatten(target), flat_sources,
output_gradients=output_gradients)
if not self._persistent:
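Taken together, the hunks above let `GradientTape.gradient` accept a list of
targets and per-target `output_gradients`, including a gradient for a target
that feeds later computation. A minimal sketch of the combined behavior, run
under eager execution and using the same internal modules as the tests below
(`persistent=True` and the specific values are illustrative assumptions, not
part of this change):

    from tensorflow.python.eager import backprop
    from tensorflow.python.framework import constant_op

    x = constant_op.constant(3.0)
    with backprop.GradientTape(persistent=True) as t:
      t.watch(x)
      y = x * x   # dy/dx = 2x = 6
      z = y + x   # dz/dx = dy/dx + 1 = 7

    # Multiple targets: per-source contributions are summed, 6 + 7 = 13.
    dx, = t.gradient([y, z], [x])

    # output_gradients scales each target's seed, and y may now be seeded
    # even though z was computed from it: 2.0 * 6 + 1.0 * 7 = 19.
    dx_scaled, = t.gradient([y, z], [x], output_gradients=[2.0, 1.0])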
self.assertAllEqual(grads_and_vars[0][0], 1.0)
self.assertAllEqual(id(grads_and_vars[0][1]), id(x))
+ def testTwoTargets(self):
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(2.0)
+ t.watch([x, y])
+ xx = 2 * x
+ yy = 3 * y
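+      # Independent chains: d(xx)/dx = 2 and d(yy)/dy = 3.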
+ dx, dy = t.gradient([xx, yy], [x, y])
+ self.assertAllEqual(dx, 2.0)
+ self.assertAllEqual(dy, 3.0)
+
+ def testOutputGradUsedInComputation(self):
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(2.0)
+ t.watch([x, y])
+ loss = x * y
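+      # Seeds: 1.0 for loss and 2.0 for target x, so
+      # dx = 1.0 * d(loss)/dx + 2.0 * d(x)/dx = 1.0 * y + 2.0 * 1 = 4.0.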
+ dx, = t.gradient([loss, x], [x], output_gradients=[1.0, 2.0])
+ self.assertAllEqual(dx, 4.0)
+
def testDy(self):
def f(x):