#pragma once
#include <c10/util/StringUtil.h>
+#include <c10/util/string_view.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
});
}
- c10::optional<int> argumentIndexWithName(const std::string& name) const {
+ c10::optional<int> argumentIndexWithName(c10::string_view name) const {
for(size_t i = 0; i < arguments().size(); ++i) {
if(name == arguments()[i].name())
return i;
}
c10::optional<int> OperatorBase::argumentIndexWithName(
- const std::string& name) const {
+ c10::string_view name) const {
#if defined(EXPOSE_C2_OPS) || \
!defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
return getFunctionSchema().argumentIndexWithName(name);
#include <c10/macros/Macros.h>
#include <c10/util/Registry.h>
+#include <c10/util/string_view.h>
#include <c10/util/typeid.h>
#include <c10/core/Stream.h>
#include "caffe2/core/blob.h"
/** @brief Checks if the operator has an argument of the given name.
*/
- inline bool HasArgument(const string& name) const {
+ inline bool HasArgument(c10::string_view name) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
// Functions that deal with arguments. Basically, this allows us to map an
// argument name to a specific type of argument that we are trying to access.
template <typename T>
- inline T GetSingleArgument(const string& name, const T& default_value) const {
+ inline T GetSingleArgument(c10::string_view name, const T& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
}
template <typename T>
- inline bool HasSingleArgumentOfType(const string& name) const {
+ inline bool HasSingleArgumentOfType(c10::string_view name) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name);
template <typename T>
inline vector<T> GetRepeatedArgument(
- const string& name,
+ c10::string_view name,
const vector<T>& default_value = {}) const;
// Get the inputs and outputs as specific types.
}
}
- c10::optional<int> argumentIndexWithName(const std::string& name) const;
+ c10::optional<int> argumentIndexWithName(c10::string_view name) const;
// An event used by asynchronous execution.
std::unique_ptr<Event> event_;
template <>
inline NetDef OperatorBase::GetSingleArgument<NetDef>(
- const std::string& name,
+ c10::string_view name,
const NetDef& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
template <typename T>
inline vector<T> OperatorBase::GetRepeatedArgument(
- const string& name,
+ c10::string_view name,
const vector<T>& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
// int16_t. We need to load it as List<int64_t> and transform to int16_t.
template <>
inline vector<int16_t> OperatorBase::GetRepeatedArgument<int16_t>(
- const string& name,
+ c10::string_view name,
const vector<int16_t>& default_value) const {
if (isLegacyOperator()) {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
}
}
-C10_EXPORT bool ArgumentHelper::HasArgument(const string& name) const {
+C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const {
+#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
return arg_map_.count(name);
+#else
+ return arg_map_.count(std::string(name));
+#endif
}
namespace {
T, fieldname, enforce_lossless_conversion) \
template <> \
C10_EXPORT T ArgumentHelper::GetSingleArgument<T>( \
- const string& name, const T& default_value) const { \
- if (arg_map_.count(name) == 0) { \
+ c10::string_view name, const T& default_value) const { \
+ auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
+ if (it == arg_map_.end()) { \
VLOG(1) << "Using default parameter value " << default_value \
<< " for parameter " << name; \
return default_value; \
} \
CAFFE_ENFORCE( \
- arg_map_.at(name).has_##fieldname(), \
+ it->second.has_##fieldname(), \
"Argument ", \
name, \
" does not have the right field: expected field " #fieldname); \
- auto value = arg_map_.at(name).fieldname(); \
+ auto value = it->second.fieldname(); \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(value), T>(value); \
} \
template <> \
C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>( \
- const string& name) const { \
- if (arg_map_.count(name) == 0) { \
+ c10::string_view name) const { \
+ auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
+ if (it == arg_map_.end()) { \
return false; \
} \
- return arg_map_.at(name).has_##fieldname(); \
+ return it->second.has_##fieldname(); \
}
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \
T, fieldname, enforce_lossless_conversion) \
template <> \
- C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
- const string& name, const std::vector<T>& default_value) const { \
- if (arg_map_.count(name) == 0) { \
+ C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
+ c10::string_view name, const std::vector<T>& default_value) const { \
+ auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
+ if (it == arg_map_.end()) { \
return default_value; \
} \
- std::vector<T> values; \
- for (const auto& v : arg_map_.at(name).fieldname()) { \
+ std::vector<T> values; \
+ for (const auto& v : it->second.fieldname()) { \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(v), T>(v); \
// Return the argument index or -1 if it does not exist.
C10_EXPORT int GetArgumentIndex(
const google::protobuf::RepeatedPtrField<Argument>& args,
- const string& name) {
+ c10::string_view name) {
int index = 0;
for (const Argument& arg : args) {
if (arg.name() == name) {
C10_EXPORT const Argument& GetArgument(
const OperatorDef& def,
- const string& name) {
+ c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return def.arg(index);
}
}
-C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
+C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return def.arg(index);
C10_EXPORT const Argument* GetArgumentPtr(
const OperatorDef& def,
- const string& name) {
+ c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return &def.arg(index);
C10_EXPORT const Argument* GetArgumentPtr(
const NetDef& def,
- const string& name) {
+ c10::string_view name) {
int index = GetArgumentIndex(def.arg(), name);
if (index != -1) {
return &def.arg(index);
C10_EXPORT bool GetFlagArgument(
const google::protobuf::RepeatedPtrField<Argument>& args,
- const string& name,
+ c10::string_view name,
bool default_value) {
int index = GetArgumentIndex(args, name);
if (index != -1) {
C10_EXPORT bool GetFlagArgument(
const OperatorDef& def,
- const string& name,
+ c10::string_view name,
bool default_value) {
return GetFlagArgument(def.arg(), name, default_value);
}
C10_EXPORT bool
-GetFlagArgument(const NetDef& def, const string& name, bool default_value) {
+GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) {
return GetFlagArgument(def.arg(), name, default_value);
}
#endif // !CAFFE2_USE_LITE_PROTO
#include <c10/util/Logging.h>
+#include <c10/util/string_view.h>
#include "caffe2/utils/proto_wrap.h"
#include "caffe2/proto/caffe2_pb.h"
+#ifndef C10_ANDROID
+#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
+#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key)
+#else
+#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key))
+#endif
+
namespace caffe2 {
using std::string;
class C10_EXPORT ArgumentHelper {
public:
template <typename Def>
- static bool HasArgument(const Def& def, const string& name) {
+ static bool HasArgument(const Def& def, c10::string_view name) {
return ArgumentHelper(def).HasArgument(name);
}
template <typename Def, typename T>
static T GetSingleArgument(
const Def& def,
- const string& name,
+ c10::string_view name,
const T& default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
}
template <typename Def, typename T>
- static bool HasSingleArgumentOfType(const Def& def, const string& name) {
+ static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) {
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
}
template <typename Def, typename T>
static std::vector<T> GetRepeatedArgument(
const Def& def,
- const string& name,
+ c10::string_view name,
const std::vector<T>& default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
}
template <typename Def, typename MessageType>
- static MessageType GetMessageArgument(const Def& def, const string& name) {
+ static MessageType GetMessageArgument(const Def& def, c10::string_view name) {
return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
}
template <typename Def, typename MessageType>
static std::vector<MessageType> GetRepeatedMessageArgument(
const Def& def,
- const string& name) {
+ c10::string_view name) {
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
}
explicit ArgumentHelper(const OperatorDef& def);
explicit ArgumentHelper(const NetDef& netdef);
- bool HasArgument(const string& name) const;
+ bool HasArgument(c10::string_view name) const;
template <typename T>
- T GetSingleArgument(const string& name, const T& default_value) const;
+ T GetSingleArgument(c10::string_view name, const T& default_value) const;
template <typename T>
- bool HasSingleArgumentOfType(const string& name) const;
+ bool HasSingleArgumentOfType(c10::string_view name) const;
template <typename T>
std::vector<T> GetRepeatedArgument(
- const string& name,
+ c10::string_view name,
const std::vector<T>& default_value = std::vector<T>()) const;
template <typename MessageType>
- MessageType GetMessageArgument(const string& name) const {
- CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
+ MessageType GetMessageArgument(c10::string_view name) const {
+ auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
+ CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
MessageType message;
- if (arg_map_.at(name).has_s()) {
+ if (it->second.has_s()) {
CAFFE_ENFORCE(
- message.ParseFromString(arg_map_.at(name).s()),
+ message.ParseFromString(it->second.s()),
"Failed to parse content from the string");
} else {
VLOG(1) << "Return empty message for parameter " << name;
}
template <typename MessageType>
- std::vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
- CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
- std::vector<MessageType> messages(arg_map_.at(name).strings_size());
+ std::vector<MessageType> GetRepeatedMessageArgument(c10::string_view name) const {
+ auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
+ CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
+ std::vector<MessageType> messages(it->second.strings_size());
for (int i = 0; i < messages.size(); ++i) {
CAFFE_ENFORCE(
- messages[i].ParseFromString(arg_map_.at(name).strings(i)),
+ messages[i].ParseFromString(it->second.strings(i)),
"Failed to parse content from the string");
}
return messages;
}
private:
- std::map<string, Argument> arg_map_;
+ std::map<string, Argument
+#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
+ , std::less<>
+#endif
+ > arg_map_;
};
// **** Arguments Utils *****
// Helper methods to get an argument from OperatorDef or NetDef given argument
// name. Throws if argument does not exist.
-TORCH_API const Argument& GetArgument(const OperatorDef& def, const string& name);
-TORCH_API const Argument& GetArgument(const NetDef& def, const string& name);
+TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name);
+TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name);
// Helper methods to get an argument from OperatorDef or NetDef given argument
// name. Returns nullptr if argument does not exist.
-TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, const string& name);
-TORCH_API const Argument* GetArgumentPtr(const NetDef& def, const string& name);
+TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name);
+TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name);
// Helper methods to query a boolean argument flag from OperatorDef or NetDef
// given argument name. If argument does not exist, return default value.
// Throws if argument exists but the type is not boolean.
TORCH_API bool GetFlagArgument(
const OperatorDef& def,
- const string& name,
+ c10::string_view name,
bool default_value = false);
TORCH_API bool GetFlagArgument(
const NetDef& def,
- const string& name,
+ c10::string_view name,
bool default_value = false);
TORCH_API Argument* GetMutableArgument(