From 92a44945fa267c966e890c3fedd3fad7fb9db87b Mon Sep 17 00:00:00 2001
From: Jihoon Lee
Date: Tue, 5 Oct 2021 13:43:37 +0900
Subject: [PATCH] [WeightSharing] enable weight sharing from manager

This patch enables weight sharing from the manager.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee
---
 nntrainer/tensor/manager.cpp     | 41 +++++++++++++++++++++++++++++-----------
 nntrainer/tensor/tensor_pool.cpp |  3 ++-
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index a2038c1..3bfec15 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -265,25 +265,44 @@ std::vector<Weight *> Manager::requestWeights(
   std::vector<Weight *> ret;
   size_t current_size = weights_v2.size();
 
-  for (auto const &ws : std::as_const(weights_spec)) {
-    auto &[dim, t_initializer, w_reg, w_reg_const, need_gradient, name] = ws;
+  for (unsigned int i = 0; i < weights_spec.size(); ++i) {
+    auto &[dim, t_initializer, w_reg, w_reg_const, need_gradient, name] =
+      weights_spec.at(i);
+
+    Tensor *var = nullptr, *grad = nullptr;
     if (!shared_names.empty()) {
-      /** @todo add case when shared names are given */
+      const auto &shared_name = shared_names.at(i);
+      /** case when shared names are given */
+      var = weight_pool.requestPrerequestedTensor(
+        dim, var_exec_order, var_ls,
+        name,         /// name
+        shared_name,  /// shared name
+        t_initializer /// tensor initializer
+      );
+
+      if (trainable && need_gradient) {
+        grad = tensor_pool.requestPrerequestedTensor(
+          dim, grad_exec_order, grad_ls,
+          name + Var_Grad::grad_suffix,        /// name
+          shared_name + Var_Grad::grad_suffix, /// shared name
+          Tensor::Initializer::ZEROS           /// tensor initializer
+        );
+      }
     } else {
-      Tensor *var = weight_pool.requestTensor(dim, var_exec_order, var_ls, name,
-                                              t_initializer);
+      /** case requesting fresh weights */
+      var = weight_pool.requestTensor(dim, var_exec_order, var_ls, name,
+                                      t_initializer);
 
-      Tensor *grad = nullptr;
       if (trainable && need_gradient)
         grad = tensor_pool.requestTensor(dim, grad_exec_order, grad_ls,
                                          name + Var_Grad::grad_suffix,
                                          Tensor::Initializer::ZEROS);
-
-      weights_v2.emplace_back(
-        std::make_unique<Weight>(var, grad, w_reg, w_reg_const));
     }
+
+    weights_v2.emplace_back(
+      std::make_unique<Weight>(var, grad, w_reg, w_reg_const));
   }
 
   std::transform(weights_v2.begin() + current_size, weights_v2.end(),
                  std::back_inserter(ret),
                  [](auto const &elem) { return elem.get(); });
@@ -383,7 +402,7 @@ Manager::requestInputs(const GraphNode &node,
       dim, /// tensor dim
       var_exec_order, var_ls,
       var_name,          /// name
-      outputs_name[idx], /// name
+      outputs_name[idx], /// shared name
       Tensor::Initializer::NONE /// tensor initializer
     );
 
@@ -391,7 +410,7 @@
       dim, /// tensor dim
       grad_exec_order, grad_ls,
       var_name + Var_Grad::grad_suffix, /// name
-      outputs_name[idx] + Var_Grad::grad_suffix, /// name
+      outputs_name[idx] + Var_Grad::grad_suffix, /// shared name
       Tensor::Initializer::ZEROS /// tensor initializer
     );
   } else if (!node.getInputConnections().empty()) {
diff --git a/nntrainer/tensor/tensor_pool.cpp b/nntrainer/tensor/tensor_pool.cpp
index 59ad6a2..a15f05a 100644
--- a/nntrainer/tensor/tensor_pool.cpp
+++ b/nntrainer/tensor/tensor_pool.cpp
@@ -77,7 +77,8 @@ Tensor *TensorPool::requestPrerequestedTensor(
   TensorLifespan lifespan, const std::string &name,
   const std::string &shared_name, const Tensor::Initializer &init) {
   if (name_map.find(shared_name) == name_map.end())
-    throw std::invalid_argument("Requested shared tensor not found");
+    throw std::invalid_argument("Requested shared tensor not found, name: " +
+                                shared_name);
 
   /** find the parent non-dependent node where the spec is stored */
   int parent_spec_idx = name_map[shared_name];
-- 
2.7.4
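A note on the mechanism for reviewers: requestPrerequestedTensor() registers the
new name as an alias of an already-requested tensor, so a shared weight reuses
the storage of the first request instead of allocating again, and the gradient
is shared the same way under the grad-suffixed names. Below is a minimal
standalone C++ sketch of that name-keyed aliasing, not part of the patch: the
Pool class and the "fc1:weight"/"fc2:weight" names are hypothetical, and the
real TensorPool additionally tracks execution orders and lifespans per request
and defers the actual allocation.

#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

/// toy stand-in for nntrainer's Tensor; only owns a flat buffer
struct Tensor {
  std::vector<float> data;
  explicit Tensor(size_t len) : data(len, 0.0f) {}
};

/// hypothetical, stripped-down pool illustrating the sharing scheme
class Pool {
public:
  /// first request under `name` creates the backing tensor
  Tensor *request(const std::string &name, size_t len) {
    if (name_map.count(name))
      throw std::invalid_argument("tensor already requested, name: " + name);
    entries.push_back(std::make_unique<Tensor>(len));
    name_map[name] = entries.size() - 1;
    return entries.back().get();
  }

  /// later request aliases `shared_name`'s storage under a new `name`,
  /// failing loudly when the shared tensor was never requested --
  /// the same check the patch extends with the offending name
  Tensor *requestPrerequested(const std::string &name,
                              const std::string &shared_name) {
    auto it = name_map.find(shared_name);
    if (it == name_map.end())
      throw std::invalid_argument("Requested shared tensor not found, name: " +
                                  shared_name);
    name_map[name] = it->second; /// alias only, no new allocation
    return entries[it->second].get();
  }

private:
  std::vector<std::unique_ptr<Tensor>> entries;
  std::map<std::string, size_t> name_map;
};

int main() {
  Pool pool;
  Tensor *original = pool.request("fc1:weight", 4);
  Tensor *shared = pool.requestPrerequested("fc2:weight", "fc1:weight");
  /// both names resolve to one buffer, so updates through either are visible
  std::cout << std::boolalpha << (original == shared) << '\n'; // prints true
}

Aliasing at the name level, rather than copying, is what lets requestWeights()
hand the same underlying variable (and gradient) to every layer that names the
same shared weight.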