}
} // namespace
+void Device::validate() {
+ AT_CHECK(index_ == -1 || index_ >= 0,
+ "Device index must be -1 or non-negative, got ", index_);
+ AT_CHECK(!is_cpu() || index_ <= 0,
+ "CPU device index must be -1 or zero, got ", index_);
+}
+
// `std::regex` is still in a very incomplete state in GCC 4.8.x,
// so we have to do our own parsing, like peasants.
// https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions
int index = device_string.find(":");
if (index == std::string::npos) {
type_ = parse_type(device_string);
- return;
} else {
std::string s;
s = device_string.substr(0, index);
AT_CHECK(!s.empty(), "Device string must not be empty");
type_ = parse_type(s);
+
+ std::string device_index = device_string.substr(index + 1);
+ try {
+ index_ = c10::stoi(device_index);
+ } catch (const std::exception &) {
+ AT_ERROR("Could not parse device index '", device_index,
+ "' in device string '", device_string, "'");
+ }
+ AT_CHECK(index_ >= 0,
+ "Device index must be non-negative, got ", index_);
}
- std::string device_index = device_string.substr(index + 1);
- try {
- index_ = c10::stoi(device_index);
- } catch (const std::exception&) {
- AT_ERROR(
- "Could not parse device index '",
- device_index,
- "' in device string '",
- device_string,
- "'");
- }
+ validate();
}
std::ostream& operator<<(std::ostream& stream, const Device& device) {
/// index.
/* implicit */ Device(DeviceType type, DeviceIndex index = -1)
: type_(type), index_(index) {
- AT_CHECK(
- index == -1 || index >= 0,
- "Device index must be -1 or non-negative, got ",
- index);
- AT_CHECK(
- !is_cpu() || index <= 0,
- "CPU device index must be -1 or zero, got ",
- index);
+ validate();
}
/// Constructs a `Device` from a string description, for convenience.
private:
DeviceType type_;
DeviceIndex index_ = -1;
+ void validate();
};
C10_API std::ostream& operator<<(
self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1))
self.assertRaises(RuntimeError, lambda: torch.device(-1))
- self.assertRaises(TypeError, lambda: torch.device('other'))
- self.assertRaises(TypeError, lambda: torch.device('other:0'))
+ self.assertRaises(RuntimeError, lambda: torch.device('other'))
+ self.assertRaises(RuntimeError, lambda: torch.device('other:0'))
device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
device_hash_set = set()
AT_CHECK(device_index >= 0, "Device index must not be negative");
return at::Device(at::DeviceType::CUDA, device_index);
}
- const std::string device_str = THPUtils_unpackString(args[i]);
- if (device_str == cpu_str) {
- return at::Device(at::DeviceType::CPU);
- } else if (device_str == cuda_str) {
- return at::Device(at::DeviceType::CUDA);
- } else if (device_str.compare(0, cpu_prefix.length(), cpu_prefix) == 0) {
- const auto device_index = std::stoi(device_str.substr(cpu_prefix.length()));
- AT_CHECK(device_index >= 0, "Device index must not be negative");
- return at::Device(at::DeviceType::CPU, device_index);
- } else if (device_str.compare(0, cuda_prefix.length(), cuda_prefix) == 0) {
- const auto device_index = std::stoi(device_str.substr(cuda_prefix.length()));
- AT_CHECK(device_index >= 0, "Device index must not be negative");
- return at::Device(at::DeviceType::CUDA, device_index);
- }
- throw torch::TypeError("only \"cuda\" and \"cpu\" are valid device types, got %s", device_str.c_str());
+ const std::string &device_str = THPUtils_unpackString(args[i]);
+ return at::Device(device_str);
}
inline at::Device PythonArgs::deviceWithDefault(int i, const at::Device& default_device) {