#pragma once
#include <c10/util/irange.h>
-#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/saved_variable.h>
-#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/jit/frontend/tracer.h>
#ifdef _MSC_VER
#ifdef Type
#undef Type
#endif
#endif
-using namespace at;
-using namespace torch::autograd::generated;
-
namespace torch { namespace autograd {
// The requires_grad argument is used to know if the inplace operation needs
// gradient to be setup for it. In particular, we can have
// tensor.requires_grad() != requires_grad when copying a Tensor that requires
// gradients into a Tensor that does not require gradients:
// a = torch.rand(2)
// b = torch.rand(2, requires_grad=True)
// a.copy_(b)
-inline void check_inplace(const Tensor& tensor, bool requires_grad) {
+inline void check_inplace(const at::Tensor& tensor, bool requires_grad) {
if (requires_grad && GradMode::is_enabled()) {
auto diff_view_meta = impl::get_view_autograd_meta(tensor);
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      // This can throw or warn
      handle_view_on_rebase(diff_view_meta);
    }
  }
}
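// Overload for operators that modify several tensors in place: runs the same
// check on every tensor in the list.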
-inline void check_inplace(const TensorList tensors, bool requires_grad) {
+inline void check_inplace(const at::TensorList tensors, bool requires_grad) {
for (const auto& tensor : tensors) {
check_inplace(tensor, requires_grad);
}
"but one of the arguments requires grad.");
}
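// Errors out when the output named `name` has a complex dtype and requires
// grad, since its backward does not support complex automatic differentiation.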
-inline void throw_error_for_complex_autograd(const Tensor& tensor, const char* name) {
+inline void throw_error_for_complex_autograd(const at::Tensor& tensor, const char* name) {
if (tensor.requires_grad()) {
TORCH_CHECK(!tensor.is_complex(), name,
" does not support automatic differentiation for outputs with complex dtype.");
}
}
-inline void throw_error_for_complex_autograd(const TensorList& tensorlist, const char* name) {
+inline void throw_error_for_complex_autograd(const at::TensorList& tensorlist, const char* name) {
for (const auto& tensor: tensorlist) {
throw_error_for_complex_autograd(tensor, name);
}
}
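// Bumps the version counter of `t` so that SavedVariables referencing it can
// detect that it was modified in place.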
-inline void increment_version(const Tensor & t) {
+inline void increment_version(const at::Tensor & t) {
impl::bump_version(t);
}
// See NOTE [ Autograd View Variables ] for details.
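// Wraps `tensor`, the result of a view operation on `base`, as an autograd
// view so that gradients and in-place updates are tracked correctly.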
-inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_differentiable,
- bool is_fw_differentiable, std::function<Tensor(const Tensor&)> view_func=nullptr,
+inline at::Tensor as_view(const at::Tensor & base, const at::Tensor & tensor, bool is_bw_differentiable,
+ bool is_fw_differentiable, std::function<at::Tensor(const at::Tensor&)> view_func=nullptr,
CreationMeta creation_meta=CreationMeta::DEFAULT, bool allow_tensor_metadata_change=true) {
// Note [View of inference tensor]
// For inference tensor this code can only be hit outside InferenceMode
}
// See NOTE [ Autograd View Variables ] for details.
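// Same as above, but for view operations that return several tensors sharing
// the same base.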
-inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& tensors, bool is_bw_differentiable,
+inline std::vector<at::Tensor> as_view(const at::Tensor & base, std::vector<at::Tensor>& tensors, bool is_bw_differentiable,
bool is_fw_differentiable, CreationMeta creation_meta=CreationMeta::DEFAULT) {
// See Note [View of inference tensor]
if (base.is_inference()) return tensors;
new_shared_info = ViewInfo(base, /* view_func */ nullptr);
}
- for(Tensor &tensor : tensors) {
+ for(at::Tensor &tensor : tensors) {
if (is_fw_differentiable || is_bw_differentiable) {
tensor = make_variable_differentiable_view(tensor, new_shared_info, c10::nullopt, /*shared_view_info*/ true, creation_meta);
} else {
creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
}
- for(Tensor &tensor : tensors) {
+ for(at::Tensor &tensor : tensors) {
if (is_fw_differentiable || is_bw_differentiable) {
tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, /*shared_view_info*/ false, creation_meta);
} else {
return tensors;
}
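// Errors out when `tensor` requires grad (while grad mode is enabled), used
// for arguments that `fn_name` is not differentiable with respect to.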
-inline void check_no_requires_grad(const Tensor& tensor, const char* name,
+inline void check_no_requires_grad(const at::Tensor& tensor, const char* name,
const char* fn_name="", bool check_grad_mode=true) {
TORCH_CHECK(!(tensor.defined() && tensor.requires_grad()) || !(check_grad_mode && GradMode::is_enabled()),
"The function '", fn_name, "' is not differentiable with respect to argument '", name,
"'. This input cannot have requires_grad True.");
}
-inline void check_no_requires_grad(const c10::optional<Tensor>& tensor, const char* name, const char* fn_name="") {
+inline void check_no_requires_grad(const c10::optional<at::Tensor>& tensor, const char* name, const char* fn_name="") {
if (tensor.has_value()) {
check_no_requires_grad(*tensor, name, fn_name);
}
}
-inline void check_no_requires_grad(TensorList tensors, const char* name, const char* fn_name="") {
+inline void check_no_requires_grad(at::TensorList tensors, const char* name, const char* fn_name="") {
// GradMode check is expensive, so check it only once for TensorLists
if (!GradMode::is_enabled()) {
return;
}
  for (const auto& tensor : tensors) {
    check_no_requires_grad(tensor, name, fn_name, /*check_grad_mode*/ false);
  }
}
-inline void check_no_requires_grad(const c10::List<c10::optional<Tensor>>& tensors, const char* name, const char* fn_name="") {
+inline void check_no_requires_grad(const c10::List<c10::optional<at::Tensor>>& tensors, const char* name, const char* fn_name="") {
// GradMode check is expensive, so check it only once for TensorLists
if (!GradMode::is_enabled()) {
return;
}
- for (c10::optional<Tensor> tensor : tensors) {
+ for (c10::optional<at::Tensor> tensor : tensors) {
if (tensor.has_value()) {
check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false);
}
  }
}
// Assumed that saved tensor lists are never inplace outputs
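// Wraps each tensor in the list as a SavedVariable so it can be stored on the
// backward node.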
-inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
- return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
+inline std::vector<SavedVariable> make_saved_variable_list(at::TensorList tensors) {
+ return fmap(tensors, [](const at::Tensor& tensor) -> SavedVariable {
return SavedVariable{tensor, false /* is output */}; });
}
// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(const c10::List<c10::optional<at::Tensor>>& tensors) {
- return fmap(tensors, [](const c10::optional<Tensor>& tensor) -> SavedVariable {
+ return fmap(tensors, [](const c10::optional<at::Tensor>& tensor) -> SavedVariable {
if (tensor.has_value()) {
return SavedVariable{*tensor, false /* is output */};
} else {
- return SavedVariable{Tensor(), false /* is output */};
+ return SavedVariable{at::Tensor(), false /* is output */};
}
});
}
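// Copies the sizes of every tensor into plain vectors, typically so that a
// backward node can record the input shapes it needs.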
-inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
+inline std::vector<std::vector<int64_t>> to_args_sizes(at::TensorList tensors) {
std::vector<std::vector<int64_t>> args_sizes(tensors.size());
for (const auto i : c10::irange(tensors.size())) {
    args_sizes[i] = tensors[i].sizes().vec();
  }
  return args_sizes;
}
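// Same as above, but records the scalar type of each tensor instead of its sizes.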
-inline std::vector<ScalarType> to_args_scalartypes(TensorList tensors) {
- std::vector<ScalarType> args_scalartypes(tensors.size());
+inline std::vector<c10::ScalarType> to_args_scalartypes(at::TensorList tensors) {
+ std::vector<c10::ScalarType> args_scalartypes(tensors.size());
for (const auto i : c10::irange(tensors.size())) {
args_scalartypes[i] = tensors[i].scalar_type();
}
return args_scalartypes;
}
+
}} // namespace torch::autograd