From 99e28baeba4f1ffb2623e64694b2aac13df5e0fb Mon Sep 17 00:00:00 2001
From: albanD
Date: Fri, 20 Aug 2021 08:42:31 -0700
Subject: [PATCH] Small custom function refactor which doesn't change anything
 (#63433)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63433

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D30431970

Pulled By: albanD

fbshipit-source-id: 905fa4d2ddeca18005b1bcb13dd6f8a080327e7c
---
 torch/csrc/autograd/custom_function.cpp | 34 +++++++++++++++++++++++++---------
 torch/csrc/autograd/custom_function.h   |  2 ++
 torch/csrc/autograd/python_function.cpp | 19 +++++++++++++++++--
 torch/csrc/autograd/python_function.h   | 11 -----------
 4 files changed, 44 insertions(+), 22 deletions(-)

diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp
index 502919f..fdcf997 100644
--- a/torch/csrc/autograd/custom_function.cpp
+++ b/torch/csrc/autograd/custom_function.cpp
@@ -26,17 +26,13 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   }
 }
 
-std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_vars,
+optional_variable_list _process_backward_mode_ad(
+  const std::unordered_set<at::TensorImpl*> &inputs_set,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
   const std::unordered_set<at::TensorImpl*> &dirty_inputs,
   const at::ArrayRef<c10::optional<Variable>> raw_outputs,
   const std::shared_ptr<Node> &cdata) {
 
-  std::unordered_set<at::TensorImpl*> inputs;
-  inputs.reserve(input_vars.size());
-  for (auto& var : input_vars) {
-    inputs.emplace(var.unsafeGetTensorImpl());
-  }
 
   int num_outputs = raw_outputs.size();
 
@@ -63,7 +59,7 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
       // Here, `y` requires_grad (!).
     } else if (is_modified) {
       if (var.is_leaf() && var.requires_grad()) {
-        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
+        TORCH_CHECK(false, "a leaf Variable that requires grad has been used in an in-place operation.");
      }
       // No need to mark as modified Tensors that are not inputs.
       if (!is_input) {
@@ -105,7 +101,7 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
     }
   };
 
-  std::vector<c10::optional<Variable>> outputs;
+  optional_variable_list outputs;
   std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
   outputs.reserve(num_outputs);
   int num_diff_outputs = 0;
@@ -125,7 +121,7 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
     Variable var = raw_outputs[i].value();
 
     auto out_tensor_impl = var.unsafeGetTensorImpl();
-    bool is_input = inputs.count(out_tensor_impl) > 0;
+    bool is_input = inputs_set.count(out_tensor_impl) > 0;
     bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
     bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0 &&
       isDifferentiableType(var.scalar_type());
@@ -177,6 +173,26 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
 
   return outputs;
 }
+
+
+optional_variable_list _wrap_outputs(const variable_list &input_vars,
+  const std::unordered_set<at::TensorImpl*> &non_differentiable,
+  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
+  const at::ArrayRef<c10::optional<Variable>> raw_outputs,
+  const std::shared_ptr<Node> &cdata) {
+
+  std::unordered_set<at::TensorImpl*> inputs_set;
+  inputs_set.reserve(input_vars.size());
+  for (auto& var : input_vars) {
+    inputs_set.emplace(var.unsafeGetTensorImpl());
+  }
+
+  auto outputs = _process_backward_mode_ad(inputs_set, non_differentiable, dirty_inputs, raw_outputs, cdata);
+
+
+  return outputs;
+}
+
 void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
   if (!original.options().type_equal(result.options())) {
     std::stringstream ss;
diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h
index 243622f..376cab6 100644
--- a/torch/csrc/autograd/custom_function.h
+++ b/torch/csrc/autograd/custom_function.h
@@ -9,6 +9,8 @@
 
 namespace torch { namespace autograd {
 
+using optional_variable_list = std::vector<c10::optional<Variable>>;
+
 TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
   const variable_list &input_vars,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index dd58a68..1487418 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -45,14 +45,29 @@ PyObject *THPFunctionClass = nullptr;
 #define THPFunction_assert(condition, ...) \
   if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); }
 
-namespace torch { namespace autograd {
+// Anonymous namespace for helpful functions used in this file
+namespace {
 
-void PyNode::throw_python_error() {
+// Throw a python_error with the PyErr state persisted, so that we
+// don't lose the error state if the GIL is released when we don't
+// have a PyThreadState created beforehand, this is made so that
+// even for pure C++ thread without a pre-created PyThreadState could
+// also capture the correct error message.
+// TODO: This is a temporary approach to allow C++ thread to correctly
+// capture Python Error in autograd, remove this when c10 thread pool
+// allow to do one time initialization.
+// see discussion in https://github.com/pytorch/pytorch/pull/34845
+// Follow up issue: https://github.com/pytorch/pytorch/issues/35006
+void throw_python_error() {
   python_error err;
   err.persist();
   throw err;
 }
 
+}
+
+namespace torch { namespace autograd {
+
 // NOTE: this function is written in a way that assumes it's only called for backward;
 // it's used by engine.cpp. This is responsible for forwarding a call from
 // C++'s Node::apply to a Python method "apply".
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 8f4d12b..3657807 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -27,17 +27,6 @@ struct PyNode : public Node {
 
   variable_list apply(variable_list&& inputs) override;
 
-  // Throw a python_error with the PyErr state persisted, so that we
-  // don't lose the error state if the GIL is released when we don't
-  // have a PyThreadState created beforehand, this is made so that
-  // even for pure C++ thread without a pre-created PyThreadState could
-  // also capture the correct error message.
-  // TODO: This is a temporary approach to allow C++ thread to correctly
-  // capture Python Error in autograd, remove this when c10 thread pool
-  // allow to do one time initialization.
-  // see discussion in https://github.com/pytorch/pytorch/pull/34845
-  // Follow up issue: https://github.com/pytorch/pytorch/issues/35006
-  void throw_python_error();
   void release_variables() override;
   std::string name() const override;
   bool is_traceable() override;
-- 
2.7.4
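
Note (not part of the patch): for readers unfamiliar with this code path, the sketch below is a minimal custom autograd Function written against libtorch's public C++ API. When `Square::apply` returns, its outputs are routed through `_wrap_outputs`, which after this refactor only collects the input `TensorImpl*`s into `inputs_set` and delegates the is_input / is_modified / is_differentiable bookkeeping to `_process_backward_mode_ad`. The `Square` class itself is a hypothetical example, chosen only to show how the refactored path gets exercised.

// Minimal sketch, assuming the public torch::autograd::Function API from libtorch.
// `Square` is a made-up illustration; it does not appear in this patch.
#include <torch/torch.h>
#include <iostream>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

struct Square : public Function<Square> {
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
    ctx->save_for_backward({x});
    // The tensor returned here is one of the `raw_outputs` that
    // _wrap_outputs() / _process_backward_mode_ad() inspect before
    // attaching the backward graph node created for this Function.
    return x * x;
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
    auto x = ctx->get_saved_variables()[0];
    // d(x^2)/dx = 2x, chained with the incoming gradient.
    return {2 * x * grad_output[0]};
  }
};

int main() {
  auto x = torch::randn({3}, torch::requires_grad());
  auto y = Square::apply(x);   // forward outputs flow through _wrap_outputs() here
  y.sum().backward();
  std::cout << x.grad() << "\n";  // expected: 2 * x
}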