From e02b2eed9d689f35e5826630e9e9633cbbcf719d Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Fri, 9 Jul 2021 15:20:32 +0900 Subject: [PATCH] [dataset/cleanup] Remove type from dataset This patch removes `datasetUsageType` from dataset interface **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- Applications/LogisticRegression/jni/main.cpp | 3 +- Applications/VGG/jni/main.cpp | 6 +-- nntrainer/dataset/databuffer.cpp | 46 +++++++------------- nntrainer/dataset/databuffer.h | 28 +++--------- nntrainer/dataset/databuffer_factory.cpp | 7 +-- nntrainer/dataset/databuffer_file.cpp | 32 ++------------ nntrainer/dataset/databuffer_file.h | 17 ++------ nntrainer/dataset/databuffer_func.cpp | 7 +-- nntrainer/dataset/databuffer_func.h | 8 +--- nntrainer/models/model_loader.cpp | 5 +-- nntrainer/models/neuralnet.cpp | 38 ++++++---------- test/tizen_capi/unittest_tizen_capi_dataset.cpp | 13 ------ test/unittest/unittest_databuffer_file.cpp | 58 +++++-------------------- 13 files changed, 65 insertions(+), 203 deletions(-) diff --git a/Applications/LogisticRegression/jni/main.cpp b/Applications/LogisticRegression/jni/main.cpp index b10525d..551c15b 100644 --- a/Applications/LogisticRegression/jni/main.cpp +++ b/Applications/LogisticRegression/jni/main.cpp @@ -171,8 +171,7 @@ int main(int argc, char *argv[]) { srand(time(NULL)); auto data_train = std::make_shared(); - data_train->setGeneratorFunc(ml::train::DatasetDataUsageType::DATA_TRAIN, - getBatch_train); + data_train->setGeneratorFunc(getBatch_train); /** * @brief Create NN diff --git a/Applications/VGG/jni/main.cpp b/Applications/VGG/jni/main.cpp index 70bad96..73b1b93 100644 --- a/Applications/VGG/jni/main.cpp +++ b/Applications/VGG/jni/main.cpp @@ -400,11 +400,9 @@ int main(int argc, char *argv[]) { count_val.duplication[i] = i; auto db_train = std::make_shared(); - db_train->setGeneratorFunc(ml::train::DatasetDataUsageType::DATA_TRAIN, - getBatch_train_file); + db_train->setGeneratorFunc(getBatch_train_file); auto db_valid = std::make_shared(); - db_valid->setGeneratorFunc(ml::train::DatasetDataUsageType::DATA_VAL, - getBatch_val_file); + db_valid->setGeneratorFunc(getBatch_val_file); /** * @brief Neural Network Create & Initialization diff --git a/nntrainer/dataset/databuffer.cpp b/nntrainer/dataset/databuffer.cpp index eb3ad0d..4ad265c 100644 --- a/nntrainer/dataset/databuffer.cpp +++ b/nntrainer/dataset/databuffer.cpp @@ -94,7 +94,8 @@ int DataBuffer::rangeRandom(int min, int max) { return dist(rng); } -int DataBuffer::run(DatasetDataUsageType type) { +int DataBuffer::run() { + auto type = DatasetDataUsageType::DATA_TRAIN; int status = ML_ERROR_NONE; switch (type) { case DatasetDataUsageType::DATA_TRAIN: @@ -103,7 +104,7 @@ int DataBuffer::run(DatasetDataUsageType type) { if (validation[static_cast(DatasetDataUsageType::DATA_TRAIN)]) { this->train_running = true; - this->train_thread = std::thread(&DataBuffer::updateData, this, type); + this->train_thread = std::thread(&DataBuffer::updateData, this); if (globalExceptionPtr) { try { std::rethrow_exception(globalExceptionPtr); @@ -122,7 +123,7 @@ int DataBuffer::run(DatasetDataUsageType type) { return ML_ERROR_INVALID_PARAMETER; if (validation[static_cast(DatasetDataUsageType::DATA_VAL)]) { this->val_running = true; - this->val_thread = std::thread(&DataBuffer::updateData, this, type); + this->val_thread = std::thread(&DataBuffer::updateData, this); if (globalExceptionPtr) { try { std::rethrow_exception(globalExceptionPtr); @@ -142,7 +143,7 @@ int DataBuffer::run(DatasetDataUsageType type) { if (validation[static_cast(DatasetDataUsageType::DATA_TEST)]) { this->test_running = true; - this->test_thread = std::thread(&DataBuffer::updateData, this, type); + this->test_thread = std::thread(&DataBuffer::updateData, this); if (globalExceptionPtr) { try { std::rethrow_exception(globalExceptionPtr); @@ -165,7 +166,8 @@ int DataBuffer::run(DatasetDataUsageType type) { return status; } -int DataBuffer::clear(DatasetDataUsageType type) { +int DataBuffer::clear() { + auto type = DatasetDataUsageType::DATA_TRAIN; int status = ML_ERROR_NONE; NN_EXCEPTION_NOTI(DATA_NOT_READY); switch (type) { @@ -207,27 +209,8 @@ int DataBuffer::clear(DatasetDataUsageType type) { return status; } -int DataBuffer::clear() { - unsigned int i; - - int status = ML_ERROR_NONE; - for (i = (int)DatasetDataUsageType::DATA_TRAIN; - i <= (int)DatasetDataUsageType::DATA_TEST; ++i) { - DatasetDataUsageType type = static_cast(i); - status = this->clear(type); - - if (status != ML_ERROR_NONE) { - ml_loge("Error: error occurred during clearing"); - return status; - } - } - - return status; -} - -bool DataBuffer::getDataFromBuffer(DatasetDataUsageType type, float *out, - float *label) { - +bool DataBuffer::getDataFromBuffer(float *out, float *label) { + auto type = DatasetDataUsageType::DATA_TRAIN; using QueueType = std::vector>; auto wait_for_data_fill = [](std::mutex &ready_mutex, @@ -385,8 +368,8 @@ int DataBuffer::setFeatureSize(TensorDim indim) { return status; } -void DataBuffer::displayProgress(const int count, DatasetDataUsageType type, - float loss) { +void DataBuffer::displayProgress(const int count, float loss) { + auto type = DatasetDataUsageType::DATA_TRAIN; int barWidth = 20; float max_size = max_train; switch (type) { @@ -510,8 +493,11 @@ int DataBuffer::setProperty(const PropertyType type, std::string &value) { return status; } -int DataBuffer::setGeneratorFunc(DatasetDataUsageType type, datagen_cb func, - void *user_data) { +int DataBuffer::setGeneratorFunc(datagen_cb func, void *user_data) { + return ML_ERROR_NOT_SUPPORTED; +} + +int DataBuffer::setDataFile(const std::string &path) { return ML_ERROR_NOT_SUPPORTED; } diff --git a/nntrainer/dataset/databuffer.h b/nntrainer/dataset/databuffer.h index 7c2224a..ffff5b9 100644 --- a/nntrainer/dataset/databuffer.h +++ b/nntrainer/dataset/databuffer.h @@ -63,29 +63,19 @@ public: /** * @brief Update Data Buffer ( it is for child thread ) - * @param[in] BufferType training, validation, test * @retval void */ - virtual void updateData(DatasetDataUsageType type) = 0; + virtual void updateData() = 0; /** * @brief function for thread ( training, validation, test ) - * @param[in] BufferType training, validation, test * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - virtual int run(DatasetDataUsageType type); + virtual int run(); /** * @brief clear thread ( training, validation, test ) - * @param[in] BufferType training, validation, test - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. - */ - virtual int clear(DatasetDataUsageType type); - - /** - * @brief clear all thread ( training, validation, test ) * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ @@ -93,14 +83,13 @@ public: /** * @brief get Data from Data Buffer using databuffer param - * @param[in] BufferType training, validation, test * @param[out] out feature data ( batch_size size ), a contiguous and * allocated memory block should be passed * @param[out] label label data ( batch_size size ), a contiguous and * allocated memory block should be passed * @retval true/false */ - bool getDataFromBuffer(DatasetDataUsageType type, float *out, float *label); + bool getDataFromBuffer(float *out, float *label); /** * @brief set number of class @@ -158,7 +147,7 @@ public: * @param[in] type buffer type ( DATA_TRAIN, DATA_VAL, DATA_TEST ) * @retval void */ - void displayProgress(const int count, DatasetDataUsageType type, float loss); + void displayProgress(const int count, float loss); /** * @brief return validation of data set @@ -184,25 +173,20 @@ public: /** * @brief set function pointer for each type - * @param[in] type Buffer Type * @param[in] call back function pointer * @param[in] user_data user_data of the callback * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - virtual int setGeneratorFunc(DatasetDataUsageType type, datagen_cb func, - void *user_data = nullptr); + virtual int setGeneratorFunc(datagen_cb func, void *user_data = nullptr); /** * @brief set train data file name - * @param[in] type data type : DATA_TRAIN, DATA_VAL, DATA_TEST * @param[in] path file path * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - virtual int setDataFile(DatasetDataUsageType type, std::string path) { - return setDataFile(type, path); - } + virtual int setDataFile(const std::string &path); /** * @brief property type of databuffer diff --git a/nntrainer/dataset/databuffer_factory.cpp b/nntrainer/dataset/databuffer_factory.cpp index 32cfa04..07c2b76 100644 --- a/nntrainer/dataset/databuffer_factory.cpp +++ b/nntrainer/dataset/databuffer_factory.cpp @@ -45,9 +45,7 @@ std::unique_ptr createDataBuffer(DatasetType type, std::unique_ptr dataset = createDataBuffer(type); - NNTR_THROW_IF(file == nullptr || - dataset->setDataFile(DatasetDataUsageType::DATA_TRAIN, - file) != ML_ERROR_NONE, + NNTR_THROW_IF(file == nullptr || dataset->setDataFile(file) != ML_ERROR_NONE, std::invalid_argument) << "invalid train file, path: " << (file ? file : "null"); @@ -65,8 +63,7 @@ std::unique_ptr createDataBuffer(DatasetType type, datagen_cb cb, std::unique_ptr dataset = createDataBuffer(type); - if (dataset->setGeneratorFunc(DatasetDataUsageType::DATA_TRAIN, cb, - user_data) != ML_ERROR_NONE) + if (dataset->setGeneratorFunc(cb, user_data) != ML_ERROR_NONE) throw std::invalid_argument("Invalid train data generator"); return dataset; diff --git a/nntrainer/dataset/databuffer_file.cpp b/nntrainer/dataset/databuffer_file.cpp index f837747..d481d32 100644 --- a/nntrainer/dataset/databuffer_file.cpp +++ b/nntrainer/dataset/databuffer_file.cpp @@ -117,7 +117,8 @@ int DataBufferFromDataFile::init() { return ML_ERROR_NONE; } -void DataBufferFromDataFile::updateData(DatasetDataUsageType type) { +void DataBufferFromDataFile::updateData() { + auto type = DatasetDataUsageType::DATA_TRAIN; unsigned int max_size = 0; unsigned int buf_size = 0; unsigned int *rest_size = NULL; @@ -265,8 +266,8 @@ void DataBufferFromDataFile::updateData(DatasetDataUsageType type) { file.close(); } -int DataBufferFromDataFile::setDataFile(DatasetDataUsageType type, - std::string path) { +int DataBufferFromDataFile::setDataFile(const std::string &path) { + auto type = DatasetDataUsageType::DATA_TRAIN; int status = ML_ERROR_NONE; std::ifstream data_file(path.c_str()); @@ -360,29 +361,4 @@ int DataBufferFromDataFile::setFeatureSize(TensorDim tdim) { return status; } -int DataBufferFromDataFile::setProperty(const PropertyType type, - std::string &value) { - int status = ML_ERROR_NONE; - - if (data_buffer_type != DatasetType::FILE) - return ML_ERROR_INVALID_PARAMETER; - - switch (type) { - case PropertyType::train_data: - status = this->setDataFile(DatasetDataUsageType::DATA_TRAIN, value); - break; - case PropertyType::val_data: - status = this->setDataFile(DatasetDataUsageType::DATA_VAL, value); - break; - case PropertyType::test_data: - status = this->setDataFile(DatasetDataUsageType::DATA_TEST, value); - break; - default: - status = DataBuffer::setProperty(type, value); - break; - } - - return status; -} - } /* namespace nntrainer */ diff --git a/nntrainer/dataset/databuffer_file.h b/nntrainer/dataset/databuffer_file.h index 5f666de..029755c 100644 --- a/nntrainer/dataset/databuffer_file.h +++ b/nntrainer/dataset/databuffer_file.h @@ -38,7 +38,7 @@ namespace nntrainer { * @class DataBufferFromDataFile Data Buffer from Raw Data File * @brief Data Buffer from reading raw data */ -class DataBufferFromDataFile : public DataBuffer { +class DataBufferFromDataFile final : public DataBuffer { public: /** @@ -60,19 +60,17 @@ public: /** * @brief Update Data Buffer ( it is for child thread ) - * @param[in] BufferType training, validation, test * @retval void */ - void updateData(DatasetDataUsageType type); + void updateData() override; /** * @brief set train data file name - * @param[in] type data type : DATA_TRAIN, DATA_VAL, DATA_TEST * @param[in] path file path * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - int setDataFile(DatasetDataUsageType type, std::string path); + int setDataFile(const std::string &path) override; /** * @brief set feature size @@ -82,15 +80,6 @@ public: */ int setFeatureSize(TensorDim indim); - /** - * @brief set property - * @param[in] type type of property - * @param[in] value string value of property - * @retval #ML_ERROR_NONE Successful. - * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. - */ - int setProperty(const PropertyType type, std::string &value); - private: /** * @brief raw data file names diff --git a/nntrainer/dataset/databuffer_func.cpp b/nntrainer/dataset/databuffer_func.cpp index 53c8bc3..f01cce5 100644 --- a/nntrainer/dataset/databuffer_func.cpp +++ b/nntrainer/dataset/databuffer_func.cpp @@ -85,8 +85,8 @@ int DataBufferFromCallback::init() { return ML_ERROR_NONE; } -int DataBufferFromCallback::setGeneratorFunc(DatasetDataUsageType type, - datagen_cb func, void *user_data) { +int DataBufferFromCallback::setGeneratorFunc(datagen_cb func, void *user_data) { + auto type = DatasetDataUsageType::DATA_TRAIN; int status = ML_ERROR_NONE; switch (type) { @@ -116,7 +116,8 @@ int DataBufferFromCallback::setGeneratorFunc(DatasetDataUsageType type, return status; } -void DataBufferFromCallback::updateData(DatasetDataUsageType type) { +void DataBufferFromCallback::updateData() { + auto type = DatasetDataUsageType::DATA_TRAIN; int status = ML_ERROR_NONE; unsigned int buf_size = 0; diff --git a/nntrainer/dataset/databuffer_func.h b/nntrainer/dataset/databuffer_func.h index 87dfc76..d89382a 100644 --- a/nntrainer/dataset/databuffer_func.h +++ b/nntrainer/dataset/databuffer_func.h @@ -60,20 +60,16 @@ public: /** * @brief set function pointer for each type - * @param[in] type Buffer Type * @param[in] call back function pointer * @retval #ML_ERROR_NONE Successful. * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter. */ - int setGeneratorFunc(DatasetDataUsageType type, datagen_cb func, - void *user_data = nullptr) override; + int setGeneratorFunc(datagen_cb func, void *user_data = nullptr) override; /** * @brief Update Data Buffer ( it is for child thread ) - * @param[in] DatasetDataUsageType training, validation, test - * @retval void */ - void updateData(DatasetDataUsageType type); + void updateData() override; /** * @brief set property diff --git a/nntrainer/models/model_loader.cpp b/nntrainer/models/model_loader.cpp index f8e8bda..7b8ab64 100644 --- a/nntrainer/models/model_loader.cpp +++ b/nntrainer/models/model_loader.cpp @@ -224,10 +224,7 @@ int ModelLoader::loadDatasetConfigIni(dictionary *ini, NeuralNetwork &model) { return status; } - /// setting data to data_train is intended for now. later the function - /// should be called without this enum - return dbuffer->setDataFile(DatasetDataUsageType::DATA_TRAIN, - resolvePath(path)); + return dbuffer->setDataFile(resolvePath(path)); }; status = diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index effa683..0397b17 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -623,12 +623,6 @@ int NeuralNetwork::train_run() { auto &label = last_layer_node->getOutputGrad(0); auto &in = first_layer_node->getInput(0); - /// below constant is needed after changing - /// databuffer having train, valid, test -> train buffer, valid buffer, test - /// buffer After the cahgne, only data train is used inside a databuffer. - /// RUN_CONSTANT is a stub value to deal with the situation - auto RUN_CONSTANT = DatasetDataUsageType::DATA_TRAIN; - auto &[train_buffer, valid_buffer, test_buffer] = data_buffers; if (train_buffer == nullptr) { @@ -638,18 +632,17 @@ int NeuralNetwork::train_run() { for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) { training.loss = 0.0f; - status = train_buffer->run(RUN_CONSTANT); + status = train_buffer->run(); if (status != ML_ERROR_NONE) { - train_buffer->clear(RUN_CONSTANT); + train_buffer->clear(); return status; } /// @todo make this working, test buffer is running but doing nothing - if (test_buffer != nullptr && - test_buffer->getValidation()[static_cast(RUN_CONSTANT)]) { - status = test_buffer->run(nntrainer::DatasetDataUsageType::DATA_TEST); + if (test_buffer != nullptr && test_buffer->getValidation()[0]) { + status = test_buffer->run(); if (status != ML_ERROR_NONE) { - test_buffer->clear(DatasetDataUsageType::DATA_TEST); + test_buffer->clear(); return status; } } @@ -657,23 +650,22 @@ int NeuralNetwork::train_run() { int count = 0; while (true) { - if (train_buffer->getDataFromBuffer(RUN_CONSTANT, in.getData(), - label.getData())) { + if (train_buffer->getDataFromBuffer(in.getData(), label.getData())) { try { forwarding(true); backwarding(iter++); } catch (std::exception &e) { - train_buffer->clear(RUN_CONSTANT); + train_buffer->clear(); ml_loge("Error: training error in #%d/%d. %s", epoch_idx, epochs, e.what()); throw; } std::cout << "#" << epoch_idx << "/" << epochs; float loss = getLoss(); - train_buffer->displayProgress(count++, RUN_CONSTANT, loss); + train_buffer->displayProgress(count++, loss); training.loss += loss; } else { - train_buffer->clear(RUN_CONSTANT); + train_buffer->clear(); break; } } @@ -687,21 +679,19 @@ int NeuralNetwork::train_run() { std::cout << "#" << epoch_idx << "/" << epochs << " - Training Loss: " << training.loss; - if (valid_buffer != nullptr && - valid_buffer->getValidation()[static_cast(RUN_CONSTANT)]) { + if (valid_buffer != nullptr && valid_buffer->getValidation()[0]) { int right = 0; validation.loss = 0.0f; unsigned int tcases = 0; - status = valid_buffer->run(RUN_CONSTANT); + status = valid_buffer->run(); if (status != ML_ERROR_NONE) { - valid_buffer->clear(RUN_CONSTANT); + valid_buffer->clear(); return status; } while (true) { - if (valid_buffer->getDataFromBuffer(RUN_CONSTANT, in.getData(), - label.getData())) { + if (valid_buffer->getDataFromBuffer(in.getData(), label.getData())) { forwarding(false); auto model_out = output.argmax(); auto label_out = label.argmax(); @@ -712,7 +702,7 @@ int NeuralNetwork::train_run() { validation.loss += getLoss(); tcases++; } else { - valid_buffer->clear(RUN_CONSTANT); + valid_buffer->clear(); break; } } diff --git a/test/tizen_capi/unittest_tizen_capi_dataset.cpp b/test/tizen_capi/unittest_tizen_capi_dataset.cpp index 8f744e6..ac7403f 100644 --- a/test/tizen_capi/unittest_tizen_capi_dataset.cpp +++ b/test/tizen_capi/unittest_tizen_capi_dataset.cpp @@ -148,22 +148,9 @@ TEST(nntrainer_capi_dataset, set_dataset_property_02_p) { &dataset, getTestResPath("trainingSet.dat").c_str(), NULL, NULL); EXPECT_EQ(status, ML_ERROR_NONE); - std::string train_prop = "train_data=" + getTestResPath("trainingSet.dat"); - std::string val_prop = "val_data=" + getTestResPath("valSet.dat"); - std::string test_prop = "test_data=" + getTestResPath("testSet.dat"); - - /** Multiple properties */ - status = ml_train_dataset_set_property(dataset, val_prop.c_str(), - test_prop.c_str(), NULL); - EXPECT_EQ(status, ML_ERROR_NONE); - status = ml_train_dataset_set_property(dataset, "buffer_size=100", NULL); EXPECT_EQ(status, ML_ERROR_NONE); - /** Overwrite properties */ - status = ml_train_dataset_set_property(dataset, train_prop.c_str(), NULL); - EXPECT_EQ(status, ML_ERROR_NONE); - status = ml_train_dataset_destroy(dataset); EXPECT_EQ(status, ML_ERROR_NONE); } diff --git a/test/unittest/unittest_databuffer_file.cpp b/test/unittest/unittest_databuffer_file.cpp index ea6a9b4..d0cf352 100644 --- a/test/unittest/unittest_databuffer_file.cpp +++ b/test/unittest/unittest_databuffer_file.cpp @@ -41,8 +41,7 @@ TEST(nntrainer_DataBuffer, setFeatureSize_01_p) { dim.setTensorDim("32:1:1:62720"); status = data_buffer.setClassNum(10); EXPECT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN, - getTestResPath("trainingSet.dat")); + status = data_buffer.setDataFile(getTestResPath("trainingSet.dat")); EXPECT_EQ(status, ML_ERROR_NONE); status = data_buffer.setFeatureSize(dim); EXPECT_EQ(status, ML_ERROR_NONE); @@ -80,14 +79,7 @@ TEST(nntrainer_DataBuffer, init_01_p) { EXPECT_EQ(status, ML_ERROR_NONE); status = data_buffer.setClassNum(10); EXPECT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN, - getTestResPath("trainingSet.dat")); - EXPECT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_VAL, - getTestResPath("valSet.dat")); - EXPECT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST, - getTestResPath("testSet.dat")); + status = data_buffer.setDataFile(getTestResPath("trainingSet.dat")); EXPECT_EQ(status, ML_ERROR_NONE); status = data_buffer.setFeatureSize(dim); EXPECT_EQ(status, ML_ERROR_NONE); @@ -123,8 +115,7 @@ TEST(nntrainer_DataBuffer, setClassNum_02_n) { TEST(nntrainer_DataBuffer, setDataFile_01_p) { int status = ML_ERROR_NONE; nntrainer::DataBufferFromDataFile data_buffer; - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN, - getTestResPath("trainingSet.dat")); + status = data_buffer.setDataFile(getTestResPath("trainingSet.dat")); EXPECT_EQ(status, ML_ERROR_NONE); } @@ -134,8 +125,7 @@ TEST(nntrainer_DataBuffer, setDataFile_01_p) { TEST(nntrainer_DataBuffer, setDataFile_02_n) { int status = ML_ERROR_NONE; nntrainer::DataBufferFromDataFile data_buffer; - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN, - "./no_exist.dat"); + status = data_buffer.setDataFile("./no_exist.dat"); EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER); } @@ -151,24 +141,13 @@ TEST(nntrainer_DataBuffer, clear_01_p) { ASSERT_EQ(status, ML_ERROR_NONE); status = data_buffer.setClassNum(10); ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN, - getTestResPath("trainingSet.dat")); - ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_VAL, - getTestResPath("valSet.dat")); - ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST, - getTestResPath("testSet.dat")); + status = data_buffer.setDataFile(getTestResPath("trainingSet.dat")); ASSERT_EQ(status, ML_ERROR_NONE); status = data_buffer.setFeatureSize(dim); ASSERT_EQ(status, ML_ERROR_NONE); status = data_buffer.init(); ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.run(nntrainer::DatasetDataUsageType::DATA_TRAIN); - ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.run(nntrainer::DatasetDataUsageType::DATA_TEST); - ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.run(nntrainer::DatasetDataUsageType::DATA_VAL); + status = data_buffer.run(); ASSERT_EQ(status, ML_ERROR_NONE); status = data_buffer.clear(); EXPECT_EQ(status, ML_ERROR_NONE); @@ -180,10 +159,9 @@ TEST(nntrainer_DataBuffer, clear_01_p) { TEST(nntrainer_DataBuffer, clear_02_p) { int status = ML_ERROR_NONE; nntrainer::DataBufferFromDataFile data_buffer; - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST, - getTestResPath("testSet.dat")); + status = data_buffer.setDataFile(getTestResPath("testSet.dat")); ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_TEST); + status = data_buffer.clear(); EXPECT_EQ(status, ML_ERROR_NONE); } @@ -193,11 +171,8 @@ TEST(nntrainer_DataBuffer, clear_02_p) { TEST(nntrainer_DataBuffer, clear_03_p) { int status = ML_ERROR_NONE; nntrainer::DataBufferFromDataFile data_buffer; - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST, - getTestResPath("testSet.dat")); + status = data_buffer.setDataFile(getTestResPath("testSet.dat")); ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_TEST); - EXPECT_EQ(status, ML_ERROR_NONE); status = data_buffer.clear(); EXPECT_EQ(status, ML_ERROR_NONE); } @@ -208,26 +183,13 @@ TEST(nntrainer_DataBuffer, clear_03_p) { TEST(nntrainer_DataBuffer, clear_04_p) { int status = ML_ERROR_NONE; nntrainer::DataBufferFromDataFile data_buffer; - status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST, - getTestResPath("testSet.dat")); + status = data_buffer.setDataFile(getTestResPath("testSet.dat")); ASSERT_EQ(status, ML_ERROR_NONE); - status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_TEST); - EXPECT_EQ(status, ML_ERROR_NONE); status = data_buffer.clear(); EXPECT_EQ(status, ML_ERROR_NONE); } /** - * @brief Data buffer clear BufferType::DATA_UNKNOWN - */ -TEST(nntrainer_DataBuffer, clear_05_n) { - int status = ML_ERROR_NONE; - nntrainer::DataBufferFromDataFile data_buffer; - status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_UNKNOWN); - EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER); -} - -/** * @brief Main gtest */ int main(int argc, char **argv) { -- 2.7.4