[Caffe2] Create fewer strings during argument fetching (#64285)
authorScott Wolchok <swolchok@fb.com>
Wed, 1 Sep 2021 20:24:11 +0000 (13:24 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 20:30:54 +0000 (13:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64285

With C++14 heterogeneous ordered container lookup, it is no longer necessary to create a `std::string` in order to look up elements of a `CaffeMap` keyed by std::string. Accordingly, this diff reworks the argument-getting operator functions to avoid that in favor of `c10::string_view`.
ghstack-source-id: 137139818
ghstack-source-id: 137139818

Test Plan: buildsizebot iOS apps -- code size win. less strings is probably marginally good for perf but this only happens at setup time anyway.

Reviewed By: dzhulgakov

Differential Revision: D26826676

fbshipit-source-id: ee653b14dc2c528bae8c90f0fc6a7a419cbca1d6

aten/src/ATen/core/function_schema.h
caffe2/core/operator.cc
caffe2/core/operator.h
caffe2/utils/proto_utils.cc
caffe2/utils/proto_utils.h

index a7b5149..f4b11fc 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/util/StringUtil.h>
+#include <c10/util/string_view.h>
 #include <ATen/core/jit_type.h>
 #include <ATen/core/interned_strings.h>
 #include <ATen/core/ivalue.h>
@@ -272,7 +273,7 @@ struct FunctionSchema {
         });
   }
 
-  c10::optional<int> argumentIndexWithName(const std::string& name) const {
+  c10::optional<int> argumentIndexWithName(c10::string_view name) const {
     for(size_t i = 0; i < arguments().size(); ++i) {
       if(name == arguments()[i].name())
         return i;
index ca66f78..e25c92a 100644 (file)
@@ -831,7 +831,7 @@ std::function<void(const OperatorDef&)> GetOperatorLogger() {
 }
 
 c10::optional<int> OperatorBase::argumentIndexWithName(
-    const std::string& name) const {
+    c10::string_view name) const {
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   return getFunctionSchema().argumentIndexWithName(name);
index b840254..15d1ead 100644 (file)
@@ -15,6 +15,7 @@
 
 #include <c10/macros/Macros.h>
 #include <c10/util/Registry.h>
+#include <c10/util/string_view.h>
 #include <c10/util/typeid.h>
 #include <c10/core/Stream.h>
 #include "caffe2/core/blob.h"
@@ -97,7 +98,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
 
   /** @brief Checks if the operator has an argument of the given name.
    */
-  inline bool HasArgument(const string& name) const {
+  inline bool HasArgument(c10::string_view name) const {
     if (isLegacyOperator()) {
       CAFFE_ENFORCE(operator_def_, "operator_def was null!");
       return ArgumentHelper::HasArgument(*operator_def_, name);
@@ -108,7 +109,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
   // Functions that deal with arguments. Basically, this allows us to map an
   // argument name to a specific type of argument that we are trying to access.
   template <typename T>
-  inline T GetSingleArgument(const string& name, const T& default_value) const {
+  inline T GetSingleArgument(c10::string_view name, const T& default_value) const {
     if (isLegacyOperator()) {
       CAFFE_ENFORCE(operator_def_, "operator_def was null!");
       return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
@@ -126,7 +127,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
   }
 
   template <typename T>
-  inline bool HasSingleArgumentOfType(const string& name) const {
+  inline bool HasSingleArgumentOfType(c10::string_view name) const {
     CAFFE_ENFORCE(operator_def_, "operator_def was null!");
     return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
         *operator_def_, name);
@@ -141,7 +142,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
 
   template <typename T>
   inline vector<T> GetRepeatedArgument(
-      const string& name,
+      c10::string_view name,
       const vector<T>& default_value = {}) const;
 
   // Get the inputs and outputs as specific types.
@@ -654,7 +655,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
     }
   }
 
-  c10::optional<int> argumentIndexWithName(const std::string& name) const;
+  c10::optional<int> argumentIndexWithName(c10::string_view name) const;
 
   // An event used by asynchronous execution.
   std::unique_ptr<Event> event_;
@@ -664,7 +665,7 @@ class TORCH_API OperatorBase : public Observable<OperatorBase> {
 
 template <>
 inline NetDef OperatorBase::GetSingleArgument<NetDef>(
-    const std::string& name,
+    c10::string_view name,
     const NetDef& default_value) const {
   if (isLegacyOperator()) {
     CAFFE_ENFORCE(operator_def_, "operator_def was null!");
@@ -756,7 +757,7 @@ inline vector<int16_t> OperatorBase::GetVectorFromIValueList<int16_t>(
 
 template <typename T>
 inline vector<T> OperatorBase::GetRepeatedArgument(
-    const string& name,
+    c10::string_view name,
     const vector<T>& default_value) const {
   if (isLegacyOperator()) {
     CAFFE_ENFORCE(operator_def_, "operator_def was null!");
@@ -778,7 +779,7 @@ inline vector<T> OperatorBase::GetRepeatedArgument(
 // int16_t. We need to load it as List<int64_t> and transform to int16_t.
 template <>
 inline vector<int16_t> OperatorBase::GetRepeatedArgument<int16_t>(
-    const string& name,
+    c10::string_view name,
     const vector<int16_t>& default_value) const {
   if (isLegacyOperator()) {
     CAFFE_ENFORCE(operator_def_, "operator_def was null!");
index d2aa59e..db37946 100644 (file)
@@ -323,8 +323,12 @@ C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
   }
 }
 
-C10_EXPORT bool ArgumentHelper::HasArgument(const string& name) const {
+C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const {
+#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
   return arg_map_.count(name);
+#else
+  return arg_map_.count(std::string(name));
+#endif
 }
 
 namespace {
@@ -364,18 +368,19 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) {
     T, fieldname, enforce_lossless_conversion)                         \
   template <>                                                          \
   C10_EXPORT T ArgumentHelper::GetSingleArgument<T>(                   \
-      const string& name, const T& default_value) const {              \
-    if (arg_map_.count(name) == 0) {                                   \
+      c10::string_view name, const T& default_value) const {           \
+    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);                      \
+    if (it == arg_map_.end()) {                                         \
       VLOG(1) << "Using default parameter value " << default_value     \
               << " for parameter " << name;                            \
       return default_value;                                            \
     }                                                                  \
     CAFFE_ENFORCE(                                                     \
-        arg_map_.at(name).has_##fieldname(),                           \
+        it->second.has_##fieldname(),                                  \
         "Argument ",                                                   \
         name,                                                          \
         " does not have the right field: expected field " #fieldname); \
-    auto value = arg_map_.at(name).fieldname();                        \
+    auto value = it->second.fieldname();                               \
     if (enforce_lossless_conversion) {                                 \
       auto supportsConversion =                                        \
           SupportsLosslessConversion<decltype(value), T>(value);       \
@@ -391,11 +396,12 @@ std::ostream& operator<<(std::ostream& output, const NetDef& n) {
   }                                                                    \
   template <>                                                          \
   C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>(          \
-      const string& name) const {                                      \
-    if (arg_map_.count(name) == 0) {                                   \
+      c10::string_view name) const {                                   \
+    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);                     \
+    if (it == arg_map_.end()) {                                        \
       return false;                                                    \
     }                                                                  \
-    return arg_map_.at(name).has_##fieldname();                        \
+    return it->second.has_##fieldname();                               \
   }
 
 INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
@@ -415,13 +421,14 @@ INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
 #define INSTANTIATE_GET_REPEATED_ARGUMENT(                             \
     T, fieldname, enforce_lossless_conversion)                         \
   template <>                                                          \
-  C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>(         \
-      const string& name, const std::vector<T>& default_value) const { \
-    if (arg_map_.count(name) == 0) {                                   \
+  C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
+      c10::string_view name, const std::vector<T>& default_value) const { \
+    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);                      \
+    if (it == arg_map_.end()) {                                         \
       return default_value;                                            \
     }                                                                  \
-    std::vector<T> values;                                                  \
-    for (const auto& v : arg_map_.at(name).fieldname()) {              \
+    std::vector<T> values;                                           \
+    for (const auto& v : it->second.fieldname()) {                     \
       if (enforce_lossless_conversion) {                               \
         auto supportsConversion =                                      \
             SupportsLosslessConversion<decltype(v), T>(v);             \
@@ -531,7 +538,7 @@ C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) {
 // Return the argument index or -1 if it does not exist.
 C10_EXPORT int GetArgumentIndex(
     const google::protobuf::RepeatedPtrField<Argument>& args,
-    const string& name) {
+    c10::string_view name) {
   int index = 0;
   for (const Argument& arg : args) {
     if (arg.name() == name) {
@@ -544,7 +551,7 @@ C10_EXPORT int GetArgumentIndex(
 
 C10_EXPORT const Argument& GetArgument(
     const OperatorDef& def,
-    const string& name) {
+    c10::string_view name) {
   int index = GetArgumentIndex(def.arg(), name);
   if (index != -1) {
     return def.arg(index);
@@ -557,7 +564,7 @@ C10_EXPORT const Argument& GetArgument(
   }
 }
 
-C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
+C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) {
   int index = GetArgumentIndex(def.arg(), name);
   if (index != -1) {
     return def.arg(index);
@@ -572,7 +579,7 @@ C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
 
 C10_EXPORT const Argument* GetArgumentPtr(
     const OperatorDef& def,
-    const string& name) {
+    c10::string_view name) {
   int index = GetArgumentIndex(def.arg(), name);
   if (index != -1) {
     return &def.arg(index);
@@ -583,7 +590,7 @@ C10_EXPORT const Argument* GetArgumentPtr(
 
 C10_EXPORT const Argument* GetArgumentPtr(
     const NetDef& def,
-    const string& name) {
+    c10::string_view name) {
   int index = GetArgumentIndex(def.arg(), name);
   if (index != -1) {
     return &def.arg(index);
@@ -594,7 +601,7 @@ C10_EXPORT const Argument* GetArgumentPtr(
 
 C10_EXPORT bool GetFlagArgument(
     const google::protobuf::RepeatedPtrField<Argument>& args,
-    const string& name,
+    c10::string_view name,
     bool default_value) {
   int index = GetArgumentIndex(args, name);
   if (index != -1) {
@@ -609,13 +616,13 @@ C10_EXPORT bool GetFlagArgument(
 
 C10_EXPORT bool GetFlagArgument(
     const OperatorDef& def,
-    const string& name,
+    c10::string_view name,
     bool default_value) {
   return GetFlagArgument(def.arg(), name, default_value);
 }
 
 C10_EXPORT bool
-GetFlagArgument(const NetDef& def, const string& name, bool default_value) {
+GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) {
   return GetFlagArgument(def.arg(), name, default_value);
 }
 
index 5767698..b5c6b31 100644 (file)
@@ -8,10 +8,18 @@
 #endif  // !CAFFE2_USE_LITE_PROTO
 
 #include <c10/util/Logging.h>
+#include <c10/util/string_view.h>
 
 #include "caffe2/utils/proto_wrap.h"
 #include "caffe2/proto/caffe2_pb.h"
 
+#ifndef C10_ANDROID
+#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
+#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key)
+#else
+#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key))
+#endif
+
 namespace caffe2 {
 
 using std::string;
@@ -204,40 +212,40 @@ TORCH_API bool HasInput(const OperatorDef& op, const std::string& input);
 class C10_EXPORT ArgumentHelper {
  public:
   template <typename Def>
-  static bool HasArgument(const Def& def, const string& name) {
+  static bool HasArgument(const Def& def, c10::string_view name) {
     return ArgumentHelper(def).HasArgument(name);
   }
 
   template <typename Def, typename T>
   static T GetSingleArgument(
       const Def& def,
-      const string& name,
+      c10::string_view name,
       const T& default_value) {
     return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
   }
 
   template <typename Def, typename T>
-  static bool HasSingleArgumentOfType(const Def& def, const string& name) {
+  static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) {
     return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
   }
 
   template <typename Def, typename T>
   static std::vector<T> GetRepeatedArgument(
       const Def& def,
-      const string& name,
+      c10::string_view name,
       const std::vector<T>& default_value = std::vector<T>()) {
     return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
   }
 
   template <typename Def, typename MessageType>
-  static MessageType GetMessageArgument(const Def& def, const string& name) {
+  static MessageType GetMessageArgument(const Def& def, c10::string_view name) {
     return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
   }
 
   template <typename Def, typename MessageType>
   static std::vector<MessageType> GetRepeatedMessageArgument(
       const Def& def,
-      const string& name) {
+      c10::string_view name) {
     return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
   }
 
@@ -255,24 +263,25 @@ class C10_EXPORT ArgumentHelper {
 
   explicit ArgumentHelper(const OperatorDef& def);
   explicit ArgumentHelper(const NetDef& netdef);
-  bool HasArgument(const string& name) const;
+  bool HasArgument(c10::string_view name) const;
 
   template <typename T>
-  T GetSingleArgument(const string& name, const T& default_value) const;
+  T GetSingleArgument(c10::string_view name, const T& default_value) const;
   template <typename T>
-  bool HasSingleArgumentOfType(const string& name) const;
+  bool HasSingleArgumentOfType(c10::string_view name) const;
   template <typename T>
   std::vector<T> GetRepeatedArgument(
-      const string& name,
+      c10::string_view name,
       const std::vector<T>& default_value = std::vector<T>()) const;
 
   template <typename MessageType>
-  MessageType GetMessageArgument(const string& name) const {
-    CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
+  MessageType GetMessageArgument(c10::string_view name) const {
+    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
+    CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
     MessageType message;
-    if (arg_map_.at(name).has_s()) {
+    if (it->second.has_s()) {
       CAFFE_ENFORCE(
-          message.ParseFromString(arg_map_.at(name).s()),
+          message.ParseFromString(it->second.s()),
           "Failed to parse content from the string");
     } else {
       VLOG(1) << "Return empty message for parameter " << name;
@@ -281,42 +290,47 @@ class C10_EXPORT ArgumentHelper {
   }
 
   template <typename MessageType>
-  std::vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
-    CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
-    std::vector<MessageType> messages(arg_map_.at(name).strings_size());
+  std::vector<MessageType> GetRepeatedMessageArgument(c10::string_view name) const {
+    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
+    CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
+    std::vector<MessageType> messages(it->second.strings_size());
     for (int i = 0; i < messages.size(); ++i) {
       CAFFE_ENFORCE(
-          messages[i].ParseFromString(arg_map_.at(name).strings(i)),
+          messages[i].ParseFromString(it->second.strings(i)),
           "Failed to parse content from the string");
     }
     return messages;
   }
 
  private:
-  std::map<string, Argument> arg_map_;
+  std::map<string, Argument
+#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
+  , std::less<>
+#endif
+  > arg_map_;
 };
 
 // **** Arguments Utils *****
 
 // Helper methods to get an argument from OperatorDef or NetDef given argument
 // name. Throws if argument does not exist.
-TORCH_API const Argument& GetArgument(const OperatorDef& def, const string& name);
-TORCH_API const Argument& GetArgument(const NetDef& def, const string& name);
+TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name);
+TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name);
 // Helper methods to get an argument from OperatorDef or NetDef given argument
 // name. Returns nullptr if argument does not exist.
-TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name);
-TORCH_API const Argument* GetArgumentPtr(const NetDef& def, const string& name);
+TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name);
+TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name);
 
 // Helper methods to query a boolean argument flag from OperatorDef or NetDef
 // given argument name. If argument does not exist, return default value.
 // Throws if argument exists but the type is not boolean.
 TORCH_API bool GetFlagArgument(
     const OperatorDef& def,
-    const string& name,
+    c10::string_view name,
     bool default_value = false);
 TORCH_API bool GetFlagArgument(
     const NetDef& def,
-    const string& name,
+    c10::string_view name,
     bool default_value = false);
 
 TORCH_API Argument* GetMutableArgument(