#ifndef __APP_CONTEXT_H__
#define __APP_CONTEXT_H__
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <sstream>
#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <layer.h>
+#include <optimizer.h>
+
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
namespace nntrainer {
+extern std::mutex factory_mutex;
+
/**
* @class AppContext contains user-dependent configuration
* @brief App
*/
class AppContext {
public:
+ using PropsType = std::vector<std::string>;
+
+ template <typename T> using PtrType = std::unique_ptr<T>;
+
+ template <typename T>
+ using FactoryType = std::function<PtrType<T>(const PropsType &)>;
+
+ template <typename T>
+ using PtrFactoryType = PtrType<T> (*)(const PropsType &);
+ template <typename T>
+ using StrIndexType = std::unordered_map<std::string, FactoryType<T>>;
+
+ /** integer to string key */
+ using IntIndexType = std::unordered_map<int, std::string>;
+
+ /**
+ * This type contains tuple of
+ * 1) integer -> string index
+ * 2) string -> factory index
+ */
+ template <typename T>
+ using IndexType = std::tuple<StrIndexType<T>, IntIndexType>;
+
+ template <typename... Ts> using FactoryMap = std::tuple<IndexType<Ts>...>;
+
+ AppContext(){};
+
/**
+ *
* @brief Get Global app context.
*
* @return AppContext&
*/
const std::string getWorkingPath(const std::string &path = "");
+ template <typename T>
+ const int registerFactory(const PtrFactoryType<T> factory,
+ const std::string &key = "",
+ const int int_key = -1) {
+ FactoryType<T> f = factory;
+ return registerFactory(f, key, int_key);
+ }
+
+ /**
+ * @brief Factory register function, use this function to register custom
+ * object
+ *
+ * @tparam T object to create
+ * @param factory factory function that creates std::unique_ptr<T>
+ * @param key key to access the factory, if key is empty, try to find key by
+ * calling factory({})->getType();
+ * @param int_key key to access the factory by integer, if it is -1(default),
+ * the function automatically unsigned the key and return
+ * @return const int unique integer value to access the current factory
+ * @throw invalid argument when key and/or int_key is already taken
+ */
+ template <typename T>
+ const int registerFactory(const FactoryType<T> factory,
+ const std::string &key = "",
+ const int int_key = -1) {
+
+ auto &index = std::get<IndexType<T>>(factory_map);
+ auto &str_map = std::get<StrIndexType<T>>(index);
+ auto &int_map = std::get<IntIndexType>(index);
+
+ const std::string &assigned_key = key == "" ? factory({})->getType() : key;
+
+ const std::lock_guard<std::mutex> lock(factory_mutex);
+ if (str_map.find(assigned_key) != str_map.end()) {
+ std::stringstream ss;
+ ss << "cannot register factory with already taken key: " << key;
+ throw std::invalid_argument(ss.str().c_str());
+ }
+
+ if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
+ std::stringstream ss;
+ ss << "cannot register factory with already taken int key: " << int_key;
+ throw std::invalid_argument(ss.str().c_str());
+ }
+
+ int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
+
+ str_map[assigned_key] = factory;
+ int_map[assigned_int_key] = assigned_key;
+
+ ml_logd("factory has registered with key: %s, int_key: %d",
+ assigned_key.c_str(), assigned_int_key);
+
+ return assigned_int_key;
+ }
+
+ template <typename T>
+ PtrType<T> createObject(const int int_key, const PropsType &props = {}) {
+ auto &index = std::get<IndexType<T>>(factory_map);
+ auto &int_map = std::get<IntIndexType>(index);
+
+ const auto &entry = int_map.find(int_key);
+
+ if (entry == int_map.end()) {
+ std::stringstream ss;
+ ss << "Int Key is not found for the object. Key: " << int_key;
+ throw exception::not_supported(ss.str().c_str());
+ }
+
+ return createObject<T>(entry->second, props);
+ }
+
+ template <typename T>
+ PtrType<T> createObject(const std::string &key, const PropsType &props = {}) {
+ auto &index = std::get<IndexType<T>>(factory_map);
+ auto &str_map = std::get<StrIndexType<T>>(index);
+
+ const auto &entry = str_map.find(key);
+
+ if (entry == str_map.end()) {
+ std::stringstream ss;
+ ss << "Key is not found for the object. Key: " << key;
+ throw exception::not_supported(ss.str().c_str());
+ }
+
+ return entry->second(props);
+ }
+
+ /**
+ * @brief special factory that throws for unknown
+ *
+ * @tparam T object to create
+ * @param props props to pass, not used
+ * @throw always throw runtime_error
+ */
+ template <typename T>
+ static PtrType<T> unknownFactory(const PropsType &props) {
+ throw std::runtime_error("cannot create unknown object");
+ }
+
private:
static AppContext instance;
+ FactoryMap<ml::train::Optimizer> factory_map;
std::string working_path_base;
};
#include <gtest/gtest.h>
#include <fstream>
+#include <memory>
+#include <typeinfo>
#include <unistd.h>
+#include <optimizer.h>
+
#include <app_context.h>
+#include <nntrainer_error.h>
class nntrainerAppContextDirectory : public ::testing::Test {
std::invalid_argument);
}
+class CustomOptimizer : public ml::train::Optimizer {
+public:
+ const std::string getType() const { return "identity_optimizer"; }
+
+ float getLearningRate() { return 1.0f; }
+
+ float getDecayRate() { return 1.0f; }
+
+ float getDecaySteps() { return 1.0f; }
+
+ int setProperty(std::vector<std::string> values) { return 1; }
+
+ void setProperty(const PropertyType type, const std::string &value = "") {}
+
+ void checkValidation() {}
+};
+
+class CustomOptimizer2 : public ml::train::Optimizer {
+public:
+ const std::string getType() const { return "identity_optimizer"; }
+
+ float getLearningRate() { return 1.0f; }
+
+ float getDecayRate() { return 1.0f; }
+
+ float getDecaySteps() { return 1.0f; }
+
+ int setProperty(std::vector<std::string> values) { return 1; }
+
+ void setProperty(const PropertyType type, const std::string &value = "") {}
+
+ void checkValidation() {}
+};
+
+using AC = nntrainer::AppContext;
+
+AC::PtrType<ml::train::Optimizer>
+createCustomOptimizer(const AC::PropsType &v) {
+ auto p = std::make_unique<CustomOptimizer>();
+ p->setProperty(v);
+ return p;
+}
+
+TEST(nntrainerAppContextObjs, RegisterCreateCustomOptimizer_p) {
+
+ // register without key in this case, getType() will be called and used
+ {
+ auto ac = nntrainer::AppContext();
+ int num_id = ac.registerFactory(createCustomOptimizer);
+ auto opt = ac.createObject<ml::train::Optimizer>("identity_optimizer", {});
+ EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+ opt = ac.createObject<ml::train::Optimizer>(num_id, {});
+ EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+ }
+
+ // register with key
+ {
+ auto ac = nntrainer::AppContext();
+ int num_id = ac.registerFactory(createCustomOptimizer, "custom_key");
+ auto opt = ac.createObject<ml::train::Optimizer>("custom_key", {});
+ EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+ opt = ac.createObject<ml::train::Optimizer>(num_id, {});
+ EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+ }
+
+ // register with key and custom id
+ {
+ auto ac = nntrainer::AppContext();
+ int num_id = ac.registerFactory(createCustomOptimizer, "custom_key", 5);
+ EXPECT_EQ(num_id, 5);
+ auto opt = ac.createObject<ml::train::Optimizer>("custom_key", {});
+ EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+ opt = ac.createObject<ml::train::Optimizer>(num_id, {});
+ EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+ }
+}
+
+TEST(nntrainerAppContextObjs, RegisterFactoryWithClashingKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ ac.registerFactory(createCustomOptimizer, "custom_key");
+
+ EXPECT_THROW(ac.registerFactory(createCustomOptimizer, "custom_key"),
+ std::invalid_argument);
+}
+
+TEST(nntrainerAppContextObjs, RegisterFactoryWithClashingIntKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ ac.registerFactory(createCustomOptimizer, "custom_key", 3);
+ EXPECT_THROW(ac.registerFactory(createCustomOptimizer, "custom_other_key", 3),
+ std::invalid_argument);
+}
+
+TEST(nntrainerAppContextObjs, RegisterFactoryWithClashingAutoKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ ac.registerFactory(createCustomOptimizer);
+ EXPECT_THROW(ac.registerFactory(createCustomOptimizer),
+ std::invalid_argument);
+}
+
+TEST(nntrainerAppContextObjs, createObjectNotExistingKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ ac.registerFactory(createCustomOptimizer);
+ EXPECT_THROW(ac.createObject<ml::train::Optimizer>("not_exisiting_key"),
+ nntrainer::exception::not_supported);
+}
+
+TEST(nntrainerAppContextObjs, createObjectNotExistingIntKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ int num = ac.registerFactory(createCustomOptimizer);
+ EXPECT_THROW(ac.createObject<ml::train::Optimizer>(num + 3),
+ nntrainer::exception::not_supported);
+}
+
+TEST(nntrainerAppContextObjs, callingUnknownFactoryOptimizerWithKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ int num = ac.registerFactory(
+ nntrainer::AppContext::unknownFactory<ml::train::Optimizer>, "unknown",
+ 999);
+
+ EXPECT_EQ(num, 999);
+ EXPECT_THROW(ac.createObject<ml::train::Optimizer>("unknown"),
+ std::runtime_error);
+}
+
+TEST(nntrainerAppContextObjs, callingUnknownFactoryOptimizerWithIntKey_n) {
+ auto ac = nntrainer::AppContext();
+
+ int num = ac.registerFactory(
+ nntrainer::AppContext::unknownFactory<ml::train::Optimizer>, "unknown",
+ 999);
+
+ EXPECT_EQ(num, 999);
+ EXPECT_THROW(ac.createObject<ml::train::Optimizer>(num), std::runtime_error);
+}
+
/**
* @brief Main gtest
*/