From 05fef153fb8364d1882ffd553a1abf33f8934afe Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Mon, 19 Oct 2020 20:00:22 +0900 Subject: [PATCH] [Bug/Act] Fix setActivation call properly `Act::setActivation` should be called for activationLayer. However, because `virtual Layer::setActivation` had diffrent signature `Act::setActivation`, There was no way this function can be called. This patch fixes the issue. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/include/layer.h | 36 +++++++++++++++++++++--------------- nntrainer/src/activation_layer.cpp | 3 ++- nntrainer/src/layer.cpp | 11 +++-------- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/nntrainer/include/layer.h b/nntrainer/include/layer.h index a93267e..1207bfd 100644 --- a/nntrainer/include/layer.h +++ b/nntrainer/include/layer.h @@ -240,6 +240,13 @@ public: int setOptimizer(std::shared_ptr opt); /** + * @brief Get the Optimizer object + * + * @return std::shared_ptr optimizer + */ + std::shared_ptr getOptimizer() { return opt; } + + /** * @brief Activation Type Getter * @retval Activation Type. */ @@ -331,6 +338,13 @@ public: return weight_list.get()[position]; } + /** + * @brief Get the number of weights + * + * @return unsigned int number of weights + */ + unsigned int getNumWeights() { return num_weights; } + #if defined(ENABLE_TEST) /** * @brief Set the batch for the layer @@ -453,13 +467,6 @@ protected: } /** - * @brief Get the number of weights - * - * @return unsigned int number of weights - */ - unsigned int getNumWeights() { return num_weights; } - - /** * @brief weight_list in this layer. This contains trainable weights of * layers. */ @@ -487,6 +494,13 @@ protected: */ void setType(LayerType type) { this->type = type; } + /** + * @brief Activation Setter + * @param[in] activation activation type + * @throw std::invalid_argument when ActivationType is unknown + */ + virtual void setActivation(ActivationType activation); + private: /** * @brief Set containing all the names of layers @@ -533,14 +547,6 @@ private: virtual void printMetric(std::ostream &out); /** - * @brief Activation Setter - * @param[in] activation activation type - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. - */ - int setActivation(ActivationType activation); - - /** * @brief set weight decay parameters * @param[in] w struct for weight decay */ diff --git a/nntrainer/src/activation_layer.cpp b/nntrainer/src/activation_layer.cpp index c796fb1..4bca0cb 100644 --- a/nntrainer/src/activation_layer.cpp +++ b/nntrainer/src/activation_layer.cpp @@ -104,6 +104,8 @@ int ActivationLayer::setActivation( * @param[in] ActivationType ActivationType ActivationType to be set */ void ActivationLayer::setActivation(ActivationType acti_type) { + Layer::setActivation(acti_type); + switch (acti_type) { case ActivationType::ACT_TANH: this->setActivation(tanhFloat, tanhPrime); @@ -124,7 +126,6 @@ void ActivationLayer::setActivation(ActivationType acti_type) { default: throw std::runtime_error("Error: Not Supported Activation Type"); } - this->activation_type = acti_type; } Tensor ActivationLayer::softmax(Tensor const &t) { diff --git a/nntrainer/src/layer.cpp b/nntrainer/src/layer.cpp index daccd27..5d3be2c 100644 --- a/nntrainer/src/layer.cpp +++ b/nntrainer/src/layer.cpp @@ -30,15 +30,11 @@ namespace nntrainer { -int Layer::setActivation(ActivationType acti) { - int status = ML_ERROR_NONE; +void Layer::setActivation(ActivationType acti) { if (acti == ActivationType::ACT_UNKNOWN) { - ml_loge("Error:have to specify activation function"); - return ML_ERROR_INVALID_PARAMETER; + throw std::invalid_argument("Error:have to specify activation function"); } activation_type = acti; - - return status; } int Layer::setOptimizer(std::shared_ptr opt) { @@ -159,8 +155,7 @@ void Layer::setProperty(const PropertyType type, const std::string &value) { break; case PropertyType::activation: if (!value.empty()) { - status = setActivation((ActivationType)parseType(value, TOKEN_ACTI)); - throw_status(status); + setActivation((ActivationType)parseType(value, TOKEN_ACTI)); } break; case PropertyType::flatten: -- 2.7.4