#include <stdlib.h>
#include <time.h>
+#if defined(__TIZEN__)
+#include <gtest/gtest.h>
+#endif
+
#include "databuffer.h"
#include "databuffer_func.h"
#include "neuralnet.h"
std::string data_path;
+float training_loss = 0.0;
+
/**
* @brief step function
* @param[in] x value to be distinguished
return ML_ERROR_NONE;
}
+#if defined(__TIZEN__)
+TEST(MNIST_training, verify_accuracy) {
+ EXPECT_FLOAT_EQ(training_loss, 2.0374029);
+}
+#endif
+
/**
* @brief create NN
* Get Feature from tflite & run foword & back propatation
* @param[in] arg 2 : resource path
*/
int main(int argc, char *argv[]) {
+ int status = 0;
if (argc < 2) {
std::cout << "./nntrainer_mnist mnist.ini\n";
exit(0);
NN.readModel();
NN.setDataBuffer((DB));
+#if defined(__TIZEN__)
+ status = NN.setProperty({"epochs=5"});
+ if (status != ML_ERROR_NONE) {
+ std::cerr << "Error setting the number of epochs" << std::endl;
+ return 0;
+ }
+#endif
/**
* @brief Neural Network Train & validation
*/
try {
NN.train();
+ training_loss = NN.getLoss();
} catch (...) {
std::cerr << "Error during train" << std::endl;
return 0;
}
+#if defined(__TIZEN__)
+ try {
+ testing::InitGoogleTest(&argc, argv);
+ } catch (...) {
+ std::cerr << "Error duing InitGoogleTest" << std::endl;
+ return 0;
+ }
+
+ try {
+ status = RUN_ALL_TESTS();
+ } catch (...) {
+ std::cerr << "Error duing RUN_ALL_TSETS()" << std::endl;
+ }
+#endif
+
/**
* @brief Finalize NN
*/
- return 0;
+ return status;
}
mnist_sources = [
'main.cpp'
]
+if build_platform == 'tizen'
+ if not gtest_dep.found()
+ error('Gtest dependency not found for MNIST application')
+ endif
+endif
executable('nntrainer_mnist',
mnist_sources,
- dependencies: [iniparser_dep, nntrainer_dep],
+ dependencies: [iniparser_dep, nntrainer_dep, gtest_dep],
include_directories: include_directories('.'),
install: get_option('install-app'),
install_dir: application_install_dir
export NNSTREAMER_CONF=$(pwd)/test/nnstreamer_filter_nntrainer/nnstreamer-test.ini
export NNSTREAMER_FILTERS=$(pwd)/build/nnstreamer/tensor_filter
pushd build
+rm -rf model.bin
TF_APP=Applications/TransferLearning/Draw_Classification
./${TF_APP}/jni/nntrainer_training ../${TF_APP}/res/Training.ini ../${TF_APP}/res
+
+rm -rf model.bin
+cp ../Applications/MNIST/jni/mnist_trainingSet.dat .
+MNIST_APP=Applications/MNIST
+./${MNIST_APP}/jni/nntrainer_mnist ../${MNIST_APP}/res/mnist.ini
+
popd
# unittest for nntrainer plugin for nnstreamer