Move variabletype functions around (#63330)
authorsoulitzer <soulitzer@gmail.com>
Thu, 26 Aug 2021 23:00:21 +0000 (16:00 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 23:02:39 +0000 (16:02 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63330

 - This is in preparation for templated/boxed autograd-not-implemented fallback
 - Make sure VariableTypeUtils does not depend on generated code
 - Lift `isFwGradDefined` into `autograd/functions/utils.cpp` so it's available to mobile builds
 - Removes `using namespace at` from VariableTypeUtils, previously we needed this for Templated version, but now its not strictly necessary but still a good change to avoid name conflicts if this header is included elsewhere in the future.

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D30518573

Pulled By: soulitzer

fbshipit-source-id: a0fb904baafc9713de609fffec4b813f6cfcc000

tools/autograd/templates/VariableType.cpp
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/autograd/VariableTypeUtils.h
torch/csrc/autograd/functions/utils.h

index 1ff3604..605a700 100644 (file)
@@ -1,4 +1,5 @@
 #include "torch/csrc/autograd/VariableTypeUtils.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
 #include "torch/csrc/autograd/FunctionsManual.h"
 
 #include <ATen/RedispatchFunctions.h>
index 86639c1..95170f0 100644 (file)
@@ -1,5 +1,7 @@
 #include <torch/csrc/autograd/FunctionsManual.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/autograd/functions/utils.h>
+
 
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
@@ -44,10 +46,6 @@ bool isDefined(const c10::optional<Tensor>& t) {
   return t.has_value() && t->defined();
 }
 
-bool isFwGradDefined(const c10::optional<Tensor>& t) {
-  return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined();
-}
-
 Tensor toNonOptTensor(const c10::optional<Tensor>& t) {
   return t.has_value() ? *t : Tensor();
 }
index d397f55..31a972e 100644 (file)
@@ -31,7 +31,6 @@ struct IndexRangeGenerator {
     size_t i = 0;
 };
 
-bool isFwGradDefined(const c10::optional<Tensor>& t);
 Tensor toNonOptFwGrad(const c10::optional<Tensor>& t);
 Tensor toNonOptPrimal(const c10::optional<Tensor>& t);
 Tensor toNonOptTensor(const c10::optional<Tensor>& t);
index f409daa..25f05fc 100644 (file)
@@ -2,6 +2,7 @@
 #include <c10/core/ScalarType.h>
 #include <torch/csrc/autograd/VariableTypeUtils.h>
 #include <torch/csrc/autograd/FunctionsManual.h>
+#include <torch/csrc/autograd/functions/utils.h>
 #include <torch/csrc/utils/memory.h>
 #include <torch/csrc/autograd/autograd.h>
 #include <ATen/TracerMode.h>
@@ -100,7 +101,7 @@ Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor & self, int64_t level) {
   if (grad_fn) {
       set_history(flatten_tensor_args( result ), grad_fn);
   }
-  if (generated::details::isFwGradDefined(self)) {
+  if (isFwGradDefined(self)) {
     // Modified from original codegen
     // We explicitly want to ignore the forward grad at the given level
     TORCH_CHECK(level == 0, "Invalid level given to _fw_primal");
@@ -131,7 +132,7 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n
   rebase_history(self , std::move(grad_fn));
 
   if (isDifferentiableType(self.scalar_type()) &&
-      (generated::details::isFwGradDefined(self) || generated::details::isFwGradDefined(src))) {
+      (isFwGradDefined(self) || isFwGradDefined(src))) {
     auto self_fw_grad = generated::details::toNonOptFwGrad(self);
     auto src_fw_grad = generated::details::toNonOptFwGrad(src);
     Tensor new_fw_grad;
index bde2dc4..977e9e4 100644 (file)
@@ -1,14 +1,12 @@
 #pragma once
 
 #include <c10/util/irange.h>
-#include <torch/csrc/autograd/generated/VariableType.h>
 
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/saved_variable.h>
-#include <torch/csrc/autograd/generated/Functions.h>
 #include <torch/csrc/autograd/functions/tensor.h>
 #include <torch/csrc/autograd/functions/basic_ops.h>
 #include <torch/csrc/jit/frontend/tracer.h>
@@ -35,9 +33,6 @@
 #endif
 #endif
 
-using namespace at;
-using namespace torch::autograd::generated;
-
 namespace torch { namespace autograd {
 
 // The requires_grad argument is used to know if the inplace operation needs
@@ -47,7 +42,7 @@ namespace torch { namespace autograd {
 // a = torch.rand(2)
 // b = torch.rand(2, requires_grad=True)
 // a.copy_(b)
-inline void check_inplace(const Tensor& tensor, bool requires_grad) {
+inline void check_inplace(const at::Tensor& tensor, bool requires_grad) {
   if (requires_grad && GradMode::is_enabled()) {
     auto diff_view_meta = impl::get_view_autograd_meta(tensor);
     if (diff_view_meta && diff_view_meta->has_bw_view()) {
@@ -65,7 +60,7 @@ inline void check_inplace(const Tensor& tensor, bool requires_grad) {
   }
 }
 
-inline void check_inplace(const TensorList tensors, bool requires_grad) {
+inline void check_inplace(const at::TensorList tensors, bool requires_grad) {
   for (const auto& tensor : tensors) {
     check_inplace(tensor, requires_grad);
   }
@@ -77,14 +72,14 @@ inline void throw_error_out_requires_grad(const char* name) {
       "but one of the arguments requires grad.");
 }
 
-inline void throw_error_for_complex_autograd(const Tensor& tensor, const char* name) {
+inline void throw_error_for_complex_autograd(const at::Tensor& tensor, const char* name) {
   if (tensor.requires_grad()) {
     TORCH_CHECK(!tensor.is_complex(), name,
                 " does not support automatic differentiation for outputs with complex dtype.");
   }
 }
 
-inline void throw_error_for_complex_autograd(const TensorList& tensorlist, const char* name) {
+inline void throw_error_for_complex_autograd(const at::TensorList& tensorlist, const char* name) {
   for (const auto& tensor: tensorlist) {
     throw_error_for_complex_autograd(tensor, name);
   }
@@ -114,7 +109,7 @@ inline void rebase_history(std::vector<Variable>&& vars, std::shared_ptr<Node> g
   }
 }
 
-inline void increment_version(const Tensor & t) {
+inline void increment_version(const at::Tensor & t) {
   impl::bump_version(t);
 }
 
@@ -138,8 +133,8 @@ template<typename... Args> inline variable_list flatten_tensor_args(Args&&... ar
 }
 
 // See NOTE [ Autograd View Variables ] for details.
-inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable,
-        bool is_fw_differentiable, std::function<Tensor(const Tensor&)> view_func=nullptr,
+inline at::Tensor as_view(const at::Tensor & base, const at::Tensor & tensor, bool is_bw_differentiable,
+        bool is_fw_differentiable, std::function<at::Tensor(const at::Tensor&)> view_func=nullptr,
         CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) {
   // Note [View of inference tensor]
   // For inference tensor this code can only be hit outside InferenceMode
@@ -202,7 +197,7 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif
 }
 
 // See NOTE [ Autograd View Variables ] for details.
-inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& tensors, bool is_bw_differentiable,
+inline std::vector<at::Tensor> as_view(const at::Tensor & base, std::vector<at::Tensor>& tensors, bool is_bw_differentiable,
                                    bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) {
   // See Note [View of inference tensor]
   if (base.is_inference()) return tensors;
@@ -228,7 +223,7 @@ inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& ten
       new_shared_info = ViewInfo(base, /* view_func */ nullptr);
     }
 
-    for(Tensor &tensor : tensors) {
+    for(at::Tensor &tensor : tensors) {
       if (is_fw_differentiable || is_bw_differentiable) {
         tensor = make_variable_differentiable_view(tensor, new_shared_info, c10::nullopt, /*shared_view_info*/ true, creation_meta);
       } else {
@@ -282,7 +277,7 @@ inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& ten
     creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
   }
 
-  for(Tensor &tensor : tensors) {
+  for(at::Tensor &tensor : tensors) {
     if (is_fw_differentiable || is_bw_differentiable) {
       tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, /*shared_view_info*/ false, creation_meta);
     } else {
@@ -292,20 +287,20 @@ inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& ten
   return tensors;
 }
 
-inline void check_no_requires_grad(const Tensor& tensor, const char* name,
+inline void check_no_requires_grad(const at::Tensor& tensor, const char* name,
                                    const char* fn_name="", bool check_grad_mode=true) {
   TORCH_CHECK(!(tensor.defined() && tensor.requires_grad()) || !(check_grad_mode && GradMode::is_enabled()),
               "The function '", fn_name, "' is not differentiable with respect to argument '", name,
               "'. This input cannot have requires_grad True.");
 }
 
-inline void check_no_requires_grad(const c10::optional<Tensor>& tensor, const char* name, const char* fn_name="") {
+inline void check_no_requires_grad(const c10::optional<at::Tensor>& tensor, const char* name, const char* fn_name="") {
   if (tensor.has_value()) {
     check_no_requires_grad(*tensor, name, fn_name);
   }
 }
 
-inline void check_no_requires_grad(TensorList tensors, const char* name, const char* fn_name="") {
+inline void check_no_requires_grad(at::TensorList tensors, const char* name, const char* fn_name="") {
   // GradMode check is expensive, so check it only once for TensorLists
   if (!GradMode::is_enabled()) {
     return;
@@ -315,12 +310,12 @@ inline void check_no_requires_grad(TensorList tensors, const char* name, const c
   }
 }
 
-inline void check_no_requires_grad(const c10::List<c10::optional<Tensor>>& tensors, const char* name, const char* fn_name="") {
+inline void check_no_requires_grad(const c10::List<c10::optional<at::Tensor>>& tensors, const char* name, const char* fn_name="") {
   // GradMode check is expensive, so check it only once for TensorLists
   if (!GradMode::is_enabled()) {
     return;
   }
-  for (c10::optional<Tensor> tensor : tensors) {
+  for (c10::optional<at::Tensor> tensor : tensors) {
     if (tensor.has_value()) {
       check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false);
     }
@@ -328,23 +323,23 @@ inline void check_no_requires_grad(const c10::List<c10::optional<Tensor>>& tenso
 }
 
 // Assumed that saved tensor lists are never inplace outputs
-inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
-  return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
+inline std::vector<SavedVariable> make_saved_variable_list(at::TensorList tensors) {
+  return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable {
       return SavedVariable{tensor, false /* is output */}; });
 }
 
 // Assumed that saved tensor lists are never inplace outputs
 inline std::vector<SavedVariable> make_saved_variable_list(const c10::List<c10::optional<at::Tensor>>& tensors) {
-  return fmap(tensors, [](const c10::optional<Tensor>& tensor) -> SavedVariable {
+  return fmap(tensors, [](const c10::optional<at::Tensor>& tensor) -> SavedVariable {
     if (tensor.has_value()) {
       return SavedVariable{*tensor, false /* is output */};
     } else {
-      return SavedVariable{Tensor(), false /* is output */};
+      return SavedVariable{at::Tensor(), false /* is output */};
     }
   });
 }
 
-inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
+inline std::vector<std::vector<int64_t>> to_args_sizes(at::TensorList tensors) {
   std::vector<std::vector<int64_t>> args_sizes(tensors.size());
   for (const auto i : c10::irange(tensors.size())) {
     args_sizes[i] = tensors[i].sizes().vec();
@@ -352,11 +347,12 @@ inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
   return args_sizes;
 }
 
-inline std::vector<ScalarType> to_args_scalartypes(TensorList tensors) {
-  std::vector<ScalarType> args_scalartypes(tensors.size());
+inline std::vector<c10::ScalarType> to_args_scalartypes(at::TensorList tensors) {
+  std::vector<c10::ScalarType> args_scalartypes(tensors.size());
   for (const auto i : c10::irange(tensors.size())) {
     args_scalartypes[i] = tensors[i].scalar_type();
   }
   return args_scalartypes;
 }
+
 }} // namespace torch::autograd
index 90811e2..331db5d 100644 (file)
@@ -86,4 +86,9 @@ inline void set_history(
     set_history(variable, grad_fn);
   }
 }
+
+inline bool isFwGradDefined(const c10::optional<at::Tensor>& t) {
+  return t.has_value() && t->defined() && t->_fw_grad(/*level */ 0).defined();
+}
+
 }}