#include <iniparser.h>
#include <app_context.h>
+#include <layer.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
+#include <optimizer.h>
#include <util_func.h>
#include <adam.h>
return keys;
}
+template <typename T>
+const int AppContext::registerFactory(const FactoryType<T> factory,
+ const std::string &key,
+ const int int_key) {
+ static_assert(isSupported<T>::value,
+ "given type is not supported for current app context");
+
+ auto &index = std::get<IndexType<T>>(factory_map);
+ auto &str_map = std::get<StrIndexType<T>>(index);
+ auto &int_map = std::get<IntIndexType>(index);
+
+ std::string assigned_key = key == "" ? factory({})->getType() : key;
+
+ std::transform(assigned_key.begin(), assigned_key.end(), assigned_key.begin(),
+ [](unsigned char c) { return std::tolower(c); });
+
+ const std::lock_guard<std::mutex> lock(factory_mutex);
+ if (str_map.find(assigned_key) != str_map.end()) {
+ std::stringstream ss;
+ ss << "cannot register factory with already taken key: " << key;
+ throw std::invalid_argument(ss.str().c_str());
+ }
+
+ if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
+ std::stringstream ss;
+ ss << "cannot register factory with already taken int key: " << int_key;
+ throw std::invalid_argument(ss.str().c_str());
+ }
+
+ int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
+
+ str_map[assigned_key] = factory;
+ int_map[assigned_int_key] = assigned_key;
+
+ ml_logd("factory has registered with key: %s, int_key: %d",
+ assigned_key.c_str(), assigned_int_key);
+
+ return assigned_int_key;
+}
+
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+template const int AppContext::registerFactory<ml::train::Optimizer>(
+ const FactoryType<ml::train::Optimizer> factory, const std::string &key,
+ const int int_key);
+
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+template const int AppContext::registerFactory<nntrainer::Layer>(
+ const FactoryType<nntrainer::Layer> factory, const std::string &key,
+ const int int_key);
+
} // namespace nntrainer
#include <memory>
#include <mutex>
#include <sstream>
+#include <stdexcept>
#include <string>
+#include <type_traits>
#include <unordered_map>
#include <vector>
#include <optimizer.h>
#include <nntrainer_error.h>
-#include <nntrainer_log.h>
namespace nntrainer {
extern std::mutex factory_mutex;
+namespace {} // namespace
/**
* @class AppContext contains user-dependent configuration
template <typename T>
const int registerFactory(const FactoryType<T> factory,
const std::string &key = "",
- const int int_key = -1) {
-
- auto &index = std::get<IndexType<T>>(factory_map);
- auto &str_map = std::get<StrIndexType<T>>(index);
- auto &int_map = std::get<IntIndexType>(index);
-
- std::string assigned_key = key == "" ? factory({})->getType() : key;
-
- std::transform(assigned_key.begin(), assigned_key.end(),
- assigned_key.begin(),
- [](unsigned char c) { return std::tolower(c); });
-
- const std::lock_guard<std::mutex> lock(factory_mutex);
- if (str_map.find(assigned_key) != str_map.end()) {
- std::stringstream ss;
- ss << "cannot register factory with already taken key: " << key;
- throw std::invalid_argument(ss.str().c_str());
- }
-
- if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
- std::stringstream ss;
- ss << "cannot register factory with already taken int key: " << int_key;
- throw std::invalid_argument(ss.str().c_str());
- }
-
- int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
-
- str_map[assigned_key] = factory;
- int_map[assigned_int_key] = assigned_key;
-
- ml_logd("factory has registered with key: %s, int_key: %d",
- assigned_key.c_str(), assigned_int_key);
-
- return assigned_int_key;
- }
+ const int int_key = -1);
/**
* @brief Create an Object from the integer key
template <typename T>
PtrType<T> createObject(const int int_key,
const PropsType &props = {}) const {
+ static_assert(isSupported<T>::value,
+ "given type is not supported for current app context");
auto &index = std::get<IndexType<T>>(factory_map);
auto &int_map = std::get<IntIndexType>(index);
private:
FactoryMap<ml::train::Optimizer, nntrainer::Layer> factory_map;
std::string working_path_base;
+
+ template <typename Args, typename T> struct isSupportedHelper;
+
+ /**
+ * @brief supportHelper to check if given type is supported within appcontext
+ */
+ template <typename T, typename... Args>
+ struct isSupportedHelper<T, AppContext::FactoryMap<Args...>> {
+ static constexpr bool value =
+ (std::is_same_v<std::decay_t<T>, std::decay_t<Args>> || ...);
+ };
+
+ /**
+ * @brief supportHelper to check if given type is supported within appcontext
+ */
+ template <typename T>
+ struct isSupported : isSupportedHelper<T, decltype(factory_map)> {};
};
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+extern template const int AppContext::registerFactory<ml::train::Optimizer>(
+ const FactoryType<ml::train::Optimizer> factory, const std::string &key,
+ const int int_key);
+
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+extern template const int AppContext::registerFactory<nntrainer::Layer>(
+ const FactoryType<nntrainer::Layer> factory, const std::string &key,
+ const int int_key);
+
namespace plugin {}
} // namespace nntrainer