-std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_vars,
+optional_variable_list _process_backward_mode_ad(
+ const std::unordered_set<at::TensorImpl*> &inputs_set,
const std::unordered_set<at::TensorImpl*> &non_differentiable,
const std::unordered_set<at::TensorImpl*> &dirty_inputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node> &cdata) {
- std::unordered_set<at::TensorImpl*> inputs;
- inputs.reserve(input_vars.size());
- for (auto& var : input_vars) {
- inputs.emplace(var.unsafeGetTensorImpl());
- }
int num_outputs = raw_outputs.size();
// Here, `y` requires_grad (!).
} else if (is_modified) {
if (var.is_leaf() && var.requires_grad()) {
- throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
+ TORCH_CHECK(false, "a leaf Variable that requires grad has been used in an in-place operation.");
}
// No need to mark as modified Tensors that are not inputs.
if (!is_input) {
}
};
- std::vector<c10::optional<Variable>> outputs;
+ optional_variable_list outputs;
std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
outputs.reserve(num_outputs);
int num_diff_outputs = 0;
Variable var = raw_outputs[i].value();
auto out_tensor_impl = var.unsafeGetTensorImpl();
- bool is_input = inputs.count(out_tensor_impl) > 0;
+ bool is_input = inputs_set.count(out_tensor_impl) > 0;
bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0
&& isDifferentiableType(var.scalar_type());
return outputs;
}
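
For context on where the two sets consumed above come from: `dirty_inputs` is populated by `mark_dirty()` and `non_differentiable` by `mark_non_differentiable()` before this function runs. A minimal sketch of a custom function exercising both paths, using the public torch::autograd::Function API (the `ScaleInPlace` name and the doubling are illustrative, not part of this patch):

    #include <torch/torch.h>

    // Sketch only: mark_dirty() feeds the dirty_inputs set and
    // mark_non_differentiable() feeds non_differentiable in the checks above.
    struct ScaleInPlace : public torch::autograd::Function<ScaleInPlace> {
      static torch::autograd::variable_list forward(
          torch::autograd::AutogradContext* ctx,
          torch::Tensor x,
          torch::Tensor mask) {
        x.mul_(2);                             // modified in place...
        ctx->mark_dirty({x});                  // ...so report it as dirty
        ctx->mark_non_differentiable({mask});  // helper output, no gradient
        return {x, mask};
      }
      static torch::autograd::variable_list backward(
          torch::autograd::AutogradContext* ctx,
          torch::autograd::variable_list grad_outputs) {
        // d(2x)/dx = 2; the non-differentiable mask gets no gradient.
        return {grad_outputs[0] * 2, torch::Tensor()};
      }
    };

Invoked as `ScaleInPlace::apply(x, mask)`, the outputs flow through exactly these checks; if `x` were a leaf that requires grad, the in-place update would trip the TORCH_CHECK in the `is_modified` branch.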
+
+optional_variable_list _wrap_outputs(const variable_list &input_vars,
+ const std::unordered_set<at::TensorImpl*> &non_differentiable,
+ const std::unordered_set<at::TensorImpl*> &dirty_inputs,
+ const at::ArrayRef<c10::optional<Variable>> raw_outputs,
+ const std::shared_ptr<Node> &cdata) {
+
+ std::unordered_set<at::TensorImpl*> inputs_set;
+ inputs_set.reserve(input_vars.size());
+ for (auto& var : input_vars) {
+ inputs_set.emplace(var.unsafeGetTensorImpl());
+ }
+
+ auto outputs = _process_backward_mode_ad(inputs_set, non_differentiable, dirty_inputs, raw_outputs, cdata);
+
+ return outputs;
+}
+
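
After the refactor, `_wrap_outputs` is a thin shell: build the input set once, then delegate to the backward-mode processing. A hypothetical caller (the helper name and the empty sets are illustrative; this assumes the `optional_variable_list` alias and the `_wrap_outputs` declaration in custom_function.h):

    #include <torch/csrc/autograd/custom_function.h>
    #include <unordered_set>

    using torch::autograd::Variable;
    using torch::autograd::variable_list;

    // Hypothetical helper, not part of this patch: wrap forward() results
    // when nothing was modified in place and nothing was marked
    // non-differentiable, so both sets stay empty.
    torch::autograd::optional_variable_list wrap_plain_outputs(
        const variable_list& input_vars,
        const std::vector<c10::optional<Variable>>& raw_outputs,
        const std::shared_ptr<torch::autograd::Node>& grad_fn) {
      std::unordered_set<at::TensorImpl*> non_differentiable;
      std::unordered_set<at::TensorImpl*> dirty_inputs;
      return torch::autograd::_wrap_outputs(
          input_vars, non_differentiable, dirty_inputs, raw_outputs, grad_fn);
    }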
void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
if (!original.options().type_equal(result.options())) {
std::stringstream ss;
#define THPFunction_assert(condition, ...) \
if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); }
-namespace torch { namespace autograd {
+// Anonymous namespace for helper functions used in this file
+namespace {
-void PyNode::throw_python_error() {
+// Throw a python_error with the PyErr state persisted, so that we
+// don't lose the error state if the GIL is released when we don't
+// have a PyThreadState created beforehand. This ensures that even a
+// pure C++ thread without a pre-created PyThreadState can capture
+// the correct error message.
+// TODO: This is a temporary approach to let C++ threads correctly
+// capture Python errors in autograd; remove this once the c10 thread
+// pool allows one-time initialization.
+// See discussion in https://github.com/pytorch/pytorch/pull/34845
+// Follow-up issue: https://github.com/pytorch/pytorch/issues/35006
+void throw_python_error() {
python_error err;
err.persist();
throw err;
}
+}
+
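
The pattern the comment describes: any failing CPython call made while the GIL is held is converted into a C++ exception whose Python error state has already been snapshotted, so it survives being rethrown on engine worker threads that have no PyThreadState. A minimal sketch of a hypothetical call site (not from this patch):

    // GIL must be held here so python_error::persist() can capture the
    // PyErr_* state before the exception crosses into GIL-free C++ code.
    PyObject* result = PyObject_CallObject(callable, args);
    if (!result) {
      throw_python_error();  // never returns; carries the persisted error
    }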
+namespace torch { namespace autograd {
+
// NOTE: this function is written in a way that assumes it's only called for backward;
// it's used by engine.cpp. This is responsible for forwarding a call from
// C++'s Node::apply to a Python method "apply".
variable_list apply(variable_list&& inputs) override;
- // Throw a python_error with the PyErr state persisted, so that we
- // don't lose the error state if the GIL is released when we don't
- // have a PyThreadState created beforehand, this is made so that
- // even for pure C++ thread without a pre-created PyThreadState could
- // also capture the correct error message.
- // TODO: This is a temporary approach to allow C++ thread to correctly
- // capture Python Error in autograd, remove this when c10 thread pool
- // allow to do one time initialization.
- // see discussion in https://github.com/pytorch/pytorch/pull/34845
- // Follow up issue: https://github.com/pytorch/pytorch/issues/35006
- void throw_python_error();
void release_variables() override;
std::string name() const override;
bool is_traceable() override;
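
The forwarding described in the NOTE boils down to looking up the Python-level "apply" and calling it. A sketch of the idea (hypothetical helper, not the actual body of PyNode::apply, which also converts variable_list to and from Python tuples and handles tracing):

    #include <Python.h>

    // GIL must be held. `obj` stands in for the Python object backing this
    // PyNode; errors are persisted so non-Python threads can report them.
    static PyObject* forward_to_python_apply(PyObject* obj, PyObject* py_inputs) {
      PyObject* apply_fn = PyObject_GetAttrString(obj, "apply");
      if (!apply_fn) {
        throw_python_error();
      }
      PyObject* py_outputs = PyObject_CallObject(apply_fn, py_inputs);
      Py_DECREF(apply_fn);
      if (!py_outputs) {
        throw_python_error();
      }
      return py_outputs;
    }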