}
}
-void NetworkGraph::init2runContext(InitLayerContext &init_context,
- RunLayerContext &run_context) {
- // NOTE: this just create the wrappers and does not actually memory inside
- // these wrappers
- // TODO: create wrappers for weights - initialize is already done, so the
- // object creation can be done
- // TODO: create wrapper for the temporary tensors
- // TODO: create wrappers for outputs
- // TODO: create a new context with these new wrappers and then copy assign
- // run_context
+void NetworkGraph::updateRunContext(std::shared_ptr<Manager> &manager,
+ const std::shared_ptr<LayerNode> &lnode) {
+ /**
+ * using copy assignment allows setting run_context without adding more
+ * interfaces
+ */
+ const GraphNode &gnode = *lnode.get();
+ const InitLayerContext &init_context = lnode->getInitContext();
+ /**
+ * @todo must use existing properties like name/trainable of run_context to
+ * create the new run_context
+ */
+ // const RunLayerContext &run_context = lnode->getRunContext();
+
+ lnode->updateRunContext(RunLayerContext(
+ manager->requestWeights(gnode, init_context.getWeightsSpec()),
+ manager->requestInputs(gnode, init_context.getInputDimensions()),
+ manager->requestOutputs(gnode, init_context.getOutputDimensions()),
+ manager->requestTensors(gnode, init_context.getTensorsSpec())));
}
int NetworkGraph::initialize(std::shared_ptr<Manager> manager) {
/**
* Initialize all the layers, allocate output tensors for each layer
- * and add optimizer related weights for the layer
+ * init2and add optimizer related weights for the layer
*/
- // TODO: pass init context, this call will fill it
+ // TODO: this call will fill the init context inside the layer
+ // lnode->initialize();
status = lptr->initialize(*manager);
NN_RETURN_STATUS();
- // TODO: call init2runContext
+ updateRunContext(manager, lnode);
+ // TODO: remove this
auto &in_out = manager->trackLayerOutputs(cur_type, lnode->getName(),
lptr->getOutputDimension(),
lptr->getInputDimension());
* @param init_context Init layer context to create run context
* @param run_context Run layer context to be created
*/
- void init2runContext(InitLayerContext &init_context,
- RunLayerContext &run_context);
+ static void updateRunContext(std::shared_ptr<Manager> &manager,
+ const std::shared_ptr<LayerNode> &lnode);
private:
std::map<std::string, std::string> sub_in_out; /** This is map to identify
/**
* @brief Construct a new Run Layer Context object
+ * @todo Include properties like name/trainable later
*
- * @param props properties of the layer
* @param w weights of the layer
* @param in inputs of the layer
* @param out outputs of the layer
* @param t extra tensors of the layer
*/
- RunLayerContext(std::tuple<props::Name> p, const std::vector<Weight *> &w,
+ RunLayerContext(const std::vector<Weight *> &w,
const std::vector<Var_Grad *> &in,
const std::vector<Var_Grad *> &out,
const std::vector<Var_Grad *> &t) :
- props(p),
weights(w),
inputs(in),
outputs(out),
*/
friend std::ostream &operator<<(std::ostream &out, const LayerNode &l);
+ /**
+ * @brief Get init layer context
+ *
+ * @retval init layer context
+ */
+ const InitLayerContext &getInitContext() const { return init_context; }
+
+ /**
+ * @brief Get run layer context
+ *
+ * @retval run layer context
+ */
+ const RunLayerContext &getRunContext() const { return run_context; }
+
+ /**
+ * @brief Set run layer context
+ *
+ * @param context Updated run layer context
+ */
+ void updateRunContext(RunLayerContext &&context) {
+ // TODO: ensure props/trainable must match
+ run_context = std::move(context);
+ }
+
private:
// TODO: make this unique_ptr once getObject API is removed
std::shared_ptr<nntrainer::LayerV1>
*/
std::vector<Weight *>
Manager::requestWeights(const GraphNode &node,
- std::vector<Weight::Spec> &weights_spec) {
+ const std::vector<Weight::Spec> &weights_spec) {
return _requestTensors<Weight>(node, weights_spec, weights_v2);
}
*/
std::vector<Var_Grad *>
Manager::requestTensors(const GraphNode &node,
- std::vector<Var_Grad::Spec> &tensors_spec) {
+ const std::vector<Var_Grad::Spec> &tensors_spec) {
return _requestTensors<Var_Grad>(node, tensors_spec, tensors_v2);
}
*/
std::vector<Var_Grad *>
Manager::requestInputs(const GraphNode &node,
- std::vector<TensorDim> &inputs_dim) {
+ const std::vector<TensorDim> &inputs_dim) {
unsigned int count = 0;
std::vector<Var_Grad::Spec> inputs_spec;
std::transform(
*/
std::vector<Var_Grad *>
Manager::requestOutputs(const GraphNode &node,
- std::vector<TensorDim> &outputs_dim) {
+ const std::vector<TensorDim> &outputs_dim) {
unsigned int count = 0;
std::vector<Var_Grad::Spec> outputs_spec;
std::transform(outputs_dim.begin(), outputs_dim.end(),
* @param w node Graph node to extract node identifiers/info
* @param w weights_spec Specficiation for the weights
*/
- std::vector<Weight *> requestWeights(const GraphNode &node,
- std::vector<Weight::Spec> &weights_spec);
+ std::vector<Weight *>
+ requestWeights(const GraphNode &node,
+ const std::vector<Weight::Spec> &weights_spec);
/**
* @brief Create tensors with the given spec
*/
std::vector<Var_Grad *>
requestTensors(const GraphNode &node,
- std::vector<Var_Grad::Spec> &tensors_spec);
+ const std::vector<Var_Grad::Spec> &tensors_spec);
/**
* @brief Create tensors with the given spec
* @param w node Graph node to extract node identifiers/info
* @param w create tensors list
*/
- std::vector<Var_Grad *> requestInputs(const GraphNode &node,
- std::vector<TensorDim> &inputs_dim);
+ std::vector<Var_Grad *>
+ requestInputs(const GraphNode &node,
+ const std::vector<TensorDim> &inputs_dim);
/**
* @brief Create tensors with the given spec
* @param w node Graph node to extract node identifiers/info
* @param w create tensors list
*/
- std::vector<Var_Grad *> requestOutputs(const GraphNode &node,
- std::vector<TensorDim> &outputs_spec);
+ std::vector<Var_Grad *>
+ requestOutputs(const GraphNode &node,
+ const std::vector<TensorDim> &outputs_spec);
/**
* @brief Get weights tracked with nntrainer
*/
template <typename T>
static std::vector<T *> _requestTensors(
- const GraphNode &node, std::vector<typename T::Spec> &tensors_spec,
+ const GraphNode &node, const std::vector<typename T::Spec> &tensors_spec,
std::vector<std::vector<std::unique_ptr<T>>> &layer_objs_list) {
std::vector<T *> ret;
std::vector<std::unique_ptr<T>> tensors_list;