[KNN] Fix querying label when not training
authorJihoon Lee <jhoon.it.lee@samsung.com>
Fri, 12 Nov 2021 02:02:45 +0000 (11:02 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 17 Nov 2021 06:57:18 +0000 (15:57 +0900)
This patch fix querying label when not training, which is not feasible

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/layers/centroid_knn.cpp
nntrainer/layers/layer_context.cpp

index 859722d..153d5b2 100644 (file)
@@ -74,14 +74,8 @@ void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
                              bool training) {
   auto &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
   auto &input_ = context.getInput(SINGLE_INOUT_IDX);
-  auto &label = context.getLabel(SINGLE_INOUT_IDX);
   const auto &input_dim = input_.getDim();
 
-  if (training && label.empty()) {
-    throw std::invalid_argument(
-      "[CentroidKNN] forwarding requires label feeded");
-  }
-
   auto &map = context.getWeight(weight_idx[KNNParams::map]);
   auto &num_samples = context.getWeight(weight_idx[KNNParams::num_samples]);
   auto feature_len = input_dim.getFeatureLen();
@@ -92,6 +86,7 @@ void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
   };
 
   if (training) {
+    auto &label = context.getLabel(SINGLE_INOUT_IDX);
     auto ans = label.argmax();
 
     for (unsigned int b = 0; b < input_.batch(); ++b) {
index f20a07f..3f1c133 100644 (file)
@@ -10,6 +10,7 @@
  * @brief  This is the layer context for each layer
  */
 
+#include "nntrainer_error.h"
 #include <functional>
 
 #include <layer_context.h>
@@ -323,8 +324,12 @@ bool RunLayerContext::isLabelAvailable(unsigned int idx) const {
 Tensor &RunLayerContext::getLabel(unsigned int idx) {
   if (isLabelAvailable(idx))
     return outputs[idx]->getGradientRef();
-  else
-    throw std::invalid_argument("Request tensor which does not exist");
+  else {
+    std::stringstream ss;
+    ss << "Requesing label of index: " << idx << "for " << getName()
+       << " does not exist";
+    throw std::invalid_argument(ss.str().c_str());
+  }
 }
 
 /**