using prop_tag = str_prop_tag; /**< property type */
};
+/**
+ * @brief model save path property
+ *
+ */
+class SaveBestPath : public Property<std::string> {
+public:
+ static constexpr const char *key =
+ "save_best_path"; /**< unique key to access */
+ using prop_tag = str_prop_tag; /**< property type */
+};
+
/**
* @brief model batch size property
*
NeuralNetwork::NeuralNetwork(AppContext app_context_, bool in_place_opt) :
model_props(props::LossType()),
model_flex_props(props::Epochs(), props::TrainingBatchSize(),
- props::SavePath(), props::ContinueTrain()),
+ props::SavePath(), props::ContinueTrain(),
+ props::SaveBestPath()),
load_path(std::string()),
epoch_idx(0),
iter(0),
stat.loss += getLoss();
};
- auto eval_epoch_end = [this, batch_size](RunStats &stat, DataBuffer &buffer) {
+ auto eval_epoch_end = [this, batch_size, max_acc = 0.0f,
+ min_loss = std::numeric_limits<float>::max()](
+ RunStats &stat, DataBuffer &buffer) mutable {
stat.loss /= static_cast<float>(stat.num_iterations);
stat.accuracy = stat.num_correct_predictions /
static_cast<float>(stat.num_iterations * batch_size) *
100.0f;
+
+ if (stat.accuracy > max_acc ||
+ (stat.accuracy == max_acc && stat.loss < min_loss)) {
+ max_acc = stat.accuracy;
+ /// @note this is not actually 'the' min loss for whole time but records
+ /// when data change
+ min_loss = stat.loss;
+ auto &save_best_path = std::get<props::SaveBestPath>(model_flex_props);
+ if (!save_best_path.empty()) {
+ save(save_best_path);
+ }
+ }
std::cout << " >> [ Accuracy: " << validation.accuracy
<< "% - Validation Loss : " << validation.loss << " ]";
};
}
std::cout << '\n';
}
+
if (test_buffer) {
std::cout << "Evaluation with test data...\n";
testing =
private:
using FlexiblePropTypes =
- std::tuple<props::Epochs, props::TrainingBatchSize, props::SavePath, props::ContinueTrain>;
+ std::tuple<props::Epochs, props::TrainingBatchSize, props::SavePath,
+ props::ContinueTrain, props::SaveBestPath>;
using RigidPropTypes = std::tuple<props::LossType>;
RigidPropTypes model_props; /**< model props */