[ CCAPI ] add copy model configuration & get/set Weight of layer
authorjijoong.moon <jijoong.moon@samsung.com>
Tue, 26 Apr 2022 12:35:53 +0000 (21:35 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 11 May 2022 05:26:46 +0000 (14:26 +0900)
This patch adds APIs to copy a model configuration and to get/set the
weight tensor data of a layer as a vector of float pointers.

The semantics of these APIs are:
  . copyConfiguration(ml::train::Model &from)
  . std::vector<float*> getWeights()
  . void setWeights(const std::vector<float*>)

This only copies the model configuration; the copied model requires compile
and initialization before training.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
api/ccapi/include/layer.h
api/ccapi/include/model.h
api/ccapi/src/factory.cpp
nntrainer/graph/graph_core.h
nntrainer/layers/layer_node.cpp
nntrainer/layers/layer_node.h
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h
test/ccapi/unittest_ccapi.cpp
test/tizen_capi/unittest_tizen_capi.cpp

index 5623205..d8c8b86 100644 (file)
@@ -164,6 +164,22 @@ public:
    * to keep the name unique to the model
    */
   virtual const std::string getName() const noexcept = 0;
+
+  /**
+   * @brief     Get weight data of the layer
+   * @retval    vector of float pointers, one entry per weight of the layer
+   * @note      nntrainer creates the returned vector; if the layer has no
+   * weights, the vector is empty
+   * @note      the layer needs to be finalized before this is called.
+   */
+  virtual const std::vector<float *> getWeights() = 0;
+
+  /**
+   * @brief     Set weight data of the layer
+   * @note      The size of the vector must be equal to the number of weights
+   * of the layer.
+   * @note      the layer needs to be finalized before this is called.
+   */
+  virtual void setWeights(const std::vector<float *>) = 0;
 };
 
 /**
index 7cbcbbc..5e57dbc 100644 (file)
@@ -298,6 +298,11 @@ public:
 std::unique_ptr<Model>
 createModel(ModelType type, const std::vector<std::string> &properties = {});
 
+/**
+ * @brief Create a new model by copying the configuration of another model
+ * @param[in] from model whose configuration is copied
+ * @retval newly created model holding the copied configuration
+ * @note Only the configuration is copied; the returned model still needs to
+ * be compiled and initialized before it can be trained.
+ */
+std::unique_ptr<Model> copyConfiguration(Model &from);
+
 } // namespace train
 } // namespace ml
 
index f3ca2aa..944fd1f 100644 (file)
@@ -75,6 +75,17 @@ std::unique_ptr<Model> createModel(ModelType type,
 }
 
 /**
+ * @brief creator by copying the configuration of other model
+ */
+std::unique_ptr<Model> copyConfiguration(Model &from) {
+  std::unique_ptr<nntrainer::NeuralNetwork> model =
+    std::make_unique<nntrainer::NeuralNetwork>();
+  nntrainer::NeuralNetwork &f = dynamic_cast<nntrainer::NeuralNetwork &>(from);
+  model->copyConfiguration(f);
+  return model;
+}
+
+/**
  * @brief Factory creator with constructor for dataset
  */
 std::unique_ptr<Dataset>
index c8ce00d..25444ee 100644 (file)
@@ -161,8 +161,8 @@ public:
   GraphCore &copy(GraphCore &from) {
     node_list.resize(from.node_list.size());
     if (this != &from) {
-      // or (unsigned int i = 0; i < node_list.size(); i++)
-      //  node_list[i]->copy(from.node_list[i]);
+      //      for (unsigned int i = 0; i < node_list.size(); ++i)
+      //        node_list[i]->copy(from.node_list[i]);
     }
     return *this;
   }
index 4c669f7..9a4c1e5 100644 (file)
@@ -188,6 +188,20 @@ void LayerNode::setProperty(const std::vector<std::string> &properties) {
   }
 }
 
+/**
+ * @brief Deep-copy caller-provided values into the weights of this layer
+ * @param[in] weights one float pointer per weight; entry idx must point to at
+ * least getWeight(idx).size() readable floats
+ * @throws std::runtime_error if the layer is not finalized yet or the number
+ * of entries differs from the number of weights
+ * @throws std::invalid_argument if any entry is a null pointer
+ */
+void LayerNode::setWeights(const std::vector<float *> weights) {
+  NNTR_THROW_IF(!run_context, std::runtime_error)
+    << __func__ << " layer needs to be finalized first!";
+
+  const auto num_weights = getNumWeights();
+  NNTR_THROW_IF(num_weights != weights.size(), std::runtime_error)
+    << __func__ << " Number of Weights dismatch!";
+
+  // Validate every pointer before touching any weight so that a bad entry
+  // cannot leave the layer partially updated.
+  for (unsigned int idx = 0; idx < num_weights; ++idx) {
+    NNTR_THROW_IF(weights[idx] == nullptr, std::invalid_argument)
+      << __func__ << " weight data at index " << idx << " is null!";
+  }
+
+  // Needs Deep copy
+  for (unsigned int idx = 0; idx < num_weights; ++idx) {
+    Tensor &w = getWeight(idx);
+    float *dst = w.getData();
+    const float *src = weights[idx];
+    // std::copy requires the destination not to lie inside the source range.
+    // Feeding back the pointers obtained from getWeights() of this very layer
+    // would violate that, and the copy would be a no-op anyway.
+    if (src == dst)
+      continue;
+    std::copy(src, src + w.size(), dst);
+  }
+}
+
 const unsigned LayerNode::getInputConnectionIndex(unsigned nth) const {
   auto &input_conns =
     std::get<std::vector<props::InputConnection>>(*layer_node_props);
index 7e73e07..1ec5340 100644 (file)
@@ -512,6 +512,31 @@ public:
   }
 
   /**
+   * @brief     Get weight data of the layer
+   * @retval    weight data of the layer
+   * @note      nntrainer assign the vector and if there is no weights, the size
+   * of vector is zero
+   * @note      layer needs to be finalized before called.
+   */
+  const std::vector<float *> getWeights() {
+    NNTR_THROW_IF(!run_context, std::runtime_error)
+      << __func__ << " layer needs to be finalized first!";
+
+    std::vector<float *> weights;
+    for (unsigned int idx = 0; idx < getNumWeights(); ++idx) {
+      weights.emplace_back(getWeight(idx).getData());
+    }
+    return weights;
+  }
+
+  /**
+   * @brief     Set weight data of the layer
+   * @param[in] weights vector of float pointers, one per weight of the layer;
+   * the pointed-to values are copied into the weights
+   * @note      The size of the vector must be equal to the number of weights.
+   * @note      the layer needs to be finalized before this is called.
+   */
+  void setWeights(const std::vector<float *> weights);
+
+  /**
    * @brief Get the Input tensor object
    *
    * @param idx Identifier of the input
index 72d3aad..cd57cf8 100644 (file)
 
 namespace nntrainer {
 
+/**
+ * @brief Default constructor of NeuralNetwork Class
+ * @note  Properties are default-constructed, epoch/iteration counters and
+ * loss start at zero, the network starts neither initialized nor compiled,
+ * and the global AppContext is used.
+ */
+NeuralNetwork::NeuralNetwork() :
+  model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
+  model_flex_props(props::Epochs(), props::TrainingBatchSize(),
+                   props::SavePath(), props::ContinueTrain(),
+                   props::SaveBestPath(), props::MemoryOptimization()),
+  load_path(std::string()),
+  epoch_idx(0),
+  iter(0),
+  loss(0.0f),
+  data_buffers({nullptr, nullptr, nullptr}),
+  initialized(false),
+  compiled(false),
+  loadedFromConfig(false) {
+  // no context was given by the caller: fall back to the global one
+  app_context = AppContext(AppContext::Global());
+}
+
 NeuralNetwork::NeuralNetwork(AppContext app_context_) :
   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
   model_flex_props(props::Epochs(), props::TrainingBatchSize(),
@@ -856,6 +872,22 @@ int NeuralNetwork::addLayer(NodeType layer) {
   return status;
 }
 
+/**
+ * @brief Copy the configuration (properties, loss, optimizer and layer
+ * configurations) of another network into this one
+ * @param[in] from network to take the configuration from
+ * @retval reference to this network
+ */
+NeuralNetwork &NeuralNetwork::copyConfiguration(NeuralNetwork &from) {
+  // copying from itself leaves the network untouched
+  if (this == &from)
+    return *this;
+
+  model_props = from.model_props;
+  model_flex_props = from.model_flex_props;
+  loss = from.loss;
+  opt = from.opt;
+
+  // every layer node of the source graph is cloned from its configuration
+  // and registered with this network
+  NetworkGraph src_graph = from.getNetworkGraph();
+  for (auto &src_node : src_graph.getLayerNodes()) {
+    addLayer(static_cast<std::shared_ptr<ml::train::Layer>>(
+      src_node->cloneConfiguration()));
+  }
+
+  return *this;
+}
+
 NeuralNetwork::GraphType
 NeuralNetwork::getUnsortedLayers(const std::string &input_layer,
                                  const std::string &output_layer) {
@@ -886,9 +918,12 @@ int NeuralNetwork::setDataBuffer(const DatasetModeType &mode,
 
 int NeuralNetwork::getLayer(const char *name,
                             std::shared_ptr<ml::train::Layer> *layer) {
-  if (compiled) {
-    return ML_ERROR_NOT_SUPPORTED;
-  }
+  // We provide the layer change through the api with user's responsibility.
+  //
+  // if (compiled) {
+  //   ml_loge("Cannot get compiled layer.");
+  //   return ML_ERROR_NOT_SUPPORTED;
+  // }
 
   *layer = std::static_pointer_cast<ml::train::Layer>(
     model_graph.getLayerNode(std::string(name)));
index 91a1f19..4696939 100644 (file)
@@ -101,7 +101,12 @@ public:
   /**
    * @brief     Constructor of NeuralNetwork Class
    */
-  NeuralNetwork(AppContext app_context_ = AppContext(AppContext::Global()));
+  NeuralNetwork();
+
+  /**
+   * @brief     Constructor of NeuralNetwork Class
+   * @param[in] app_context_ AppContext to construct the network with
+   */
+  NeuralNetwork(AppContext app_context_);
 
   /**
    * @brief     Destructor of NeuralNetwork Class
@@ -251,10 +256,22 @@ public:
    * @brief     Copy Neural Network
    * @param[in] from NeuralNetwork Object to copy
    * @retval    NeuralNewtork Object copyed
+   * @todo Need to implement the copy of graph core
    */
   NeuralNetwork &copy(NeuralNetwork &from);
 
   /**
+   * @brief     Copy Neural Network Configuration
+   * @param[in] from NeuralNetwork object whose configuration is copied
+   * @retval    reference to this NeuralNetwork object
+   * @note This does not copy the context of the neural network model. It only
+   * copies the configuration of the network model. Therefore, compile and
+   * initialization are still needed to run the model. Also, if initialized
+   * weights are needed, a load call is required.
+   */
+  NeuralNetwork &copyConfiguration(NeuralNetwork &from);
+
+  /**
    * @brief     Run NeuralNetwork train
    * @param[in] values hyper parameters
    * @param[in] stop_cb callback function to decide stop training or not
index ac2d3c1..0355b4d 100644 (file)
@@ -437,6 +437,205 @@ TEST(nntrainer_ccapi, save_ini_p) {
   model->save(saved_ini_name, ml::train::ModelFormat::MODEL_FORMAT_INI);
 }
 
+/**
+ * @brief Copy the configuration of an ini-loaded model, then check that both
+ * models compile, initialize and agree on the input layer
+ */
+TEST(nntrainer_ccapi, model_copy_01_p) {
+  auto src_model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);
+
+  ScopedIni ini("ccapi_simple_ini",
+                {model_base + "batch_size = 16", optimizer, learning_rate,
+                 dataset + "-BufferSize", inputlayer, outputlayer});
+
+  std::shared_ptr<ml::train::Dataset> train_set = ml::train::createDataset(
+    ml::train::DatasetType::FILE, getTestResPath("trainingSet.dat").c_str());
+  EXPECT_NO_THROW(train_set->setProperty({"buffer_size=100"}));
+  EXPECT_EQ(
+    src_model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, train_set),
+    ML_ERROR_NONE);
+
+  EXPECT_EQ(src_model->loadFromConfig(ini.getIniName()), ML_ERROR_NONE);
+
+  // the configuration is copied while the source model is not compiled yet
+  std::unique_ptr<ml::train::Model> dst_model = copyConfiguration(*src_model);
+
+  EXPECT_EQ(src_model->compile(), ML_ERROR_NONE);
+  EXPECT_EQ(src_model->initialize(), ML_ERROR_NONE);
+
+  EXPECT_EQ(dst_model->compile(), ML_ERROR_NONE);
+  EXPECT_EQ(dst_model->initialize(), ML_ERROR_NONE);
+
+  std::shared_ptr<ml::train::Layer> src_layer;
+  std::shared_ptr<ml::train::Layer> dst_layer;
+  src_model->getLayer("inputlayer", &src_layer);
+  dst_model->getLayer("inputlayer", &dst_layer);
+
+  EXPECT_EQ(src_layer->getName(), dst_layer->getName());
+  EXPECT_EQ(src_layer->getType(), dst_layer->getType());
+}
+
+/**
+ * @brief Copying the configuration of an already compiled and initialized
+ * model is expected to throw
+ */
+TEST(nntrainer_ccapi, model_copy_01_n) {
+  auto src_model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);
+
+  ScopedIni ini("ccapi_simple_ini",
+                {model_base + "batch_size = 16", optimizer, learning_rate,
+                 dataset + "-BufferSize", inputlayer, outputlayer});
+
+  std::shared_ptr<ml::train::Dataset> train_set = ml::train::createDataset(
+    ml::train::DatasetType::FILE, getTestResPath("trainingSet.dat").c_str());
+  EXPECT_NO_THROW(train_set->setProperty({"buffer_size=100"}));
+  EXPECT_EQ(
+    src_model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, train_set),
+    ML_ERROR_NONE);
+
+  EXPECT_EQ(src_model->loadFromConfig(ini.getIniName()), ML_ERROR_NONE);
+
+  EXPECT_EQ(src_model->compile(), ML_ERROR_NONE);
+  EXPECT_EQ(src_model->initialize(), ML_ERROR_NONE);
+
+  // copyConfiguration is not allowed after finalize
+  EXPECT_THROW(copyConfiguration(*src_model), std::invalid_argument);
+}
+
+/**
+ * @brief Copy the configuration of a programmatically built model, train both
+ * models and exchange weights through Layer::getWeights / Layer::setWeights
+ */
+TEST(nntrainer_ccapi, model_copy_02_p) {
+  // shape of the "fc" layer used below: 62720 inputs, 10 units
+  const unsigned int num_unit = 10;
+  const unsigned int num_weight = 62720 * num_unit;
+
+  std::unique_ptr<ml::train::Model> model;
+  std::unique_ptr<ml::train::Model> c_model;
+  std::shared_ptr<ml::train::Layer> layer;
+  std::shared_ptr<ml::train::Optimizer> optimizer;
+  std::unique_ptr<ml::train::LearningRateScheduler> lrs;
+  std::shared_ptr<ml::train::Dataset> dataset;
+  std::shared_ptr<ml::train::Dataset> c_dataset;
+
+  EXPECT_NO_THROW(model =
+                    ml::train::createModel(ml::train::ModelType::NEURAL_NET));
+
+  EXPECT_NO_THROW(
+    layer = ml::train::layer::Input(
+      {"name=input0", "input_shape=1:1:62720", "normalization=true"}));
+  EXPECT_NO_THROW(model->addLayer(layer));
+
+  EXPECT_NO_THROW(
+    layer = ml::train::layer::FullyConnected(
+      {"name=fc", "unit= 10", "activation=softmax", "bias_initializer=zeros",
+       "weight_regularizer=l2norm", "weight_regularizer_constant=0.005",
+       "weight_initializer=xavier_uniform", "input_layers=input0"}));
+
+  EXPECT_NO_THROW(model->addLayer(layer));
+
+  EXPECT_NO_THROW(optimizer = ml::train::optimizer::Adam(
+                    {"beta1=0.002", "beta2=0.001", "epsilon=1e-7"}));
+
+  EXPECT_NO_THROW(
+    lrs = ml::train::optimizer::learning_rate::Exponential(
+      {"learning_rate=0.0001", "decay_rate=0.96", "decay_steps=1000"}));
+
+  EXPECT_NO_THROW(optimizer->setLearningRateScheduler(std::move(lrs)));
+  EXPECT_NO_THROW(model->setOptimizer(optimizer));
+
+  EXPECT_NO_THROW(
+    dataset = ml::train::createDataset(
+      ml::train::DatasetType::FILE, getTestResPath("trainingSet.dat").c_str()));
+  EXPECT_NO_THROW(dataset->setProperty({"buffer_size=100"}));
+  EXPECT_EQ(model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, dataset),
+            ML_ERROR_NONE);
+
+  EXPECT_NO_THROW(
+    dataset = ml::train::createDataset(ml::train::DatasetType::FILE,
+                                       getTestResPath("valSet.dat").c_str()));
+  EXPECT_NO_THROW(dataset->setProperty({"buffer_size=100"}));
+  EXPECT_EQ(model->setDataset(ml::train::DatasetModeType::MODE_VALID, dataset),
+            ML_ERROR_NONE);
+
+  EXPECT_NO_THROW(model->setProperty(
+    {"loss=cross", "batch_size=16", "epochs=2", "save_path=model.bin"}));
+
+  // the configuration is copied while the source model is not compiled yet
+  c_model = copyConfiguration(*model);
+
+  EXPECT_EQ(model->compile(), ML_ERROR_NONE);
+  EXPECT_EQ(model->initialize(), ML_ERROR_NONE);
+
+  EXPECT_NO_THROW(model->train());
+
+  // the copied model gets its own dataset objects (c_dataset), not the ones
+  // already attached to the source model
+  EXPECT_NO_THROW(
+    c_dataset = ml::train::createDataset(
+      ml::train::DatasetType::FILE, getTestResPath("trainingSet.dat").c_str()));
+
+  EXPECT_NO_THROW(c_dataset->setProperty({"buffer_size=10"}));
+
+  EXPECT_EQ(
+    c_model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, c_dataset),
+    ML_ERROR_NONE);
+
+  EXPECT_NO_THROW(
+    c_dataset = ml::train::createDataset(ml::train::DatasetType::FILE,
+                                         getTestResPath("valSet.dat").c_str()));
+
+  EXPECT_NO_THROW(c_dataset->setProperty({"buffer_size=10"}));
+  EXPECT_EQ(
+    c_model->setDataset(ml::train::DatasetModeType::MODE_VALID, c_dataset),
+    ML_ERROR_NONE);
+
+  EXPECT_EQ(c_model->compile(), ML_ERROR_NONE);
+  EXPECT_EQ(c_model->initialize(), ML_ERROR_NONE);
+
+  EXPECT_NO_THROW(c_model->train());
+
+  std::shared_ptr<ml::train::Layer> l;
+  std::shared_ptr<ml::train::Layer> c_l;
+
+  model->getLayer("fc", &l);
+  c_model->getLayer("fc", &c_l);
+
+  EXPECT_EQ(l->getName(), c_l->getName());
+  EXPECT_EQ(l->getType(), c_l->getType());
+
+  std::vector<float *> l_weight;
+  std::vector<float *> cl_weight;
+
+  l_weight = l->getWeights();
+  cl_weight = c_l->getWeights();
+
+  EXPECT_EQ(l_weight.size(), cl_weight.size());
+
+  c_l->setWeights(l_weight);
+  cl_weight = c_l->getWeights();
+
+  for (unsigned int i = 0; i < num_unit; ++i) {
+    // Weight of Fully Connected Layer
+    EXPECT_EQ(cl_weight[0][i], l_weight[0][i]);
+    // Bias of Fully Connected Layer
+    EXPECT_EQ(cl_weight[1][i], l_weight[1][i]);
+  }
+
+  // Run after set the weights
+  EXPECT_NO_THROW(c_model->train());
+
+  // Buffers filled with ones. std::vector owns the storage, so nothing has to
+  // be released by hand (the arrays used to be freed with scalar delete,
+  // which is undefined behavior for memory obtained from new[]).
+  std::vector<float> w_one(num_weight, 1.0f);
+  std::vector<float> b_one(num_unit, 1.0f);
+
+  std::vector<float *> one_weights;
+  one_weights.push_back(w_one.data());
+  one_weights.push_back(b_one.data());
+
+  c_model->getLayer("fc", &c_l);
+
+  c_l->setWeights(one_weights);
+  cl_weight = c_l->getWeights();
+
+  for (unsigned int i = 0; i < num_weight; ++i)
+    EXPECT_EQ(cl_weight[0][i], 1.0);
+  for (unsigned int i = 0; i < num_unit; ++i)
+    EXPECT_EQ(cl_weight[1][i], 1.0);
+
+  // Run after set the weights again
+  EXPECT_NO_THROW(c_model->train());
+}
+
 /**
  * @brief Main gtest
  */
index 38e3970..6c05094 100644 (file)
@@ -651,7 +651,7 @@ TEST(nntrainer_capi_nnmodel, getLayer_03_n) {
 /**
  * @brief Neural Network Model Get Layer Test
  */
-TEST(nntrainer_capi_nnmodel, getLayer_04_n) {
+TEST(nntrainer_capi_nnmodel, getLayer_04_p) {
   int status = ML_ERROR_NONE;
 
   ml_train_model_h model;
@@ -666,7 +666,7 @@ TEST(nntrainer_capi_nnmodel, getLayer_04_n) {
   EXPECT_EQ(status, ML_ERROR_NONE);
 
   status = ml_train_model_get_layer(model, "inputlayer", &get_layer);
-  EXPECT_EQ(status, ML_ERROR_NOT_SUPPORTED);
+  EXPECT_EQ(status, ML_ERROR_NONE);
 
   status = ml_train_model_destroy(model);
   EXPECT_EQ(status, ML_ERROR_NONE);