[ccapi] change setLearningRateScheduler function prototype
author: hyeonseok lee <hs89.lee@samsung.com>
Thu, 13 Apr 2023 03:09:36 +0000 (12:09 +0900)
committer: Jijoong Moon <jijoong.moon@samsung.com>
Tue, 18 Apr 2023 04:49:44 +0000 (13:49 +0900)
 - Change return type from void to int.
   The C-API will call this function, so it should return a status code.
 - Change learning rate scheduler pointer from unique_ptr to shared_ptr

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
api/ccapi/include/optimizer.h
nntrainer/optimizers/optimizer_wrapped.cpp
nntrainer/optimizers/optimizer_wrapped.h

index da1f943..584caad 100644 (file)
@@ -83,8 +83,8 @@ public:
    *
    * @param lrs the learning rate scheduler object
    */
-  virtual void setLearningRateScheduler(
-    std::unique_ptr<ml::train::LearningRateScheduler> &&lrs) = 0;
+  virtual int setLearningRateScheduler(
+    std::shared_ptr<ml::train::LearningRateScheduler> lrs) = 0;
 };
 
 /**
index d0513ab..69b0338 100644 (file)
@@ -124,11 +124,11 @@ OptimizerWrapped::getOptimizerVariableDim(const TensorDim &dim) {
   return optimizer->getOptimizerVariableDim(dim);
 }
 
-void OptimizerWrapped::setLearningRateScheduler(
-  std::unique_ptr<ml::train::LearningRateScheduler> &&lrs) {
-  nntrainer::LearningRateScheduler *ptr =
-    static_cast<nntrainer::LearningRateScheduler *>(lrs.release());
-  lr_sched = std::unique_ptr<nntrainer::LearningRateScheduler>(ptr);
+int OptimizerWrapped::setLearningRateScheduler(
+  std::shared_ptr<ml::train::LearningRateScheduler> lrs) {
+  lr_sched = std::static_pointer_cast<nntrainer::LearningRateScheduler>(lrs);
+
+  return ML_ERROR_NONE;
 }
 
 nntrainer::LearningRateScheduler *OptimizerWrapped::getLearningRateScheduler() {
index 9257c08..c0667f8 100644 (file)
@@ -85,8 +85,8 @@ public:
    *
    * @param lrs the learning rate scheduler object
    */
-  void setLearningRateScheduler(
-    std::unique_ptr<ml::train::LearningRateScheduler> &&lrs) override;
+  int setLearningRateScheduler(
+    std::shared_ptr<ml::train::LearningRateScheduler> lrs) override;
 
   /**
    * Support all the interface requirements by nntrainer::Optimizer
@@ -150,7 +150,7 @@ public:
 
 private:
   std::unique_ptr<OptimizerCore> optimizer; /**< the underlying optimizer */
-  std::unique_ptr<nntrainer::LearningRateScheduler>
+  std::shared_ptr<nntrainer::LearningRateScheduler>
     lr_sched; /**< the underlying learning rate scheduler */
 
   std::tuple<props::LearningRate, props::DecayRate, props::DecaySteps>