[AppContext] Add registerer,invoke factory methods
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 11 Nov 2020 12:30:23 +0000 (21:30 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 27 Nov 2020 06:32:50 +0000 (15:32 +0900)
**Changes proposed in this PR:**
- Add factory registerer
- Add factory invoker
- Register built-in objects to each layers(postponed)
- Add tests

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nnstreamer/tensor_filter/tensor_filter_nntrainer.cc
nntrainer/app_context.cpp
nntrainer/app_context.h
test/unittest/unittest_nntrainer_appcontext.cpp

index 73164bbd515c62fe9f0a6ad8075fa37d7e1eeabb..e05f975348658082c979189fb412119929577e25 100644 (file)
 
 #include <neuralnet.h>
 
+#ifdef ml_loge
+#undef ml_loge
+#endif
+
 #define ml_loge g_critical
 
 /**
index 01360085503c89713756f34d1a78bdc2d59c4875..83aa3345958a0b358f7acfffdfb7c045c16a1734 100644 (file)
@@ -21,6 +21,8 @@
 
 namespace nntrainer {
 
+std::mutex factory_mutex;
+
 AppContext AppContext::instance;
 
 /**
index c9f8d2a31535aeb83a4e89dd1d85ec149c7e697f..e54f02d0c4c539047388fd9e8201c210287f79df 100644 (file)
 #ifndef __APP_CONTEXT_H__
 #define __APP_CONTEXT_H__
 
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <sstream>
 #include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <layer.h>
+#include <optimizer.h>
+
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
 
 namespace nntrainer {
 
+extern std::mutex factory_mutex;
+
 /**
  * @class AppContext contains user-dependent configuration
  * @brief App
  */
 class AppContext {
 public:
+  using PropsType = std::vector<std::string>;
+
+  template <typename T> using PtrType = std::unique_ptr<T>;
+
+  template <typename T>
+  using FactoryType = std::function<PtrType<T>(const PropsType &)>;
+
+  template <typename T>
+  using PtrFactoryType = PtrType<T> (*)(const PropsType &);
+  template <typename T>
+  using StrIndexType = std::unordered_map<std::string, FactoryType<T>>;
+
+  /** integer to string key */
+  using IntIndexType = std::unordered_map<int, std::string>;
+
+  /**
+   * This type contains tuple of
+   * 1) integer -> string index
+   * 2) string -> factory index
+   */
+  template <typename T>
+  using IndexType = std::tuple<StrIndexType<T>, IntIndexType>;
+
+  template <typename... Ts> using FactoryMap = std::tuple<IndexType<Ts>...>;
+
+  AppContext(){};
+
   /**
+   *
    * @brief Get Global app context.
    *
    * @return AppContext&
@@ -52,9 +94,110 @@ public:
    */
   const std::string getWorkingPath(const std::string &path = "");
 
+  template <typename T>
+  const int registerFactory(const PtrFactoryType<T> factory,
+                            const std::string &key = "",
+                            const int int_key = -1) {
+    FactoryType<T> f = factory;
+    return registerFactory(f, key, int_key);
+  }
+
+  /**
+   * @brief Factory register function, use this function to register custom
+   * object
+   *
+   * @tparam T object to create
+   * @param factory factory function that creates std::unique_ptr<T>
+   * @param key key to access the factory, if key is empty, try to find key by
+   * calling factory({})->getType();
+   * @param int_key key to access the factory by integer, if it is -1(default),
+   * the function automatically unsigned the key and return
+   * @return const int unique integer value to access the current factory
+   * @throw invalid argument when key and/or int_key is already taken
+   */
+  template <typename T>
+  const int registerFactory(const FactoryType<T> factory,
+                            const std::string &key = "",
+                            const int int_key = -1) {
+
+    auto &index = std::get<IndexType<T>>(factory_map);
+    auto &str_map = std::get<StrIndexType<T>>(index);
+    auto &int_map = std::get<IntIndexType>(index);
+
+    const std::string &assigned_key = key == "" ? factory({})->getType() : key;
+
+    const std::lock_guard<std::mutex> lock(factory_mutex);
+    if (str_map.find(assigned_key) != str_map.end()) {
+      std::stringstream ss;
+      ss << "cannot register factory with already taken key: " << key;
+      throw std::invalid_argument(ss.str().c_str());
+    }
+
+    if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
+      std::stringstream ss;
+      ss << "cannot register factory with already taken int key: " << int_key;
+      throw std::invalid_argument(ss.str().c_str());
+    }
+
+    int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
+
+    str_map[assigned_key] = factory;
+    int_map[assigned_int_key] = assigned_key;
+
+    ml_logd("factory has registered with key: %s, int_key: %d",
+            assigned_key.c_str(), assigned_int_key);
+
+    return assigned_int_key;
+  }
+
+  template <typename T>
+  PtrType<T> createObject(const int int_key, const PropsType &props = {}) {
+    auto &index = std::get<IndexType<T>>(factory_map);
+    auto &int_map = std::get<IntIndexType>(index);
+
+    const auto &entry = int_map.find(int_key);
+
+    if (entry == int_map.end()) {
+      std::stringstream ss;
+      ss << "Int Key is not found for the object. Key: " << int_key;
+      throw exception::not_supported(ss.str().c_str());
+    }
+
+    return createObject<T>(entry->second, props);
+  }
+
+  template <typename T>
+  PtrType<T> createObject(const std::string &key, const PropsType &props = {}) {
+    auto &index = std::get<IndexType<T>>(factory_map);
+    auto &str_map = std::get<StrIndexType<T>>(index);
+
+    const auto &entry = str_map.find(key);
+
+    if (entry == str_map.end()) {
+      std::stringstream ss;
+      ss << "Key is not found for the object. Key: " << key;
+      throw exception::not_supported(ss.str().c_str());
+    }
+
+    return entry->second(props);
+  }
+
+  /**
+   * @brief special factory that throws for unknown
+   *
+   * @tparam T object to create
+   * @param props props to pass, not used
+   * @throw always throw runtime_error
+   */
+  template <typename T>
+  static PtrType<T> unknownFactory(const PropsType &props) {
+    throw std::runtime_error("cannot create unknown object");
+  }
+
 private:
   static AppContext instance;
 
+  FactoryMap<ml::train::Optimizer> factory_map;
   std::string working_path_base;
 };
 
index fe78e3cea0c67a8724c2c4911896ebce4fb39729..dba9b1a85457c771d5144fb062786c59907a60d2 100644 (file)
 #include <gtest/gtest.h>
 
 #include <fstream>
+#include <memory>
+#include <typeinfo>
 #include <unistd.h>
 
+#include <optimizer.h>
+
 #include <app_context.h>
+#include <nntrainer_error.h>
 
 class nntrainerAppContextDirectory : public ::testing::Test {
 
@@ -83,6 +88,147 @@ TEST_F(nntrainerAppContextDirectory, notExisitingSetDirectory_n) {
                std::invalid_argument);
 }
 
+class CustomOptimizer : public ml::train::Optimizer {
+public:
+  const std::string getType() const { return "identity_optimizer"; }
+
+  float getLearningRate() { return 1.0f; }
+
+  float getDecayRate() { return 1.0f; }
+
+  float getDecaySteps() { return 1.0f; }
+
+  int setProperty(std::vector<std::string> values) { return 1; }
+
+  void setProperty(const PropertyType type, const std::string &value = "") {}
+
+  void checkValidation() {}
+};
+
+class CustomOptimizer2 : public ml::train::Optimizer {
+public:
+  const std::string getType() const { return "identity_optimizer"; }
+
+  float getLearningRate() { return 1.0f; }
+
+  float getDecayRate() { return 1.0f; }
+
+  float getDecaySteps() { return 1.0f; }
+
+  int setProperty(std::vector<std::string> values) { return 1; }
+
+  void setProperty(const PropertyType type, const std::string &value = "") {}
+
+  void checkValidation() {}
+};
+
+using AC = nntrainer::AppContext;
+
+AC::PtrType<ml::train::Optimizer>
+createCustomOptimizer(const AC::PropsType &v) {
+  auto p = std::make_unique<CustomOptimizer>();
+  p->setProperty(v);
+  return p;
+}
+
+TEST(nntrainerAppContextObjs, RegisterCreateCustomOptimizer_p) {
+
+  // register without key in this case, getType() will be called and used
+  {
+    auto ac = nntrainer::AppContext();
+    int num_id = ac.registerFactory(createCustomOptimizer);
+    auto opt = ac.createObject<ml::train::Optimizer>("identity_optimizer", {});
+    EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+    opt = ac.createObject<ml::train::Optimizer>(num_id, {});
+    EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+  }
+
+  // register with key
+  {
+    auto ac = nntrainer::AppContext();
+    int num_id = ac.registerFactory(createCustomOptimizer, "custom_key");
+    auto opt = ac.createObject<ml::train::Optimizer>("custom_key", {});
+    EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+    opt = ac.createObject<ml::train::Optimizer>(num_id, {});
+    EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+  }
+
+  // register with key and custom id
+  {
+    auto ac = nntrainer::AppContext();
+    int num_id = ac.registerFactory(createCustomOptimizer, "custom_key", 5);
+    EXPECT_EQ(num_id, 5);
+    auto opt = ac.createObject<ml::train::Optimizer>("custom_key", {});
+    EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+    opt = ac.createObject<ml::train::Optimizer>(num_id, {});
+    EXPECT_EQ(typeid(*opt).hash_code(), typeid(CustomOptimizer).hash_code());
+  }
+}
+
+TEST(nntrainerAppContextObjs, RegisterFactoryWithClashingKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  ac.registerFactory(createCustomOptimizer, "custom_key");
+
+  EXPECT_THROW(ac.registerFactory(createCustomOptimizer, "custom_key"),
+               std::invalid_argument);
+}
+
+TEST(nntrainerAppContextObjs, RegisterFactoryWithClashingIntKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  ac.registerFactory(createCustomOptimizer, "custom_key", 3);
+  EXPECT_THROW(ac.registerFactory(createCustomOptimizer, "custom_other_key", 3),
+               std::invalid_argument);
+}
+
+TEST(nntrainerAppContextObjs, RegisterFactoryWithClashingAutoKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  ac.registerFactory(createCustomOptimizer);
+  EXPECT_THROW(ac.registerFactory(createCustomOptimizer),
+               std::invalid_argument);
+}
+
+TEST(nntrainerAppContextObjs, createObjectNotExistingKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  ac.registerFactory(createCustomOptimizer);
+  EXPECT_THROW(ac.createObject<ml::train::Optimizer>("not_exisiting_key"),
+               nntrainer::exception::not_supported);
+}
+
+TEST(nntrainerAppContextObjs, createObjectNotExistingIntKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  int num = ac.registerFactory(createCustomOptimizer);
+  EXPECT_THROW(ac.createObject<ml::train::Optimizer>(num + 3),
+               nntrainer::exception::not_supported);
+}
+
+TEST(nntrainerAppContextObjs, callingUnknownFactoryOptimizerWithKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  int num = ac.registerFactory(
+    nntrainer::AppContext::unknownFactory<ml::train::Optimizer>, "unknown",
+    999);
+
+  EXPECT_EQ(num, 999);
+  EXPECT_THROW(ac.createObject<ml::train::Optimizer>("unknown"),
+               std::runtime_error);
+}
+
+TEST(nntrainerAppContextObjs, callingUnknownFactoryOptimizerWithIntKey_n) {
+  auto ac = nntrainer::AppContext();
+
+  int num = ac.registerFactory(
+    nntrainer::AppContext::unknownFactory<ml::train::Optimizer>, "unknown",
+    999);
+
+  EXPECT_EQ(num, 999);
+  EXPECT_THROW(ac.createObject<ml::train::Optimizer>(num), std::runtime_error);
+}
+
 /**
  * @brief Main gtest
  */