Small custom function refactor which doesn't change anything (#63433)
authoralbanD <desmaison.alban@gmail.com>
Fri, 20 Aug 2021 15:42:31 +0000 (08:42 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 15:44:23 +0000 (08:44 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63433

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D30431970

Pulled By: albanD

fbshipit-source-id: 905fa4d2ddeca18005b1bcb13dd6f8a080327e7c

torch/csrc/autograd/custom_function.cpp
torch/csrc/autograd/custom_function.h
torch/csrc/autograd/python_function.cpp
torch/csrc/autograd/python_function.h

index 502919f..fdcf997 100644 (file)
@@ -26,17 +26,13 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   }
 }
 
-std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_vars,
+optional_variable_list _process_backward_mode_ad(
+  const std::unordered_set<at::TensorImpl*> &inputs_set,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
   const std::unordered_set<at::TensorImpl*> &dirty_inputs,
   const at::ArrayRef<c10::optional<Variable>> raw_outputs,
   const std::shared_ptr<Node> &cdata) {
 
-  std::unordered_set<at::TensorImpl*> inputs;
-  inputs.reserve(input_vars.size());
-  for (auto& var : input_vars) {
-    inputs.emplace(var.unsafeGetTensorImpl());
-  }
 
   int num_outputs = raw_outputs.size();
 
@@ -63,7 +59,7 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
       // Here, `y` requires_grad (!).
     } else if (is_modified) {
       if (var.is_leaf() && var.requires_grad()) {
-        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
+        TORCH_CHECK(false, "a leaf Variable that requires grad has been used in an in-place operation.");
       }
       // No need to mark as modified Tensors that are not inputs.
       if (!is_input) {
@@ -105,7 +101,7 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
     }
   };
 
-  std::vector<c10::optional<Variable>> outputs;
+  optional_variable_list outputs;
   std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
   outputs.reserve(num_outputs);
   int num_diff_outputs = 0;
@@ -125,7 +121,7 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
     Variable var = raw_outputs[i].value();
 
     auto out_tensor_impl = var.unsafeGetTensorImpl();
-    bool is_input = inputs.count(out_tensor_impl) > 0;
+    bool is_input = inputs_set.count(out_tensor_impl) > 0;
     bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
     bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0
                               && isDifferentiableType(var.scalar_type());
@@ -177,6 +173,26 @@ std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_va
   return outputs;
 }
 
+
+
+optional_variable_list _wrap_outputs(const variable_list &input_vars,
+  const std::unordered_set<at::TensorImpl*> &non_differentiable,
+  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
+  const at::ArrayRef<c10::optional<Variable>> raw_outputs,
+  const std::shared_ptr<Node> &cdata) {
+
+  std::unordered_set<at::TensorImpl*> inputs_set;
+  inputs_set.reserve(input_vars.size());
+  for (auto& var : input_vars) {
+    inputs_set.emplace(var.unsafeGetTensorImpl());
+  }
+
+  auto outputs = _process_backward_mode_ad(inputs_set, non_differentiable, dirty_inputs, raw_outputs, cdata);
+
+
+  return outputs;
+}
+
 void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
   if (!original.options().type_equal(result.options())) {
     std::stringstream ss;
index 243622f..376cab6 100644 (file)
@@ -9,6 +9,8 @@
 
 namespace torch { namespace autograd {
 
+using optional_variable_list = std::vector<c10::optional<Variable>>;
+
 TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
   const variable_list &input_vars,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
index dd58a68..1487418 100644 (file)
@@ -45,14 +45,29 @@ PyObject *THPFunctionClass = nullptr;
 #define THPFunction_assert(condition, ...)                                     \
   if (!(condition)) { THPUtils_setError(__VA_ARGS__); throw python_error(); }
 
-namespace torch { namespace autograd {
+// Anonymous namespace for helpful functions used in this file
+namespace {
 
-void PyNode::throw_python_error() {
+// Throw a python_error with the PyErr state persisted, so that we
+// don't lose the error state if the GIL is released when we don't
+// have a PyThreadState created beforehand, this is made so that
+// even for pure C++ thread without a pre-created PyThreadState could
+// also capture the correct error message.
+// TODO: This is a temporary approach to allow C++ thread to correctly
+// capture Python Error in autograd, remove this when c10 thread pool
+// allow to do one time initialization.
+// see discussion in https://github.com/pytorch/pytorch/pull/34845
+// Follow up issue: https://github.com/pytorch/pytorch/issues/35006
+void throw_python_error() {
   python_error err;
   err.persist();
   throw err;
 }
 
+}
+
+namespace torch { namespace autograd {
+
 // NOTE: this function is written in a way that assumes it's only called for backward;
 // it's used by engine.cpp.  This is responsible for forwarding a call from
 // C++'s Node::apply to a Python method "apply".
index 8f4d12b..3657807 100644 (file)
@@ -27,17 +27,6 @@ struct PyNode : public Node {
 
   variable_list apply(variable_list&& inputs) override;
 
-  // Throw a python_error with the PyErr state persisted, so that we
-  // don't lose the error state if the GIL is released when we don't
-  // have a PyThreadState created beforehand, this is made so that
-  // even for pure C++ thread without a pre-created PyThreadState could
-  // also capture the correct error message.
-  // TODO: This is a temporary approach to allow C++ thread to correctly
-  // capture Python Error in autograd, remove this when c10 thread pool
-  // allow to do one time initialization.
-  // see discussion in https://github.com/pytorch/pytorch/pull/34845
-  // Follow up issue: https://github.com/pytorch/pytorch/issues/35006
-  void throw_python_error();
   void release_variables() override;
   std::string name() const override;
   bool is_traceable() override;