NeuralNetwork::NeuralNetwork() :
model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
- model_flex_props(props::Epochs(), props::TrainingBatchSize(),
- props::SavePath(), props::ContinueTrain(),
- props::SaveBestPath(), props::MemoryOptimization(),
- props::MemorySwap(), props::MemorySwapPath()),
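+   // MemorySwapLookahead joins the existing swap properties; its value is read
+   // later when the network graph is constructed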
+ model_flex_props(
+ props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
+ props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
+ props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
load_path(std::string()),
epoch_idx(0),
iter(0),
NeuralNetwork::NeuralNetwork(AppContext app_context_) :
model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
- model_flex_props(props::Epochs(), props::TrainingBatchSize(),
- props::SavePath(), props::ContinueTrain(),
- props::SaveBestPath(), props::MemoryOptimization(),
- props::MemorySwap(), props::MemorySwapPath()),
+ model_flex_props(
+ props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
+ props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
+ props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
load_path(std::string()),
epoch_idx(0),
iter(0),
bool memory_swap = std::get<props::MemorySwap>(model_flex_props);
const std::string memory_swap_path =
std::get<props::MemorySwapPath>(model_flex_props);
- model_graph = NetworkGraph(memory_swap, memory_swap_path);
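+ // lookahead: how many execution steps ahead swapped-out tensors are prefetched
+ // back in (assumed semantics of MemorySwapLookahead); forwarded to the graph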
+ unsigned int lookahead =
+ std::get<props::MemorySwapLookahead>(model_flex_props);
+ model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead);
model_graph.setMemoryOptimizations(
std::get<props::MemoryOptimization>(model_flex_props));
/**
* @brief free layers
*/
-NeuralNetwork::~NeuralNetwork() = default;
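+// Release allocated tensors and graph resources explicitly via deallocate()
+// instead of relying on the default destructor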
+NeuralNetwork::~NeuralNetwork() { deallocate(); }
/**
* @brief forward propagation using layers object which has layer
auto train_for_iteration = [this, stop_cb](RunStats &stat,
DataBuffer &buffer) {
- model_graph.flushCache();
-
forwarding(true, stop_cb);
backwarding(iter++, stop_cb);
+ // Flush the cache once forwarding and backwarding are done so stale entries do not leak memory
+ model_graph.flushCache();
+
if (!stop_cb(nullptr)) {
std::cout << "#" << epoch_idx << "/" << getEpochs();
ml_logi("# %d / %d", epoch_idx, getEpochs());