#include <c10/util/Optional.h>
#include <c10/util/Flags.h>
#include <c10/util/Logging.h>
+#include <c10/util/python_stub.h>
// A global boolean variable to control whether we free memory when a Tensor
// is shrinked to a smaller size. As a result, a Tensor is always going to
return impl;
}
+ inline void set_pyobj(PyObject* pyobj) noexcept {
+ pyobj_ = pyobj;
+ }
+
+ inline PyObject* pyobj() const noexcept {
+ return pyobj_;
+ }
+
private:
// As an optimization, get_device handles the typical CUDA Tensor case and
// calls get_device_slow if the tensor stores its device somewhere else
// at a time).
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
+ PyObject* pyobj_ = nullptr; // weak reference
+
// We could save a word or two by combining the SmallVector structs,
// since their size is redundant, and if we need to overflow the buffer space
// we could keep the two pointers together. However, that would require
// numel
// data type pointer
// autograd metadata pointer
+// PyObject pointer
// miscellaneous bitfield
//
static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
- sizeof(TensorImpl) == sizeof(int64_t) * 25,
+ sizeof(TensorImpl) == sizeof(int64_t) * 26,
"You changed the size of TensorImpl on 64-bit arch."
"See Note [TensorImpl size constraints] on how to proceed.");
autograd_meta->requires_grad_ = false;
autograd_meta->is_view_ = false;
autograd_meta->output_nr_ = gradient_edge.input_nr;
- autograd_meta->pyobj_ = nullptr;
// set_requires_grad also checks error conditions.
autograd_meta->set_requires_grad(requires_grad, this);
// We use this to make sure we can setup the backwards trace
// correctly when this variable is passed to another function.
uint32_t output_nr_;
- PyObject* pyobj_ = nullptr; // weak reference
// Mutex to ensure that concurrent read operations that modify internal
// state are still thread-safe. Used by grad_fn() and
}
inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
- get_autograd_meta()->pyobj_ = pyobj;
+ get()->set_pyobj(pyobj);
}
inline PyObject* Variable::pyobj() const noexcept {
- return get_autograd_meta()->pyobj_;
+ return get()->pyobj();
}
inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {