endif
endif
-
include $(CLEAR_VARS)
NNTRAINER_JNI_ROOT := $(NNTRAINER_ROOT)/jni
endif #ENABLE_TFLITE_BACKBONE
-
ifeq ($(ENABLE_BLAS), 1)
include $(CLEAR_VARS)
include $(PREBUILT_STATIC_LIBRARY)
endif #ENABLE_BLAS
-
include $(CLEAR_VARS)
NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
$(NNTRAINER_ROOT)/nntrainer/models/model_loader.cpp \
+ $(NNTRAINER_ROOT)/nntrainer/models/dynamic_training_optimization.cpp \
$(NNTRAINER_ROOT)/nntrainer/dataset/databuffer.cpp \
$(NNTRAINER_ROOT)/nntrainer/dataset/databuffer_factory.cpp \
$(NNTRAINER_ROOT)/nntrainer/dataset/databuffer_func.cpp \
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file dynamic_training_optimization.cpp
+ * @date 5 January 2021
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is Dynamic Training Optimization for Neural Network
+ *
+ */
+
+#include <cmath>
+#include <functional>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include <dynamic_training_optimization.h>
+#include <layer_internal.h>
+#include <tensor.h>
+#include <util_func.h>
+
+namespace nntrainer {
+DynamicTrainingOptimization::DynamicTrainingOptimization(int threshold_,
+ int skip_n_iter) :
+ threshold(threshold_),
+ enabled(false),
+ epsilon(1e-7),
+ skip_n_iterations(skip_n_iter) {
+ reduce_op = reduceByNorm;
+ calc_ratio_op = ratioUsingDerivative;
+ rng.seed(getSeed());
+ dist = std::uniform_real_distribution<float>(0.0, 1.0);
+}
+
+/**
+ * @brief Check if the given weights can skip updating
+ * @note true if the update should be applied, else false
+ */
+bool DynamicTrainingOptimization::checkIfApply(
+ const std::vector<Weight> &weights, const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ const std::shared_ptr<Optimizer> &opt, int iteration) {
+ if (!enabled || iteration < skip_n_iterations)
+ return true;
+
+ std::vector<bool> apply;
+ apply.reserve(weights.size());
+
+ for (auto const &weight : weights)
+ apply.push_back(checkIfApply(weight, input, output, opt, iteration));
+
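+  /** the layer-level update is applied only if every per-weight check says apply */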
+ return std::accumulate(apply.begin(), apply.end(), true,
+ std::logical_and<bool>());
+}
+
+/**
+ * @brief Check if the given weight can skip updating
+ * @note true if the update should be applied, else false
+ */
+bool DynamicTrainingOptimization::checkIfApply(
+ const Weight &weight, const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ const std::shared_ptr<Optimizer> &opt, int iteration) {
+ if (iteration < skip_n_iterations)
+ return true;
+
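+  /** never skip the update for weights that are not trainable or have no gradient allocated */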
+ if (!weight.getTrainable() || weight.getGradientRef().uninitialized())
+ return true;
+
+ float reduced_ratio = calc_ratio_op(weight, input, output, reduce_op);
+
+ return checkIfApply(reduced_ratio, (float)opt->getLearningRate(iteration));
+}
+
+/**
+ * @brief Calculate the ratio of update to the weight using derivative
+ */
+float DynamicTrainingOptimization::ratioUsingDerivative(
+ const Weight &weight, const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ std::function<float(Tensor const &)> reduce_op) {
+ float reduced_derivative = reduce_op(output->getGradientRef());
+ float reduced_input = reduce_op(input->getVariableRef());
+ float reduced_weight = reduce_op(weight.getVariableRef());
+ float reduced_grad = reduced_derivative * reduced_input;
+
+ return reduced_grad / reduced_weight;
+}
+
+/**
+ * @brief Calculate the ratio of update to the weight using gradient
+ */
+float DynamicTrainingOptimization::ratioUsingGradient(
+ const Weight &weight, const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ std::function<float(Tensor const &)> reduce_op) {
+ Tensor ratio = weight.getGradientRef().divide(weight.getVariableRef());
+ return reduce_op(ratio);
+}
+
+/**
+ * @brief Check if the update should be applied or skipped
+ * @note true if the update should be applied, else false
+ */
+bool DynamicTrainingOptimization::checkIfApply(float reduced_ratio,
+ float learning_rate) {
+ /**
+   * If the scaled update ratio (ratio * learning rate / threshold) is 1 or
+   * higher, the update is always applied. If it is less than 1, the update is
+   * applied with probability equal to the scaled ratio.
+ */
+ if (dist(rng) < reduced_ratio * learning_rate / threshold)
+ return true;
+
+ return false;
+}
+
+/**
+ * @brief Operation to decide if update should be skipped
+ * @note Calculate the maximum absolute value (infinity norm) of the tensor
+ */
+float DynamicTrainingOptimization::reduceByMax(Tensor const &ratio) {
+ return ratio.max_abs();
+}
+
+/**
+ * @brief Operation to decide if update should be skipped
+ * @note Calculate the l2 norm of the tensor, normalized by the square root of its size
+ */
+float DynamicTrainingOptimization::reduceByNorm(Tensor const &ratio) {
+ float l2norm = ratio.l2norm();
+ return l2norm / std::sqrt(ratio.length());
+}
+
+/**< Different types of reduce operations */
+const std::string DynamicTrainingOptimization::dft_opt_max = "max";
+const std::string DynamicTrainingOptimization::dft_opt_norm = "norm";
+
+const std::string DynamicTrainingOptimization::dft_opt_mode_gradient =
+ "gradient";
+const std::string DynamicTrainingOptimization::dft_opt_mode_derivative =
+ "derivative";
+
+} /* namespace nntrainer */
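
As a minimal standalone illustration of the two reduce operations implemented above, the following sketch uses a plain std::vector instead of nntrainer::Tensor; the helper names are hypothetical and for illustration only.

#include <algorithm>
#include <cmath>
#include <vector>

/* what reduceByMax computes: the largest absolute value (infinity norm) */
float reduce_max_sketch(const std::vector<float> &v) {
  float m = 0.0f;
  for (float x : v)
    m = std::max(m, std::fabs(x));
  return m;
}

/* what reduceByNorm computes: the l2 norm scaled by 1/sqrt(size), i.e. the RMS value */
float reduce_norm_sketch(const std::vector<float> &v) {
  float sum_sq = 0.0f;
  for (float x : v)
    sum_sq += x * x;
  return std::sqrt(sum_sq) / std::sqrt((float)v.size());
}

/* e.g. for {3.0f, -4.0f}: reduce_max_sketch -> 4.0, reduce_norm_sketch -> 5 / sqrt(2) ~ 3.54 */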
/**
* Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
*
- * @file activation_layer.cpp
+ * @file dynamic_training_optimization.h
* @date 4 January 2021
* @see https://github.com/nnstreamer/nntrainer
* @author Parichay Kapoor <pk.kapoor@samsung.com>
* @bug No known bugs except for NYI items
* @brief This is Dynamic Training Optimization for Neural Network
*
+ * Dynamic training aims to optimize the cost of applying the gradient.
+ * The cost of applying the gradient includes the cost of the optimizer (adam,
+ * etc) where the optimizer variables are updated, and the cost of actually
+ * updating the weights (which can be non-trivial with bigger weights and
+ * distributed training).
+ *
+ * There are two supported modes:
+ * 1. Gradient Mode: The already calculated gradient is used to decide whether
+ * it should be applied to the weight or whether the update can be skipped.
+ *
+ * 2. Derivative Mode: This mode estimates an approximate gradient at low cost
+ * and uses it to decide whether the full gradient is worth computing at all,
+ * since the cost of calculating a gradient is wasted if it is never applied.
+ *
+ * There are two supported reduction operations which reduce the gradient and
+ * the weight to a single value so that it can be compared against a threshold.
+ * The reduced value, scaled by the learning rate, is compared with the
+ * threshold: if it is below the threshold, the update is performed with a
+ * probability proportional to the value; if it is above, the update is always
+ * performed.
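+ *
+ * For example (illustrative numbers): with threshold 1 and a learning rate of
+ * 0.01, a reduced ratio of 50 scales to 50 * 0.01 / 1 = 0.5, so the update is
+ * applied with probability 0.5, while a reduced ratio of 200 scales to 2.0 and
+ * the update is always applied.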
+ *
*/
#ifndef __DYNAMIC_TRAINING_OPT_H__
#include <layer_internal.h>
#include <tensor.h>
-#include <util_func.h>
namespace nntrainer {
/**
* @brief Constructor of DynamicFineTuning Optimization
*/
- DynamicTrainingOptimization(int threshold_ = 1, int skip_n_iter = 1) :
- threshold(threshold_),
- enabled(false),
- epsilon(1e-7),
- skip_n_iterations(skip_n_iter) {
- reduce_op = reduce_by_norm;
- rng.seed(getSeed());
- dist = std::uniform_real_distribution<float>(0.0, 1.0);
- }
+ DynamicTrainingOptimization(int threshold_ = 1, int skip_n_iter = 1);
/**
* @brief Set threshold for optimization
*/
- void setThreshold(float threshold_) { threshold = threshold_; };
+ void setThreshold(float threshold_) {
+ if (threshold_ < epsilon)
+ throw std::invalid_argument("Threshold is too small or negative");
+
+ threshold = threshold_;
+ };
/**
* @brief Set the reduce operation for dynamic optimization
*/
- void setOp(std::string op) {
- enabled = true;
+ void setOp(const std::string &op) {
if (op == dft_opt_max)
- reduce_op = reduce_by_max;
+ reduce_op = reduceByMax;
else if (op == dft_opt_norm)
- reduce_op = reduce_by_norm;
+ reduce_op = reduceByNorm;
else
- enabled = false;
+ throw std::invalid_argument(
+ "Unsupported reduction op in dynamic training");
};
/**
- * @brief Set initial iteraions to skip from optimization
+ * @brief Enable the optimization
*/
- void setSkipIterations(int skip_n_iter) { skip_n_iterations = skip_n_iter; }
+ void enable() { enabled = true; }
/**
- * @brief Check if the given weights can skip updating
+ * @brief Disable the optimization
*/
- std::vector<bool> checkIfApply(const std::vector<Weight> &weights,
- const std::shared_ptr<Var_Grad> input,
- const std::shared_ptr<Var_Grad> output,
- const std::shared_ptr<Optimizer> opt,
- int iteration) {
- if (!enabled)
- return std::vector<bool>(weights.size(), true);
+ void disable() { enabled = false; }
- std::vector<bool> apply;
- apply.reserve(weights.size());
+ /**
+ * @brief Set the mode for optimization
+ */
+ void setMode(const std::string &mode_) {
+ calc_ratio_mode = mode_;
+ if (mode_ == dft_opt_mode_derivative)
+ calc_ratio_op = ratioUsingDerivative;
+ else if (mode_ == dft_opt_mode_gradient)
+ calc_ratio_op = ratioUsingGradient;
+ else
+ throw std::invalid_argument("Unsupported mode in dynamic training");
+ }
- for (auto const &weight : weights)
- apply.push_back(checkIfApply(weight, input, output, opt, iteration));
+ /**
+ * @brief Check if the derivative mode is used for optimization
+ * @note Use the derivative to calculate an approximate gradient to estimate
+ * if the actual gradient needs applying
+ */
+ bool isDerivativeMode() {
+ if (enabled && calc_ratio_mode == dft_opt_mode_derivative)
+ return true;
+ return false;
+ }
- return apply;
+ /**
+ * @brief Check if the gradient mode is used for optimization
+ * @note Use the gradient to estimate if this gradient needs applying
+ */
+ bool isGradientMode() {
+ if (enabled && calc_ratio_mode == dft_opt_mode_gradient)
+ return true;
+ return false;
}
+ /**
+   * @brief Set the initial iterations for which dynamic training optimization is skipped
+   * @note If the current iteration is less than skip_n_iterations, the weights
+   * will be updated and dynamic training optimization will not be performed.
+ *
+ */
+ void setSkipIterations(int skip_n_iter) { skip_n_iterations = skip_n_iter; }
+
+ /**
+ * @brief Check if the given weights can skip updating
+ * @param[in] weights All the weight tensors for a layer
+ * @param[in] input Input tensor for a layer
+ * @param[in] output Output tensor for a layer, from forward operation
+ * @param[in] opt Optimizer used to update the layer weights
+ * @param[in] iteration Current iteration number in training
+   * @note true if the update should be applied, else false
+ */
+ bool checkIfApply(const std::vector<Weight> &weights,
+ const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ const std::shared_ptr<Optimizer> &opt, int iteration);
+
/**
* @brief Check if the given weight can skip updating
+ * @param[in] weight Weight tensor for a layer
+ * @param[in] input Input tensor for a layer
+ * @param[in] output Output tensor for a layer, from forward operation
+ * @param[in] opt Optimizer used to update the layer weights
+ * @param[in] iteration Current iteration number in training
+   * @note true if the update should be applied, else false
*/
bool checkIfApply(const Weight &weight,
const std::shared_ptr<Var_Grad> &input,
const std::shared_ptr<Var_Grad> &output,
- const std::shared_ptr<Optimizer> &opt, int iteration) {
- // by gradient
- if (iteration < skip_n_iterations)
- return true;
+ const std::shared_ptr<Optimizer> &opt, int iteration);
- Tensor &weight_grad = weight.getGradientRef();
- Tensor &weight_var = weight.getVariableRef();
+ /**< Different types of reduce operations */
+ static const std::string dft_opt_max;
+ static const std::string dft_opt_norm;
- if (!weight.getTrainable() || weight_grad.uninitialized())
- return true;
+ /**< Different types of optimization modes */
+ static const std::string dft_opt_mode_gradient;
+ static const std::string dft_opt_mode_derivative;
- Tensor ratio = weight_grad.divide(weight_var);
+private:
+ std::mt19937 rng; /**< random number generator */
+ std::uniform_real_distribution<float>
+ dist; /**< uniform random distribution */
+ float threshold; /**< threshold to decide when to skip updating */
+ bool enabled; /**< if optimization is enabled */
+ float epsilon; /**< epsilon to skip overflow */
+ int skip_n_iterations; /**< skip initial iterations from optimization */
+ std::string calc_ratio_mode; /**< the mode to calc the ratio */
- // by derivative
- // Tensor ratio = output.getGradientRef().divide(weight.getVariableRef());
- // ratio.multiply_i(input.getVariableRef());
+ std::function<float(Tensor const &)>
+ reduce_op; /**< operation to reduce update ratio to value */
+ std::function<float(const Weight &, const std::shared_ptr<Var_Grad> &,
+ const std::shared_ptr<Var_Grad> &,
+ std::function<float(Tensor const &)> reduce_op)>
+ calc_ratio_op; /**< calculate the ratio of update to the weight */
- /**
- * If the reduced update ratio is higher than 1, then always apply update.
- * If the reduced update raito is less than 1, then apply it with
- * probability = update ratio
- */
- if (dist(rng) <
- reduce_op(ratio) * ((float)opt->getLearningRate(iteration)) / threshold)
- return false;
+ /**
+ * @brief Calculate the ratio of update to the weight using derivative
+ * @param[in] weight Weight tensor for a layer
+ * @param[in] input Input tensor for a layer
+ * @param[in] output Output tensor for a layer, from forward operation
+ * @param[in] reduce_op Operation to reduce the ratio
+ */
+ static float
+ ratioUsingDerivative(const Weight &weight,
+ const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ std::function<float(Tensor const &)> reduce_op);
- return true;
- }
+ /**
+ * @brief Calculate the ratio of update to the weight using gradient
+ * @param[in] weight Weight tensor for a layer
+ * @param[in] input Input tensor for a layer
+ * @param[in] output Output tensor for a layer, from forward operation
+ * @param[in] reduce_op Operation to reduce the ratio
+ */
+ static float
+ ratioUsingGradient(const Weight &weight,
+ const std::shared_ptr<Var_Grad> &input,
+ const std::shared_ptr<Var_Grad> &output,
+ std::function<float(Tensor const &)> reduce_op);
+
+ /**
+ * @brief Check if the update should be applied or skipped
+   * @note true if the update should be applied, else false
+ */
+ bool checkIfApply(float reduced_ratio, float learning_rate);
/**
* @brief Operation to decide if update should be skipped
   * @note Calculate the maximum absolute value (infinity norm) of the tensor
*/
- static float reduce_by_max(Tensor const &ratio) { return ratio.max_abs(); }
+ static float reduceByMax(Tensor const &ratio);
/**
* @brief Operation to decide if update should be skipped
   * @note Calculate l2 norm of the tensor averaged by its size
*/
- static float reduce_by_norm(Tensor const &ratio) {
- float l2norm = ratio.l2norm();
- return (l2norm * l2norm) / ratio.length();
- }
-
- /**< Different types of reduce operations */
- static const std::string dft_opt_off;
- static const std::string dft_opt_max;
- static const std::string dft_opt_norm;
-
-private:
- std::mt19937 rng; /**< random number generator */
- std::uniform_real_distribution<float>
- dist; /**< uniform random distribution */
- float threshold; /**< threshold to decide when to skip updating */
- bool enabled; /**< if optimization is enabled */
- float epsilon; /**< epsilon to skip overflow */
- int skip_n_iterations; /**< skip initial iterations from optimization */
- std::function<float(Tensor const &)>
- reduce_op; /**< operation to reduce update ratio to value */
+ static float reduceByNorm(Tensor const &ratio);
};
-/**< Different types of reduce operations */
-const std::string dft_opt_off = "off";
-const std::string dft_opt_max = "max";
-const std::string dft_opt_norm = "norm";
-
} /* namespace nntrainer */
#endif /* __cplusplus */
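
A minimal configuration sketch against the interface declared above; the values and the wrapper function name are illustrative only.

#include <dynamic_training_optimization.h>

void configureDynamicTrainingSketch() {
  nntrainer::DynamicTrainingOptimization dto(/*threshold_=*/1, /*skip_n_iter=*/1);
  dto.setThreshold(1.5f);
  dto.setOp(nntrainer::DynamicTrainingOptimization::dft_opt_norm);
  dto.setMode(nntrainer::DynamicTrainingOptimization::dft_opt_mode_derivative);
  dto.setSkipIterations(10);
  dto.enable(); /* the optimization stays disabled until enable() is called */
}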
model_sources = [
'model_loader.cpp',
- 'neuralnet.cpp'
+ 'neuralnet.cpp',
+ 'dynamic_training_optimization.cpp'
]
model_headers = [
return forwarding(training);
}
+void NeuralNetwork::backwarding(std::shared_ptr<Layer> layer, int iteration,
+ bool calc_derivative) {
+ /**
+ * Do not change this order:
+ * 1. calcGradient
+ * 2. calcDerivative
+ * 3. applyGradient
+ */
+ bool apply_gradient;
+  /** In gradient mode, calculate the gradient first (it is needed by the check below) */
+ if (dynamic_training_opt.isGradientMode())
+ layer->calcGradient();
+
+ /**
+   * This is true if the optimization is off or if the gradient must be applied
+ */
+ apply_gradient = dynamic_training_opt.checkIfApply(
+ layer->getWeightsRef(), layer->net_input[0], layer->net_hidden[0], opt,
+ iteration);
+
+  /** If the gradient must be applied and this is not gradient mode, the
+   * gradient still needs to be calculated here
+   */
+ if (!dynamic_training_opt.isGradientMode() && apply_gradient)
+ layer->calcGradient();
+
+ if (calc_derivative)
+ layer->calcDerivative();
+
+ if (apply_gradient)
+ opt->apply_gradients(layer->getWeightsRef(), iteration);
+}
+
/**
* @brief back propagation
* Call backwarding function of layer in reverse order
*/
auto iter_begin = model_graph.getBackwardingBeginIter();
auto iter_end = model_graph.getBackwardingEndIter();
- for (auto iter = iter_begin; iter != iter_end - 1; iter++) {
- auto layer = iter->layer;
- layer->backwarding();
-
- auto apply_grad_check =
- dft_opt.checkIfApply(layer->getWeightsRef(), layer->net_input[0],
- layer->net_hidden[0], opt, iteration);
- std::vector<Weight> weights_to_update;
-
- for (unsigned int idx = 0; idx < apply_grad_check.size(); idx++) {
- if (apply_grad_check[idx])
- weights_to_update.emplace_back(layer->getWeightsRef()[idx]);
- }
- opt->apply_gradients(weights_to_update, iteration);
+ for (auto iter = iter_begin; iter != iter_end - 1; iter++) {
+ backwarding(iter->layer, iteration, true);
}
auto last_layer = (iter_end - 1)->layer;
/**
* The last trainable layer need not calculate the derivatives
- * Do not change this order:
- * 1. calcGradient
- * 2. calcDerivative
- * 3. applyGradient
*/
- last_layer->calcGradient();
#ifdef ENABLE_TEST
- last_layer->calcDerivative();
+ backwarding(last_layer, iteration, true);
+#else
+ backwarding(last_layer, iteration, false);
#endif
- opt->apply_gradients(last_layer->getWeightsRef(), iteration);
}
/**
#include <optimizer_internal.h>
#include <pooling2d_layer.h>
#include <tensor.h>
-#include <util_func.h>
#include <model.h>
#include <nntrainer-api-common.h>
* "max" and "norm" for now
*/
void enableDynamicTraining(
- float threshold,
- std::string op = DynamicTrainingOptimization::dft_opt_norm) {
- dft_opt.setThreshold(threshold);
- dft_opt.setOp(op);
+ float threshold, std::string op = DynamicTrainingOptimization::dft_opt_norm,
+ std::string mode = DynamicTrainingOptimization::dft_opt_mode_derivative) {
+ dynamic_training_opt.setThreshold(threshold);
+ dynamic_training_opt.setOp(op);
+ dynamic_training_opt.setMode(mode);
+ dynamic_training_opt.enable();
}
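
A hypothetical call site for the method above, assuming `model` is the NeuralNetwork instance exposing it; the arguments simply spell out the defaults.

model.enableDynamicTraining(
  1.0f, nntrainer::DynamicTrainingOptimization::dft_opt_norm,
  nntrainer::DynamicTrainingOptimization::dft_opt_mode_derivative);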
/**
* @brief Disable dynamic fine-tuning optimization
*/
- void disableDynamicFineTuning() {
- dft_opt.setOp(DynamicTrainingOptimization::dft_opt_off);
- }
+ void disableDynamicFineTuning() { dynamic_training_opt.disable(); }
/// @todo Make a more common class have this
/// Maybe appcontext can have this?
bool in_place_optimization; /**< Run batch normalization, activation, etc
layers in-place */
- DynamicTrainingOptimization dft_opt; /**< Dynamic fine-tuning optimization
- mode. supported modes are "off", "max" and "norm" */
+  DynamicTrainingOptimization dynamic_training_opt; /**< Dynamic fine-tuning
+    optimization. Supported reduction ops are "max" and "norm"; supported modes
+    are "gradient" and "derivative" */
/**
* @brief print function for neuralnet
* @retval true if matches, false is error
*/
bool validateInput(sharedConstTensors X);
+
+ /**
+ * @brief Backward Propagation for the layer
+ * @param[in] layer Layer to backpropagate
+ * @param[in] iteration Iteration Number for the optimizer
+ * @param[in] calc_derivative If the derivative for previous layer must be
+ * calculated
+ */
+ void backwarding(std::shared_ptr<Layer> layer, int iteration,
+ bool calc_derivative);
};
} /* namespace nntrainer */
}
}
+/** Fallback for cblas_isamax: index of the element with the largest absolute
+ * value */
+static unsigned int isamax_raw(const unsigned int N, const float *X,
+                               const int incX) {
+
+  unsigned int max_idx = 0;
+  float max_val = std::fabs(X[0]);
+  for (unsigned int n = 1; n < N; n += incX) {
+    float cur_val = std::fabs(X[n]);
+    if (cur_val > max_val) {
+      max_val = cur_val;
+      max_idx = n;
+    }
+  }
+
+  return max_idx;
+}
+
#endif
void saxpy(const unsigned int N, const float alpha, const float *X,
#endif
}
+unsigned int isamax(const unsigned int N, const float *X, const int incX) {
+#ifdef USE_BLAS
+ return cblas_isamax(N, X, incX);
+#else
+ return isamax_raw(N, X, incX);
+#endif
+}
+
} // namespace nntrainer
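
A small sanity-check sketch for the new isamax entry point; the values and the function name are illustrative. With incX = 1 the raw fallback and cblas_isamax return the same index.

#include <blas_interface.h>

unsigned int isamaxExampleSketch() {
  const float v[] = {0.5f, -3.0f, 2.0f};
  /* expected result: 1, since |-3.0f| is the largest magnitude */
  return nntrainer::isamax(3, v, 1);
}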
const unsigned int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
+unsigned int isamax(const unsigned int N, const float *X, const int incX);
+
} /* namespace nntrainer */
#endif /* __cplusplus */
#endif /* __BLAS_INTERFACE_H__ */
return snrm2(len, data, 1);
}
+float Tensor::max_abs() const {
+ unsigned int len = length();
+ const float *data = getData();
+
+ unsigned int idx = isamax(len, data, 1);
+  /** return the magnitude of the element with the largest absolute value */
+  return std::fabs(*(data + idx));
+}
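For example (illustrative values), a tensor holding {0.3f, -2.5f, 1.0f} yields max_abs() == 2.5f, i.e. the infinity norm of its data.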
+
Tensor &Tensor::normalization(Tensor &output) const {
if (output.uninitialized())
output = Tensor(dim);
*
* @return Tensor Variable tensor
*/
- Tensor &getVariableRef() const { return *var.get(); }
+ const Tensor &getVariableRef() const { return *var.get(); }
/**
* @brief Get the Gradient tensor (by reference)
*
* @return Tensor Gradient tensor
*/
- Tensor &getGradientRef() const { return *grad.get(); }
+ const Tensor &getGradientRef() const { return *grad.get(); }
protected:
TensorDim dim; /**< dimension of the tensor */