Remove usages of TypeID (#19183)
authorRoy Li <royboy@fb.com>
Thu, 18 Apr 2019 06:52:44 +0000 (23:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 06:55:47 +0000 (23:55 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19183
ghimport-source-id: 9af190b072523459fa61e5e79419b88ac8586a4d

Differential Revision: D14909203

Pulled By: li-roy

fbshipit-source-id: d716179c484aebfe3ec30087c5ecd4a11848ffc3

aten/src/ATen/native/NNPACK.cpp
torch/csrc/cuda/Module.cpp
torch/csrc/nn/type_checks.h
torch/csrc/utils/tensor_flatten.cpp

index 25cb444..d513bb7 100644 (file)
@@ -213,10 +213,10 @@ Tensor _nnpack_spatial_convolution(
   auto algorithm = nnp_convolution_algorithm_auto;
 
   // All Tensors must be float Tensors
-  if (input.dispatch_type().ID() != at::TypeID::CPUFloat ||
-      weight.dispatch_type().ID() != at::TypeID::CPUFloat ||
-      output.dispatch_type().ID() != at::TypeID::CPUFloat ||
-      (bias.defined() && bias.dispatch_type().ID() != at::TypeID::CPUFloat)) {
+  if (input.device().type() != kCPU || input.scalar_type() != kFloat ||
+      weight.device().type() != kCPU || weight.scalar_type() != kFloat ||
+      output.device().type() != kCPU || output.scalar_type() != kFloat ||
+      (bias.defined() && (bias.device().type() != kCPU || bias.scalar_type() != kFloat))) {
     throw std::runtime_error(
         "Mismatched Tensor types in NNPack convolutionOutput");
   }
index 2860112..ac46b6c 100644 (file)
@@ -145,7 +145,8 @@ PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
 {
   HANDLE_TH_ERRORS
   if (!THPVariable_Check(obj) ||
-      at::globalContext().getNonVariableType(THPVariable_Unpack(obj).type().backend(), THPVariable_Unpack(obj).scalar_type()).ID() != at::TypeID::CPUByte) {
+      THPVariable_Unpack(obj).type_id() != at::CPUTensorId() ||
+      THPVariable_Unpack(obj).scalar_type() != at::kByte) {
     throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
         Py_TYPE(obj)->tp_name);
   }
index 2a00caf..d8903d4 100644 (file)
@@ -9,10 +9,10 @@
 
 namespace torch { namespace nn {
 
-inline bool check_type(PyObject* obj, at::TypeID typeID) {
+inline bool check_type(PyObject* obj, at::TensorTypeId id, at::ScalarType dtype) {
   if (THPVariable_Check(obj)) {
     auto& tensor = ((THPVariable*)obj)->cdata;
-    return at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type()).ID() == typeID;
+    return tensor.type_id() == id && tensor.dtype() == dtype;
   }
   return false;
 }
@@ -38,19 +38,19 @@ static inline int get_device(PyObject* args) {
 }
 
 static inline bool THNN_FloatTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CPUFloat);
+  return torch::nn::check_type(obj, at::CPUTensorId(), at::kFloat);
 }
 
 static inline bool THNN_DoubleTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CPUDouble);
+  return torch::nn::check_type(obj, at::CPUTensorId(), at::kDouble);
 }
 
 static inline bool THNN_LongTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CPULong);
+  return torch::nn::check_type(obj, at::CPUTensorId(), at::kLong);
 }
 
 static inline bool THNN_IntTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CPUInt);
+  return torch::nn::check_type(obj, at::CPUTensorId(), at::kInt);
 }
 
 static inline THFloatTensor* THNN_FloatTensor_Unpack(PyObject* obj) {
@@ -72,19 +72,19 @@ static inline THIntTensor* THNN_IntTensor_Unpack(PyObject* obj) {
 #ifdef USE_CUDA
 
 static inline bool THNN_CudaHalfTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CUDAHalf);
+  return torch::nn::check_type(obj, at::CUDATensorId(), at::kHalf);
 }
 
 static inline bool THNN_CudaFloatTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CUDAFloat);
+  return torch::nn::check_type(obj, at::CUDATensorId(), at::kFloat);
 }
 
 static inline bool THNN_CudaDoubleTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CUDADouble);
+  return torch::nn::check_type(obj, at::CUDATensorId(), at::kDouble);
 }
 
 static inline bool THNN_CudaLongTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::TypeID::CUDALong);
+  return torch::nn::check_type(obj, at::CUDATensorId(), at::kLong);
 }
 
 static inline THCudaHalfTensor* THNN_CudaHalfTensor_Unpack(PyObject* obj) {
index 0423cbc..64eab94 100644 (file)
@@ -14,7 +14,7 @@ std::vector<TensorGroup> take_tensors(
   std::vector<TensorGroup> results;
   // an overapproximation, but at least we won't have to copy stuff around
   results.reserve(tensors.size());
-  std::map<TypeID, TensorGroup> groups;
+  std::unordered_map<DeprecatedTypeProperties*, TensorGroup> groups;
   size_t cur_group_size = 0;
 
   for (const auto & tensor : tensors) {
@@ -28,7 +28,7 @@ std::vector<TensorGroup> take_tensors(
       tensor_size = tensor.numel() * tensor.element_size();
     }
 
-    auto& type_group = groups[tensor.dispatch_type().ID()];
+    auto& type_group = groups[&tensor.type()];
     type_group.tensors.push_back(tensor);
 
     if (fine_grained) {