Unify device argument parsing between torch and c10
authorJunjie Bai <bai@in.tum.de>
Thu, 6 Dec 2018 02:35:21 +0000 (18:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 6 Dec 2018 02:37:32 +0000 (18:37 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14786

Differential Revision: D13334501

Pulled By: bddppq

fbshipit-source-id: ae3536be1fe0dcd6a1552ec93629ecc9554c0d7c

c10/Device.cpp
c10/Device.h
test/test_torch.py
torch/csrc/utils/python_arg_parser.h

index 44f14ea..c0bb507 100644 (file)
@@ -36,6 +36,13 @@ DeviceType parse_type(const std::string& device_string) {
 }
 } // namespace
 
+void Device::validate() {
+  AT_CHECK(index_ == -1 || index_ >= 0,
+           "Device index must be -1 or non-negative, got ", index_);
+  AT_CHECK(!is_cpu() || index_ <= 0,
+           "CPU device index must be -1 or zero, got ", index_);
+}
+
 // `std::regex` is still in a very incomplete state in GCC 4.8.x,
 // so we have to do our own parsing, like peasants.
 // https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions
@@ -64,24 +71,23 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
   int index = device_string.find(":");
   if (index == std::string::npos) {
     type_ = parse_type(device_string);
-    return;
   } else {
     std::string s;
     s = device_string.substr(0, index);
     AT_CHECK(!s.empty(), "Device string must not be empty");
     type_ = parse_type(s);
+
+    std::string device_index = device_string.substr(index + 1);
+    try {
+      index_ = c10::stoi(device_index);
+    } catch (const std::exception &) {
+      AT_ERROR("Could not parse device index '", device_index,
+               "' in device string '", device_string, "'");
+    }
+    AT_CHECK(index_ >= 0,
+             "Device index must be non-negative, got ", index_);
   }
-  std::string device_index = device_string.substr(index + 1);
-  try {
-    index_ = c10::stoi(device_index);
-  } catch (const std::exception&) {
-    AT_ERROR(
-        "Could not parse device index '",
-        device_index,
-        "' in device string '",
-        device_string,
-        "'");
-  }
+  validate();
 }
 
 std::ostream& operator<<(std::ostream& stream, const Device& device) {
index 3c9fafa..81c5cee 100644 (file)
@@ -34,14 +34,7 @@ struct C10_API Device final {
   /// index.
   /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
       : type_(type), index_(index) {
-    AT_CHECK(
-        index == -1 || index >= 0,
-        "Device index must be -1 or non-negative, got ",
-        index);
-    AT_CHECK(
-        !is_cpu() || index <= 0,
-        "CPU device index must be -1 or zero, got ",
-        index);
+    validate();
   }
 
   /// Constructs a `Device` from a string description, for convenience.
@@ -96,6 +89,7 @@ struct C10_API Device final {
  private:
   DeviceType type_;
   DeviceIndex index_ = -1;
+  void validate();
 };
 
 C10_API std::ostream& operator<<(
index f9634dc..602bc6f 100644 (file)
@@ -2203,8 +2203,8 @@ class _TestTorchMixin(object):
         self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1))
         self.assertRaises(RuntimeError, lambda: torch.device(-1))
 
-        self.assertRaises(TypeError, lambda: torch.device('other'))
-        self.assertRaises(TypeError, lambda: torch.device('other:0'))
+        self.assertRaises(RuntimeError, lambda: torch.device('other'))
+        self.assertRaises(RuntimeError, lambda: torch.device('other:0'))
 
         device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
         device_hash_set = set()
index 70906fb..cb93bf1 100644 (file)
@@ -369,21 +369,8 @@ inline at::Device PythonArgs::device(int i) {
     AT_CHECK(device_index >= 0, "Device index must not be negative");
     return at::Device(at::DeviceType::CUDA, device_index);
   }
-  const std::string device_str = THPUtils_unpackString(args[i]);
-  if (device_str == cpu_str) {
-    return at::Device(at::DeviceType::CPU);
-  } else if (device_str == cuda_str) {
-    return at::Device(at::DeviceType::CUDA);
-  } else if (device_str.compare(0, cpu_prefix.length(), cpu_prefix) == 0) {
-    const auto device_index = std::stoi(device_str.substr(cpu_prefix.length()));
-    AT_CHECK(device_index >= 0, "Device index must not be negative");
-    return at::Device(at::DeviceType::CPU, device_index);
-  } else if (device_str.compare(0, cuda_prefix.length(), cuda_prefix) == 0) {
-    const auto device_index = std::stoi(device_str.substr(cuda_prefix.length()));
-    AT_CHECK(device_index >= 0, "Device index must not be negative");
-    return at::Device(at::DeviceType::CUDA, device_index);
-  }
-  throw torch::TypeError("only \"cuda\" and \"cpu\" are valid device types, got %s", device_str.c_str());
+  const std::string &device_str = THPUtils_unpackString(args[i]);
+  return at::Device(device_str);
 }
 
 inline at::Device PythonArgs::deviceWithDefault(int i, const at::Device& default_device) {