#include <algorithm>
#include <array>
+#include <cctype>
#include <exception>
#include <ostream>
-#include <regex>
#include <string>
#include <tuple>
#include <vector>
-// Check if compiler has working std::regex implementation
-//
-// Test below is adapted from https://stackoverflow.com/a/41186162
-#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
-// Compiler has working regex. MSVC has erroneous __cplusplus.
-#elif __cplusplus >= 201103L && \
- (!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
- (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \
- defined(_GLIBCXX_REGEX_STATE_LIMIT) || \
- (defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4)))
-// Compiler has working regex.
-#else
-static_assert(false, "Compiler does not have proper regex support.");
-#endif
-
namespace c10 {
namespace {
DeviceType parse_type(const std::string& device_string) {
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
device_string);
}
+enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
+
} // namespace
Device::Device(const std::string& device_string) : Device(Type::CPU) {
TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
- // We assume gcc 5+, so we can use proper regex.
- static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
- std::smatch match;
- TORCH_CHECK(
- std::regex_match(device_string, match, regex),
- "Invalid device string: '",
- device_string,
- "'");
- type_ = parse_type(match[1].str());
- if (match[2].matched) {
- try {
- index_ = c10::stoi(match[2].str());
- } catch (const std::exception&) {
- TORCH_CHECK(
- false,
- "Could not parse device index '",
- match[2].str(),
- "' in device string '",
- device_string,
- "'");
+ std::string device_name, device_index_str;
+ DeviceStringParsingState pstate = DeviceStringParsingState::START;
+
+ // The code below tries to match the string in the variable
+ // device_string against the regular expression:
+ // ([a-zA-Z_]+)(?::([1-9]\\d*|0))?
+ for (size_t i = 0;
+ pstate != DeviceStringParsingState::ERROR && i < device_string.size();
+ ++i) {
+ const char ch = device_string.at(i);
+ switch (pstate) {
+ case DeviceStringParsingState::START:
+ if (ch != ':') {
+ if (isalpha(ch) || ch == '_') {
+ device_name.push_back(ch);
+ } else {
+ pstate = DeviceStringParsingState::ERROR;
+ }
+ } else {
+ pstate = DeviceStringParsingState::INDEX_START;
+ }
+ break;
+
+ case DeviceStringParsingState::INDEX_START:
+ if (isdigit(ch)) {
+ device_index_str.push_back(ch);
+ pstate = DeviceStringParsingState::INDEX_REST;
+ } else {
+ pstate = DeviceStringParsingState::ERROR;
+ }
+ break;
+
+ case DeviceStringParsingState::INDEX_REST:
+ if (device_index_str.at(0) == '0') {
+ pstate = DeviceStringParsingState::ERROR;
+ break;
+ }
+ if (isdigit(ch)) {
+ device_index_str.push_back(ch);
+ } else {
+ pstate = DeviceStringParsingState::ERROR;
+ }
+ break;
+
+ case DeviceStringParsingState::ERROR:
+ // Execution won't reach here.
+ break;
+ }
+ }
+
+ const bool has_error = device_name.empty() ||
+ pstate == DeviceStringParsingState::ERROR ||
+ (pstate == DeviceStringParsingState::INDEX_START &&
+ device_index_str.empty());
+
+ TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
+
+ try {
+ if (!device_index_str.empty()) {
+ index_ = c10::stoi(device_index_str);
}
+ } catch (const std::exception&) {
+ TORCH_CHECK(
+ false,
+ "Could not parse device index '",
+ device_index_str,
+ "' in device string '",
+ device_string,
+ "'");
}
+ type_ = parse_type(device_name);
validate();
}