return ML_ERROR_INVALID_PARAMETER;
}
- for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
- training.loss = 0.0f;
-
+ /**
+ * @brief run a single epoch with given callback, @a auto is used instead of
+ * std::function for performance measure
+ * @param buffer buffer to run
+ * @param shuffle whether to shuffle or not
+ * @param on_iteration_fetch function that will recieve reference to stat,
+ * buffer which will be called every time data is fetched and set
+ * @param on_epoch_end function that will recieve reference to stat,
+ * buffer which will be called on the epoch end
+ */
+ auto run_epoch = [this, &in, &label, &in_dims, &label_dims](
+ DataBuffer *buffer, bool shuffle,
+ auto &&on_iteration_fetch, auto &&on_epoch_end) {
+ /// @todo managing metrics must be handled here as well!! for now it is
+ /// handled in individual callbacks
+ RunStats stat;
std::future<std::shared_ptr<IterationQueue>> future_iq =
- train_buffer->startFetchWorker(in_dims, label_dims, true);
-
- // /// @todo make this working, test buffer is running but doing nothing
- // if (test_buffer != nullptr && test_buffer->isValid()) {
- // status = test_buffer->run();
- // if (status != ML_ERROR_NONE) {
- // test_buffer->clear();
- // return status;
- // }
- // }
-
- int count = 0;
-
+ buffer->startFetchWorker(in_dims, label_dims, shuffle);
while (true) {
- ScopedView<Iteration> iter_view = train_buffer->fetch();
+ ScopedView<Iteration> iter_view = buffer->fetch();
if (iter_view.isEmpty()) {
break;
}
in = iteration.getInputsRef().front();
label = iteration.getLabelsRef().front();
- forwarding(true);
- backwarding(iter++);
+ on_iteration_fetch(stat, *buffer);
+ }
+ future_iq.get();
+ on_epoch_end(stat, *buffer);
- std::cout << "#" << epoch_idx << "/" << epochs;
- float loss = getLoss();
- train_buffer->displayProgress(count++, loss);
- training.loss += loss;
+ if (stat.num_iterations == 0) {
+ throw std::runtime_error("No data came while buffer ran");
}
- future_iq.get();
+ return stat;
+ };
- if (count == 0)
- throw std::runtime_error("No training data");
+ auto train_for_iteration = [this](RunStats &stat, DataBuffer &buffer) {
+ forwarding(true);
+ backwarding(iter++);
- training.loss /= count;
+ std::cout << "#" << epoch_idx << "/" << epochs;
+ auto loss = getLoss();
+ stat.loss += loss;
+ buffer.displayProgress(stat.num_iterations++, loss);
+ };
+
+ auto train_epoch_end = [this](RunStats &stat, DataBuffer &buffer) {
+ stat.loss /= static_cast<float>(stat.num_iterations);
if (!save_path.empty()) {
save(save_path, ml::train::ModelFormat::MODEL_FORMAT_BIN);
}
std::cout << "#" << epoch_idx << "/" << epochs
- << " - Training Loss: " << training.loss;
-
- if (valid_buffer != nullptr) {
- int right = 0;
- validation.loss = 0.0f;
- unsigned int tcases = 0;
-
- std::future<std::shared_ptr<IterationQueue>> future_iq =
- valid_buffer->startFetchWorker(in_dims, label_dims, false);
-
- while (true) {
- ScopedView<Iteration> iter_view = valid_buffer->fetch();
- if (iter_view.isEmpty()) {
- break;
- }
- auto &iter = iter_view.get();
- if (iter.batch() != batch_size) {
- /// @todo support partial batch
- continue;
- }
- /// @todo multiple input support
- in = iter.getInputsRef().front();
- label = iter.getLabelsRef().front();
-
- forwarding(false);
- auto model_out = output.argmax();
- auto label_out = label.argmax();
- for (unsigned int b = 0; b < batch_size; b++) {
- if (model_out[b] == label_out[b])
- right++;
- }
- validation.loss += getLoss();
- tcases++;
- }
+ << " - Training Loss: " << stat.loss;
+ };
- future_iq.get();
+ auto eval_for_iteration = [this, &output, &label](RunStats &stat,
+ DataBuffer &buffer) {
+ forwarding(false);
+ auto model_out = output.argmax();
+ auto label_out = label.argmax();
+ for (unsigned int b = 0; b < batch_size; b++) {
+ if (model_out[b] == label_out[b])
+ stat.num_correct_predictions++;
+ }
+ stat.num_iterations++;
+ stat.loss += getLoss();
+ };
- if (tcases == 0) {
- ml_loge("Error : 0 test cases");
- status = ML_ERROR_INVALID_PARAMETER;
- return status;
- }
- validation.loss /= (float)(tcases);
- validation.accuracy = right / (float)(tcases * batch_size) * 100.0f;
- std::cout << " >> [ Accuracy: " << validation.accuracy
- << "% - Validation Loss : " << validation.loss << " ] ";
+ auto eval_epoch_end = [this](RunStats &stat, DataBuffer &buffer) {
+ stat.loss /= static_cast<float>(stat.num_iterations);
+ stat.accuracy = stat.num_correct_predictions /
+ static_cast<float>(stat.num_iterations * batch_size) *
+ 100.0f;
+ std::cout << " >> [ Accuracy: " << validation.accuracy
+ << "% - Validation Loss : " << validation.loss << " ]";
+ };
+
+ for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
+ training =
+ run_epoch(train_buffer.get(), true, train_for_iteration, train_epoch_end);
+ if (valid_buffer) {
+ validation = run_epoch(valid_buffer.get(), false, eval_for_iteration,
+ eval_epoch_end);
}
- std::cout << std::endl;
+ std::cout << '\n';
+ }
+ if (test_buffer) {
+ std::cout << "Evaluation with test data...\n";
+ testing =
+ run_epoch(test_buffer.get(), false, eval_for_iteration, eval_epoch_end);
}
return status;