From cb645aaea994cd4c0423c7af36dafe0f8bdf5f21 Mon Sep 17 00:00:00 2001
From: hyeonseok lee
Date: Thu, 13 Apr 2023 12:31:07 +0900
Subject: [PATCH] [capi] add learning rate scheduler related api

 - Added learning rate scheduler create/destroy/set property/set property
   with single param api
 - Added set learning rate scheduler to optimizer
 - Added ml_train_lr_scheduler_type_e enum
 - Fixed some comments

Signed-off-by: hyeonseok lee
---
 api/capi/include/nntrainer.h          |  79 +++++++++++++++++-
 api/capi/include/nntrainer_internal.h |  55 ++++++++++++-
 api/capi/src/nntrainer.cpp            | 149 +++++++++++++++++++++++++++++++++-
 api/ccapi/include/optimizer.h         |   7 +-
 api/nntrainer-api-common.h            |  11 +++
 5 files changed, 287 insertions(+), 14 deletions(-)

diff --git a/api/capi/include/nntrainer.h b/api/capi/include/nntrainer.h
index 1026fa4..b69c72d 100644
--- a/api/capi/include/nntrainer.h
+++ b/api/capi/include/nntrainer.h
@@ -52,25 +52,31 @@ extern "C" {
  */
 
 /**
- * @brief A handle of a NNTrainer model.
+ * @brief A handle of an NNTrainer model.
  * @since_tizen 6.0
  */
 typedef void *ml_train_model_h;
 
 /**
- * @brief A handle of a NNTrainer layer.
+ * @brief A handle of an NNTrainer layer.
  * @since_tizen 6.0
  */
 typedef void *ml_train_layer_h;
 
 /**
- * @brief A handle of a NNTrainer optimizer.
+ * @brief A handle of an NNTrainer optimizer.
  * @since_tizen 6.0
  */
 typedef void *ml_train_optimizer_h;
 
 /**
- * @brief A handle of a NNTrainer dataset.
+ * @brief A handle of an NNTrainer learning rate scheduler.
+ * @since_tizen 7.5
+ */
+typedef void *ml_train_lr_scheduler_h;
+
+/**
+ * @brief A handle of an NNTrainer dataset.
  * @since_tizen 6.0
  */
 typedef void *ml_train_dataset_h;
 
@@ -391,6 +397,71 @@ int ml_train_optimizer_destroy(ml_train_optimizer_h optimizer);
 int ml_train_optimizer_set_property(ml_train_optimizer_h optimizer, ...);
 
 /**
+ * @brief Sets the learning rate scheduler for the optimizer.
+ * @details Use this function to set a learning rate scheduler for the
+ * optimizer. This transfers the ownership of @a lr_scheduler to the
+ * optimizer, so there is no need to destroy @a lr_scheduler once it has
+ * been set to an optimizer.
+ * @since_tizen 7.5
+ * @remarks Unsets the previously set lr_scheduler, if any. The previously set
+ * lr_scheduler must be freed using ml_train_lr_scheduler_destroy().
+ * @param[in] optimizer The NNTrainer optimizer handle.
+ * @param[in] lr_scheduler The NNTrainer lr scheduler handle.
+ * @return @c 0 on success. Otherwise a negative error value.
+ * @retval #ML_ERROR_NONE Successful.
+ * @retval #ML_ERROR_NOT_SUPPORTED Not supported.
+ * @retval #ML_ERROR_INVALID_PARAMETER Invalid parameter.
+ */
+int ml_train_optimizer_set_lr_scheduler(ml_train_optimizer_h optimizer,
+                                        ml_train_lr_scheduler_h lr_scheduler);
+
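The ownership transfer documented above is the part callers most often get
wrong, so a minimal usage sketch follows (illustrative only, not part of the
patch; the optimizer type is chosen arbitrarily and error handling is
elided):

    ml_train_optimizer_h opt;
    ml_train_lr_scheduler_h sched;

    /* Create an optimizer and a scheduler; both start out caller-owned. */
    ml_train_optimizer_create(&opt, ML_TRAIN_OPTIMIZER_TYPE_SGD);
    ml_train_lr_scheduler_create(&sched, ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL);

    /* Ownership of sched moves to opt; do not destroy sched afterwards. */
    ml_train_optimizer_set_lr_scheduler(opt, sched);

    /* Destroying the optimizer also releases the attached scheduler. */
    ml_train_optimizer_destroy(opt);
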
+/**
+ * @brief Creates a learning rate scheduler for an optimizer.
+ * @details Use this function to create a learning rate scheduler for an
+ * optimizer.
+ * @since_tizen 7.5
+ * @remarks If the function succeeds, @a lr_scheduler must be released using
+ * ml_train_lr_scheduler_destroy() unless it is set to an optimizer. Once set
+ * to an optimizer, @a lr_scheduler is available until the optimizer is
+ * released.
+ * @param[out] lr_scheduler The NNTrainer learning rate scheduler handle.
+ * @param[in] type The NNTrainer learning rate scheduler type.
+ * @return @c 0 on success. Otherwise a negative error value.
+ * @retval #ML_ERROR_NONE Successful.
+ * @retval #ML_ERROR_NOT_SUPPORTED Not supported.
+ * @retval #ML_ERROR_INVALID_PARAMETER Invalid parameter.
+ */
+int ml_train_lr_scheduler_create(ml_train_lr_scheduler_h *lr_scheduler,
+                                 ml_train_lr_scheduler_type_e type);
+
+/**
+ * @brief Frees the learning rate scheduler.
+ * @details Use this function to destroy a learning rate scheduler. The call
+ * fails if @a lr_scheduler is owned by an optimizer.
+ * @since_tizen 7.5
+ * @param[in] lr_scheduler The NNTrainer learning rate scheduler handle.
+ * @return @c 0 on success. Otherwise a negative error value.
+ * @retval #ML_ERROR_NONE Successful.
+ * @retval #ML_ERROR_NOT_SUPPORTED Not supported.
+ * @retval #ML_ERROR_INVALID_PARAMETER Invalid parameter.
+ */
+int ml_train_lr_scheduler_destroy(ml_train_lr_scheduler_h lr_scheduler);
+
+/**
+ * @brief Sets the learning rate scheduler property.
+ * @details Use this function to set learning rate scheduler properties. The
+ * argument list must be terminated with NULL.
+ * @since_tizen 7.5
+ * @param[in] lr_scheduler The NNTrainer learning rate scheduler handle.
+ * @param[in] ... Property values with NULL for termination.
+ * @return @c 0 on success. Otherwise a negative error value.
+ * @retval #ML_ERROR_NONE Successful.
+ * @retval #ML_ERROR_NOT_SUPPORTED Not supported.
+ * @retval #ML_ERROR_INVALID_PARAMETER Invalid parameter.
+ */
+int ml_train_lr_scheduler_set_property(ml_train_lr_scheduler_h lr_scheduler,
+                                       ...);
+
+/**
  * @deprecated Deprecated since 6.5. Use ml_train_dataset_create() instead.
  * @brief Creates a dataset with generators to feed to a neural network.
  * @details Use this function to create a neural network dataset using
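Because the property setter is variadic, a call supplies one "key=value"
string per property and terminates the list with NULL. A short sketch (the
property names come from the single-param example later in this patch; the
values are illustrative):

    ml_train_lr_scheduler_h sched;

    ml_train_lr_scheduler_create(&sched, ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL);

    /* One "key=value" string per property; NULL tells va_arg where to stop. */
    ml_train_lr_scheduler_set_property(sched, "learning_rate=0.01",
                                       "decay_rate=0.5", "decay_steps=1000",
                                       NULL);
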
diff --git a/api/capi/include/nntrainer_internal.h b/api/capi/include/nntrainer_internal.h
index aceb59c..592dd70 100644
--- a/api/capi/include/nntrainer_internal.h
+++ b/api/capi/include/nntrainer_internal.h
@@ -95,6 +95,18 @@ typedef struct {
 } ml_train_layer;
 
 /**
+ * @brief Struct to wrap learning rate scheduler for the API
+ * @note optimizer mutex must be locked before learning rate scheduler lock,
+ * if optimizer lock is needed
+ */
+typedef struct {
+  uint magic;
+  std::shared_ptr<ml::train::LearningRateScheduler> lr_scheduler;
+  bool in_use;
+  std::mutex m;
+} ml_train_lr_scheduler;
+
+/**
  * @brief Struct to wrap neural network optimizer for the API
  * @note model mutex must be locked before optimizer lock, if model lock is
  * needed
@@ -102,6 +114,7 @@ typedef struct {
 typedef struct {
   uint magic;
   std::shared_ptr<ml::train::Optimizer> optimizer;
+  ml_train_lr_scheduler *lr_scheduler;
   bool in_use;
   std::mutex m;
 } ml_train_optimizer;
 
@@ -141,7 +154,7 @@ typedef struct {
 } while (0)
 
 /**
- * @brief Check validity of the user passed arguments and lock the object
+ * @brief Check validity of the user passed arguments
  */
 #define ML_TRAIN_GET_VALID_HANDLE(obj, obj_h, obj_type, obj_name) \
   do {                                                            \
@@ -164,7 +177,8 @@ typedef struct {
 } while (0)
 
 /**
- * @brief Check validity of the user passed arguments and lock the object
+ * @brief Check validity of the user passed arguments, reset magic if in use
+ * and lock the object
  */
 #define ML_TRAIN_GET_VALID_HANDLE_LOCKED_RESET(obj, obj_h, obj_type, obj_name) \
   do {                                                                         \
@@ -177,7 +191,7 @@ typedef struct {
 } while (0)
 
 /**
- * @brief Check validity of the user passed arguments and lock the object
+ * @brief Reset object magic
  */
 #define ML_TRAIN_RESET_VALIDATED_HANDLE(obj) \
   do {                                       \
@@ -231,6 +245,22 @@ typedef struct {
                                   "optimizer")
 
 /**
+ * @brief Check validity of passed lr_scheduler and lock the object
+ */
+#define ML_TRAIN_GET_VALID_LR_SCHEDULER_LOCKED(nnlrscheduler, lrscheduler) \
+  ML_TRAIN_GET_VALID_HANDLE_LOCKED(nnlrscheduler, lrscheduler,             \
+                                   ml_train_lr_scheduler, "lr_scheduler")
+
+/**
+ * @brief Check validity of passed lr_scheduler, reset magic and lock the
+ * object
+ */
+#define ML_TRAIN_GET_VALID_LR_SCHEDULER_LOCKED_RESET(nnlrscheduler, \
+                                                     lrscheduler)  \
+  ML_TRAIN_GET_VALID_HANDLE_LOCKED_RESET(                          \
+    nnlrscheduler, lrscheduler, ml_train_lr_scheduler, "lr_scheduler")
+
+/**
  * @brief Check validity of passed dataset and lock the object
  */
 #define ML_TRAIN_GET_VALID_DATASET_LOCKED(nndataset, dataset) \
@@ -394,6 +424,25 @@ int ml_train_optimizer_set_property_with_single_param(
   ml_train_optimizer_h optimizer, const char *single_param);
 
 /**
+ * @brief Sets the learning rate scheduler property with a single parameter.
+ * @details Use this function to set learning rate scheduler properties. This
+ * API exists to work around the va_list marshalling issue of C# DllImport
+ * interop. @a single_param must be given as 'key=value' pairs joined with
+ * the delimiter '|', e.g.
+ * ml_train_lr_scheduler_set_property_with_single_param(lr_scheduler,
+ * "learning_rate=0.01 | decay_rate=0.5 | decay_steps=1000");
+ * @since_tizen 7.5
+ * @param[in] lr_scheduler The learning rate scheduler handle.
+ * @param[in] single_param Property values.
+ * @return @c 0 on success. Otherwise a negative error value.
+ * @retval #ML_ERROR_NONE Successful.
+ * @retval #ML_ERROR_NOT_SUPPORTED Not supported.
+ * @retval #ML_ERROR_INVALID_PARAMETER Invalid parameter.
+ */
+int ml_train_lr_scheduler_set_property_with_single_param(
+  ml_train_lr_scheduler_h lr_scheduler, const char *single_param);
+
+/**
  * @brief Sets the neural network dataset property with single param.
  * @details Use this function to set dataset property for a specific mode.
  * API to solve va_list issue of Dllimport of C# interop.
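For bindings that cannot pass variadic arguments, the same properties collapse
into one delimited string. A sketch mirroring the example in the comment above
(continuing the `sched` handle from the earlier sketch):

    /* Equivalent to the NULL-terminated variadic call, but packed into one
     * string for bindings that cannot marshal va_list (e.g. C# DllImport). */
    ml_train_lr_scheduler_set_property_with_single_param(
      sched, "learning_rate=0.01 | decay_rate=0.5 | decay_steps=1000");
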
diff --git a/api/capi/src/nntrainer.cpp b/api/capi/src/nntrainer.cpp
index 0e8f64e..2e90275 100644
--- a/api/capi/src/nntrainer.cpp
+++ b/api/capi/src/nntrainer.cpp
@@ -393,10 +393,12 @@ int ml_train_model_destroy(ml_train_model_h model) {
     ML_TRAIN_ADOPT_LOCK(nnmodel, model_lock);
   }
 
-  std::shared_ptr<ml::train::Model> m;
-  m = nnmodel->model;
-
   if (nnmodel->optimizer) {
+    if (nnmodel->optimizer->lr_scheduler) {
+      ML_TRAIN_RESET_VALIDATED_HANDLE(nnmodel->optimizer->lr_scheduler);
+      delete nnmodel->optimizer->lr_scheduler;
+    }
+
     ML_TRAIN_RESET_VALIDATED_HANDLE(nnmodel->optimizer);
     delete nnmodel->optimizer;
   }
@@ -547,8 +549,9 @@ int ml_train_model_set_optimizer(ml_train_model_h model,
   status = nntrainer_exception_boundary(f);
   if (status == ML_ERROR_NONE) {
     nnopt->in_use = true;
-    if (nnmodel->optimizer)
+    if (nnmodel->optimizer) {
       nnmodel->optimizer->in_use = false;
+    }
     nnmodel->optimizer = nnopt;
   }
 
@@ -754,6 +757,7 @@ int ml_train_optimizer_create(ml_train_optimizer_h *optimizer,
   ml_train_optimizer *nnopt = new ml_train_optimizer;
   nnopt->magic = ML_NNTRAINER_MAGIC;
   nnopt->in_use = false;
+  nnopt->lr_scheduler = NULL;
 
   returnable f = [&]() {
     nnopt->optimizer =
@@ -787,6 +791,11 @@ int ml_train_optimizer_destroy(ml_train_optimizer_h optimizer) {
             "Delete model will delete this optimizer.");
       return ML_ERROR_INVALID_PARAMETER;
     }
+
+    if (nnopt->lr_scheduler) {
+      ML_TRAIN_RESET_VALIDATED_HANDLE(nnopt->lr_scheduler);
+      delete nnopt->lr_scheduler;
+    }
   }
 
   delete nnopt;
@@ -837,6 +846,138 @@ int ml_train_optimizer_set_property_with_single_param(
   return ml_train_optimizer_set_property(optimizer, single_param, NULL);
 }
 
+int ml_train_optimizer_set_lr_scheduler(ml_train_optimizer_h optimizer,
+                                        ml_train_lr_scheduler_h lr_scheduler) {
+  int status = ML_ERROR_NONE;
+  ml_train_optimizer *nnopt;
+  ml_train_lr_scheduler *nnlrscheduler;
+
+  check_feature_state();
+
+  ML_TRAIN_GET_VALID_OPT_LOCKED(nnopt, optimizer);
+  ML_TRAIN_ADOPT_LOCK(nnopt, opt_lock);
+  ML_TRAIN_GET_VALID_LR_SCHEDULER_LOCKED(nnlrscheduler, lr_scheduler);
+  ML_TRAIN_ADOPT_LOCK(nnlrscheduler, lr_scheduler_lock);
+
+  if (nnlrscheduler->in_use) {
+    ml_loge("learning rate scheduler is already in use.");
+    return ML_ERROR_INVALID_PARAMETER;
+  }
+
+  std::shared_ptr<ml::train::Optimizer> opt;
+  std::shared_ptr<ml::train::LearningRateScheduler> lr_sched;
+
+  opt = nnopt->optimizer;
+  lr_sched = nnlrscheduler->lr_scheduler;
+
+  returnable f = [&]() { return opt->setLearningRateScheduler(lr_sched); };
+
+  status = nntrainer_exception_boundary(f);
+  if (status == ML_ERROR_NONE) {
+    nnlrscheduler->in_use = true;
+    if (nnopt->lr_scheduler) {
+      nnopt->lr_scheduler->in_use = false;
+    }
+    nnopt->lr_scheduler = nnlrscheduler;
+  }
+
+  return status;
+}
+
+int ml_train_lr_scheduler_create(ml_train_lr_scheduler_h *lr_scheduler,
+                                 ml_train_lr_scheduler_type_e type) {
+  int status = ML_ERROR_NONE;
+
+  check_feature_state();
+
+  ml_train_lr_scheduler *nnlrscheduler = new ml_train_lr_scheduler;
+  nnlrscheduler->magic = ML_NNTRAINER_MAGIC;
+  nnlrscheduler->in_use = false;
+
+  returnable f = [&]() {
+    nnlrscheduler->lr_scheduler = ml::train::createLearningRateScheduler(
+      (ml::train::LearningRateSchedulerType)type);
+    return ML_ERROR_NONE;
+  };
+
+  status = nntrainer_exception_boundary(f);
+  if (status != ML_ERROR_NONE) {
+    delete nnlrscheduler;
+    ml_loge("creating learning rate scheduler failed");
+  } else {
+    *lr_scheduler = nnlrscheduler;
+  }
+
+  return status;
+}
+
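The in_use flag flipped in ml_train_optimizer_set_lr_scheduler() is what the
destroy path below checks. A sketch of the observable behavior (illustrative
only; optimizer and scheduler types are arbitrary):

    ml_train_optimizer_h opt;
    ml_train_lr_scheduler_h sched;

    ml_train_optimizer_create(&opt, ML_TRAIN_OPTIMIZER_TYPE_ADAM);
    ml_train_lr_scheduler_create(&sched, ML_TRAIN_LR_SCHEDULER_TYPE_STEP);
    ml_train_optimizer_set_lr_scheduler(opt, sched); /* marks sched in_use */

    /* Rejected with ML_ERROR_INVALID_PARAMETER: the optimizer owns sched. */
    ml_train_lr_scheduler_destroy(sched);

    /* Correct teardown: destroying the optimizer frees sched as well. */
    ml_train_optimizer_destroy(opt);
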
+int ml_train_lr_scheduler_destroy(ml_train_lr_scheduler_h lr_scheduler) {
+  int status = ML_ERROR_NONE;
+  ml_train_lr_scheduler *nnlrscheduler;
+
+  check_feature_state();
+
+  {
+    ML_TRAIN_GET_VALID_LR_SCHEDULER_LOCKED_RESET(nnlrscheduler, lr_scheduler);
+    ML_TRAIN_ADOPT_LOCK(nnlrscheduler, lr_scheduler_lock);
+
+    if (nnlrscheduler->in_use) {
+      ml_loge("Cannot delete a learning rate scheduler that is set to an "
+              "optimizer. Deleting the optimizer will delete this learning "
+              "rate scheduler.");
+      return ML_ERROR_INVALID_PARAMETER;
+    }
+  }
+
+  delete nnlrscheduler;
+  return status;
+}
+
+int ml_train_lr_scheduler_set_property(ml_train_lr_scheduler_h lr_scheduler,
+                                       ...) {
+  int status = ML_ERROR_NONE;
+  ml_train_lr_scheduler *nnlrscheduler;
+  const char *data;
+  std::shared_ptr<ml::train::LearningRateScheduler> lr_sched;
+
+  check_feature_state();
+
+  ML_TRAIN_VERIFY_VALID_HANDLE(lr_scheduler);
+
+  std::vector<std::string> arg_list;
+  va_list arguments;
+  va_start(arguments, lr_scheduler);
+
+  while ((data = va_arg(arguments, const char *))) {
+    arg_list.push_back(data);
+  }
+
+  va_end(arguments);
+
+  {
+    ML_TRAIN_GET_VALID_LR_SCHEDULER_LOCKED(nnlrscheduler, lr_scheduler);
+    ML_TRAIN_ADOPT_LOCK(nnlrscheduler, lr_scheduler_lock);
+
+    lr_sched = nnlrscheduler->lr_scheduler;
+  }
+
+  returnable f = [&]() {
+    lr_sched->setProperty(arg_list);
+    return ML_ERROR_NONE;
+  };
+
+  status = nntrainer_exception_boundary(f);
+
+  return status;
+}
+
+int ml_train_lr_scheduler_set_property_with_single_param(
+  ml_train_lr_scheduler_h lr_scheduler, const char *single_param) {
+  ML_TRAIN_VERIFY_VALID_HANDLE(lr_scheduler);
+
+  return ml_train_lr_scheduler_set_property(lr_scheduler, single_param, NULL);
+}
+
 int ml_train_dataset_create(ml_train_dataset_h *dataset) {
   return ml_train_dataset_create(dataset, ml::train::DatasetType::UNKNOWN,
                                  nullptr, nullptr, nullptr);
diff --git a/api/ccapi/include/optimizer.h b/api/ccapi/include/optimizer.h
index 584caad..6cc6363 100644
--- a/api/ccapi/include/optimizer.h
+++ b/api/ccapi/include/optimizer.h
@@ -141,9 +141,10 @@ SGD(const std::vector<std::string> &properties = {}) {
  * @brief Enumeration of learning rate scheduler type
  */
 enum LearningRateSchedulerType {
-  CONSTANT = 0, /**< constant */
-  EXPONENTIAL,  /**< exponentially decay */
-  STEP          /**< step wise decay */
+  CONSTANT = ML_TRAIN_LR_SCHEDULER_TYPE_CONSTANT, /**< constant */
+  EXPONENTIAL =
+    ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL, /**< exponential decay */
+  STEP = ML_TRAIN_LR_SCHEDULER_TYPE_STEP    /**< step-wise decay */
 };
 
 /**
diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h
index b11a884..f91d3e4 100644
--- a/api/nntrainer-api-common.h
+++ b/api/nntrainer-api-common.h
@@ -88,6 +88,17 @@ typedef enum {
 } ml_train_optimizer_type_e;
 
 /**
+ * @brief Enumeration for the learning rate scheduler type of NNTrainer.
+ * @since_tizen 7.5
+ */
+typedef enum {
+  ML_TRAIN_LR_SCHEDULER_TYPE_CONSTANT = 0,    /**< Constant lr scheduler */
+  ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL = 1, /**< Exponential decay lr scheduler */
+  ML_TRAIN_LR_SCHEDULER_TYPE_STEP = 2,        /**< Step-wise decay lr scheduler */
+  ML_TRAIN_LR_SCHEDULER_TYPE_UNKNOWN = 999    /**< Unknown lr scheduler */
+} ml_train_lr_scheduler_type_e;
+
+/**
  * @brief Dataset generator callback function for train/valid/test data.
  *
  * @details The user of the API must provide this callback function to supply
-- 
2.7.4
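Taken together, the new API supports roughly the following end-to-end flow.
This is a hedged sketch, not part of the patch: the helper name and property
values are assumptions, the include path may differ per platform, and
failure-path cleanup is abbreviated.

    #include <nntrainer.h>

    int attach_scheduled_optimizer(ml_train_model_h model) {
      ml_train_optimizer_h opt = NULL;
      ml_train_lr_scheduler_h sched = NULL;
      int status;

      status = ml_train_optimizer_create(&opt, ML_TRAIN_OPTIMIZER_TYPE_SGD);
      if (status != ML_ERROR_NONE)
        return status;

      status = ml_train_lr_scheduler_create(
        &sched, ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL);
      if (status != ML_ERROR_NONE) {
        ml_train_optimizer_destroy(opt);
        return status;
      }

      status = ml_train_lr_scheduler_set_property(
        sched, "learning_rate=0.01", "decay_rate=0.5", "decay_steps=1000",
        NULL);
      if (status == ML_ERROR_NONE)
        status = ml_train_optimizer_set_lr_scheduler(opt, sched);
      /* On success opt owns sched; after set_optimizer, model owns opt. */
      if (status == ML_ERROR_NONE)
        status = ml_train_model_set_optimizer(model, opt);

      return status;
    }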