This patch refines the forwarding operation of the neural network class.
The call depth of the forward and backward operations did not match: for
the backward path there is already a backwarding_op, which is passed to
the graph so that the graph can handle the operation. To keep the forward
path at the same depth, this patch adds a forwarding_op function and
passes it to the graph class in the same way.
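As an illustration, the forward path now mirrors the backward one. A
minimal caller-side sketch (the `graph` instance is a placeholder; the
signatures follow the diff below):

    // Hypothetical usage: supply a per-layer callback, mirroring how
    // backwarding_op is passed for the backward path.
    auto forwarding_op = [](std::shared_ptr<LayerNode> node, bool training) {
      node->forwarding(training); // the callback decides how each layer runs
    };
    sharedConstTensors out = graph.forwarding(true, forwarding_op);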
Related Issue:
#2108
Signed-off-by: Jiho Chu <jiho.chu@samsung.com>
}
}
-sharedConstTensors
-NetworkGraph::forwarding(bool training,
- std::function<bool(void *userdata)> stop_cb) {
+sharedConstTensors NetworkGraph::forwarding(
+ bool training,
+ std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op,
+ std::function<bool(void *userdata)> stop_cb) {
for (auto iter = cbegin(); iter != cend() && !stop_cb(nullptr); iter++) {
- auto const &ln = *iter;
- PROFILE_TIME_START(profile_keys.at(ln->getType()));
- PROFILE_MEM_ANNOTATE("Forwarding for layer: " + ln->getName());
-
- auto f = std::get<0>(ln->getExecutionOrder());
- flushCacheExcept(f);
-
- ln->forwarding(training);
+ auto &ln = *iter;
+    PROFILE_TIME_START(profile_keys.at(ln->getType()));
+    forwarding_op(ln, training);
PROFILE_TIME_END(profile_keys.at(ln->getType()));
}
* @param[in] training true if forwarding is on training
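+ * @param[in] forwarding_op per-layer forwarding operation passed by the caller
+ * @param[in] stop_cb callback to check whether forwarding should stop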
* @retval output tensors
*/
- sharedConstTensors forwarding(bool training = false,
- std::function<bool(void *userdata)> stop_cb =
- [](void *user_data) { return false; });
+ sharedConstTensors forwarding(
+ bool training = false,
+ std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
+ [](std::shared_ptr<LayerNode>, bool) {},
+ std::function<bool(void *userdata)> stop_cb =
+ [](void *user_data) { return false; });
/**
* @brief backwarding the network graph
sharedConstTensors
NeuralNetwork::forwarding(bool training,
std::function<bool(void *userdata)> stop_cb) {
- return model_graph.forwarding(training, stop_cb);
+  std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
+    [this](std::shared_ptr<LayerNode> node, bool training) -> void {
+ PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName());
+
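+    // flush cached data, keeping only what this forward execution order needs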
+ auto f = std::get<0>(node->getExecutionOrder());
+ model_graph.flushCacheExcept(f);
+
+ node->forwarding(training);
+ };
+
+ return model_graph.forwarding(training, forwarding_op, stop_cb);
}
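Note: the public NeuralNetwork::forwarding signature is unchanged, so existing
call sites keep working. A minimal sketch, assuming a constructed `model`
instance (not part of this patch):

    // Hypothetical caller: the new forwarding_op plumbing stays internal.
    sharedConstTensors out = model.forwarding(false);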
/**