[LayerNode] Change context to be RAII
authorJihoon Lee <jhoon.it.lee@samsung.com>
Mon, 6 Sep 2021 08:49:02 +0000 (17:49 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 9 Sep 2021 05:03:05 +0000 (14:03 +0900)
This patch refactors context to be RAII to make layer node properties
dumb.

**Changes**
- init context is removed
   - layer node instead created and returned from layerNode::finalize(input_dims)
   - remove dependency to input dimension before initialize
   - layerNode now has input_shapes instead of setting directly from initContext
   - runcontext is used to query things
   - fix multiple bugs regarding the validity of layer node

- networkGraph::updateRunContext() -> finalizeContext() now
finalizes inside this function for brevity
- layerNode::updateRunContext() -> configureContext() to make it RAIIer.
- minor code cleans in networkGraph::initialize

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
Applications/Custom/LayerPlugin/layer_plugin_common_test.cpp
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/layers/layer_node.cpp
nntrainer/layers/layer_node.h
nntrainer/layers/time_dist.cpp
test/unittest/layers/layers_dependent_common_tests.cpp

index 1f9e5016d54a89420e944da54c9a6d98e5762a15..bfa467d19107fb4603dd62ce30f109297574765a 100644 (file)
@@ -46,8 +46,8 @@ TEST_P(LayerPluginCommonTest, DefaultEnvironmentPath_p) {
   auto lnode = std::static_pointer_cast<nntrainer::LayerNode>(l);
 
   EXPECT_THROW(lnode->setProperty({"invalid_values"}), std::invalid_argument);
-  EXPECT_EQ(lnode->getOutputDimensions().size(), size_t(0));
-  EXPECT_EQ(lnode->getInputDimensions().size(), size_t(0));
+  EXPECT_THROW(lnode->getOutputDimensions(), std::runtime_error);
+  EXPECT_THROW(lnode->getInputDimensions(), std::runtime_error);
 }
 
 TEST_P(LayerPluginCommonTest, DefaultEnvironmentPathLayerNotExist_n) {
index 7d0fe912da93d1e740df025adb2533b45a4a364b..842f8a825f99174563382a0f5c88d7d85f232eca 100644 (file)
@@ -342,9 +342,9 @@ int NetworkGraph::checkCompiledGraph() {
   /** Dimension of input layers must be known */
   for (auto iter = cbegin(); iter != cend(); iter++) {
     auto lnode = (*iter);
-    if (lnode->getType() == InputLayer::type) {
-      if (lnode->getInputDimensions().size() == 0) {
-        ml_loge("InputDimension of first layer is not set");
+    if (lnode->getNumInputConnections() == 0) {
+      if (!lnode->hasInputShapeProperty()) {
+        ml_loge("Layer with no inbound connection need input_shape property");
         return ML_ERROR_INVALID_PARAMETER;
       }
     }
@@ -371,12 +371,10 @@ int NetworkGraph::realizeGraph() {
     /** If a layer does not has input nodes, then it must have input dimension
      */
     if (lnode->getNumInputConnections() == 0) {
-      for (unsigned int i = 0; i < lnode->getInputDimensions().size(); ++i) {
-        if (lnode->getInputDimensions()[i].getDataLen() == 0) {
-          ml_loge("Input Dimension must be set");
-          status = ML_ERROR_INVALID_PARAMETER;
-          NN_RETURN_STATUS();
-        }
+      if (!lnode->hasInputShapeProperty()) {
+        ml_loge("Input Dimension must be set");
+        status = ML_ERROR_INVALID_PARAMETER;
+        NN_RETURN_STATUS();
       }
     }
 
@@ -407,8 +405,8 @@ int NetworkGraph::realizeGraph() {
 
   /**
    * invariant: the new realized nodes are added to the end,
-   * otherwise this iteration becomes invalid. So, every iteration must be fresh
-   * iterator as vector resize invalidates all the iterators.
+   * otherwise this iteration becomes invalid. So, every iteration must be
+   * fresh iterator as vector resize invalidates all the iterators.
    */
   for (unsigned int i = 0; i < graph.size(); ++i) {
     auto const &lnode = LNODE(*(cbegin() + i));
@@ -519,7 +517,8 @@ void NetworkGraph::extendGraph(std::vector<std::shared_ptr<LayerNode>> ex_graph,
 
   /**
    * The input_layers for ex_graph[0] here is provided to the backbone by the
-   * ini file and is overwritten here by the model loader for connection making.
+   * ini file and is overwritten here by the model loader for connection
+   * making.
    *
    * This loop intends to connect a new backbone to be added with an old
    * backbone.
@@ -666,15 +665,17 @@ void NetworkGraph::inPlaceOptimize() {
 }
 
 std::vector<Var_Grad *>
-NetworkGraph::updateRunContext(std::shared_ptr<Manager> &tensor_manager,
-                               const std::shared_ptr<LayerNode> &lnode,
-                               const std::vector<Var_Grad *> &prev_inputs) {
-  /**
-   * using copy assignment allows setting run_context without adding more
-   * interfaces
-   */
+NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
+                              const std::vector<Var_Grad *> &prev_inputs) {
   const GraphNode &gnode = *lnode.get();
-  const InitLayerContext &init_context = lnode->getInitContext();
+  std::vector<TensorDim> input_dims;
+  input_dims.reserve(prev_inputs.size());
+  std::transform(prev_inputs.begin(), prev_inputs.end(),
+                 std::back_inserter(input_dims),
+                 [](const Var_Grad *vg) { return vg->getDim(); });
+
+  auto init_context = lnode->finalize(input_dims);
+
   std::vector<Var_Grad *> inputs = prev_inputs;
   if (inputs.empty())
     inputs =
@@ -683,17 +684,11 @@ NetworkGraph::updateRunContext(std::shared_ptr<Manager> &tensor_manager,
   const std::vector<Var_Grad *> &outputs =
     tensor_manager->requestOutputs(gnode, init_context.getOutputDimensions());
 
-  /**
-   * @note must use existing properties like name/trainable of run_context to
-   * create the new run_context
-   */
-  const RunLayerContext &run_context = lnode->getRunContext();
-  lnode->updateRunContext(RunLayerContext(
-    run_context.getName(), run_context.getLoss(),
+  lnode->configureRunContext(
     // TODO: update weights spec for trainable based on layer trainable prop
     tensor_manager->requestWeights(gnode, init_context.getWeightsSpec()),
     inputs, outputs,
-    tensor_manager->requestTensors(gnode, init_context.getTensorsSpec())));
+    tensor_manager->requestTensors(gnode, init_context.getTensorsSpec()));
 
   return outputs;
 }
@@ -708,9 +703,9 @@ int NetworkGraph::initialize() {
     return node->getInputConnections().empty();
   };
 
+  std::vector<Var_Grad *> inputs;
   for (unsigned int idx = 0; idx < graph.size(); ++idx) {
     auto const &lnode = getSortedLayerNode(idx);
-    std::string cur_type = lnode->getType();
     ml_logd("layer name : %s", lnode->getName().c_str());
 
     /**
@@ -718,41 +713,16 @@ int NetworkGraph::initialize() {
      * For input layer, as input dimension is known, set input tensor.
      */
     if (!is_input_node(lnode)) {
-      auto &input_layers = lnode->getInputLayers();
-      lnode->resizeInputDimensions(input_layers.size());
-      for (unsigned int i = 0; i < input_layers.size(); ++i) {
-        auto in_layer_node = getLayerNode(input_layers[i]);
-
-        auto const &in_layer_out_connect = in_layer_node->getOutputLayers();
-        unsigned int location =
-          std::find(in_layer_out_connect.begin(), in_layer_out_connect.end(),
-                    lnode->getName()) -
-          in_layer_out_connect.begin();
-
-#ifdef DEBUG
-        if (location == in_layer_out_connect.size())
-          throw std::runtime_error("Invalid connection between nodes.");
-#endif
-
-        lnode->setInputDimension(in_layer_node->getOutputDimensions()[location],
-                                 i);
-      }
+      if (input_map.find(lnode->getName()) == input_map.end())
+        throw std::runtime_error("Cannot find input buffers for the node");
+      inputs = input_map.at(lnode->getName());
     }
 
     /**
      * Initialize all the layers, allocate output tensors for each layer
      * init2and add optimizer related weights for the layer
      */
-    lnode->finalize();
-
-    std::vector<Var_Grad *> inputs = {};
-    if (!is_input_node(lnode)) {
-      if (input_map.find(lnode->getName()) == input_map.end())
-        throw std::runtime_error("Cannot find input buffers for the node");
-      inputs = input_map.at(lnode->getName());
-    }
-    const std::vector<Var_Grad *> &outputs =
-      updateRunContext(tensor_manager, lnode, inputs);
+    const std::vector<Var_Grad *> &outputs = finalizeContext(lnode, inputs);
 
     /** no need to update input_map for the last layer */
     if (idx == graph.size() - 1)
index 4e8143ecf821d93a0dc32c11929de7cc1a2c006a..c3dc96c707c4236511cddef673084b2ad61e5548 100644 (file)
@@ -248,13 +248,12 @@ public:
   /**
    * @brief Create run layer context from the given init layer context
    *
-   * @param init_context Init layer context to create run context
-   * @param run_context Run layer context to be created
+   * @param lnode layer node to finalize and set run context
+   * @param prev_inputs previous input information
    */
-  static std::vector<Var_Grad *>
-  updateRunContext(std::shared_ptr<Manager> &manager,
-                   const std::shared_ptr<LayerNode> &lnode,
-                   const std::vector<Var_Grad *> &inputs);
+  std::vector<Var_Grad *>
+  finalizeContext(const std::shared_ptr<LayerNode> &lnode,
+                  const std::vector<Var_Grad *> &prev_inputs);
 
   /** Interface for manager */
 
index 1e1c6beef05692e01ac0480c7c163115042d8f98..eeda2c54befb79d5ae447fe4ad2b1b92f02967c6 100644 (file)
 #include <weight.h>
 
 namespace nntrainer {
+RunLayerContext::RunLayerContext(const std::string &name, float l,
+                                 const std::vector<Weight *> &w,
+                                 const std::vector<Var_Grad *> &in,
+                                 const std::vector<Var_Grad *> &out,
+                                 const std::vector<Var_Grad *> &t) :
+  loss(l),
+  weights(w),
+  inputs(in),
+  outputs(out),
+  tensors(t) {
+  std::get<props::Name>(props).set(name);
+  NNTR_THROW_IF(!readyToUse(), std::invalid_argument)
+    << "run context is not ready to use upon creation";
+}
 
 /**
  * @brief Get the Weight tensor object
index 49acac54384161a74df4b3a64b05042eebebb00b..bf19f6ce4e5184d20e46a103e7e38586bc66f943 100644 (file)
@@ -37,22 +37,20 @@ class Var_Grad;
  */
 class InitLayerContext {
 public:
-  /**
-   * @brief Construct a new Init Layer Context object
-   *
-   */
-  InitLayerContext() : InitLayerContext({}, 1) {}
-
   /**
    * @brief Construct a new Init Layer Context object
    *
    * @param dim Input dimensions for the layer
    */
   InitLayerContext(const std::vector<TensorDim> &dim, unsigned int num_out,
-                   const std::string &n = "") :
+                   const std::string &n) :
     input_dim(dim),
     num_outputs(num_out),
-    name(n) {}
+    name(n) {
+    NNTR_THROW_IF(!validate(), std::invalid_argument)
+      << "Invalid init context name: " << name
+      << " num inputs: " << getNumInputs();
+  }
 
   /**
    * @brief   get name by the layer
@@ -269,8 +267,9 @@ public:
       }
     }
 
-    if (name.empty())
+    if (name.empty()) {
       return false;
+    }
 
     return true;
   }
@@ -299,20 +298,6 @@ private:
  */
 class RunLayerContext {
 public:
-  /**
-   * @brief Construct a new Run Layer Context object
-   *
-   */
-  RunLayerContext() : loss(0.0) {}
-
-  /**
-   * @brief Construct a new Run Layer Context object
-   *
-   */
-  RunLayerContext(const std::string &name) : RunLayerContext() {
-    std::get<props::Name>(props).set(name);
-  }
-
   /**
    * @brief Construct a new Run Layer Context object
    * @todo  Include properties like name/trainable later
@@ -326,14 +311,7 @@ public:
                   const std::vector<Weight *> &w,
                   const std::vector<Var_Grad *> &in,
                   const std::vector<Var_Grad *> &out,
-                  const std::vector<Var_Grad *> &t) :
-    loss(l),
-    weights(w),
-    inputs(in),
-    outputs(out),
-    tensors(t) {
-    std::get<props::Name>(props).set(name);
-  }
+                  const std::vector<Var_Grad *> &t);
 
   /**
    * @brief Get the Weight tensor object
index a4acfda67e94e6fc7648ca907fe0f2e305ef77de..5266505bf6477f98847cdc68fbd6ec777858b71e 100644 (file)
@@ -141,8 +141,9 @@ createLayerNode(std::unique_ptr<nntrainer::Layer> &&layer,
 
 LayerNode::LayerNode(std::unique_ptr<nntrainer::Layer> &&l) :
   layer(std::move(l)),
-  finalized(false),
   activation_type(ActivationType::ACT_NONE),
+  run_context(nullptr),
+  input_shapes(),
   layer_node_props(new PropsType(props::Name(), props::Flatten(),
                                  props::Distribute(), props::Trainable(), {})),
   loss(new props::Loss()),
@@ -207,11 +208,11 @@ bool LayerNode::setProperty(const std::string &key, const std::string &value) {
   PropertyType type = static_cast<PropertyType>(parseLayerProperty(key));
   switch (type) {
   case PropertyType::input_shape: {
-    std::vector<TensorDim> input_dim = init_context.getInputDimensions();
-    if (getNumInputs() > 1) {
+    std::vector<TensorDim> input_dim = input_shapes;
+    if (input_shapes.size() > 1) {
       throw std::invalid_argument("input_shape keyword is only for one input");
     }
-    if (getNumInputs() == 0)
+    if (input_shapes.empty())
       input_dim.resize(1);
 
     TensorDim &in_dim = input_dim[0];
@@ -231,8 +232,7 @@ bool LayerNode::setProperty(const std::string &key, const std::string &value) {
       /** set back to cache value of dimension */
       in_dim.batch(cache_batch_size);
       throw_status(status);
-
-      init_context = InitLayerContext(input_dim, init_context.getNumOutputs());
+      input_shapes = std::move(input_dim);
     }
   } break;
   case PropertyType::activation: {
@@ -373,6 +373,36 @@ void LayerNode::setInputLayers(const std::vector<std::string> &layers) {
   input_layers = std::vector<props::InputLayer>(layers.begin(), layers.end());
 }
 
+bool LayerNode::hasInputShapeProperty() const { return !input_shapes.empty(); }
+
+const std::vector<TensorDim> LayerNode::getInputDimensions() const {
+  NNTR_THROW_IF(!run_context, std::runtime_error)
+    << __func__ << " layer needs to be finalized first!";
+  auto sz = run_context->getNumInputs();
+  std::vector<TensorDim> dims;
+  dims.reserve(sz);
+
+  for (auto i = 0u; i < sz; ++i) {
+    dims.push_back(run_context->getInput(i).getDim());
+  }
+
+  return dims;
+}
+
+const std::vector<TensorDim> LayerNode::getOutputDimensions() const {
+  NNTR_THROW_IF(!run_context, std::runtime_error)
+    << __func__ << " layer needs to be finalized first!";
+  auto sz = run_context->getNumOutputs();
+  std::vector<TensorDim> dims;
+  dims.reserve(sz);
+
+  for (auto i = 0u; i < sz; ++i) {
+    dims.push_back(run_context->getOutput(i).getDim());
+  }
+
+  return dims;
+}
+
 void LayerNode::exportTo(Exporter &exporter,
                          const ExportMethods &method) const {
   exporter.saveResult(*layer_node_props, method, this);
@@ -382,69 +412,93 @@ void LayerNode::exportTo(Exporter &exporter,
 }
 
 void LayerNode::read(std::ifstream &file) {
-  for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
-    run_context.getWeight(i).read(file);
+  NNTR_THROW_IF(!run_context, std::runtime_error)
+    << __func__ << " layer needs to be finalized first!";
+  for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
+    run_context->getWeight(i).read(file);
   }
 }
 
 void LayerNode::save(std::ofstream &file) const {
-  for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
-    run_context.getWeight(i).save(file);
+  NNTR_THROW_IF(!run_context, std::runtime_error)
+    << __func__ << " layer needs to be finalized first!";
+  for (unsigned int i = 0; i < run_context->getNumWeights(); ++i) {
+    run_context->getWeight(i).save(file);
   }
 }
 
 /**
  * @brief     Finalize creating the layer node
  */
-void LayerNode::finalize() {
+InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims) {
+  std::vector<TensorDim> actual_input_dims;
+  auto &prop_dims = input_shapes;
+  if (!input_dims.empty()) {
+    actual_input_dims = input_dims;
+    if (!prop_dims.empty()) {
+      /// if prop_dims exist, check if it's same with given input_dims
+      NNTR_THROW_IF(input_dims != prop_dims, std::invalid_argument)
+        << "calculated input dimension is different from given input_shape "
+           "property";
+    }
+  } else {
+    actual_input_dims = prop_dims;
+  }
+
+  NNTR_THROW_IF(input_dims.size() < getNumInputConnections(),
+                std::invalid_argument)
+    << "number of input dimensions must be equal or larger "
+    << "than number of input connections, node name: " << getName()
+    << " num input dims: " << input_dims.size()
+    << " num connections: " << getNumInputConnections();
+
   /** Create init context right before finalize */
-  if (finalized)
+  if (run_context)
     throw std::runtime_error("Finalizing a layer which is already finalized");
 
-  init_context = InitLayerContext(init_context.getInputDimensions(),
-                                  init_context.getNumOutputs(), getName());
-  NNTR_THROW_IF(!init_context.validate(), std::invalid_argument)
-    << "Invalid init context, name: " << getName()
-    << " initContext num inputs: " << init_context.getNumInputs();
+  auto num_outputs = output_layers.size();
+  if (output_layers.empty()) {
+    num_outputs = 1;
+  }
+
+  auto init_context =
+    InitLayerContext(actual_input_dims, num_outputs, getName());
 
-  if (layer)
-    layer->finalize(init_context);
-  finalized = true;
-  run_context = RunLayerContext(getName());
+  layer->finalize(init_context);
+  return init_context;
 }
 
 /**
  * @brief     Forward Propagation of a layer
  */
 void LayerNode::forwarding(bool training) {
-  loss->set(run_context.getRegularizationLoss());
-  layer->forwarding(run_context, training);
+  loss->set(run_context->getRegularizationLoss());
+  layer->forwarding(*run_context, training);
 }
 
 /**
  * @brief     calc the derivative to be passed to the previous layer
  */
-void LayerNode::calcDerivative() { layer->calcDerivative(run_context); }
+void LayerNode::calcDerivative() { layer->calcDerivative(*run_context); }
 
 /**
  * @brief     Calculate the derivative of a layer
  */
-void LayerNode::calcGradient() { layer->calcGradient(run_context); }
+void LayerNode::calcGradient() { layer->calcGradient(*run_context); }
 
 /**
  * @brief Set the batch for the layer
  */
 void LayerNode::setBatch(unsigned int batch) {
-  run_context.setBatch(batch);
-  init_context.setBatch(batch);
-
-  if (finalized) {
-    if (run_context.readyToUse()) {
-      getLayer()->setBatch(run_context, batch);
-    } else {
-      /** run_context has not been created yet */
-      getLayer()->setBatch(init_context, batch);
-    }
+  /** @todo we won't going to need Layer::setBatch(InitLayerContext), remove it
+   */
+  for (auto &input_shape : input_shapes) {
+    input_shape.batch(batch);
+  }
+
+  if (run_context) {
+    run_context->setBatch(batch);
+    getLayer()->setBatch(*run_context, batch);
   }
 }
 
@@ -465,21 +519,17 @@ bool LayerNode::requireLabel() const { return getLayer()->requireLabel(); }
 float LayerNode::getLoss() const {
   /** add loss only for loss layers */
   if (requireLabel())
-    loss->set(*loss + run_context.getLoss());
+    loss->set(*loss + run_context->getLoss());
 
   return *loss;
 }
 
-void LayerNode::setInputDimension(const TensorDim &dim, unsigned int idx) {
-  NNTR_THROW_IF(idx >= getNumInputs(), std::out_of_range)
-    << "Setting dimensions out of bounds, idx: " << idx
-    << " size: " << getNumInputs() << " name: " << getName();
-
-  std::vector<TensorDim> input_dim = init_context.getInputDimensions();
-  if (input_dim[idx] != dim) {
-    input_dim[idx] = dim;
-    init_context = InitLayerContext(input_dim, init_context.getNumOutputs());
-  }
+void LayerNode::configureRunContext(const std::vector<Weight *> &weights,
+                                    const std::vector<Var_Grad *> &inputs,
+                                    const std::vector<Var_Grad *> &outputs,
+                                    const std::vector<Var_Grad *> &tensors) {
+  run_context = std::make_unique<RunLayerContext>(getName(), 0.0f, weights,
+                                                  inputs, outputs, tensors);
 }
 
 /**
@@ -518,24 +568,15 @@ void LayerNode::printPreset(std::ostream &out, PrintPreset preset) {
   print(out, flags);
 }
 
-void LayerNode::resizeInputDimensions(unsigned int size) {
-  auto cur_input_dim = init_context.getInputDimensions();
-  if (cur_input_dim.size() != size) {
-    cur_input_dim.resize(size);
-    init_context =
-      InitLayerContext(cur_input_dim, init_context.getNumOutputs());
-  }
-}
-
 void LayerNode::printShapeInfo(std::ostream &out) {
-  for (unsigned int idx = 0; idx < init_context.getNumInputs(); ++idx) {
-    out << "input " << init_context.getInputDimensions()[idx];
+  for (unsigned int idx = 0; idx < getNumInputs(); ++idx) {
+    out << "input " << run_context->getInput(idx).getDim();
   }
-  for (unsigned int i = 0; i < init_context.getNumWeights(); i++) {
-    out << "weight" << std::get<0>(init_context.getWeightsSpec()[i]);
+  for (unsigned int idx = 0; idx < getNumWeights(); idx++) {
+    out << "weight " << run_context->getWeight(idx).getDim();
   }
-  for (unsigned int idx = 0; idx < init_context.getNumOutputs(); ++idx) {
-    out << "output " << init_context.getOutputDimensions()[idx];
+  for (unsigned int idx = 0; idx < getNumOutputs(); ++idx) {
+    out << "output " << run_context->getOutput(idx).getDim();
   }
 }
 
@@ -556,7 +597,7 @@ void LayerNode::print(std::ostream &out, unsigned int flags) {
   }
 
   if (flags & PRINT_SHAPE_INFO) {
-    if (init_context.validate()) {
+    if (run_context) {
       out << "======shape information: " << std::endl;
       printShapeInfo(out);
     }
@@ -573,13 +614,10 @@ void LayerNode::print(std::ostream &out, unsigned int flags) {
   }
 
   if (flags & PRINT_WEIGHTS) {
-    if (init_context.validate()) {
+    if (run_context) {
       out << "======weights: " << std::endl;
-      for (unsigned int idx = 0; idx < init_context.getNumWeights(); idx++) {
-        out << '[' << std::get<5>(init_context.getWeightsSpec()[idx]) << ']'
-            << std::endl;
-        if (run_context.readyToUse())
-          out << run_context.getWeight(idx);
+      for (unsigned int idx = 0; idx < getNumWeights(); idx++) {
+        out << run_context->getWeight(idx);
       }
     }
   }
index a442d6dc66fd22ad3ed564902ea84bbf4e666426..d0f13262295863543993cd7749652716c40cf3a7 100644 (file)
@@ -57,11 +57,6 @@ class InputLayer;
  */
 class LayerNode final : public ml::train::Layer, public GraphNode {
 public:
-  /**
-   * @brief Default constructor
-   */
-  LayerNode() : LayerNode(nullptr) {}
-
   /**
    * @brief Constructor of LayerNode class for v2
    * @param l layer to wrap with, the ownership is transferred to layer node
@@ -159,17 +154,17 @@ public:
   /**
    * @brief     Finalize creating the layer node
    *
-   * @details   Input dimensions will be provided set in the context. This
-   * function must set output dimensions in the given context. Further, context
-   * can be used to request weights for the layer, and any extra tensor required
-   * for the operation of the layer.
+   * @param   input_dims input dimension provided to be used to set output
+   * dimensions. if empty function This function must set output dimensions in
+   * the given context. Further, context can be used to request weights for the
+   * layer, and any extra tensor required for the operation of the layer.
    * @note      After calling this it is not allowed to
    * change properties.
    * @note      No memory allocation must be performed in the initialization
    * step. Any tensor memory required must be requested to the context which
    * will be made available during execution of the layer with the context.
    */
-  void finalize();
+  InitLayerContext finalize(const std::vector<TensorDim> &input_dims = {});
 
   /**
    * @brief     Forward Propagation of a layer
@@ -291,20 +286,32 @@ public:
    * @brief     Get number of inputs
    * @retval    number of inputs
    */
-  unsigned int getNumInputs() const { return init_context.getNumInputs(); }
+  unsigned int getNumInputs() const {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getNumInputs();
+  }
 
   /**
    * @brief     Get number of outputs
    * @retval    number of outputs
    */
-  unsigned int getNumOutputs() const { return init_context.getNumOutputs(); }
+  unsigned int getNumOutputs() const {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getNumOutputs();
+  }
 
   /**
    * @brief Get the number of weights
    *
    * @return unsigned int number of weights
    */
-  unsigned int getNumWeights() const { return init_context.getNumWeights(); }
+  unsigned int getNumWeights() const {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getNumWeights();
+  }
 
   /**
    * @brief     Get the Input Layers object
@@ -352,8 +359,6 @@ public:
    */
   void addOutputLayers(const std::string &out_layer) {
     output_layers.push_back(out_layer);
-    init_context =
-      InitLayerContext(init_context.getInputDimensions(), output_layers.size());
   }
 
   /**
@@ -370,27 +375,26 @@ public:
    */
   void setOutputLayers(const std::vector<std::string> &layers) {
     output_layers = layers;
-    init_context =
-      InitLayerContext(init_context.getInputDimensions(),
-                       std::max((unsigned int)output_layers.size(), 1u));
   }
 
+  /**
+   * @brief check if input shape property is set
+   *
+   * @return bool true if input shape property has set
+   */
+  bool hasInputShapeProperty() const;
+
   /**
    * @brief Get the input dimension
    * @return TensorDim dimension of the input
    */
-  const std::vector<TensorDim> getInputDimensions() const {
-    return init_context.getInputDimensions();
-  }
+  const std::vector<TensorDim> getInputDimensions() const;
 
   /**
    * @brief Get the output dimension
    * @return TensorDim dimension of the output
    */
-  const std::vector<TensorDim> getOutputDimensions() const {
-    return init_context.getOutputDimensions();
-  }
-
+  const std::vector<TensorDim> getOutputDimensions() const;
   /**
    * @brief Get the Weight object
    *
@@ -398,12 +402,15 @@ public:
    * @return Weight& Reference to the weight
    */
   Weight getWeightWrapper(unsigned int idx) {
-    if (run_context.weightHasGradient(idx)) {
-      return Weight(run_context.getWeight(idx), run_context.getWeightGrad(idx),
-                    run_context.getWeightName(idx));
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    if (run_context->weightHasGradient(idx)) {
+      return Weight(run_context->getWeight(idx),
+                    run_context->getWeightGrad(idx),
+                    run_context->getWeightName(idx));
     } else {
-      return Weight(run_context.getWeight(idx), Tensor(),
-                    run_context.getWeightName(idx));
+      return Weight(run_context->getWeight(idx), Tensor(),
+                    run_context->getWeightName(idx));
     }
   }
 
@@ -414,7 +421,9 @@ public:
    * @return Tensor& Reference to the weight tensor
    */
   Weight &getWeightObject(unsigned int idx) {
-    return run_context.getWeightObject(idx);
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getWeightObject(idx);
   }
 
   /**
@@ -423,7 +432,11 @@ public:
    * @param idx Identifier of the weight
    * @return Tensor& Reference to the weight tensor
    */
-  Tensor &getWeight(unsigned int idx) { return run_context.getWeight(idx); }
+  Tensor &getWeight(unsigned int idx) {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getWeight(idx);
+  }
 
   /**
    * @brief Get the Weight Gradient tensor object
@@ -432,7 +445,9 @@ public:
    * @return Tensor& Reference to the weight grad tensor
    */
   Tensor &getWeightGrad(unsigned int idx) {
-    return run_context.getWeightGrad(idx);
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getWeightGrad(idx);
   }
 
   /**
@@ -442,7 +457,9 @@ public:
    * @return const std::string &Name of the weight
    */
   const std::string &getWeightName(unsigned int idx) {
-    return run_context.getWeightName(idx);
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getWeightName(idx);
   }
 
   /**
@@ -451,7 +468,11 @@ public:
    * @param idx Identifier of the input
    * @return Tensor& Reference to the input grad tensor
    */
-  Tensor &getInput(unsigned int idx) { return run_context.getInput(idx); }
+  Tensor &getInput(unsigned int idx) {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getInput(idx);
+  }
 
   /**
    * @brief Get the Input Grad tensor object
@@ -460,7 +481,9 @@ public:
    * @return Tensor& Reference to the input grad tensor
    */
   Tensor &getInputGrad(unsigned int idx) {
-    return run_context.getInputGrad(idx);
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getInputGrad(idx);
   }
 
   /**
@@ -469,7 +492,11 @@ public:
    * @param idx Identifier of the output
    * @return Tensor& Reference to the output tensor
    */
-  Tensor &getOutput(unsigned int idx) { return run_context.getOutput(idx); }
+  Tensor &getOutput(unsigned int idx) {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getOutput(idx);
+  }
 
   /**
    * @brief Get the Output Grad tensor object
@@ -478,7 +505,9 @@ public:
    * @return Tensor& Reference to the output grad tensor
    */
   Tensor &getOutputGrad(unsigned int idx) {
-    return run_context.getOutputGrad(idx);
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return run_context->getOutputGrad(idx);
   }
 
   /**
@@ -488,7 +517,7 @@ public:
    * @return Tensor& Reference to the output grad tensor
    */
   Tensor &getOutputGradUnsafe(unsigned int idx) {
-    return run_context.getOutputGradUnsafe(idx);
+    return run_context->getOutputGradUnsafe(idx);
   }
 
   /**
@@ -518,37 +547,29 @@ public:
    */
   friend std::ostream &operator<<(std::ostream &out, const LayerNode &l);
 
-  /**
-   * @brief   Get init layer context
-   *
-   * @retval  init layer context
-   */
-  const InitLayerContext &getInitContext() const { return init_context; }
-
   /**
    * @brief   Get run layer context
    *
    * @retval  run layer context
    */
-  const RunLayerContext &getRunContext() const { return run_context; }
-
-  /**
-   * @brief   Set run layer context
-   *
-   * @param  context Updated run layer context
-   */
-  void updateRunContext(RunLayerContext &&context) {
-    // TODO: ensure props/trainable must match
-    run_context = std::move(context);
+  const RunLayerContext &getRunContext() const {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+    return *run_context;
   }
 
   /**
-   * @brief Set input dimension for the layer
+   * @brief Set the Run Context object with given tensor packs
    *
-   * @param dim Input tensor dim
-   * @param idx Index of the dim
+   * @param weights weights
+   * @param inputs inputs
+   * @param outputs outputs
+   * @param tensors tensors
    */
-  void setInputDimension(const TensorDim &dim, unsigned int idx);
+  void configureRunContext(const std::vector<Weight *> &weights,
+                           const std::vector<Var_Grad *> &inputs,
+                           const std::vector<Var_Grad *> &outputs,
+                           const std::vector<Var_Grad *> &tensors);
 
   /**
    * @brief Preset modes for printing summary for the layer
@@ -570,36 +591,27 @@ public:
   void printPreset(std::ostream &out,
                    PrintPreset preset = PrintPreset::PRINT_SUMMARY);
 
-  /**
-   * @brief     Resize the input dimensions
-   *
-   * @param size Number of input dimensions
-   */
-  void resizeInputDimensions(unsigned int size);
-
 private:
   std::unique_ptr<nntrainer::Layer>
     layer; /**< The actual object in the graph node */
 
-  bool finalized; /**< if the layer node has been finalized */
-
   std::vector<std::string> output_layers; /**< output layer names */
   ActivationType
     activation_type; /**< activation applied to the output of this node */
 
-  InitLayerContext init_context; /**< context to be built for/while
-                                    initialization of the layer. This will also
-                                    contain the properties of the layer. */
-
-  RunLayerContext run_context; /**< context required for running/execution of
-                    the layer. This will also contain the properties of the
-                    layer. The properties will be copied upon final creation.
-                    Editing properties of the layer after init will not the
-                    properties in the context/graph unless intended. */
+  std::unique_ptr<RunLayerContext>
+    run_context; /**< context required for running/execution of the layer. This
+will also contain the properties of the layer. The properties will be copied
+upon final creation. Editing properties of the layer after init will not the
+properties in the context/graph unless intended. */
 
   using PropsType =
     std::tuple<props::Name, props::Flatten, props::Distribute, props::Trainable,
                std::vector<props::InputLayer>>;
+
+  std::vector<TensorDim>
+    input_shapes; /**< input shapes, @see LayerNode::finalize() to know how this
+                     is interpreted */
   /**
    * These properties are set for the layer by the user but are intercepted
    * and used in the node which forms the basic element of the graph.
index cbaffc0069a20711d48c6ddef3ea9c71044878c9..7f4d0f95a71bfc5f021260b11550ec627a009c8f 100644 (file)
@@ -120,7 +120,7 @@ void TimeDistLayer::finalize(InitLayerContext &context) {
    */
   TensorDim dist_dim = input_dim;
   dist_dim.height(1);
-  InitLayerContext dist_context({dist_dim}, context.getNumOutputs());
+  InitLayerContext dist_context({dist_dim}, context.getNumOutputs(), getType());
 
   // During forwarding and backwarding, it set the input and output buffer of
   // dist_layer properly
@@ -407,7 +407,8 @@ void TimeDistLayer::setBatch(RunLayerContext &context, unsigned int batch) {
 void TimeDistLayer::setBatch(InitLayerContext &context, unsigned int batch) {
   TensorDim input_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
   input_dim.height(1);
-  InitLayerContext dist_context({input_dim}, context.getNumOutputs());
+  InitLayerContext dist_context({input_dim}, context.getNumOutputs(),
+                                getType());
 
   TensorDim output_dim = context.getOutputDimensions()[0];
   // input_dim.height is number of time iteration
index f2d5ecfd198088cffc7fe44498c2f132d0747208..84a953af093b349e1ec210289d66d78e3ff5a36f 100644 (file)
@@ -47,9 +47,7 @@ TEST_P(LayerSemantics, finalizeValidateLayerNode_p) {
   EXPECT_NO_THROW(lnode->setProperty(valid_properties));
 
   if (!must_fail) {
-    EXPECT_NO_THROW(lnode->finalize());
-
-    auto &init_context = lnode->getInitContext();
+    nntrainer::InitLayerContext init_context = lnode->finalize();
     EXPECT_EQ(init_context.getOutputDimensions().size(),
               init_context.getNumOutputs());
 
@@ -87,9 +85,6 @@ TEST_P(LayerSemantics, setBatchValidateLayerNode_p) {
 
   if (!must_fail) {
     EXPECT_NO_THROW(lnode->finalize());
-    auto &init_context = lnode->getInitContext();
-    EXPECT_NO_THROW(
-      lnode->setBatch(init_context.getInputDimensions()[0].batch() + 10));
   } else {
     EXPECT_THROW(lnode->finalize(), nntrainer::exception::not_supported);
   }