[unittest/models] Fix matching for models unittest
author     Parichay Kapoor <pk.kapoor@samsung.com>
Thu, 16 Sep 2021 10:48:34 +0000 (19:48 +0900)
committer  Jijoong Moon <jijoong.moon@samsung.com>
Thu, 23 Sep 2021 04:07:52 +0000 (13:07 +0900)
This patch fixes the output matching for the last layer of each model in
the models unittest. A node's forward output is now verified whenever the
next node does not run in-place, and the last node, which has no
successor, is always verified. Backwarding of the last node in the
reverse pass is verified only when that node supports backwarding.
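
For context, the new verification policy can be summarised with a small,
self-contained sketch; the Node struct and the loops below are illustrative
stand-ins only, not the actual NodeWatcher/GraphWatcher code. A node's
output is checked only when its successor does not run in-place, presumably
because an in-place successor may overwrite the shared output buffer; the
last node has no successor and is therefore always checked.

// Sketch only: 'Node' and 'in_place' are illustrative, not the real
// NodeWatcher/GraphWatcher types used in unittest_nntrainer_models.cpp.
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool in_place; /**< true if the node overwrites its input buffer */
};

int main() {
  std::vector<Node> nodes = {
    {"input", false}, {"fc", false}, {"relu", true}, {"loss", false}};

  /* verify a node's output only when the next node does not run in-place */
  auto it = nodes.begin();
  for (; it != nodes.end() - 1; ++it)
    std::cout << it->name << ": verify output = " << std::boolalpha
              << !(it + 1)->in_place << '\n';

  /* the last node has no successor, so it is always verified */
  std::cout << it->name << ": verify output = true\n";
  return 0;
}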

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
test/unittest/unittest_nntrainer_models.cpp

index 8700dd2..cbf2808 100644
@@ -152,7 +152,7 @@ public:
    * @param iteration iteration
    * @return nntrainer::sharedConstTensor
    */
-  void forward(int iteration, NodeWatcher &next_node);
+  void forward(int iteration, bool verify);
 
   /**
    * @brief backward pass of the node with verifying inputs/gradients/outputs
@@ -208,6 +208,20 @@ public:
    */
   bool isLossType() { return node->requireLabel(); }
 
+  /**
+   * @brief support in-place operation
+   *
+   * @return true if support in-place else false
+   */
+  bool supportInPlace() { return node->supportInPlace(); }
+
+  /**
+   * @brief support backwarding operation
+   *
+   * @return true if support backwarding else false
+   */
+  bool supportBackwarding() { return node->supportBackwarding(); }
+
 private:
   NodeType node;
   std::vector<nntrainer::Tensor> expected_output;
@@ -259,7 +273,7 @@ void NodeWatcher::verifyGrad(const std::string &error_msg) {
   }
 }
 
-void NodeWatcher::forward(int iteration, NodeWatcher &next_node) {
+void NodeWatcher::forward(int iteration, bool verify_forward) {
   std::stringstream ss;
   ss << "forward failed at " << node->getName() << " at iteration "
      << iteration;
@@ -270,8 +284,7 @@ void NodeWatcher::forward(int iteration, NodeWatcher &next_node) {
     out.push_back(node->getOutput(idx));
   }
 
-  if (!next_node.node->supportInPlace() &&
-      getNodeType() != nntrainer::MultiOutLayer::type)
+  if (verify_forward && getNodeType() != nntrainer::MultiOutLayer::type)
     verify(out, expected_output, err_msg + " at output");
 }
 
@@ -423,18 +436,23 @@ void GraphWatcher::compareFor(const std::string &reference,
                   nntrainer::Tensor::epsilon);
     }
 
-    for (auto it = nodes.begin(); it != nodes.end() - 1; ++it) {
-      it->forward(iteration, *(it + 1));
+    auto it = nodes.begin();
+    for (; it != nodes.end() - 1; ++it) {
+      it->forward(iteration, !(it + 1)->supportInPlace());
     }
+    it->forward(iteration, true);
 
     if (loss_nodes.size()) {
       nn.backwarding(label, iteration);
 
-      for (auto it = nodes.rbegin(); it != nodes.rend() - 1; it++) {
-        if (it == nodes.rend())
-          it->backward(iteration, true, !optimize);
-        else
+      for (auto it = nodes.rbegin(); it != nodes.rend(); it++) {
+        if (it == nodes.rend() - 1) {
+          /** verify last node's backwarding only when it is not an input layer */
+          if (it->supportBackwarding())
+            it->backward(iteration, true, !optimize);
+        } else {
           it->backward(iteration, !optimize, !optimize);
+        }
       }
     } else {
       EXPECT_THROW(nn.backwarding(label, iteration), std::runtime_error);
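
A side note on the backward loop: with the previous bounds
(it != nodes.rend() - 1), the last element of the reverse walk, i.e. the
model's first layer, was never visited, and the it == nodes.rend() branch
inside the loop could never be taken. The rewritten loop visits every node
and gates the check on supportBackwarding. A minimal sketch of that
pattern, with illustrative names only:

// Sketch only: illustrative stand-in for the GraphWatcher backward loop,
// not the actual test code.
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool support_backwarding; /**< false for e.g. input layers */
};

int main() {
  std::vector<Node> nodes = {
    {"input", false}, {"fc", true}, {"relu", true}, {"loss", true}};

  /* walk every node in reverse; nodes.rend() - 1 is the last element of
   * the reverse walk, i.e. the model's first layer */
  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
    if (it == nodes.rend() - 1) {
      if (it->support_backwarding)
        std::cout << it->name << ": backward verified\n";
      else
        std::cout << it->name << ": backward skipped (no backwarding)\n";
    } else {
      std::cout << it->name << ": backward verified\n";
    }
  }
  return 0;
}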