From b9a1210ba5d49ef51300f150460c8f1c7b911804 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Thu, 10 Sep 2020 16:58:54 +0900 Subject: [PATCH] [Application] Update application to use C-API Update draw-classification application to use the C-API Also add more polishing to the application Update building of android application to build with C-API Update meson to use C-API than nntrainer directly Inference of application is still remaining **Self evaluation:** 1. Build test: [x]Passed [ ]Failed [ ]Skipped 2. Run test: [x]Passed [ ]Failed [ ]Skipped Signed-off-by: Parichay Kapoor --- .../Draw_Classification/jni/Android.mk | 10 +- .../Draw_Classification/jni/main.cpp | 437 +++++++++++---------- .../Draw_Classification/jni/meson.build | 2 +- 3 files changed, 241 insertions(+), 208 deletions(-) diff --git a/Applications/TransferLearning/Draw_Classification/jni/Android.mk b/Applications/TransferLearning/Draw_Classification/jni/Android.mk index 9a393e4..603e5df 100644 --- a/Applications/TransferLearning/Draw_Classification/jni/Android.mk +++ b/Applications/TransferLearning/Draw_Classification/jni/Android.mk @@ -13,6 +13,7 @@ endif NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer/include \ $(NNTRAINER_ROOT)/api \ + $(NNTRAINER_ROOT)/api/capi/include \ $(NNTRAINER_ROOT)/api/capi/include/platform NNTRAINER_APPLICATION := $(NNTRAINER_ROOT)/Applications @@ -48,6 +49,13 @@ include $(PREBUILT_SHARED_LIBRARY) include $(CLEAR_VARS) +LOCAL_MODULE := capi-nntrainer +LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/libs/$(TARGET_ARCH_ABI)/libcapi-nntrainer.so + +include $(PREBUILT_SHARED_LIBRARY) + +include $(CLEAR_VARS) + LOCAL_ARM_NEON := true LOCAL_CFLAGS += -std=c++14 -Ofast -mcpu=cortex-a53 -Ilz4-nougat/lib LOCAL_LDFLAGS += -Llz4-nougat/lib/obj/local/$(TARGET_ARCH_ABI)/ @@ -61,7 +69,7 @@ LOCAL_LDLIBS := -llog LOCAL_SRC_FILES := main.cpp bitmap_helpers.cpp -LOCAL_SHARED_LIBRARIES := nntrainer +LOCAL_SHARED_LIBRARIES := capi-nntrainer LOCAL_STATIC_LIBRARIES := tensorflow-lite diff --git a/Applications/TransferLearning/Draw_Classification/jni/main.cpp b/Applications/TransferLearning/Draw_Classification/jni/main.cpp index 3feeb99..72e5eb8 100644 --- a/Applications/TransferLearning/Draw_Classification/jni/main.cpp +++ b/Applications/TransferLearning/Draw_Classification/jni/main.cpp @@ -39,184 +39,201 @@ #include #include -#include "neuralnet.h" -#include "tensor.h" +#include -/** - * @brief Data size for each category - */ -#define TOTAL_DATA_SIZE 5 +/** Number of dimensions for the input data */ +#define MAX_DIM 4 -/** - * @brief Number of category : Three - */ -#define TOTAL_LABEL_SIZE 3 +/** Data size for each category */ +#define NUM_DATA_PER_LABEL 5 -/** - * @brief Number of Test Set - */ +/** Size of each label (number of label categories) */ +#define LABEL_SIZE 3 + +/** Size of each input */ +#define INPUT_SIZE 128 + +/** Number of test data points */ #define TOTAL_TEST_SIZE 8 -/** - * @brief Max Epochs - */ -#define ITERATION 1000 +/** Total number of data points in an epoch */ +#define EPOCH_SIZE LABEL_SIZE *NUM_DATA_PER_LABEL -using namespace std; +/** Max Epochs */ +#define EPOCHS 1000 -/** - * @brief location of resources ( ../../res/ ) - */ -string data_path; +/** labels values */ +const std::string label_names[LABEL_SIZE] = {"happy", "sad", "soso"}; + +/** Vectors containing the training data */ +std::vector> inputVector, labelVector; +unsigned int iteration = 0; /** * @brief step function * @param[in] x value to be distinguished * @retval 0.0 or 1.0 */ -float stepFunction(float x) { - if (x > 0.9) { - return 1.0; - } - - if (x < 0.1) { - return 0.0; - } +// float stepFunction(float x) { +// if (x > 0.9) { +// return 1.0; +// } +// +// if (x < 0.1) { +// return 0.0; +// } +// +// return x; +// } + +struct TFLiteData { + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + std::unique_ptr model; + std::string data_path; - return x; -} + int output_number_of_pixels; + int inputDimReq[MAX_DIM]; +}; -/** - * @brief Get Feature vector from tensorflow lite - * This creates interpreter & inference with ssd tflite - * @param[in] filename input file path - * @param[out] feature_input save output of tflite - */ -void getFeature(const string filename, vector &feature_input) { +void setupTensorflowLiteModel(const std::string &data_path, + TFLiteData &tflite_data) { int input_size; int output_size; - int *output_idx_list; - int *input_idx_list; - int inputDim[4]; - int outputDim[4]; - int input_idx_list_len = 0; - int output_idx_list_len = 0; + int len; + int outputDim[MAX_DIM]; + + tflite_data.data_path = data_path; std::string model_path = data_path + "ssd_mobilenet_v2_coco_feature.tflite"; - std::unique_ptr model = + tflite_data.model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str()); + if (tflite_data.model == NULL) + throw std::runtime_error("Unable to build model from file"); - assert(model != NULL); - tflite::ops::builtin::BuiltinOpResolver resolver; - std::unique_ptr interpreter; - tflite::InterpreterBuilder(*model.get(), resolver)(&interpreter); + tflite::InterpreterBuilder(*tflite_data.model.get(), + tflite_data.resolver)(&tflite_data.interpreter); - input_size = interpreter->inputs().size(); - output_size = interpreter->outputs().size(); + if (tflite_data.interpreter->AllocateTensors() != kTfLiteOk) + throw std::runtime_error("Failed to allocate tensors!"); - input_idx_list = new int[input_size]; - output_idx_list = new int[output_size]; + input_size = tflite_data.interpreter->inputs().size(); + output_size = tflite_data.interpreter->outputs().size(); - int t_size = interpreter->tensors_size(); - for (int i = 0; i < t_size; i++) { - for (int j = 0; j < input_size; j++) { - if (strcmp(interpreter->tensor(i)->name, interpreter->GetInputName(j)) == - 0) - input_idx_list[input_idx_list_len++] = i; - } - for (int j = 0; j < output_size; j++) { - if (strcmp(interpreter->tensor(i)->name, interpreter->GetOutputName(j)) == - 0) - output_idx_list[output_idx_list_len++] = i; - } - } - for (int i = 0; i < 4; i++) { - inputDim[i] = 1; + if (input_size > 1 || output_size > 1) + throw std::runtime_error("Model is expected with single input and output"); + + for (int i = 0; i < MAX_DIM; i++) { + tflite_data.inputDimReq[i] = 1; outputDim[i] = 1; } - int len = interpreter->tensor(input_idx_list[0])->dims->size; - std::reverse_copy(interpreter->tensor(input_idx_list[0])->dims->data, - interpreter->tensor(input_idx_list[0])->dims->data + len, - inputDim); - len = interpreter->tensor(output_idx_list[0])->dims->size; - std::reverse_copy(interpreter->tensor(output_idx_list[0])->dims->data, - interpreter->tensor(output_idx_list[0])->dims->data + len, - outputDim); - - int output_number_of_pixels = 1; - int wanted_channels = inputDim[0]; - int wanted_height = inputDim[1]; - int wanted_width = inputDim[2]; - - for (int k = 0; k < 4; k++) - output_number_of_pixels *= inputDim[k]; - - int _input = interpreter->inputs()[0]; + int input_idx = tflite_data.interpreter->inputs()[0]; + len = tflite_data.interpreter->tensor(input_idx)->dims->size; + std::reverse_copy(tflite_data.interpreter->tensor(input_idx)->dims->data, + tflite_data.interpreter->tensor(input_idx)->dims->data + + len, + tflite_data.inputDimReq); + + int output_idx = tflite_data.interpreter->outputs()[0]; + len = tflite_data.interpreter->tensor(output_idx)->dims->size; + std::reverse_copy( + tflite_data.interpreter->tensor(output_idx)->dims->data, + tflite_data.interpreter->tensor(output_idx)->dims->data + len, outputDim); + + tflite_data.output_number_of_pixels = 1; + for (int k = 0; k < MAX_DIM; k++) + tflite_data.output_number_of_pixels *= tflite_data.inputDimReq[k]; +} +/** + * @brief Get Feature vector from tensorflow lite + * This creates interpreter & inference with ssd tflite + * @param[in] filename input file path + * @param[out] feature_input save output of tflite + */ +void getInputFeature(const TFLiteData &tflite_data, const std::string filename, + std::vector &feature_input) { uint8_t *in; - float *output; - in = tflite::label_image::read_bmp(filename, &wanted_width, &wanted_height, - &wanted_channels); - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cout << "Failed to allocate tensors!" << std::endl; - exit(0); + int inputDim[MAX_DIM] = {1, 1, 1, 1}; + in = tflite::label_image::read_bmp(filename, inputDim, inputDim + 1, + inputDim + 2); + + int input_img_size = 1; + for (int idx = 0; idx < MAX_DIM; idx++) { + input_img_size *= inputDim[idx]; } - for (int l = 0; l < output_number_of_pixels; l++) { - (interpreter->typed_tensor(_input))[l] = - ((float)in[l] - 127.5f) / 127.5f; + if (tflite_data.output_number_of_pixels != input_img_size) { + delete in; + throw std::runtime_error("Input size does not match the required size"); } - if (interpreter->Invoke() != kTfLiteOk) { - std::cout << "Failed to invoke!" << std::endl; - exit(0); + int input_idx = tflite_data.interpreter->inputs()[0]; + for (int l = 0; l < tflite_data.output_number_of_pixels; l++) { + (tflite_data.interpreter->typed_tensor(input_idx))[l] = + ((float)in[l] - 127.5f) / 127.5f; } - output = interpreter->typed_output_tensor(0); + if (tflite_data.interpreter->Invoke() != kTfLiteOk) + std::runtime_error("Failed to invoke."); - for (int l = 0; l < 128; l++) { + float *output = tflite_data.interpreter->typed_output_tensor(0); + for (int l = 0; l < INPUT_SIZE; l++) { feature_input[l] = output[l]; } - delete[] input_idx_list; - delete[] output_idx_list; delete[] in; } /** - * @brief Extract the features from all three categories - * @param[in] p data path - * @param[out] feature_input save output of tflite - * @param[out] feature_output save label data + * @brief Extract the features from pretrained model + * @param[in] data_path data path + * @param[out] input_data output of tflite model (input for the nntrainer model) + * @param[out] label_data one hot label data */ -void ExtractFeatures(std::string p, vector> &feature_input, - vector> &feature_output) { - string total_label[TOTAL_LABEL_SIZE] = {"happy", "sad", "soso"}; - - int trainingSize = TOTAL_LABEL_SIZE * TOTAL_DATA_SIZE; - - feature_input.resize(trainingSize); - feature_output.resize(trainingSize); - - int count = 0; - - for (int i = 0; i < TOTAL_LABEL_SIZE; i++) { - std::string path = p; - path += total_label[i]; +void extractFeatures(const TFLiteData &tflite_data, + std::vector> &input_data, + std::vector> &label_data) { + int trainingSize = LABEL_SIZE * NUM_DATA_PER_LABEL; + + input_data.resize(trainingSize, std::vector(INPUT_SIZE)); + /** resize label data to size and initialize to 0 */ + label_data.resize(trainingSize, std::vector(LABEL_SIZE, 0)); + + for (int i = 0; i < LABEL_SIZE; i++) { + for (int j = 0; j < NUM_DATA_PER_LABEL; j++) { + std::string label_file = label_names[i] + std::to_string(j + 1) + ".bmp"; + std::string img = + tflite_data.data_path + "/" + label_names[i] + "/" + label_file; + + int count = i * NUM_DATA_PER_LABEL + j; + getInputFeature(tflite_data, img, input_data[count]); + label_data[count][i] = 1; + } + } +} - for (int j = 0; j < TOTAL_DATA_SIZE; j++) { - std::string img = path + "/"; - img += total_label[i] + std::to_string(j + 1) + ".bmp"; - printf("%s\n", img.c_str()); +/** + * Data generator callback + */ +int getBatch_train(float **input, float **label, bool *last, void *user_data) { + if (iteration >= EPOCH_SIZE) { + *last = true; + iteration = 0; + return ML_ERROR_NONE; + } - feature_input[count].resize(128); + for (int idx = 0; idx < INPUT_SIZE; idx++) { + input[0][idx] = inputVector[iteration][idx]; + } - getFeature(img, feature_input[count]); - feature_output[count].resize(TOTAL_LABEL_SIZE); - feature_output[count][i] = 1; - count++; - } + for (int idx = 0; idx < LABEL_SIZE; idx++) { + label[0][idx] = labelVector[iteration][idx]; } + + *last = false; + iteration += 1; + return ML_ERROR_NONE; } /** @@ -226,97 +243,105 @@ void ExtractFeatures(std::string p, vector> &feature_input, * @param[in] arg 2 : resource path */ int main(int argc, char *argv[]) { + int status = ML_ERROR_NONE; if (argc < 3) { std::cout << "./TransferLearning Config.ini resources\n"; exit(0); } - const vector args(argv + 1, argv + argc); + + const std::vector args(argv + 1, argv + argc); std::string config = args[0]; - data_path = args[1]; + + /** location of resources ( ../../res/ ) */ + std::string data_path = args[1]; srand(time(NULL)); - std::string ini_file = data_path + "ini.bin"; - std::vector> inputVector, outputVector; - /** - * @brief Extract Feature - */ - ExtractFeatures(data_path, inputVector, outputVector); - /** - * @brief Neural Network Create & Initialization - */ - nntrainer::NeuralNetwork NN; - - try { - NN.loadFromConfig(config); - NN.init(); - } catch (...) { - std::cerr << "Error during initiation" << std::endl; - NN.finalize(); - return -1; + TFLiteData tflite_data; + setupTensorflowLiteModel(data_path, tflite_data); + + /** Extract features from the already trained model */ + extractFeatures(tflite_data, inputVector, labelVector); + + /** Neural Network Create & Initialization */ + ml_train_model_h handle = NULL; + ml_train_dataset_h dataset = NULL; + + status = ml_train_model_construct_with_conf(config.c_str(), &handle); + if (status != ML_ERROR_NONE) { + std::cerr << "Failed to construct the model" << std::endl; + return status; } - /** - * @brief back propagation - */ - for (int i = 0; i < ITERATION; i++) { - for (unsigned int j = 0; j < inputVector.size(); j++) { - nntrainer::Tensor in, out; - try { - in = nntrainer::Tensor({inputVector[j]}); - } catch (...) { - std::cerr << "Error during tensor initialization" << std::endl; - NN.finalize(); - return -1; - } - try { - out = nntrainer::Tensor({outputVector[j]}); - } catch (...) { - std::cerr << "Error during tensor initialization" << std::endl; - NN.finalize(); - return -1; - } - - try { - NN.backwarding(MAKE_SHARED_TENSOR(in), MAKE_SHARED_TENSOR(out), i); - } catch (...) { - std::cerr << "Error during backwarding the model" << std::endl; - NN.finalize(); - return -1; - } - } - cout << "#" << i + 1 << "/" << ITERATION << " - Loss : " << NN.getLoss() - << endl; - NN.setLoss(0.0); + status = ml_train_model_compile(handle, NULL); + if (status != ML_ERROR_NONE) { + std::cerr << "Failed to compile the model" << std::endl; + ml_train_model_destroy(handle); + return status; + } + + /** Set the dataset from generator */ + status = ml_train_dataset_create_with_generator(&dataset, getBatch_train, + NULL, NULL); + if (status != ML_ERROR_NONE) { + std::cerr << "Failed to create the dataset" << std::endl; + ml_train_model_destroy(handle); + return status; + } + + status = ml_train_dataset_set_property(dataset, "buffer_size=100", NULL); + if (status != ML_ERROR_NONE) { + std::cerr << "Failed to set property for the dataset" << std::endl; + ml_train_dataset_destroy(dataset); + ml_train_model_destroy(handle); + return status; + } + + status = ml_train_model_set_dataset(handle, dataset); + if (status != ML_ERROR_NONE) { + std::cerr << "Failed to set dataset to the dataset" << std::endl; + ml_train_dataset_destroy(dataset); + ml_train_model_destroy(handle); + return status; } /** - * @brief test + * @brief back propagation */ - for (int i = 0; i < TOTAL_TEST_SIZE; i++) { - std::string path = data_path; - path += "testset"; - printf("\n[%s]\n", path.c_str()); - std::string img = path + "/"; - img += "test" + std::to_string(i + 1) + ".bmp"; - printf("%s\n", img.c_str()); - - std::vector featureVector, resultVector; - featureVector.resize(128); - getFeature(img, featureVector); - nntrainer::Tensor X; - try { - X = nntrainer::Tensor({featureVector}); - NN.forwarding(MAKE_SHARED_TENSOR(X))->apply(stepFunction); - } catch (...) { - std::cerr << "Error during forwaring the model" << std::endl; - NN.finalize(); - return -1; - } + std::stringstream epoch_string; + epoch_string << "epochs=" << EPOCHS << std::endl; + status = ml_train_model_run(handle, epoch_string.str().c_str(), NULL); + if (status != ML_ERROR_NONE) { + std::cerr << "Failed to train the model" << std::endl; + ml_train_model_destroy(handle); + return status; } + /** destroy the model */ + ml_train_model_destroy(handle); + /** - * @brief Finalize NN + * @brief test */ - NN.finalize(); + // for (int i = 0; i < TOTAL_TEST_SIZE; i++) { + // std::string path = data_path; + // path += "testset"; + // printf("\n[%s]\n", path.c_str()); + // std::string img = path + "/"; + // img += "test" + std::to_string(i + 1) + ".bmp"; + // printf("%s\n", img.c_str()); + + // std::vector featureVector, resultVector; + // featureVector.resize(128); + // getFeature(img, featureVector); + // nntrainer::Tensor X; + // try { + // X = nntrainer::Tensor({featureVector}); + // NN.forwarding(MAKE_SHARED_TENSOR(X))->apply(stepFunction); + // } catch (...) { + // std::cerr << "Error during forwaring the model" << std::endl; + // NN.finalize(); + // return -1; + // } + // } } diff --git a/Applications/TransferLearning/Draw_Classification/jni/meson.build b/Applications/TransferLearning/Draw_Classification/jni/meson.build index 8b080ac..686b0ca 100644 --- a/Applications/TransferLearning/Draw_Classification/jni/meson.build +++ b/Applications/TransferLearning/Draw_Classification/jni/meson.build @@ -8,7 +8,7 @@ training_sources = [ e = executable('nntrainer_training', training_sources, - dependencies: [iniparser_dep, nntrainer_dep, tflite_dep], + dependencies: [iniparser_dep, nntrainer_capi_dep, tflite_dep], include_directories: include_directories('.'), install: get_option('install-app'), install_dir: application_install_dir -- 2.7.4