From: Parichay Kapoor Date: Thu, 16 Sep 2021 10:48:34 +0000 (+0900) Subject: [unittest/models] Fix matching for models unittest X-Git-Tag: accepted/tizen/6.5/unified/20211028.114744~10 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=abc838db1332facf0b2ea319ab1061265aee8759;p=platform%2Fcore%2Fml%2Fnntrainer.git [unittest/models] Fix matching for models unittest This patch fixes the matching for the last layer of each model for the unittest of the models. Signed-off-by: Parichay Kapoor --- diff --git a/test/unittest/unittest_nntrainer_models.cpp b/test/unittest/unittest_nntrainer_models.cpp index 8700dd2..cbf2808 100644 --- a/test/unittest/unittest_nntrainer_models.cpp +++ b/test/unittest/unittest_nntrainer_models.cpp @@ -152,7 +152,7 @@ public: * @param iteration iteration * @return nntrainer::sharedConstTensor */ - void forward(int iteration, NodeWatcher &next_node); + void forward(int iteration, bool verify); /** * @brief backward pass of the node with verifying inputs/gradients/outputs @@ -208,6 +208,20 @@ public: */ bool isLossType() { return node->requireLabel(); } + /** + * @brief support in-place operation + * + * @return true if support in-place else false + */ + bool supportInPlace() { return node->supportInPlace(); } + + /** + * @brief support backwarding operation + * + * @return true if support backwarding else false + */ + bool supportBackwarding() { return node->supportBackwarding(); } + private: NodeType node; std::vector expected_output; @@ -259,7 +273,7 @@ void NodeWatcher::verifyGrad(const std::string &error_msg) { } } -void NodeWatcher::forward(int iteration, NodeWatcher &next_node) { +void NodeWatcher::forward(int iteration, bool verify_forward) { std::stringstream ss; ss << "forward failed at " << node->getName() << " at iteration " << iteration; @@ -270,8 +284,7 @@ void NodeWatcher::forward(int iteration, NodeWatcher &next_node) { out.push_back(node->getOutput(idx)); } - if (!next_node.node->supportInPlace() && - getNodeType() != nntrainer::MultiOutLayer::type) + if (verify_forward && getNodeType() != nntrainer::MultiOutLayer::type) verify(out, expected_output, err_msg + " at output"); } @@ -423,18 +436,23 @@ void GraphWatcher::compareFor(const std::string &reference, nntrainer::Tensor::epsilon); } - for (auto it = nodes.begin(); it != nodes.end() - 1; ++it) { - it->forward(iteration, *(it + 1)); + auto it = nodes.begin(); + for (; it != nodes.end() - 1; ++it) { + it->forward(iteration, !(it + 1)->supportInPlace()); } + it->forward(iteration, true); if (loss_nodes.size()) { nn.backwarding(label, iteration); - for (auto it = nodes.rbegin(); it != nodes.rend() - 1; it++) { - if (it == nodes.rend()) - it->backward(iteration, true, !optimize); - else + for (auto it = nodes.rbegin(); it != nodes.rend(); it++) { + if (it == nodes.rend() - 1) { + /** check last layer backwarding only when not input layers */ + if (it->supportBackwarding()) + it->backward(iteration, true, !optimize); + } else { it->backward(iteration, !optimize, !optimize); + } } } else { EXPECT_THROW(nn.backwarding(label, iteration), std::runtime_error);