[CAPI] Open centroid KNN
author Jihoon Lee <jhoon.it.lee@samsung.com>
Wed, 18 Aug 2021 08:35:34 +0000 (17:35 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Mon, 23 Aug 2021 12:44:25 +0000 (21:44 +0900)
**Changes proposed in this PR:**
- Move centroid KNN to layer, delete centroid knn from the Application

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Cc: Inki Dae <inki.dae@samsung.com>
Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
13 files changed:
Applications/SimpleShot/layers/centroid_knn.cpp [deleted file]
Applications/SimpleShot/layers/centroid_knn.h [deleted file]
Applications/SimpleShot/meson.build
Applications/SimpleShot/task_runner.cpp
Applications/SimpleShot/test/simpleshot_layer_common_tests.cpp
api/ccapi/include/layer.h
jni/Android.mk
nntrainer/app_context.cpp
nntrainer/layers/centroid_knn.cpp [new file with mode: 0644]
nntrainer/layers/centroid_knn.h [new file with mode: 0644]
nntrainer/layers/common_properties.cpp
nntrainer/layers/common_properties.h
nntrainer/layers/meson.build

diff --git a/Applications/SimpleShot/layers/centroid_knn.cpp b/Applications/SimpleShot/layers/centroid_knn.cpp
deleted file mode 100644 (file)
index fa6db7a..0000000
+++ /dev/null
@@ -1,150 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
- *
- * @file   centroid_knn.cpp
- * @date   09 Jan 2021
- * @brief  This file contains the simple nearest neighbor layer
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Jihoon Lee <jhoon.it.lee@samsung.com>
- * @bug    No known bugs except for NYI items
- *
- * @details This layer takes centroid and calculate l2 distance
- */
-
-#include <iostream>
-#include <limits>
-#include <regex>
-#include <sstream>
-
-#include <nntrainer_error.h>
-#include <nntrainer_log.h>
-#include <tensor.h>
-#include <weight.h>
-
-#include <centroid_knn.h>
-#include <simpleshot_utils.h>
-
-namespace simpleshot {
-namespace layers {
-
-static constexpr size_t SINGLE_INOUT_IDX = 0;
-
-enum KNNParams { map, num_samples };
-
-void CentroidKNN::setProperty(const std::vector<std::string> &values) {
-  util::Entry e;
-
-  for (auto &val : values) {
-    e = util::getKeyValue(val);
-
-    if (e.key == "num_class") {
-      num_class = std::stoul(e.value);
-      if (num_class == 0) {
-        throw std::invalid_argument("[CentroidKNN] num_class cannot be zero");
-      }
-    } else {
-      std::string msg = "[CentroidKNN] Unknown Layer Properties count " + val;
-      throw nntrainer::exception::not_supported(msg);
-    }
-  }
-}
-
-void CentroidKNN::finalize(nntrainer::InitLayerContext &context) {
-  auto const &input_dim = context.getInputDimensions()[0];
-  if (input_dim.channel() != 1 || input_dim.height() != 1) {
-    ml_logw("centroid nearest layer is designed for flattend feature for now, "
-            "please check");
-  }
-
-  if (num_class == 0) {
-    throw std::invalid_argument(
-      "Error: num_class must be a positive non-zero integer");
-  }
-
-  auto output_dim = nntrainer::TensorDim({num_class});
-  context.setOutputDimensions({output_dim});
-
-  /// weight is a distance map that contains centroid of features of each class
-  auto map_dim = nntrainer::TensorDim({num_class, input_dim.getFeatureLen()});
-
-  /// samples seen for the current run to calculate the centroid
-  auto samples_seen = nntrainer::TensorDim({num_class});
-
-  weight_idx[KNNParams::map] =
-    context.requestWeight(map_dim, nntrainer::Tensor::Initializer::ZEROS,
-                          nntrainer::WeightRegularizer::NONE, 1.0f,
-                          context.getName() + ":map", false);
-
-  weight_idx[KNNParams::num_samples] =
-    context.requestWeight(samples_seen, nntrainer::Tensor::Initializer::ZEROS,
-                          nntrainer::WeightRegularizer::NONE, 1.0f,
-                          context.getName() + ":num_samples", false);
-}
-
-void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
-                             bool training) {
-  auto &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
-  auto &input_ = context.getInput(SINGLE_INOUT_IDX);
-  auto &label = context.getLabel(SINGLE_INOUT_IDX);
-  const auto &input_dim = input_.getDim();
-
-  if (training && label.empty()) {
-    throw std::invalid_argument(
-      "[CentroidKNN] forwarding requires label feeded");
-  }
-
-  auto &map = context.getWeight(weight_idx[KNNParams::map]);
-  auto &num_samples = context.getWeight(weight_idx[KNNParams::num_samples]);
-  auto feature_len = input_dim.getFeatureLen();
-
-  auto get_distance = [](const nntrainer::Tensor &a,
-                         const nntrainer::Tensor &b) {
-    return -a.subtract(b).l2norm();
-  };
-
-  if (training) {
-    auto ans = label.argmax();
-
-    for (unsigned int b = 0; b < input_.batch(); ++b) {
-      auto saved_feature =
-        map.getSharedDataTensor({feature_len}, ans[b] * feature_len);
-
-      //  nntrainer::Tensor::Map(map.getData(), {feature_len},
-      // ans[b] * feature_len);
-      auto num_sample = num_samples.getValue(0, 0, 0, ans[b]);
-      auto current_feature = input_.getBatchSlice(b, 1);
-      saved_feature.multiply_i(num_sample);
-      saved_feature.add_i(current_feature);
-      saved_feature.divide_i(num_sample + 1);
-      num_samples.setValue(0, 0, 0, ans[b], num_sample + 1);
-    }
-  }
-
-  for (unsigned int i = 0; i < num_class; ++i) {
-    auto saved_feature =
-      map.getSharedDataTensor({feature_len}, i * feature_len);
-    // nntrainer::Tensor::Map(map.getData(), {feature_len}, i * feature_len);
-
-    auto num_sample = num_samples.getValue(0, 0, 0, i);
-
-    for (unsigned int b = 0; b < input_.batch(); ++b) {
-      auto current_feature = input_.getBatchSlice(b, 1);
-
-      if (num_sample == 0) {
-        hidden_.setValue(b, 0, 0, i, std::numeric_limits<float>::min());
-      } else {
-        hidden_.setValue(b, 0, 0, i,
-                         get_distance(current_feature, saved_feature));
-      }
-    }
-  }
-}
-
-void CentroidKNN::calcDerivative(nntrainer::RunLayerContext &context) {
-  throw std::invalid_argument("[CentroidKNN::calcDerivative] This Layer "
-                              "does not support backward propagation");
-}
-
-} // namespace layers
-} // namespace simpleshot
diff --git a/Applications/SimpleShot/layers/centroid_knn.h b/Applications/SimpleShot/layers/centroid_knn.h
deleted file mode 100644 (file)
index 9e67aea..0000000
+++ /dev/null
@@ -1,106 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
- *
- * @file   centroid_knn.h
- * @date   09 Jan 2021
- * @details  This file contains the simple nearest neighbor layer, this layer
- * takes centroid and calculate l2 distance
- * @see    https://github.com/nnstreamer/nntrainer
- * @author Jihoon Lee <jhoon.it.lee@samsung.com>
- * @bug    No known bugs except for NYI items
- *
- */
-
-#ifndef __NEAREST_NEIGHBORS_H__
-#define __NEAREST_NEIGHBORS_H__
-#include <string>
-
-#include <layer_context.h>
-#include <layer_devel.h>
-#include <node_exporter.h>
-
-namespace simpleshot {
-namespace layers {
-
-/**
- * @brief Centroid KNN layer which takes centroid and do k-nearest neighbor
- * classification
- */
-class CentroidKNN : public nntrainer::Layer {
-public:
-  /**
-   * @brief Construct a new NearestNeighbors Layer object that does elementwise
-   * subtraction from mean feature vector
-   */
-  CentroidKNN() : Layer(), num_class(0), weight_idx({0}) {}
-
-  /**
-   *  @brief  Move constructor.
-   *  @param[in] CentroidKNN &&
-   */
-  CentroidKNN(CentroidKNN &&rhs) noexcept = default;
-
-  /**
-   * @brief  Move assignment operator.
-   * @parma[in] rhs CentroidKNN to be moved.
-   */
-  CentroidKNN &operator=(CentroidKNN &&rhs) = default;
-
-  /**
-   * @brief Destroy the NearestNeighbors Layer object
-   *
-   */
-  ~CentroidKNN() = default;
-
-  /**
-   * @copydoc Layer::requireLabel()
-   */
-  bool requireLabel() const override { return true; }
-
-  /**
-   * @copydoc Layer::finalize(InitLayerContext &context)
-   */
-  void finalize(nntrainer::InitLayerContext &context) override;
-
-  /**
-   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
-   */
-  void forwarding(nntrainer::RunLayerContext &context, bool training) override;
-
-  /**
-   * @copydoc Layer::calcDerivative(RunLayerContext &context)
-   */
-  void calcDerivative(nntrainer::RunLayerContext &context) override;
-
-  /**
-   * @copydoc bool supportBackwarding() const
-   */
-  bool supportBackwarding() const override { return false; };
-
-  /**
-   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
-   */
-  void exportTo(nntrainer::Exporter &exporter,
-                const nntrainer::ExportMethods &method) const override {}
-
-  /**
-   * @copydoc Layer::getType()
-   */
-  const std::string getType() const override { return CentroidKNN::type; };
-
-  /**
-   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
-   */
-  void setProperty(const std::vector<std::string> &values) override;
-
-  inline static const std::string type = "centroid_knn";
-
-private:
-  unsigned int num_class;
-  std::array<unsigned int, 2> weight_idx; /**< indices of the weights */
-};
-} // namespace layers
-} // namespace simpleshot
-
-#endif /** __NEAREST_NEIGHBORS_H__ */
index 5838cf3ef8afa1ac6fb7720c489e49fad92d625c..bd8d70be3e8a0973d264fb25288680d883ad57d2 100644 (file)
@@ -1,6 +1,5 @@
 simpleshot_sources = [
   'simpleshot_utils.cpp',
-  'layers/centroid_knn.cpp',
   'layers/centering.cpp',
   'layers/l2norm.cpp',
 ]
index fb579ab0372122978fe04f1935491f6a8b874b63..804d198d859ea8663f14d5e115c09d010b82d9bd 100644 (file)
@@ -22,7 +22,6 @@
 #include <nntrainer-api-common.h>
 
 #include "layers/centering.h"
-#include "layers/centroid_knn.h"
 #include "layers/l2norm.h"
 
 namespace simpleshot {
@@ -225,8 +224,6 @@ int main(int argc, char **argv) {
       nntrainer::createLayer<simpleshot::layers::CenteringLayer>);
     app_context.registerFactory(
       nntrainer::createLayer<simpleshot::layers::L2NormLayer>);
-    app_context.registerFactory(
-      nntrainer::createLayer<simpleshot::layers::CentroidKNN>);
   } catch (std::exception &e) {
     std::cerr << "registering factory failed: " << e.what();
     return 1;
index 98bb55228f339d88643e82dba087ff3f2c6f1d8a..fd35457b099dfa20380c4bb9afc2d8d30f8a95e5 100644 (file)
@@ -23,8 +23,8 @@ auto semantic_activation_l2norm = LayerSemanticsParamType(
   simpleshot::layers::L2NormLayer::type, {}, 0, false);
 
 auto semantic_activation_centroid_knn = LayerSemanticsParamType(
-  nntrainer::createLayer<simpleshot::layers::CentroidKNN>,
-  simpleshot::layers::CentroidKNN::type, {"num_class=1"}, 0, false);
+  nntrainer::createLayer<nntrainer::CentroidKNN>, nntrainer::CentroidKNN::type,
+  {"num_class=1"}, 0, false);
 
 auto semantic_activation_centering = LayerSemanticsParamType(
   nntrainer::createLayer<simpleshot::layers::CenteringLayer>,
index e17bb7a9455a0ef4c83bee7268a252b3fa03d069..1e547097d2e25745dc53554653cb92f342b72cf5 100644 (file)
@@ -317,6 +317,14 @@ TimeDistLayer(const std::vector<std::string> &properties = {}) {
   return createLayer(LayerType::LAYER_TIME_DIST, properties);
 }
 
+/**
+ * @brief Helper function to create Centroid KNN Layer
+ */
+inline std::unique_ptr<Layer>
+CentroidKNN(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_CENTROID_KNN, properties);
+}
+
 /**
  * @brief Helper function to create activation layer
  */
index 43e7ac8a0d103d546eebd62592dd28c5cc345ea9..afafc39b5e08773f62a7370e57e4990305b4675d 100644 (file)
@@ -163,6 +163,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/time_dist.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/dropout.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/permute_layer.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/centroid_knn.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/acti_func.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/split_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/common_properties.cpp \
index 2c48cfcbf6b72f5b58d5357fc1cf9c9583a6042a..79f697760c508f85991100a6eec6901b5f8328bf 100644 (file)
@@ -31,6 +31,7 @@
 #include <activation_layer.h>
 #include <addition_layer.h>
 #include <bn_layer.h>
+#include <centroid_knn.h>
 #include <concat_layer.h>
 #include <conv2d_layer.h>
 #include <cross_entropy_sigmoid_loss_layer.h>
@@ -226,12 +227,7 @@ static void add_default_object(AppContext &ac) {
                      LayerType::LAYER_MULTIOUT);
   ac.registerFactory(nntrainer::createLayer<ConcatLayer>, ConcatLayer::type,
                      LayerType::LAYER_CONCAT);
-  ac.registerFactory(nntrainer::createLayer<PreprocessFlipLayer>,
-                     PreprocessFlipLayer::type,
-                     LayerType::LAYER_PREPROCESS_FLIP);
-  ac.registerFactory(nntrainer::createLayer<PreprocessTranslateLayer>,
-                     PreprocessTranslateLayer::type,
-                     LayerType::LAYER_PREPROCESS_TRANSLATE);
+
 #ifdef ENABLE_NNSTREAMER_BACKBONE
   ac.registerFactory(nntrainer::createLayer<NNStreamerLayer>,
                      NNStreamerLayer::type,
@@ -257,6 +253,16 @@ static void add_default_object(AppContext &ac) {
                      LayerType::LAYER_SPLIT);
   ac.registerFactory(nntrainer::createLayer<PermuteLayer>, PermuteLayer::type,
                      LayerType::LAYER_PERMUTE);
+  ac.registerFactory(nntrainer::createLayer<CentroidKNN>, CentroidKNN::type,
+                     LayerType::LAYER_CENTROID_KNN);
+
+  /** preprocess layers */
+  ac.registerFactory(nntrainer::createLayer<PreprocessFlipLayer>,
+                     PreprocessFlipLayer::type,
+                     LayerType::LAYER_PREPROCESS_FLIP);
+  ac.registerFactory(nntrainer::createLayer<PreprocessTranslateLayer>,
+                     PreprocessTranslateLayer::type,
+                     LayerType::LAYER_PREPROCESS_TRANSLATE);
 
   /** register losses */
   ac.registerFactory(nntrainer::createLayer<MSELossLayer>, MSELossLayer::type,
diff --git a/nntrainer/layers/centroid_knn.cpp b/nntrainer/layers/centroid_knn.cpp
new file mode 100644 (file)
index 0000000..3552098
--- /dev/null
@@ -0,0 +1,139 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file   centroid_knn.cpp
+ * @date   09 Jan 2021
+ * @brief  This file contains the simple nearest neighbor layer
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ * @details This layer takes centroid and calculate l2 distance
+ */
+
+#include <iostream>
+#include <limits>
+#include <regex>
+#include <sstream>
+
+#include <centroid_knn.h>
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <tensor.h>
+#include <weight.h>
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+enum KNNParams { map, num_samples };
+
+CentroidKNN::CentroidKNN() :
+  Layer(),
+  centroid_knn_props(props::NumClass()),
+  weight_idx({0}) {}
+
+CentroidKNN::~CentroidKNN() {}
+
+void CentroidKNN::setProperty(const std::vector<std::string> &values) {
+  auto left = loadProperties(values, centroid_knn_props);
+  NNTR_THROW_IF(!left.empty(), std::invalid_argument)
+    << "[Centroid KNN] there are unparsed properties " << left.front();
+}
+
+void CentroidKNN::finalize(nntrainer::InitLayerContext &context) {
+  auto const &input_dim = context.getInputDimensions()[0];
+  if (input_dim.channel() != 1 || input_dim.height() != 1) {
+    ml_logw("centroid nearest layer is designed for flattend feature for now, "
+            "please check");
+  }
+
+  auto num_class = std::get<props::NumClass>(centroid_knn_props);
+
+  auto output_dim = nntrainer::TensorDim({num_class});
+  context.setOutputDimensions({output_dim});
+
+  /// weight is a distance map that contains centroid of features of each class
+  auto map_dim = nntrainer::TensorDim({num_class, input_dim.getFeatureLen()});
+
+  /// samples seen for the current run to calculate the centroid
+  auto samples_seen = nntrainer::TensorDim({num_class});
+
+  weight_idx[KNNParams::map] =
+    context.requestWeight(map_dim, nntrainer::Tensor::Initializer::ZEROS,
+                          nntrainer::WeightRegularizer::NONE, 1.0f,
+                          context.getName() + ":map", false);
+
+  weight_idx[KNNParams::num_samples] =
+    context.requestWeight(samples_seen, nntrainer::Tensor::Initializer::ZEROS,
+                          nntrainer::WeightRegularizer::NONE, 1.0f,
+                          context.getName() + ":num_samples", false);
+}
+
+void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
+                             bool training) {
+  auto &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
+  auto &input_ = context.getInput(SINGLE_INOUT_IDX);
+  auto &label = context.getLabel(SINGLE_INOUT_IDX);
+  const auto &input_dim = input_.getDim();
+
+  if (training && label.empty()) {
+    throw std::invalid_argument(
+      "[CentroidKNN] forwarding requires label feeded");
+  }
+
+  auto &map = context.getWeight(weight_idx[KNNParams::map]);
+  auto &num_samples = context.getWeight(weight_idx[KNNParams::num_samples]);
+  auto feature_len = input_dim.getFeatureLen();
+
+  auto get_distance = [](const nntrainer::Tensor &a,
+                         const nntrainer::Tensor &b) {
+    return -a.subtract(b).l2norm();
+  };
+
+  if (training) {
+    auto ans = label.argmax();
+
+    for (unsigned int b = 0; b < input_.batch(); ++b) {
+      auto saved_feature =
+        map.getSharedDataTensor({feature_len}, ans[b] * feature_len);
+
+      //  nntrainer::Tensor::Map(map.getData(), {feature_len},
+      // ans[b] * feature_len);
+      auto num_sample = num_samples.getValue(0, 0, 0, ans[b]);
+      auto current_feature = input_.getBatchSlice(b, 1);
+      saved_feature.multiply_i(num_sample);
+      saved_feature.add_i(current_feature);
+      saved_feature.divide_i(num_sample + 1);
+      num_samples.setValue(0, 0, 0, ans[b], num_sample + 1);
+    }
+  }
+
+  for (unsigned int i = 0; i < std::get<props::NumClass>(centroid_knn_props);
+       ++i) {
+    auto saved_feature =
+      map.getSharedDataTensor({feature_len}, i * feature_len);
+    // nntrainer::Tensor::Map(map.getData(), {feature_len}, i * feature_len);
+
+    auto num_sample = num_samples.getValue(0, 0, 0, i);
+
+    for (unsigned int b = 0; b < input_.batch(); ++b) {
+      auto current_feature = input_.getBatchSlice(b, 1);
+
+      if (num_sample == 0) {
+        hidden_.setValue(b, 0, 0, i, std::numeric_limits<float>::min());
+      } else {
+        hidden_.setValue(b, 0, 0, i,
+                         get_distance(current_feature, saved_feature));
+      }
+    }
+  }
+}
+
+void CentroidKNN::calcDerivative(nntrainer::RunLayerContext &context) {
+  throw std::invalid_argument("[CentroidKNN::calcDerivative] This Layer "
+                              "does not support backward propagation");
+}
+} // namespace nntrainer
diff --git a/nntrainer/layers/centroid_knn.h b/nntrainer/layers/centroid_knn.h
new file mode 100644 (file)
index 0000000..83208d6
--- /dev/null
@@ -0,0 +1,104 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
+ *
+ * @file   centroid_knn.h
+ * @date   09 Jan 2021
+ * @details  This file contains the simple nearest neighbor layer, this layer
+ * takes centroid and calculate l2 distance
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jihoon Lee <jhoon.it.lee@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __CENTROID_KNN_H__
+#define __CENTROID_KNN_H__
+#include <string>
+
+#include <common_properties.h>
+#include <layer_context.h>
+#include <layer_devel.h>
+
+namespace nntrainer {
+
+/**
+ * @brief Centroid KNN layer which takes centroid and do k-nearest neighbor
+ * classification
+ */
+class CentroidKNN : public Layer {
+public:
+  /**
+   * @brief Construct a new NearestNeighbors Layer object that does elementwise
+   * subtraction from mean feature vector
+   */
+  CentroidKNN();
+
+  /**
+   *  @brief  Move constructor.
+   *  @param[in] CentroidKNN &&
+   */
+  CentroidKNN(CentroidKNN &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @param[in] rhs CentroidKNN to be moved.
+   */
+  CentroidKNN &operator=(CentroidKNN &&rhs) noexcept = default;
+
+  /**
+   * @brief Destroy the NearestNeighbors Layer object
+   *
+   */
+  ~CentroidKNN();
+
+  /**
+   * @copydoc Layer::requireLabel()
+   */
+  bool requireLabel() const override { return true; }
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(nntrainer::InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(nntrainer::RunLayerContext &context) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return false; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(nntrainer::Exporter &exporter,
+                const nntrainer::ExportMethods &method) const override {}
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return CentroidKNN::type; };
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  inline static const std::string type = "centroid_knn";
+
+private:
+  std::tuple<props::NumClass> centroid_knn_props;
+  std::array<unsigned int, 2> weight_idx; /**< indices of the weights */
+};
+} // namespace nntrainer
+
+#endif /** __CENTROID_KNN_H__ */
index 61b198eeb372209da4ed6a4d2148b65b82975cce..9b8d945f85639343e7af1aff6b5eebb959246445 100644 (file)
@@ -48,6 +48,8 @@ void FilePath::set(const std::string &v) {
 
 std::ifstream::pos_type FilePath::file_size() { return cached_pos_size; }
 
+bool NumClass::isValid(const unsigned int &v) const { return v > 0; }
+
 ConnectionSpec::ConnectionSpec(const std::vector<props::Name> &layer_ids_,
                                const std::string &op_type_) :
   op_type(op_type_),
index f1706838031b9b12e4436efd8a5a6ce5af0bdf27..72aeec6feb155da4947051d6dabc4412e2bf6b0c 100644 (file)
@@ -297,6 +297,21 @@ public:
 private:
   std::ifstream::pos_type cached_pos_size;
 };
+
+/**
+ * @brief Number of class
+ * @todo deprecate this
+ */
+class NumClass final : public nntrainer::Property<unsigned int> {
+public:
+  using prop_tag = uint_prop_tag;                 /**< property type */
+  static constexpr const char *key = "num_class"; /**< unique key to access */
+
+  /**
+   * @copydoc nntrainer::Property<unsigned int>::isValid(const unsigned int &v);
+   */
+  bool isValid(const unsigned int &v) const override;
+};
 } // namespace props
 } // namespace nntrainer
 
index 23b783d21a36fa3868eb1af058aeb73facdb645e..2c1b41f08e110c2141bb535d3f37155bea45979a 100644 (file)
@@ -26,6 +26,7 @@ layer_sources = [
   'layer_impl.cpp',
   'gru.cpp',
   'dropout.cpp',
+  'centroid_knn.cpp',
   'layer_context.cpp'
 ]