Fix Python device type property for XLA and MSNPU
authorAlex Şuhan <asuhan@google.com>
Thu, 28 Feb 2019 21:28:17 +0000 (13:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Feb 2019 21:36:19 +0000 (13:36 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17361

Differential Revision: D14243546

Pulled By: soumith

fbshipit-source-id: b7498968f72e3d97de5bf6e5b44c5a59b6913acb

torch/csrc/autograd/python_variable.cpp
torch/csrc/autograd/utils/python_arg_parsing.h
torch/csrc/tensor/python_tensor.cpp
torch/csrc/tensor/python_tensor.h

index c4ed2ab..1d47a0e 100644 (file)
@@ -404,7 +404,7 @@ static PyObject * THPVariable_layout(THPVariable* self) {
 
 static PyObject * THPVariable_device(THPVariable* self) {
   HANDLE_TH_ERRORS
-  return THPDevice_New(torch::tensors::getDevice(self->cdata));
+  return THPDevice_New(self->cdata.device());
   END_HANDLE_TH_ERRORS
 }
 
index baf769f..5a4e6c2 100644 (file)
@@ -31,7 +31,7 @@ inline std::tuple<c10::optional<at::Device>, c10::optional<at::ScalarType>, bool
     if (!allow_copy && !r.isNone(2))
       throw std::runtime_error(".to() does not accept copy argument");
     return std::make_tuple(
-      torch::tensors::getDevice(tensor),
+      tensor.device(),
       tensor.type().scalarType(),
       r.toBool(1),
       r.toBool(2)
index 54bda65..f2eceb3 100644 (file)
@@ -387,11 +387,4 @@ at::Type& get_default_tensor_type() {
   AT_ASSERT(default_tensor_type);
   return *default_tensor_type;
 }
-
-Device getDevice(const at::Tensor& tensor) {
-  if (tensor.is_cuda()) {
-    return at::Device(at::DeviceType::CUDA, tensor.get_device());
-  }
-  return at::Device(at::DeviceType::CPU);
-}
 }} // namespace torch::tensors
index 0ea1349..3449359 100644 (file)
@@ -31,7 +31,4 @@ void py_set_default_dtype(PyObject* dtype_obj);
 // returned value will be a VariableType instance.
 at::Type& get_default_tensor_type();
 
-// Gets the torch::Device object of a given at::Tensor
-c10::Device getDevice(const at::Tensor& tensor);
-
 }} // namespace torch::tensors