From 9bbf80969ece148ca5da2107ef9ad26a99891738 Mon Sep 17 00:00:00 2001 From: Dhruv Matani Date: Wed, 18 Aug 2021 14:47:19 -0700 Subject: [PATCH] [PyTorch] Avoid using std::regex for device string parsing in Device.cpp (#63464) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63464 This was previously committed as D30281388 (https://github.com/pytorch/pytorch/commit/4d6f98ecada2d85b2474b023838debad4305316d), but was reverted due to t98478641. jnkwok1 confirmed that this change was not the root cause, so trying to land it again. Currently, `std::regex` is used for parsing device strings. This is undesirable for a few reasons. 1. Increases binary size 2. Slows down model loading 3. Potentially uses more memory at runtime 4. Takes marginally longer time to build code that uses std::regex v/s not using std::regex This change avoids the use of `std::regex` for parsing the device string since we don't need to. ghstack-source-id: 136006963 ghstack-source-id: 136081898 Test Plan: ### AI Bench Runs **Before this change:** 1. Model Load time: [252ms](https://www.internalfb.com/intern/aibench/details/332471502816548) 2. Model unload time: 3.5ms **After this change:** 1. Model Load time: [240ms](https://www.internalfb.com/intern/aibench/details/652195589031318), which is an approx 5% reduction for the current model. I suspect percentage wise, it will be larger for smaller models since this is a fixed cost reduction. 2. Model unload time: 3.3ms (probably too small to be meaningfully impactful to an end user). ### BSB Results ``` D30281388 (https://github.com/pytorch/pytorch/commit/4d6f98ecada2d85b2474b023838debad4305316d)-V1 (https://www.internalfb.com/intern/diff/D30281388 (https://github.com/pytorch/pytorch/commit/4d6f98ecada2d85b2474b023838debad4305316d)/?dest_number=135713848) messenger-pika-optimized-device: Succeeded Change in Download Size for arm64 + 3x assets variation: -7.1 KiB Change in Uncompressed Size for arm64 + 3x assets variation: -17.6 KiB Mbex Comparison: https://our.intern.facebook.com/intern/mbex/bsb:551399955987465@base/bsb:551399955987465@diff/ ``` Reviewed By: raziel, pavithranrao Differential Revision: D30388269 fbshipit-source-id: 10942e7aa56f9ea47aa479a8f50187f2ce2899bf --- c10/core/Device.cpp | 108 +++++++++++++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 36 deletions(-) diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index ee6f1b473f..2709c29ce8 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -4,28 +4,13 @@ #include #include +#include #include #include -#include #include #include #include -// Check if compiler has working std::regex implementation -// -// Test below is adapted from https://stackoverflow.com/a/41186162 -#if defined(_MSVC_LANG) && _MSVC_LANG >= 201103L -// Compiler has working regex. MSVC has erroneous __cplusplus. -#elif __cplusplus >= 201103L && \ - (!defined(__GLIBCXX__) || (__cplusplus >= 201402L) || \ - (defined(_GLIBCXX_REGEX_DFS_QUANTIFIERS_LIMIT) || \ - defined(_GLIBCXX_REGEX_STATE_LIMIT) || \ - (defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE > 4))) -// Compiler has working regex. -#else -static_assert(false, "Compiler does not have proper regex support."); -#endif - namespace c10 { namespace { DeviceType parse_type(const std::string& device_string) { @@ -65,33 +50,84 @@ DeviceType parse_type(const std::string& device_string) { "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ", device_string); } +enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR }; + } // namespace Device::Device(const std::string& device_string) : Device(Type::CPU) { TORCH_CHECK(!device_string.empty(), "Device string must not be empty"); - // We assume gcc 5+, so we can use proper regex. - static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?"); - std::smatch match; - TORCH_CHECK( - std::regex_match(device_string, match, regex), - "Invalid device string: '", - device_string, - "'"); - type_ = parse_type(match[1].str()); - if (match[2].matched) { - try { - index_ = c10::stoi(match[2].str()); - } catch (const std::exception&) { - TORCH_CHECK( - false, - "Could not parse device index '", - match[2].str(), - "' in device string '", - device_string, - "'"); + std::string device_name, device_index_str; + DeviceStringParsingState pstate = DeviceStringParsingState::START; + + // The code below tries to match the string in the variable + // device_string against the regular expression: + // ([a-zA-Z_]+)(?::([1-9]\\d*|0))? + for (size_t i = 0; + pstate != DeviceStringParsingState::ERROR && i < device_string.size(); + ++i) { + const char ch = device_string.at(i); + switch (pstate) { + case DeviceStringParsingState::START: + if (ch != ':') { + if (isalpha(ch) || ch == '_') { + device_name.push_back(ch); + } else { + pstate = DeviceStringParsingState::ERROR; + } + } else { + pstate = DeviceStringParsingState::INDEX_START; + } + break; + + case DeviceStringParsingState::INDEX_START: + if (isdigit(ch)) { + device_index_str.push_back(ch); + pstate = DeviceStringParsingState::INDEX_REST; + } else { + pstate = DeviceStringParsingState::ERROR; + } + break; + + case DeviceStringParsingState::INDEX_REST: + if (device_index_str.at(0) == '0') { + pstate = DeviceStringParsingState::ERROR; + break; + } + if (isdigit(ch)) { + device_index_str.push_back(ch); + } else { + pstate = DeviceStringParsingState::ERROR; + } + break; + + case DeviceStringParsingState::ERROR: + // Execution won't reach here. + break; + } + } + + const bool has_error = device_name.empty() || + pstate == DeviceStringParsingState::ERROR || + (pstate == DeviceStringParsingState::INDEX_START && + device_index_str.empty()); + + TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'"); + + try { + if (!device_index_str.empty()) { + index_ = c10::stoi(device_index_str); } + } catch (const std::exception&) { + TORCH_CHECK( + false, + "Could not parse device index '", + device_index_str, + "' in device string '", + device_string, + "'"); } + type_ = parse_type(device_name); validate(); } -- 2.34.1