std::string data_path;
float training_loss = 0.0;
+float validation_loss = 0.0;
+float last_batch_loss = 0.0;
/**
* @brief step function
#if defined(APP_VALIDATE)
TEST(MNIST_training, verify_accuracy) {
- EXPECT_FLOAT_EQ(training_loss, 2.0374029);
+ EXPECT_FLOAT_EQ(training_loss, 2.3255470);
+ EXPECT_FLOAT_EQ(validation_loss, 2.3074534);
+ EXPECT_FLOAT_EQ(last_batch_loss, 2.2916341);
}
#endif
*/
try {
NN.train();
- training_loss = NN.getLoss();
+ training_loss = NN.getTrainingLoss();
+ validation_loss = NN.getValidationLoss();
+ last_batch_loss = NN.getLoss();
} catch (...) {
std::cerr << "Error during train" << std::endl;
return 0;
~NeuralNetwork();
/**
- * @brief Get Loss
+ * @brief Get Loss from the previous ran batch of data
* @retval loss value
*/
float getLoss();
/**
+ * @brief Get Loss from the previous epoch of training data
+ * @retval loss value
+ */
+ float getTrainingLoss() { return training.loss; }
+
+ /**
+ * @brief Get Loss from the previous epoch of validation data
+ * @retval loss value
+ */
+ float getValidationLoss() { return validation.loss; }
+
+ /**
* @brief Get Learning rate
* @retval Learning rate
*/