[PyTorch] Avoid using std::regex for device string parsing in Device.cpp (#63464)
authorDhruv Matani <dhruvbird@fb.com>
Wed, 18 Aug 2021 21:47:19 +0000 (14:47 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 21:55:12 +0000 (14:55 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63464

This was previously committed as D30281388 (https://github.com/pytorch/pytorch/commit/4d6f98ecada2d85b2474b023838debad4305316d), but was reverted due to t98478641. jnkwok1 confirmed that this change was not the root cause, so trying to land it again.

Currently, `std::regex` is used for parsing device strings. This is undesirable for a few reasons.

1. Increases binary size
2. Slows down model loading
3. Potentially uses more memory at runtime
4. Takes marginally longer time to build code that uses std::regex v/s not using std::regex

This change avoids the use of `std::regex` for parsing the device string since we don't need to.
ghstack-source-id: 136006963
ghstack-source-id: 136081898

Test Plan:
### AI Bench Runs

**Before this change:**
1. Model Load time: [252ms](https://www.internalfb.com/intern/aibench/details/332471502816548)
2. Model unload time: 3.5ms

**After this change:**
1. Model Load time: [240ms](https://www.internalfb.com/intern/aibench/details/652195589031318), which is an approx 5% reduction for the current model. I suspect percentage wise, it will be larger for smaller models since this is a fixed cost reduction.
2. Model unload time: 3.3ms (probably too small to be meaningfully impactful to an end user).

### BSB Results

```
D30281388 (https://github.com/pytorch/pytorch/commit/4d6f98ecada2d85b2474b023838debad4305316d)-V1 (https://www.internalfb.com/intern/diff/D30281388 (https://github.com/pytorch/pytorch/commit/4d6f98ecada2d85b2474b023838debad4305316d)/?dest_number=135713848)

messenger-pika-optimized-device: Succeeded
Change in Download Size for arm64 + 3x assets variation: -7.1 KiB
Change in Uncompressed Size for arm64 + 3x assets variation: -17.6 KiB

Mbex Comparison: https://our.intern.facebook.com/intern/mbex/bsb:551399955987465@base/bsb:551399955987465@diff/
```

Reviewed By: raziel, pavithranrao

Differential Revision: D30388269

fbshipit-source-id: 10942e7aa56f9ea47aa479a8f50187f2ce2899bf

c10/core/Device.cpp

index ee6f1b473fe08a363576fff1a98211da02a40d5b..2709c29ce8460956c0cd8480ef874012be3b9470 100644 (file)
@@ -4,28 +4,13 @@
 
 #include <algorithm>
 #include <array>
+#include <cctype>
 #include <exception>
 #include <ostream>
-#include <regex>
 #include <string>
 #include <tuple>
 #include <vector>
 
-// Check if compiler has working std::regex implementation
-//
-// Test below is adapted from https://stackoverflow.com/a/41186162
-#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L
-// Compiler has working regex. MSVC has erroneous __cplusplus.
-#elif __cplusplus >= 201103L &&                           \
-    (!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \
-     (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) ||    \
-      defined(_GLIBCXX_REGEX_STATE_LIMIT) ||              \
-      (defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4)))
-// Compiler has working regex.
-#else
-static_assert(false, "Compiler does not have proper regex support.");
-#endif
-
 namespace c10 {
 namespace {
 DeviceType parse_type(const std::string& device_string) {
@@ -65,33 +50,84 @@ DeviceType parse_type(const std::string& device_string) {
       "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
       device_string);
 }
+enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
+
 } // namespace
 
 Device::Device(const std::string& device_string) : Device(Type::CPU) {
   TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
 
-  // We assume gcc 5+, so we can use proper regex.
-  static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
-  std::smatch match;
-  TORCH_CHECK(
-      std::regex_match(device_string, match, regex),
-      "Invalid device string: '",
-      device_string,
-      "'");
-  type_ = parse_type(match[1].str());
-  if (match[2].matched) {
-    try {
-      index_ = c10::stoi(match[2].str());
-    } catch (const std::exception&) {
-      TORCH_CHECK(
-          false,
-          "Could not parse device index '",
-          match[2].str(),
-          "' in device string '",
-          device_string,
-          "'");
+  std::string device_name, device_index_str;
+  DeviceStringParsingState pstate = DeviceStringParsingState::START;
+
+  // The code below tries to match the string in the variable
+  // device_string against the regular expression:
+  // ([a-zA-Z_]+)(?::([1-9]\\d*|0))?
+  for (size_t i = 0;
+       pstate != DeviceStringParsingState::ERROR && i < device_string.size();
+       ++i) {
+    const char ch = device_string.at(i);
+    switch (pstate) {
+      case DeviceStringParsingState::START:
+        if (ch != ':') {
+          if (isalpha(ch) || ch == '_') {
+            device_name.push_back(ch);
+          } else {
+            pstate = DeviceStringParsingState::ERROR;
+          }
+        } else {
+          pstate = DeviceStringParsingState::INDEX_START;
+        }
+        break;
+
+      case DeviceStringParsingState::INDEX_START:
+        if (isdigit(ch)) {
+          device_index_str.push_back(ch);
+          pstate = DeviceStringParsingState::INDEX_REST;
+        } else {
+          pstate = DeviceStringParsingState::ERROR;
+        }
+        break;
+
+      case DeviceStringParsingState::INDEX_REST:
+        if (device_index_str.at(0) == '0') {
+          pstate = DeviceStringParsingState::ERROR;
+          break;
+        }
+        if (isdigit(ch)) {
+          device_index_str.push_back(ch);
+        } else {
+          pstate = DeviceStringParsingState::ERROR;
+        }
+        break;
+
+      case DeviceStringParsingState::ERROR:
+        // Execution won't reach here.
+        break;
+    }
+  }
+
+  const bool has_error = device_name.empty() ||
+      pstate == DeviceStringParsingState::ERROR ||
+      (pstate == DeviceStringParsingState::INDEX_START &&
+       device_index_str.empty());
+
+  TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
+
+  try {
+    if (!device_index_str.empty()) {
+      index_ = c10::stoi(device_index_str);
     }
+  } catch (const std::exception&) {
+    TORCH_CHECK(
+        false,
+        "Could not parse device index '",
+        device_index_str,
+        "' in device string '",
+        device_string,
+        "'");
   }
+  type_ = parse_type(device_name);
   validate();
 }