From 24a7738539cadd9d6ce5c0ee479eade373f8ea6f Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Fri, 12 Nov 2021 11:02:45 +0900 Subject: [PATCH] [KNN] Fix querying label when not training This patch fix querying label when not training, which is not feasible **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/layers/centroid_knn.cpp | 7 +------ nntrainer/layers/layer_context.cpp | 9 +++++++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nntrainer/layers/centroid_knn.cpp b/nntrainer/layers/centroid_knn.cpp index 859722d..153d5b2 100644 --- a/nntrainer/layers/centroid_knn.cpp +++ b/nntrainer/layers/centroid_knn.cpp @@ -74,14 +74,8 @@ void CentroidKNN::forwarding(nntrainer::RunLayerContext &context, bool training) { auto &hidden_ = context.getOutput(SINGLE_INOUT_IDX); auto &input_ = context.getInput(SINGLE_INOUT_IDX); - auto &label = context.getLabel(SINGLE_INOUT_IDX); const auto &input_dim = input_.getDim(); - if (training && label.empty()) { - throw std::invalid_argument( - "[CentroidKNN] forwarding requires label feeded"); - } - auto &map = context.getWeight(weight_idx[KNNParams::map]); auto &num_samples = context.getWeight(weight_idx[KNNParams::num_samples]); auto feature_len = input_dim.getFeatureLen(); @@ -92,6 +86,7 @@ void CentroidKNN::forwarding(nntrainer::RunLayerContext &context, }; if (training) { + auto &label = context.getLabel(SINGLE_INOUT_IDX); auto ans = label.argmax(); for (unsigned int b = 0; b < input_.batch(); ++b) { diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index f20a07f..3f1c133 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -10,6 +10,7 @@ * @brief This is the layer context for each layer */ +#include "nntrainer_error.h" #include #include @@ -323,8 +324,12 @@ bool RunLayerContext::isLabelAvailable(unsigned int idx) const { Tensor &RunLayerContext::getLabel(unsigned int idx) { if (isLabelAvailable(idx)) return outputs[idx]->getGradientRef(); - else - throw std::invalid_argument("Request tensor which does not exist"); + else { + std::stringstream ss; + ss << "Requesing label of index: " << idx << "for " << getName() + << " does not exist"; + throw std::invalid_argument(ss.str().c_str()); + } } /** -- 2.7.4