std::vector<std::string> shared_weight_names;
std::vector<std::string> shared_tensor_names;
if (auto shared_node_str = lnode->getSharedFrom(); !shared_node_str.empty()) {
- auto shared_node = getLayerNode(shared_node_str).get();
- NNTR_THROW_IF(shared_node == nullptr, std::invalid_argument)
- << "shared_node requested but it is not registered in the graph, name: "
- << shared_node_str << " requested from " << lnode->getName();
- NNTR_THROW_IF(shared_node->getType() != lnode->getType(),
- std::invalid_argument)
- << " shared_node and lnode type mismatch, source node type: "
- << shared_node->getType() << " depedent node type: " << lnode->getType()
- << " depedent node name: " << lnode->getName();
- NNTR_THROW_IF(!shared_node->isFinalized(), std::invalid_argument)
- << "shared node must be prior to the dependent node and it should be "
- "finalized beforehand, shared node name: "
- << shared_node_str << " dependent node name: " << lnode->getName();
- auto num_weight = shared_node->getNumWeights();
- shared_weight_names.reserve(num_weight);
- for (auto i = 0u; i < num_weight; ++i) {
- shared_weight_names.emplace_back(shared_node->getWeightName(i));
- }
-
- auto &rc = shared_node->getRunContext();
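+ // When this node is shared from another node, gather the weight/tensor
+ // names to be shared; they are now taken from this node's init context
+ // (see below) instead of being queried from the already-finalized source
+ // node as before.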
+ // auto shared_node = getLayerNode(shared_node_str).get();
+ // NNTR_THROW_IF(shared_node == nullptr, std::invalid_argument)
+ // << "shared_node requested but it is not registered in the graph, name: "
+ // << shared_node_str << " requested from " << lnode->getName();
+ // NNTR_THROW_IF(shared_node->getType() != lnode->getType(),
+ // std::invalid_argument)
+ // << " shared_node and lnode type mismatch, source node type: "
+ // << shared_node->getType() << " dependent node type: " << lnode->getType()
+ // << " dependent node name: " << lnode->getName();
+ // NNTR_THROW_IF(!shared_node->isFinalized(), std::invalid_argument)
+ // << "shared node must be prior to the dependent node and it should be "
+ // "finalized beforehand, shared node name: "
+ // << shared_node_str << " dependent node name: " << lnode->getName();
+ // auto num_weight = shared_node->getNumWeights();
+ // shared_weight_names.reserve(num_weight);
+ // for (auto i = 0u; i < num_weight; ++i) {
+ // shared_weight_names.emplace_back(shared_node->getWeightName(i));
+ // }
+ // auto &rc = shared_node->getRunContext();
/// @fixme a tensor should only be shared if the context explicitly requested
/// it. This has to be added as part of the tensor spec; otherwise it will
/// break many things
- auto num_tensors = rc.getNumTensors();
- for (auto i = 0u; i < num_tensors; ++i) {
- shared_tensor_names.emplace_back(rc.getTensorName(i));
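+ // collect the tensor names declared in this node's init context; the fourth
+ // element (std::get<3>) of each tensor spec is assumed to be the tensor name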
+ const auto &t_specs = init_context.getTensorsSpec();
+ for (auto i = 0u; i < t_specs.size(); ++i) {
+ shared_tensor_names.emplace_back(std::get<3>(t_specs.at(i)));
+ }
+
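+ // likewise collect the weight names; the sixth element (std::get<5>) of each
+ // weight spec is assumed to be the weight name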
+ const auto &w_specs = init_context.getWeightsSpec();
+ for (auto i = 0u; i < w_specs.size(); ++i) {
+ shared_weight_names.emplace_back(std::get<5>(w_specs.at(i)));
}
}
Tensor *var = nullptr, *grad = nullptr;
bool is_dependent = !shared_names.empty();
if (is_dependent) {
+ /// shared_name is used and the original name is discarded
const auto &shared_name = shared_names.at(i);
/** case when shared names are given */
- var = weight_pool.view(name, shared_name, dim, var_exec_order, var_ls);
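+ // requestOrExtend presumably creates the tensor under shared_name on first
+ // use and extends the existing entry's execution order and lifespan for
+ // every subsequent dependent node, so all of them share one allocation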
+ var = weight_pool.requestOrExtend(shared_name, dim, var_exec_order,
+ var_ls, t_initializer);
if (trainable && need_gradient) {
- grad = tensor_pool.view(name + Var_Grad::grad_suffix,
- shared_name + Var_Grad::grad_suffix, dim,
- grad_exec_order, grad_ls);
+ grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
+ dim, grad_exec_order, grad_ls,
+ Tensor::Initializer::ZEROS);
}
} else {
if (is_dependent) {
const auto &shared_name = shared_names.at(i);
- var = tensor_pool.view(name, shared_name, dim, var_exec_order, tspan);
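+ // as with the weights above, request the shared tensor by shared_name or
+ // (presumably) extend the existing pool entry's execution order and lifespan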
+ var = tensor_pool.requestOrExtend(shared_name, dim, var_exec_order, tspan,
+ t_init);
if (need_grad && tspan > TensorLifespan::FORWARD_FUNC_LIFESPAN) {
- grad = tensor_pool.view(name + Var_Grad::grad_suffix,
- shared_name + Var_Grad::grad_suffix, dim,
- grad_exec_order, tspan);
+ grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
+ dim, grad_exec_order, tspan,
+ Tensor::Initializer::ZEROS);
}
} else {