Revert D14909203: Remove usages of TypeID
authorPieter Noordhuis <pietern@fb.com>
Fri, 19 Apr 2019 00:44:37 +0000 (17:44 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 00:47:39 +0000 (17:47 -0700)
Differential Revision:
D14909203

Original commit changeset: d716179c484a

fbshipit-source-id: 992ff1fcd6d35d3f2ae768c7e164b7a0ba871914

aten/src/ATen/native/NNPACK.cpp
torch/csrc/cuda/Module.cpp
torch/csrc/nn/type_checks.h
torch/csrc/utils/tensor_flatten.cpp

index d513bb7..25cb444 100644 (file)
@@ -213,10 +213,10 @@ Tensor _nnpack_spatial_convolution(
   auto algorithm = nnp_convolution_algorithm_auto;
 
   // All Tensors must be float Tensors
-  if (input.device().type() != kCPU || input.scalar_type() != kFloat ||
-      weight.device().type() != kCPU || weight.scalar_type() != kFloat ||
-      output.device().type() != kCPU || output.scalar_type() != kFloat ||
-      (bias.defined() && (bias.device().type() != kCPU || bias.scalar_type() != kFloat))) {
+  if (input.dispatch_type().ID() != at::TypeID::CPUFloat ||
+      weight.dispatch_type().ID() != at::TypeID::CPUFloat ||
+      output.dispatch_type().ID() != at::TypeID::CPUFloat ||
+      (bias.defined() && bias.dispatch_type().ID() != at::TypeID::CPUFloat)) {
     throw std::runtime_error(
         "Mismatched Tensor types in NNPack convolutionOutput");
   }
index ac46b6c..2860112 100644 (file)
@@ -145,8 +145,7 @@ PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
 {
   HANDLE_TH_ERRORS
   if (!THPVariable_Check(obj) ||
-      THPVariable_Unpack(obj).type_id() != at::CPUTensorId() ||
-      THPVariable_Unpack(obj).scalar_type() != at::kByte) {
+      at::globalContext().getNonVariableType(THPVariable_Unpack(obj).type().backend(), THPVariable_Unpack(obj).scalar_type()).ID() != at::TypeID::CPUByte) {
     throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
         Py_TYPE(obj)->tp_name);
   }
index d8903d4..2a00caf 100644 (file)
@@ -9,10 +9,10 @@
 
 namespace torch { namespace nn {
 
-inline bool check_type(PyObject* obj, at::TensorTypeId id, at::ScalarType dtype) {
+inline bool check_type(PyObject* obj, at::TypeID typeID) {
   if (THPVariable_Check(obj)) {
     auto& tensor = ((THPVariable*)obj)->cdata;
-    return tensor.type_id() == id && tensor.dtype() == dtype;
+    return at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type()).ID() == typeID;
   }
   return false;
 }
@@ -38,19 +38,19 @@ static inline int get_device(PyObject* args) {
 }
 
 static inline bool THNN_FloatTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CPUTensorId(), at::kFloat);
+  return torch::nn::check_type(obj, at::TypeID::CPUFloat);
 }
 
 static inline bool THNN_DoubleTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CPUTensorId(), at::kDouble);
+  return torch::nn::check_type(obj, at::TypeID::CPUDouble);
 }
 
 static inline bool THNN_LongTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CPUTensorId(), at::kLong);
+  return torch::nn::check_type(obj, at::TypeID::CPULong);
 }
 
 static inline bool THNN_IntTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CPUTensorId(), at::kInt);
+  return torch::nn::check_type(obj, at::TypeID::CPUInt);
 }
 
 static inline THFloatTensor* THNN_FloatTensor_Unpack(PyObject* obj) {
@@ -72,19 +72,19 @@ static inline THIntTensor* THNN_IntTensor_Unpack(PyObject* obj) {
 #ifdef USE_CUDA
 
 static inline bool THNN_CudaHalfTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CUDATensorId(), at::kHalf);
+  return torch::nn::check_type(obj, at::TypeID::CUDAHalf);
 }
 
 static inline bool THNN_CudaFloatTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CUDATensorId(), at::kFloat);
+  return torch::nn::check_type(obj, at::TypeID::CUDAFloat);
 }
 
 static inline bool THNN_CudaDoubleTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CUDATensorId(), at::kDouble);
+  return torch::nn::check_type(obj, at::TypeID::CUDADouble);
 }
 
 static inline bool THNN_CudaLongTensor_Check(PyObject* obj) {
-  return torch::nn::check_type(obj, at::CUDATensorId(), at::kLong);
+  return torch::nn::check_type(obj, at::TypeID::CUDALong);
 }
 
 static inline THCudaHalfTensor* THNN_CudaHalfTensor_Unpack(PyObject* obj) {
index 64eab94..0423cbc 100644 (file)
@@ -14,7 +14,7 @@ std::vector<TensorGroup> take_tensors(
   std::vector<TensorGroup> results;
   // an overapproximation, but at least we won't have to copy stuff around
   results.reserve(tensors.size());
-  std::unordered_map<DeprecatedTypeProperties*, TensorGroup> groups;
+  std::map<TypeID, TensorGroup> groups;
   size_t cur_group_size = 0;
 
   for (const auto & tensor : tensors) {
@@ -28,7 +28,7 @@ std::vector<TensorGroup> take_tensors(
       tensor_size = tensor.numel() * tensor.element_size();
     }
 
-    auto& type_group = groups[&tensor.type()];
+    auto& type_group = groups[tensor.dispatch_type().ID()];
     type_group.tensors.push_back(tensor);
 
     if (fine_grained) {