[graph] Support creating RunLayerContext
authorParichay Kapoor <pk.kapoor@samsung.com>
Wed, 16 Jun 2021 10:55:19 +0000 (19:55 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 23 Jun 2021 07:42:19 +0000 (16:42 +0900)
Implement the function to create RunLayerContext given the
InitLayerContext and update it in the given LayerNode.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/layers/layer_context.h
nntrainer/layers/layer_node.h
nntrainer/tensor/manager.cpp
nntrainer/tensor/manager.h

index a95ca84..d7fd2e1 100644 (file)
@@ -695,16 +695,25 @@ void NetworkGraph::inPlaceOptimize(Manager &manager) {
   }
 }
 
-void NetworkGraph::init2runContext(InitLayerContext &init_context,
-                                   RunLayerContext &run_context) {
-  // NOTE: this just create the wrappers and does not actually memory inside
-  // these wrappers
-  // TODO: create wrappers for weights - initialize is already done, so the
-  // object creation can be done
-  // TODO: create wrapper for the temporary tensors
-  // TODO: create wrappers for outputs
-  // TODO: create a new context with these new wrappers and then copy assign
-  // run_context
+void NetworkGraph::updateRunContext(std::shared_ptr<Manager> &manager,
+                                    const std::shared_ptr<LayerNode> &lnode) {
+  /**
+   * using copy assignment allows setting run_context without adding more
+   * interfaces
+   */
+  const GraphNode &gnode = *lnode.get();
+  const InitLayerContext &init_context = lnode->getInitContext();
+  /**
+   * @todo must use existing properties like name/trainable of run_context to
+   * create the new run_context
+   */
+  // const RunLayerContext &run_context = lnode->getRunContext();
+
+  lnode->updateRunContext(RunLayerContext(
+    manager->requestWeights(gnode, init_context.getWeightsSpec()),
+    manager->requestInputs(gnode, init_context.getInputDimensions()),
+    manager->requestOutputs(gnode, init_context.getOutputDimensions()),
+    manager->requestTensors(gnode, init_context.getTensorsSpec())));
 }
 
 int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
@@ -751,13 +760,15 @@ int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
 
     /**
      * Initialize all the layers, allocate output tensors for each layer
-     * and add optimizer related weights for the layer
+     * init2and add optimizer related weights for the layer
      */
-    // TODO: pass init context, this call will fill it
+    // TODO: this call will fill the init context inside the layer
+    // lnode->initialize();
     status = lptr->initialize(*manager);
     NN_RETURN_STATUS();
 
-    // TODO: call init2runContext
+    updateRunContext(manager, lnode);
+    // TODO: remove this
     auto &in_out = manager->trackLayerOutputs(cur_type, lnode->getName(),
                                               lptr->getOutputDimension(),
                                               lptr->getInputDimension());
index d37469a..9573ba4 100644 (file)
@@ -262,8 +262,8 @@ public:
    * @param init_context Init layer context to create run context
    * @param run_context Run layer context to be created
    */
-  void init2runContext(InitLayerContext &init_context,
-                       RunLayerContext &run_context);
+  static void updateRunContext(std::shared_ptr<Manager> &manager,
+                               const std::shared_ptr<LayerNode> &lnode);
 
 private:
   std::map<std::string, std::string> sub_in_out; /** This is map to identify
index 70f48ea..0d4b3c0 100644 (file)
@@ -170,18 +170,17 @@ public:
 
   /**
    * @brief Construct a new Run Layer Context object
+   * @todo  Include properties like name/trainable later
    *
-   * @param props properties of the layer
    * @param w weights of the layer
    * @param in inputs of the layer
    * @param out outputs of the layer
    * @param t extra tensors of the layer
    */
-  RunLayerContext(std::tuple<props::Name> p, const std::vector<Weight *> &w,
+  RunLayerContext(const std::vector<Weight *> &w,
                   const std::vector<Var_Grad *> &in,
                   const std::vector<Var_Grad *> &out,
                   const std::vector<Var_Grad *> &t) :
-    props(p),
     weights(w),
     inputs(in),
     outputs(out),
index f911cb0..64cadc0 100644 (file)
@@ -308,6 +308,30 @@ public:
    */
   friend std::ostream &operator<<(std::ostream &out, const LayerNode &l);
 
+  /**
+   * @brief   Get init layer context
+   *
+   * @retval  init layer context
+   */
+  const InitLayerContext &getInitContext() const { return init_context; }
+
+  /**
+   * @brief   Get run layer context
+   *
+   * @retval  run layer context
+   */
+  const RunLayerContext &getRunContext() const { return run_context; }
+
+  /**
+   * @brief   Set run layer context
+   *
+   * @param  context Updated run layer context
+   */
+  void updateRunContext(RunLayerContext &&context) {
+    // TODO: ensure props/trainable must match
+    run_context = std::move(context);
+  }
+
 private:
   // TODO: make this unique_ptr once getObject API is removed
   std::shared_ptr<nntrainer::LayerV1>
index df7ea05..41696fb 100644 (file)
@@ -632,7 +632,7 @@ void Manager::deinitializeTensors() {
  */
 std::vector<Weight *>
 Manager::requestWeights(const GraphNode &node,
-                        std::vector<Weight::Spec> &weights_spec) {
+                        const std::vector<Weight::Spec> &weights_spec) {
   return _requestTensors<Weight>(node, weights_spec, weights_v2);
 }
 
@@ -642,7 +642,7 @@ Manager::requestWeights(const GraphNode &node,
  */
 std::vector<Var_Grad *>
 Manager::requestTensors(const GraphNode &node,
-                        std::vector<Var_Grad::Spec> &tensors_spec) {
+                        const std::vector<Var_Grad::Spec> &tensors_spec) {
   return _requestTensors<Var_Grad>(node, tensors_spec, tensors_v2);
 }
 
@@ -651,7 +651,7 @@ Manager::requestTensors(const GraphNode &node,
  */
 std::vector<Var_Grad *>
 Manager::requestInputs(const GraphNode &node,
-                       std::vector<TensorDim> &inputs_dim) {
+                       const std::vector<TensorDim> &inputs_dim) {
   unsigned int count = 0;
   std::vector<Var_Grad::Spec> inputs_spec;
   std::transform(
@@ -672,7 +672,7 @@ Manager::requestInputs(const GraphNode &node,
  */
 std::vector<Var_Grad *>
 Manager::requestOutputs(const GraphNode &node,
-                        std::vector<TensorDim> &outputs_dim) {
+                        const std::vector<TensorDim> &outputs_dim) {
   unsigned int count = 0;
   std::vector<Var_Grad::Spec> outputs_spec;
   std::transform(outputs_dim.begin(), outputs_dim.end(),
index f7f0efd..8b7fd10 100644 (file)
@@ -163,8 +163,9 @@ public:
    * @param w   node Graph node to extract node identifiers/info
    * @param w   weights_spec Specficiation for the weights
    */
-  std::vector<Weight *> requestWeights(const GraphNode &node,
-                                       std::vector<Weight::Spec> &weights_spec);
+  std::vector<Weight *>
+  requestWeights(const GraphNode &node,
+                 const std::vector<Weight::Spec> &weights_spec);
 
   /**
    * @brief     Create tensors with the given spec
@@ -174,7 +175,7 @@ public:
    */
   std::vector<Var_Grad *>
   requestTensors(const GraphNode &node,
-                 std::vector<Var_Grad::Spec> &tensors_spec);
+                 const std::vector<Var_Grad::Spec> &tensors_spec);
 
   /**
    * @brief     Create tensors with the given spec
@@ -182,8 +183,9 @@ public:
    * @param w   node Graph node to extract node identifiers/info
    * @param w   create tensors list
    */
-  std::vector<Var_Grad *> requestInputs(const GraphNode &node,
-                                        std::vector<TensorDim> &inputs_dim);
+  std::vector<Var_Grad *>
+  requestInputs(const GraphNode &node,
+                const std::vector<TensorDim> &inputs_dim);
 
   /**
    * @brief     Create tensors with the given spec
@@ -191,8 +193,9 @@ public:
    * @param w   node Graph node to extract node identifiers/info
    * @param w   create tensors list
    */
-  std::vector<Var_Grad *> requestOutputs(const GraphNode &node,
-                                         std::vector<TensorDim> &outputs_spec);
+  std::vector<Var_Grad *>
+  requestOutputs(const GraphNode &node,
+                 const std::vector<TensorDim> &outputs_spec);
 
   /**
    * @brief     Get weights tracked with nntrainer
@@ -552,7 +555,7 @@ private:
    */
   template <typename T>
   static std::vector<T *> _requestTensors(
-    const GraphNode &node, std::vector<typename T::Spec> &tensors_spec,
+    const GraphNode &node, const std::vector<typename T::Spec> &tensors_spec,
     std::vector<std::vector<std::unique_ptr<T>>> &layer_objs_list) {
     std::vector<T *> ret;
     std::vector<std::unique_ptr<T>> tensors_list;