[Header] Remove nntrainer_log.h from app_context.h
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 1 Dec 2021 08:53:23 +0000 (17:53 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 2 Dec 2021 06:54:31 +0000 (15:54 +0900)
This patch removes nntrainer_log.h from app_context.h and implement
additional safecheck

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/app_context.cpp
nntrainer/app_context.h

index 7a52e70..66c930f 100644 (file)
 #include <iniparser.h>
 
 #include <app_context.h>
+#include <layer.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
+#include <optimizer.h>
 #include <util_func.h>
 
 #include <adam.h>
@@ -484,4 +486,58 @@ AppContext::registerPluggableFromDirectory(const std::string &base_path) {
   return keys;
 }
 
+template <typename T>
+const int AppContext::registerFactory(const FactoryType<T> factory,
+                                      const std::string &key,
+                                      const int int_key) {
+  static_assert(isSupported<T>::value,
+                "given type is not supported for current app context");
+
+  auto &index = std::get<IndexType<T>>(factory_map);
+  auto &str_map = std::get<StrIndexType<T>>(index);
+  auto &int_map = std::get<IntIndexType>(index);
+
+  std::string assigned_key = key == "" ? factory({})->getType() : key;
+
+  std::transform(assigned_key.begin(), assigned_key.end(), assigned_key.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+
+  const std::lock_guard<std::mutex> lock(factory_mutex);
+  if (str_map.find(assigned_key) != str_map.end()) {
+    std::stringstream ss;
+    ss << "cannot register factory with already taken key: " << key;
+    throw std::invalid_argument(ss.str().c_str());
+  }
+
+  if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
+    std::stringstream ss;
+    ss << "cannot register factory with already taken int key: " << int_key;
+    throw std::invalid_argument(ss.str().c_str());
+  }
+
+  int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
+
+  str_map[assigned_key] = factory;
+  int_map[assigned_int_key] = assigned_key;
+
+  ml_logd("factory has registered with key: %s, int_key: %d",
+          assigned_key.c_str(), assigned_int_key);
+
+  return assigned_int_key;
+}
+
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+template const int AppContext::registerFactory<ml::train::Optimizer>(
+  const FactoryType<ml::train::Optimizer> factory, const std::string &key,
+  const int int_key);
+
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+template const int AppContext::registerFactory<nntrainer::Layer>(
+  const FactoryType<nntrainer::Layer> factory, const std::string &key,
+  const int int_key);
+
 } // namespace nntrainer
index 105ad8c..9b14b5d 100644 (file)
@@ -20,7 +20,9 @@
 #include <memory>
 #include <mutex>
 #include <sstream>
+#include <stdexcept>
 #include <string>
+#include <type_traits>
 #include <unordered_map>
 #include <vector>
 
 #include <optimizer.h>
 
 #include <nntrainer_error.h>
-#include <nntrainer_log.h>
 
 namespace nntrainer {
 
 extern std::mutex factory_mutex;
+namespace {} // namespace
 
 /**
  * @class AppContext contains user-dependent configuration
@@ -188,41 +190,7 @@ public:
   template <typename T>
   const int registerFactory(const FactoryType<T> factory,
                             const std::string &key = "",
-                            const int int_key = -1) {
-
-    auto &index = std::get<IndexType<T>>(factory_map);
-    auto &str_map = std::get<StrIndexType<T>>(index);
-    auto &int_map = std::get<IntIndexType>(index);
-
-    std::string assigned_key = key == "" ? factory({})->getType() : key;
-
-    std::transform(assigned_key.begin(), assigned_key.end(),
-                   assigned_key.begin(),
-                   [](unsigned char c) { return std::tolower(c); });
-
-    const std::lock_guard<std::mutex> lock(factory_mutex);
-    if (str_map.find(assigned_key) != str_map.end()) {
-      std::stringstream ss;
-      ss << "cannot register factory with already taken key: " << key;
-      throw std::invalid_argument(ss.str().c_str());
-    }
-
-    if (int_key != -1 && int_map.find(int_key) != int_map.end()) {
-      std::stringstream ss;
-      ss << "cannot register factory with already taken int key: " << int_key;
-      throw std::invalid_argument(ss.str().c_str());
-    }
-
-    int assigned_int_key = int_key == -1 ? str_map.size() + 1 : int_key;
-
-    str_map[assigned_key] = factory;
-    int_map[assigned_int_key] = assigned_key;
-
-    ml_logd("factory has registered with key: %s, int_key: %d",
-            assigned_key.c_str(), assigned_int_key);
-
-    return assigned_int_key;
-  }
+                            const int int_key = -1);
 
   /**
    * @brief Create an Object from the integer key
@@ -235,6 +203,8 @@ public:
   template <typename T>
   PtrType<T> createObject(const int int_key,
                           const PropsType &props = {}) const {
+    static_assert(isSupported<T>::value,
+                  "given type is not supported for current app context");
     auto &index = std::get<IndexType<T>>(factory_map);
     auto &int_map = std::get<IntIndexType>(index);
 
@@ -295,8 +265,39 @@ public:
 private:
   FactoryMap<ml::train::Optimizer, nntrainer::Layer> factory_map;
   std::string working_path_base;
+
+  template <typename Args, typename T> struct isSupportedHelper;
+
+  /**
+   * @brief supportHelper to check if given type is supported within appcontext
+   */
+  template <typename T, typename... Args>
+  struct isSupportedHelper<T, AppContext::FactoryMap<Args...>> {
+    static constexpr bool value =
+      (std::is_same_v<std::decay_t<T>, std::decay_t<Args>> || ...);
+  };
+
+  /**
+   * @brief supportHelper to check if given type is supported within appcontext
+   */
+  template <typename T>
+  struct isSupported : isSupportedHelper<T, decltype(factory_map)> {};
 };
 
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+extern template const int AppContext::registerFactory<ml::train::Optimizer>(
+  const FactoryType<ml::train::Optimizer> factory, const std::string &key,
+  const int int_key);
+
+/**
+ * @copydoc const int AppContext::registerFactory
+ */
+extern template const int AppContext::registerFactory<nntrainer::Layer>(
+  const FactoryType<nntrainer::Layer> factory, const std::string &key,
+  const int int_key);
+
 namespace plugin {}
 
 } // namespace nntrainer