Avoid using static unordered_map (#1304)
authorLei Zhang <antiagainst@google.com>
Thu, 15 Feb 2018 15:19:15 +0000 (10:19 -0500)
committerGitHub <noreply@github.com>
Thu, 15 Feb 2018 15:19:15 +0000 (10:19 -0500)
unordered_map is not POD. Using it as static may cause problems
when operator new() and operator delete() is customized.

Also changed some function signatures to use const char* instead
of std::string, which will give caller the flexibility to avoid
creating a std::string.

source/enum_string_mapping.cpp
source/enum_string_mapping.h
source/opt/feature_manager.cpp
source/validate.cpp
source/validate_instruction.cpp
test/enum_string_mapping_test.cpp
utils/generate_grammar_tables.py

index ff40bd8..e993b58 100644 (file)
@@ -14,7 +14,9 @@
 
 #include "enum_string_mapping.h"
 
+#include <algorithm>
 #include <cassert>
+#include <cstring>
 #include <string>
 #include <unordered_map>
 
index c9ac58a..0345d5b 100644 (file)
 namespace libspirv {
 
 // Finds Extension enum corresponding to |str|. Returns false if not found.
-bool GetExtensionFromString(const std::string& str, Extension* extension);
+bool GetExtensionFromString(const char* str, Extension* extension);
 
 // Returns text string corresponding to |extension|.
-std::string ExtensionToString(Extension extension);
+const char* ExtensionToString(Extension extension);
 
 // Returns text string corresponding to |capability|.
-std::string CapabilityToString(SpvCapability capability);
+const char* CapabilityToString(SpvCapability capability);
 
 }  // namespace libspirv
 
index ebb1dd5..8e1bcdc 100644 (file)
@@ -31,7 +31,7 @@ void FeatureManager::AddExtensions(ir::Module* module) {
     const std::string name =
         reinterpret_cast<const char*>(ext.GetInOperand(0u).words.data());
     libspirv::Extension extension;
-    if (libspirv::GetExtensionFromString(name, &extension)) {
+    if (libspirv::GetExtensionFromString(name.c_str(), &extension)) {
       extensions_.Add(extension);
     }
   }
index fadecd0..b1aad07 100644 (file)
@@ -127,7 +127,7 @@ void RegisterExtension(ValidationState_t& _,
                        const spv_parsed_instruction_t* inst) {
   const std::string extension_str = libspirv::GetExtensionString(inst);
   Extension extension;
-  if (!GetExtensionFromString(extension_str, &extension)) {
+  if (!GetExtensionFromString(extension_str.c_str(), &extension)) {
     // The error will be logged in the ProcessInstruction pass.
     return;
   }
index dc013cb..60a0688 100644 (file)
@@ -424,7 +424,7 @@ void CheckIfKnownExtension(ValidationState_t& _,
                            const spv_parsed_instruction_t* inst) {
   const std::string extension_str = GetExtensionString(inst);
   Extension extension;
-  if (!GetExtensionFromString(extension_str, &extension)) {
+  if (!GetExtensionFromString(extension_str.c_str(), &extension)) {
     _.diag(SPV_SUCCESS) << "Found unrecognized extension " << extension_str;
     return;
   }
index 01d3227..4e46fb4 100644 (file)
@@ -38,8 +38,8 @@ TEST_P(ExtensionTest, TestExtensionFromString) {
   const Extension extension = param.first;
   const std::string extension_str = param.second;
   Extension result_extension;
-  ASSERT_TRUE(
-      libspirv::GetExtensionFromString(extension_str, &result_extension));
+  ASSERT_TRUE(libspirv::GetExtensionFromString(extension_str.c_str(),
+                                               &result_extension));
   EXPECT_EQ(extension, result_extension);
 }
 
@@ -53,7 +53,8 @@ TEST_P(ExtensionTest, TestExtensionToString) {
 
 TEST_P(UnknownExtensionTest, TestExtensionFromStringFails) {
   Extension result_extension;
-  ASSERT_FALSE(libspirv::GetExtensionFromString(GetParam(), &result_extension));
+  ASSERT_FALSE(
+      libspirv::GetExtensionFromString(GetParam().c_str(), &result_extension));
 }
 
 TEST_P(CapabilityTest, TestCapabilityToString) {
@@ -86,6 +87,8 @@ INSTANTIATE_TEST_CASE_P(
 
 INSTANTIATE_TEST_CASE_P(UnknownExtensions, UnknownExtensionTest,
                         Values("", "SPV_KHR_", "SPV_KHR_device_group_ERROR",
+                               /*alphabetically before all extensions*/ "A",
+                               /*alphabetically after all extensions*/ "Z",
                                "SPV_ERROR_random_string_hfsdklhlktherh"));
 
 INSTANTIATE_TEST_CASE_P(
index b307554..f83b113 100755 (executable)
@@ -328,7 +328,7 @@ def generate_instruction(inst, version, is_ext_inst):
 def generate_instruction_table(inst_table, version):
     """Returns the info table containing all SPIR-V instructions,
     sorted by opcode, and prefixed by capability arrays.
-    
+
     Note:
       - the built-in sorted() function is guaranteed to be stable.
         https://docs.python.org/3/library/functions.html#sorted
@@ -551,7 +551,7 @@ def generate_capability_to_string_table(operands):
 def generate_extension_to_string_mapping(operands):
     """Returns mapping function from extensions to corresponding strings."""
     extensions = get_extension_list(operands)
-    function = 'std::string ExtensionToString(Extension extension) {\n'
+    function = 'const char* ExtensionToString(Extension extension) {\n'
     function += '  switch (extension) {\n'
     template = '    case Extension::k{extension}:\n' \
         '      return "{extension}";\n'
@@ -563,14 +563,26 @@ def generate_extension_to_string_mapping(operands):
 
 def generate_string_to_extension_mapping(operands):
     """Returns mapping function from strings to corresponding extensions."""
-    function = 'bool GetExtensionFromString(' \
-        'const std::string& str, Extension* extension) {\n ' \
-        'static const std::unordered_map<std::string, Extension> mapping =\n'
-    function += generate_string_to_extension_table(operands)
-    function += ';\n\n'
-    function += '  const auto it = mapping.find(str);\n\n' \
-        '  if (it == mapping.end()) return false;\n\n' \
-        '  *extension = it->second;\n  return true;\n}'
+    extensions = get_extension_list(operands)  # Already sorted
+
+    function = '''
+    bool GetExtensionFromString(const char* str, Extension* extension) {{
+        static const char* known_ext_strs[] = {{ {strs} }};
+        static const Extension known_ext_ids[] = {{ {ids} }};
+        const auto b = std::begin(known_ext_strs);
+        const auto e = std::end(known_ext_strs);
+        const auto found = std::equal_range(
+            b, e, str, [](const char* str1, const char* str2) {{
+                return std::strcmp(str1, str2) < 0;
+            }});
+        if (found.first == e || found.first == found.second) return false;
+
+        *extension = known_ext_ids[found.first - b];
+        return true;
+    }}
+    '''.format(strs=', '.join(['"{}"'.format(e) for e in extensions]),
+               ids=', '.join(['Extension::k{}'.format(e) for e in extensions]))
+
     return function
 
 
@@ -578,7 +590,7 @@ def generate_capability_to_string_mapping(operands):
     """Returns mapping function from capabilities to corresponding strings.
     We take care to avoid emitting duplicate values.
     """
-    function = 'std::string CapabilityToString(SpvCapability capability) {\n'
+    function = 'const char* CapabilityToString(SpvCapability capability) {\n'
     function += '  switch (capability) {\n'
     template = '    case SpvCapability{capability}:\n' \
         '      return "{capability}";\n'