[dataset] Clean up dataset enums
authorJihoon Lee <jhoon.it.lee@samsung.com>
Thu, 8 Jul 2021 10:27:50 +0000 (19:27 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 15 Jul 2021 04:37:48 +0000 (13:37 +0900)
`nntrainer::DataType` and `nntrainer::BufferType` is duplicated from
ccapi which was adding complication. This patch simply alternate those
types ccapi `DatasetType` / `DatasetDataUsageType`

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
15 files changed:
Applications/LogisticRegression/jni/main.cpp
Applications/VGG/jni/main.cpp
api/ccapi/include/dataset.h
nntrainer/dataset/databuffer.cpp
nntrainer/dataset/databuffer.h
nntrainer/dataset/databuffer_factory.cpp
nntrainer/dataset/databuffer_factory.h
nntrainer/dataset/databuffer_file.cpp
nntrainer/dataset/databuffer_file.h
nntrainer/dataset/databuffer_func.cpp
nntrainer/dataset/databuffer_func.h
nntrainer/dataset/databuffer_util.h
nntrainer/models/model_loader.cpp
nntrainer/models/neuralnet.cpp
test/unittest/unittest_databuffer_file.cpp

index a2435ed..7d35857 100644 (file)
@@ -172,7 +172,8 @@ int main(int argc, char *argv[]) {
 
   std::shared_ptr<nntrainer::DataBufferFromCallback> DB =
     std::make_shared<nntrainer::DataBufferFromCallback>();
-  DB->setGeneratorFunc(nntrainer::BufferType::BUF_TRAIN, getBatch_train);
+  DB->setGeneratorFunc(nntrainer::DatasetDataUsageType::DATA_TRAIN,
+                       getBatch_train);
 
   /**
    * @brief     Create NN
index b803331..f0aa196 100644 (file)
@@ -401,8 +401,10 @@ int main(int argc, char *argv[]) {
 
   std::shared_ptr<nntrainer::DataBufferFromCallback> DB =
     std::make_shared<nntrainer::DataBufferFromCallback>();
-  DB->setGeneratorFunc(nntrainer::BufferType::BUF_TRAIN, getBatch_train_file);
-  DB->setGeneratorFunc(nntrainer::BufferType::BUF_VAL, getBatch_val_file);
+  DB->setGeneratorFunc(nntrainer::DatasetDataUsageType::DATA_TRAIN,
+                       getBatch_train_file);
+  DB->setGeneratorFunc(nntrainer::DatasetDataUsageType::DATA_VAL,
+                       getBatch_val_file);
 
   /**
    * @brief     Neural Network Create & Initialization
index 4d6f691..0c285e8 100644 (file)
@@ -47,7 +47,7 @@ enum class DatasetType {
 /**
  * @brief     Enumeration of data type
  */
-enum class DatasetDataType {
+enum class DatasetDataUsageType {
   DATA_TRAIN = ML_TRAIN_DATASET_DATA_USAGE_TRAIN, /** data for training */
   DATA_VAL = ML_TRAIN_DATASET_DATA_USAGE_VALID,   /** data for validation */
   DATA_TEST = ML_TRAIN_DATASET_DATA_USAGE_TEST,   /** data for test */
index 81a04ba..f3eed19 100644 (file)
@@ -56,7 +56,7 @@ std::condition_variable cv_train;
 std::condition_variable cv_val;
 std::condition_variable cv_test;
 
-DataBuffer::DataBuffer(DataBufferType type) :
+DataBuffer::DataBuffer(DatasetType type) :
   train_running(),
   val_running(),
   test_running(),
@@ -94,14 +94,14 @@ int DataBuffer::rangeRandom(int min, int max) {
   return dist(rng);
 }
 
-int DataBuffer::run(BufferType type) {
+int DataBuffer::run(DatasetDataUsageType type) {
   int status = ML_ERROR_NONE;
   switch (type) {
-  case BufferType::BUF_TRAIN:
+  case DatasetDataUsageType::DATA_TRAIN:
     if (trainReadyFlag == DATA_ERROR)
       return ML_ERROR_INVALID_PARAMETER;
 
-    if (validation[DATA_TRAIN]) {
+    if (validation[static_cast<int>(DatasetDataUsageType::DATA_TRAIN)]) {
       this->train_running = true;
       this->train_thread = std::thread(&DataBuffer::updateData, this, type);
       if (globalExceptionPtr) {
@@ -117,10 +117,10 @@ int DataBuffer::run(BufferType type) {
       return ML_ERROR_INVALID_PARAMETER;
     }
     break;
-  case BufferType::BUF_VAL:
+  case DatasetDataUsageType::DATA_VAL:
     if (valReadyFlag == DATA_ERROR)
       return ML_ERROR_INVALID_PARAMETER;
-    if (validation[DATA_VAL]) {
+    if (validation[static_cast<int>(DatasetDataUsageType::DATA_VAL)]) {
       this->val_running = true;
       this->val_thread = std::thread(&DataBuffer::updateData, this, type);
       if (globalExceptionPtr) {
@@ -136,11 +136,11 @@ int DataBuffer::run(BufferType type) {
       return ML_ERROR_INVALID_PARAMETER;
     }
     break;
-  case BufferType::BUF_TEST:
+  case DatasetDataUsageType::DATA_TEST:
     if (testReadyFlag == DATA_ERROR)
       return ML_ERROR_INVALID_PARAMETER;
 
-    if (validation[DATA_TEST]) {
+    if (validation[static_cast<int>(DatasetDataUsageType::DATA_TEST)]) {
       this->test_running = true;
       this->test_thread = std::thread(&DataBuffer::updateData, this, type);
       if (globalExceptionPtr) {
@@ -165,31 +165,34 @@ int DataBuffer::run(BufferType type) {
   return status;
 }
 
-int DataBuffer::clear(BufferType type) {
+int DataBuffer::clear(DatasetDataUsageType type) {
   int status = ML_ERROR_NONE;
   NN_EXCEPTION_NOTI(DATA_NOT_READY);
   switch (type) {
-  case BufferType::BUF_TRAIN: {
+  case DatasetDataUsageType::DATA_TRAIN: {
     train_running = false;
-    if (validation[DATA_TRAIN] && true == train_thread.joinable())
+    if (validation[static_cast<int>(DatasetDataUsageType::DATA_TRAIN)] &&
+        true == train_thread.joinable())
       train_thread.join();
     this->train_data.clear();
     this->train_data_label.clear();
     this->cur_train_bufsize = 0;
     this->rest_train = max_train;
   } break;
-  case BufferType::BUF_VAL: {
+  case DatasetDataUsageType::DATA_VAL: {
     val_running = false;
-    if (validation[DATA_VAL] && true == val_thread.joinable())
+    if (validation[static_cast<int>(DatasetDataUsageType::DATA_VAL)] &&
+        true == val_thread.joinable())
       val_thread.join();
     this->val_data.clear();
     this->val_data_label.clear();
     this->cur_val_bufsize = 0;
     this->rest_val = max_val;
   } break;
-  case BufferType::BUF_TEST: {
+  case DatasetDataUsageType::DATA_TEST: {
     test_running = false;
-    if (validation[DATA_TEST] && true == test_thread.joinable())
+    if (validation[static_cast<int>(DatasetDataUsageType::DATA_TEST)] &&
+        true == test_thread.joinable())
       test_thread.join();
     this->test_data.clear();
     this->test_data_label.clear();
@@ -208,8 +211,9 @@ int DataBuffer::clear() {
   unsigned int i;
 
   int status = ML_ERROR_NONE;
-  for (i = (int)BufferType::BUF_TRAIN; i <= (int)BufferType::BUF_TEST; ++i) {
-    BufferType type = static_cast<BufferType>(i);
+  for (i = (int)DatasetDataUsageType::DATA_TRAIN;
+       i <= (int)DatasetDataUsageType::DATA_TEST; ++i) {
+    DatasetDataUsageType type = static_cast<DatasetDataUsageType>(i);
     status = this->clear(type);
 
     if (status != ML_ERROR_NONE) {
@@ -221,7 +225,8 @@ int DataBuffer::clear() {
   return status;
 }
 
-bool DataBuffer::getDataFromBuffer(BufferType type, float *out, float *label) {
+bool DataBuffer::getDataFromBuffer(DatasetDataUsageType type, float *out,
+                                   float *label) {
 
   using QueueType = std::vector<std::vector<float>>;
 
@@ -256,7 +261,7 @@ bool DataBuffer::getDataFromBuffer(BufferType type, float *out, float *label) {
 
   /// facade that wait for the databuffer to be filled and pass it to outparam
   /// note that batch_size is passed as an argument because it can vary by
-  /// BufferType::BUF_TYPE later...
+  /// DatasetDataUsageType::BUF_TYPE later...
   auto fill_out_params =
     [&](std::mutex &ready_mutex, std::condition_variable &cv, DataStatus &flag,
         QueueType &data_q, QueueType &label_q, const unsigned int batch_size,
@@ -275,17 +280,17 @@ bool DataBuffer::getDataFromBuffer(BufferType type, float *out, float *label) {
     };
 
   switch (type) {
-  case BufferType::BUF_TRAIN:
+  case DatasetDataUsageType::DATA_TRAIN:
     if (!fill_out_params(readyTrainData, cv_train, trainReadyFlag, train_data,
                          train_data_label, batch_size, cur_train_bufsize))
       return false;
     break;
-  case BufferType::BUF_VAL:
+  case DatasetDataUsageType::DATA_VAL:
     if (!fill_out_params(readyValData, cv_val, valReadyFlag, val_data,
                          val_data_label, batch_size, cur_val_bufsize))
       return false;
     break;
-  case BufferType::BUF_TEST:
+  case DatasetDataUsageType::DATA_TEST:
     if (!fill_out_params(readyTestData, cv_test, testReadyFlag, test_data,
                          test_data_label, batch_size, cur_test_bufsize))
       return false;
@@ -380,17 +385,18 @@ int DataBuffer::setFeatureSize(TensorDim indim) {
   return status;
 }
 
-void DataBuffer::displayProgress(const int count, BufferType type, float loss) {
+void DataBuffer::displayProgress(const int count, DatasetDataUsageType type,
+                                 float loss) {
   int barWidth = 20;
   float max_size = max_train;
   switch (type) {
-  case BufferType::BUF_TRAIN:
+  case DatasetDataUsageType::DATA_TRAIN:
     max_size = max_train;
     break;
-  case BufferType::BUF_VAL:
+  case DatasetDataUsageType::DATA_VAL:
     max_size = max_val;
     break;
-  case BufferType::BUF_TEST:
+  case DatasetDataUsageType::DATA_TEST:
     max_size = max_test;
     break;
   default:
@@ -504,11 +510,7 @@ int DataBuffer::setProperty(const PropertyType type, std::string &value) {
   return status;
 }
 
-int DataBuffer::setGeneratorFunc(BufferType type, datagen_cb func) {
-  return ML_ERROR_NOT_SUPPORTED;
-}
-
-int DataBuffer::setDataFile(DataType type, std::string path) {
+int DataBuffer::setGeneratorFunc(DatasetDataUsageType type, datagen_cb func) {
   return ML_ERROR_NOT_SUPPORTED;
 }
 
index 86f66af..f7a245c 100644 (file)
@@ -38,34 +38,11 @@ namespace nntrainer {
 /**
  * @brief     Aliasing from ccapi ml::train
  */
-using DataBufferType = ml::train::DatasetType;
-using DatasetDataType = ml::train::DatasetDataType;
+using DatasetType = ml::train::DatasetType;
+using DatasetDataUsageType = ml::train::DatasetDataUsageType;
 using datagen_cb = ml::train::datagen_cb;
 
 /**
- * @brief     Enumeration of data type
- */
-typedef enum {
-  DATA_TRAIN =
-    (int)ml::train::DatasetDataType::DATA_TRAIN, /** data for training */
-  DATA_VAL =
-    (int)ml::train::DatasetDataType::DATA_VAL, /** data for validation */
-  DATA_TEST = (int)ml::train::DatasetDataType::DATA_TEST, /** data for test */
-  DATA_UNKNOWN =
-    (int)ml::train::DatasetDataType::DATA_UNKNOWN /** data not known */
-} DataType;
-
-/**
- * @brief     Enumeration of buffer type
- */
-enum class BufferType {
-  BUF_TRAIN = DATA_TRAIN,    /** BUF_TRAIN ( Buffer for training ) */
-  BUF_VAL = DATA_VAL,        /** BUF_VAL ( Buffer for validation ) */
-  BUF_TEST = DATA_TEST,      /** BUF_TEST ( Buffer for test ) */
-  BUF_UNKNOWN = DATA_UNKNOWN /** BUF_UNKNOWN ( unknown ) */
-};
-
-/**
  * @class   DataBuffer Data Buffers
  * @brief   Data Buffer for read and manage data
  */
@@ -75,7 +52,7 @@ public:
    * @brief     Create Buffer
    * @retval    DataBuffer
    */
-  DataBuffer(DataBufferType type);
+  DataBuffer(DatasetType type);
 
   /**
    * @brief     Initialize Buffer with data buffer private variables
@@ -89,7 +66,7 @@ public:
    * @param[in] BufferType training, validation, test
    * @retval    void
    */
-  virtual void updateData(BufferType type) = 0;
+  virtual void updateData(DatasetDataUsageType type) = 0;
 
   /**
    * @brief     function for thread ( training, validation, test )
@@ -97,7 +74,7 @@ public:
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  virtual int run(BufferType type);
+  virtual int run(DatasetDataUsageType type);
 
   /**
    * @brief     clear thread ( training, validation, test )
@@ -105,7 +82,7 @@ public:
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  virtual int clear(BufferType type);
+  virtual int clear(DatasetDataUsageType type);
 
   /**
    * @brief     clear all thread ( training, validation, test )
@@ -123,7 +100,7 @@ public:
    * allocated memory block should be passed
    * @retval    true/false
    */
-  bool getDataFromBuffer(BufferType type, float *out, float *label);
+  bool getDataFromBuffer(DatasetDataUsageType type, float *out, float *label);
 
   /**
    * @brief     set number of class
@@ -178,10 +155,10 @@ public:
   /**
    * @brief     Display Progress
    * @param[in] count calculated set ( batch_size size )
-   * @param[in] type buffer type ( BUF_TRAIN, BUF_VAL, BUF_TEST )
+   * @param[in] type buffer type ( DATA_TRAIN, DATA_VAL, DATA_TEST )
    * @retval void
    */
-  void displayProgress(const int count, BufferType type, float loss);
+  void displayProgress(const int count, DatasetDataUsageType type, float loss);
 
   /**
    * @brief     return validation of data set
@@ -207,23 +184,12 @@ public:
 
   /**
    * @brief     set function pointer for each type
-   * @param[in] type data type : DATA_TRAIN, DATA_VAL, DATA_TEST
-   * @param[in] call back function pointer
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   */
-  virtual int setGeneratorFunc(DatasetDataType type, datagen_cb func) {
-    return setGeneratorFunc((BufferType)type, func);
-  }
-
-  /**
-   * @brief     set function pointer for each type
    * @param[in] type Buffer Type
    * @param[in] call back function pointer
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  virtual int setGeneratorFunc(BufferType type, datagen_cb func);
+  virtual int setGeneratorFunc(DatasetDataUsageType type, datagen_cb func);
 
   /**
    * @brief     set train data file name
@@ -232,20 +198,11 @@ public:
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  virtual int setDataFile(DatasetDataType type, std::string path) {
-    return setDataFile((DataType)type, path);
+  virtual int setDataFile(DatasetDataUsageType type, std::string path) {
+    return setDataFile(type, path);
   }
 
   /**
-   * @brief     set train data file name
-   * @param[in] type data type : DATA_TRAIN, DATA_VAL, DATA_TEST
-   * @param[in] path file path
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   */
-  virtual int setDataFile(DataType type, std::string path);
-
-  /**
    * @brief property type of databuffer
    *
    */
@@ -367,7 +324,7 @@ protected:
   /**
    * @brief     The type of data buffer
    */
-  DataBufferType data_buffer_type;
+  DatasetType data_buffer_type;
 
   /** The user_data to be used for the data generator callback */
   void *user_data;
index 7b3e508..63972eb 100644 (file)
@@ -21,13 +21,13 @@ namespace nntrainer {
 /**
  * @brief Factory creator with constructor
  */
-std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type) {
+std::unique_ptr<DataBuffer> createDataBuffer(DatasetType type) {
   switch (type) {
-  case DataBufferType::GENERATOR:
+  case DatasetType::GENERATOR:
     return std::make_unique<DataBufferFromCallback>();
-  case DataBufferType::FILE:
+  case DatasetType::FILE:
     return std::make_unique<DataBufferFromDataFile>();
-  case DataBufferType::UNKNOWN:
+  case DatasetType::UNKNOWN:
     /** fallthrough intended */
   default:
     throw std::invalid_argument("Unknown type for the dataset");
@@ -37,32 +37,32 @@ std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type) {
 /**
  * @brief Factory creator with constructor for dataset
  */
-std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type,
+std::unique_ptr<DataBuffer> createDataBuffer(DatasetType type,
                                              const char *train_file,
                                              const char *valid_file,
                                              const char *test_file) {
-  if (type != DataBufferType::FILE)
+  if (type != DatasetType::FILE)
     throw std::invalid_argument(
       "Cannot create dataset with files with the given dataset type");
 
   std::unique_ptr<DataBuffer> dataset = createDataBuffer(type);
 
   NNTR_THROW_IF(train_file == nullptr ||
-                  dataset->setDataFile(DataType::DATA_TRAIN, train_file) !=
-                    ML_ERROR_NONE,
+                  dataset->setDataFile(DatasetDataUsageType::DATA_TRAIN,
+                                       train_file) != ML_ERROR_NONE,
                 std::invalid_argument)
     << "invalid train file, path: " << (train_file ? train_file : "null");
 
   if (valid_file) {
-    NNTR_THROW_IF(dataset->setDataFile(DataType::DATA_VAL, valid_file) !=
-                    ML_ERROR_NONE,
+    NNTR_THROW_IF(dataset->setDataFile(DatasetDataUsageType::DATA_VAL,
+                                       valid_file) != ML_ERROR_NONE,
                   std::invalid_argument)
       << "invalid valid file, path: " << (valid_file ? valid_file : "null");
   }
 
   if (test_file) {
-    NNTR_THROW_IF(dataset->setDataFile(DataType::DATA_TEST, test_file) !=
-                    ML_ERROR_NONE,
+    NNTR_THROW_IF(dataset->setDataFile(DatasetDataUsageType::DATA_TEST,
+                                       test_file) != ML_ERROR_NONE,
                   std::invalid_argument)
       << "invalid test file, path: " << (test_file ? test_file : "null");
   }
@@ -73,24 +73,24 @@ std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type,
 /**
  * @brief Factory creator with constructor for dataset
  */
-std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type,
-                                             datagen_cb train, datagen_cb valid,
+std::unique_ptr<DataBuffer> createDataBuffer(DatasetType type, datagen_cb train,
+                                             datagen_cb valid,
                                              datagen_cb test) {
-  if (type != DataBufferType::GENERATOR)
+  if (type != DatasetType::GENERATOR)
     throw std::invalid_argument("Cannot create dataset with generator "
                                 "callbacks with the given dataset type");
 
   std::unique_ptr<DataBuffer> dataset = createDataBuffer(type);
 
-  if (dataset->setGeneratorFunc((BufferType)DataType::DATA_TRAIN, train) !=
+  if (dataset->setGeneratorFunc(DatasetDataUsageType::DATA_TRAIN, train) !=
       ML_ERROR_NONE)
     throw std::invalid_argument("Invalid train data generator");
 
-  if (valid && dataset->setGeneratorFunc((BufferType)DataType::DATA_VAL,
+  if (valid && dataset->setGeneratorFunc(DatasetDataUsageType::DATA_VAL,
                                          valid) != ML_ERROR_NONE)
     throw std::invalid_argument("Invalid valid data generator");
 
-  if (test && dataset->setGeneratorFunc((BufferType)DataType::DATA_TEST,
+  if (test && dataset->setGeneratorFunc(DatasetDataUsageType::DATA_TEST,
                                         test) != ML_ERROR_NONE)
     throw std::invalid_argument("Invalid test data generator");
 
index 93ed81e..306a574 100644 (file)
@@ -21,12 +21,12 @@ namespace nntrainer {
 /**
  * @brief Factory creator with constructor
  */
-std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type);
+std::unique_ptr<DataBuffer> createDataBuffer(DatasetType type);
 
 /**
  * @brief Factory creator with constructor for databuffer with files
  */
-std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type,
+std::unique_ptr<DataBuffer> createDataBuffer(DatasetType type,
                                              const char *train_file,
                                              const char *valid_file = nullptr,
                                              const char *test_file = nullptr);
@@ -34,8 +34,7 @@ std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type,
 /**
  * @brief Factory creator with constructor for databuffer with callbacks
  */
-std::unique_ptr<DataBuffer> createDataBuffer(DataBufferType type,
-                                             datagen_cb train,
+std::unique_ptr<DataBuffer> createDataBuffer(DatasetType type, datagen_cb train,
                                              datagen_cb valid = nullptr,
                                              datagen_cb test = nullptr);
 
index c054844..f837747 100644 (file)
@@ -69,15 +69,18 @@ int DataBufferFromDataFile::init() {
   if (status != ML_ERROR_NONE)
     return status;
 
-  if (validation[DATA_TRAIN] && max_train < batch_size) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_TRAIN)] &&
+      max_train < batch_size) {
     max_train = batch_size;
   }
 
-  if (validation[DATA_VAL] && max_val < batch_size) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_VAL)] &&
+      max_val < batch_size) {
     max_val = batch_size;
   }
 
-  if (validation[DATA_TEST] && max_test < batch_size) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_TEST)] &&
+      max_test < batch_size) {
     max_test = batch_size;
   }
 
@@ -89,20 +92,23 @@ int DataBufferFromDataFile::init() {
   this->val_running = true;
   this->test_running = true;
 
-  if (validation[DATA_TRAIN] && max_train < train_bufsize) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_TRAIN)] &&
+      max_train < train_bufsize) {
     ml_logw("Warning: Total number of train is less than train buffer size. "
             "Train buffer size is set as total number of train");
     train_bufsize = batch_size;
   }
 
-  if (validation[DATA_VAL] && max_val < val_bufsize) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_VAL)] &&
+      max_val < val_bufsize) {
     ml_logw(
       "Warning: Total number of validation is less than validation buffer "
       "size. Validation buffer size is set as total number of validation");
     val_bufsize = batch_size;
   }
 
-  if (validation[DATA_TEST] && max_test < test_bufsize) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_TEST)] &&
+      max_test < test_bufsize) {
     ml_logw("Warning: Total number of test is less than test buffer size. Test "
             "buffer size is set as total number of test");
     test_bufsize = batch_size;
@@ -111,7 +117,7 @@ int DataBufferFromDataFile::init() {
   return ML_ERROR_NONE;
 }
 
-void DataBufferFromDataFile::updateData(BufferType type) {
+void DataBufferFromDataFile::updateData(DatasetDataUsageType type) {
   unsigned int max_size = 0;
   unsigned int buf_size = 0;
   unsigned int *rest_size = NULL;
@@ -121,7 +127,7 @@ void DataBufferFromDataFile::updateData(BufferType type) {
   std::vector<std::vector<float>> *datalabel = NULL;
   std::ifstream file;
   switch (type) {
-  case BufferType::BUF_TRAIN: {
+  case DatasetDataUsageType::DATA_TRAIN: {
     max_size = max_train;
     buf_size = train_bufsize;
     rest_size = &rest_train;
@@ -137,7 +143,7 @@ void DataBufferFromDataFile::updateData(BufferType type) {
     readyTrainData.unlock();
 
   } break;
-  case BufferType::BUF_VAL: {
+  case DatasetDataUsageType::DATA_VAL: {
     max_size = max_val;
     buf_size = val_bufsize;
     rest_size = &rest_val;
@@ -153,7 +159,7 @@ void DataBufferFromDataFile::updateData(BufferType type) {
     readyValData.unlock();
 
   } break;
-  case BufferType::BUF_TEST: {
+  case DatasetDataUsageType::DATA_TEST: {
     max_size = max_test;
     buf_size = test_bufsize;
     rest_size = &rest_test;
@@ -259,48 +265,50 @@ void DataBufferFromDataFile::updateData(BufferType type) {
   file.close();
 }
 
-int DataBufferFromDataFile::setDataFile(DataType type, std::string path) {
+int DataBufferFromDataFile::setDataFile(DatasetDataUsageType type,
+                                        std::string path) {
   int status = ML_ERROR_NONE;
   std::ifstream data_file(path.c_str());
 
   switch (type) {
-  case DATA_TRAIN: {
-    validation[type] = true;
+  case DatasetDataUsageType::DATA_TRAIN: {
+    validation[static_cast<int>(type)] = true;
     if (!data_file.good()) {
       ml_loge(
         "Error: Cannot open data file, Datafile is necessary for training");
-      validation[type] = false;
+      validation[static_cast<int>(type)] = false;
       return ML_ERROR_INVALID_PARAMETER;
     }
     train_name = path;
   } break;
-  case DATA_VAL: {
-    validation[type] = true;
+  case DatasetDataUsageType::DATA_VAL: {
+    validation[static_cast<int>(type)] = true;
     if (!data_file.good()) {
       ml_loge("Error: Cannot open validation data file. Cannot validate "
               "training result");
-      validation[type] = false;
+      validation[static_cast<int>(type)] = false;
       return ML_ERROR_INVALID_PARAMETER;
     }
     val_name = path;
   } break;
-  case DATA_TEST: {
-    validation[type] = true;
+  case DatasetDataUsageType::DATA_TEST: {
+    validation[static_cast<int>(type)] = true;
     if (!data_file.good()) {
       ml_loge("Error: Cannot open test data file. Cannot test training result");
-      validation[type] = false;
+      validation[static_cast<int>(type)] = false;
       return ML_ERROR_INVALID_PARAMETER;
     }
     test_name = path;
   } break;
-  case DATA_UNKNOWN:
+  case DatasetDataUsageType::DATA_UNKNOWN:
   default:
     ml_loge("Error: Not Supported Data Type");
     SET_VALIDATION(false);
     return ML_ERROR_INVALID_PARAMETER;
     break;
   }
-  ml_logd("datafile has set. type: %d, path: %s", type, path.c_str());
+  ml_logd("datafile has set. type: %d, path: %s", static_cast<int>(type),
+          path.c_str());
 
   return status;
 }
@@ -313,7 +321,7 @@ int DataBufferFromDataFile::setFeatureSize(TensorDim tdim) {
   if (status != ML_ERROR_NONE)
     return status;
 
-  if (validation[DATA_TRAIN]) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_TRAIN)]) {
     file_size = getFileSize(train_name);
     max_train = static_cast<unsigned int>(
       file_size /
@@ -325,7 +333,7 @@ int DataBufferFromDataFile::setFeatureSize(TensorDim tdim) {
     max_train = 0;
   }
 
-  if (validation[DATA_VAL]) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_VAL)]) {
     file_size = getFileSize(val_name);
     max_val = static_cast<unsigned int>(
       file_size /
@@ -337,7 +345,7 @@ int DataBufferFromDataFile::setFeatureSize(TensorDim tdim) {
     max_val = 0;
   }
 
-  if (validation[DATA_TEST]) {
+  if (validation[static_cast<int>(DatasetDataUsageType::DATA_TEST)]) {
     file_size = getFileSize(test_name);
     max_test = static_cast<unsigned int>(
       file_size /
@@ -356,18 +364,18 @@ int DataBufferFromDataFile::setProperty(const PropertyType type,
                                         std::string &value) {
   int status = ML_ERROR_NONE;
 
-  if (data_buffer_type != DataBufferType::FILE)
+  if (data_buffer_type != DatasetType::FILE)
     return ML_ERROR_INVALID_PARAMETER;
 
   switch (type) {
   case PropertyType::train_data:
-    status = this->setDataFile(DATA_TRAIN, value);
+    status = this->setDataFile(DatasetDataUsageType::DATA_TRAIN, value);
     break;
   case PropertyType::val_data:
-    status = this->setDataFile(DATA_VAL, value);
+    status = this->setDataFile(DatasetDataUsageType::DATA_VAL, value);
     break;
   case PropertyType::test_data:
-    status = this->setDataFile(DATA_TEST, value);
+    status = this->setDataFile(DatasetDataUsageType::DATA_TEST, value);
     break;
   default:
     status = DataBuffer::setProperty(type, value);
index e7a5e8c..5f666de 100644 (file)
@@ -44,7 +44,7 @@ public:
   /**
    * @brief     Constructor
    */
-  DataBufferFromDataFile() : DataBuffer(DataBufferType::FILE) {}
+  DataBufferFromDataFile() : DataBuffer(DatasetType::FILE) {}
 
   /**
    * @brief     Destructor
@@ -63,7 +63,7 @@ public:
    * @param[in] BufferType training, validation, test
    * @retval    void
    */
-  void updateData(BufferType type);
+  void updateData(DatasetDataUsageType type);
 
   /**
    * @brief     set train data file name
@@ -72,7 +72,7 @@ public:
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  int setDataFile(DataType type, std::string path);
+  int setDataFile(DatasetDataUsageType type, std::string path);
 
   /**
    * @brief     set feature size
index c161867..e87e9e4 100644 (file)
@@ -85,23 +85,24 @@ int DataBufferFromCallback::init() {
   return ML_ERROR_NONE;
 }
 
-int DataBufferFromCallback::setGeneratorFunc(BufferType type, datagen_cb func) {
+int DataBufferFromCallback::setGeneratorFunc(DatasetDataUsageType type,
+                                             datagen_cb func) {
 
   int status = ML_ERROR_NONE;
   switch (type) {
-  case BufferType::BUF_TRAIN:
+  case DatasetDataUsageType::DATA_TRAIN:
     if (!func)
       return ML_ERROR_INVALID_PARAMETER;
     callback_train = func;
     if (func)
       validation[0] = true;
     break;
-  case BufferType::BUF_VAL:
+  case DatasetDataUsageType::DATA_VAL:
     callback_val = func;
     if (func)
       validation[1] = true;
     break;
-  case BufferType::BUF_TEST:
+  case DatasetDataUsageType::DATA_TEST:
     callback_test = func;
     if (func)
       validation[2] = true;
@@ -114,7 +115,7 @@ int DataBufferFromCallback::setGeneratorFunc(BufferType type, datagen_cb func) {
   return status;
 }
 
-void DataBufferFromCallback::updateData(BufferType type) {
+void DataBufferFromCallback::updateData(DatasetDataUsageType type) {
   int status = ML_ERROR_NONE;
 
   unsigned int buf_size = 0;
@@ -125,7 +126,7 @@ void DataBufferFromCallback::updateData(BufferType type) {
   datagen_cb callback;
 
   switch (type) {
-  case BufferType::BUF_TRAIN: {
+  case DatasetDataUsageType::DATA_TRAIN: {
     buf_size = train_bufsize;
     cur_size = &cur_train_bufsize;
     running = &train_running;
@@ -133,7 +134,7 @@ void DataBufferFromCallback::updateData(BufferType type) {
     datalabel = &train_data_label;
     callback = callback_train;
   } break;
-  case BufferType::BUF_VAL: {
+  case DatasetDataUsageType::DATA_VAL: {
     buf_size = val_bufsize;
     cur_size = &cur_val_bufsize;
     running = &val_running;
@@ -141,7 +142,7 @@ void DataBufferFromCallback::updateData(BufferType type) {
     datalabel = &val_data_label;
     callback = callback_val;
   } break;
-  case BufferType::BUF_TEST: {
+  case DatasetDataUsageType::DATA_TEST: {
     buf_size = test_bufsize;
     cur_size = &cur_test_bufsize;
     running = &test_running;
index c1712fd..184c459 100644 (file)
@@ -44,7 +44,7 @@ public:
   /**
    * @brief     Constructor
    */
-  DataBufferFromCallback() : DataBuffer(DataBufferType::GENERATOR) {}
+  DataBufferFromCallback() : DataBuffer(DatasetType::GENERATOR) {}
 
   /**
    * @brief     Destructor
@@ -65,14 +65,14 @@ public:
    * @retval #ML_ERROR_NONE Successful.
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
-  int setGeneratorFunc(BufferType type, datagen_cb func);
+  int setGeneratorFunc(DatasetDataUsageType type, datagen_cb func);
 
   /**
    * @brief     Update Data Buffer ( it is for child thread )
-   * @param[in] BufferType training, validation, test
+   * @param[in] DatasetDataUsageType training, validation, test
    * @retval    void
    */
-  void updateData(BufferType type);
+  void updateData(DatasetDataUsageType type);
 
   /**
    * @brief     set property
index 1641c66..f4b95b6 100644 (file)
  *
  */
 
-#define SET_VALIDATION(val)                                              \
-  do {                                                                   \
-    for (DataType i = DATA_TRAIN; i < DATA_UNKNOWN; i = DataType(i + 1)) \
-      validation[i] = val;                                               \
+#define SET_VALIDATION(val)                                               \
+  do {                                                                    \
+    validation[static_cast<int>(DatasetDataUsageType::DATA_TRAIN)] = val; \
+    validation[static_cast<int>(DatasetDataUsageType::DATA_VAL)] = val;   \
+    validation[static_cast<int>(DatasetDataUsageType::DATA_TEST)] = val;  \
   } while (0)
 
 #define NN_EXCEPTION_NOTI(val)                             \
   do {                                                     \
     switch (type) {                                        \
-    case BufferType::BUF_TRAIN: {                          \
+    case DatasetDataUsageType::DATA_TRAIN: {               \
       std::lock_guard<std::mutex> lgtrain(readyTrainData); \
       trainReadyFlag = val;                                \
       cv_train.notify_all();                               \
     } break;                                               \
-    case BufferType::BUF_VAL: {                            \
+    case DatasetDataUsageType::DATA_VAL: {                 \
       std::lock_guard<std::mutex> lgval(readyValData);     \
       valReadyFlag = val;                                  \
       cv_val.notify_all();                                 \
     } break;                                               \
-    case BufferType::BUF_TEST: {                           \
+    case DatasetDataUsageType::DATA_TEST: {                \
       std::lock_guard<std::mutex> lgtest(readyTestData);   \
       testReadyFlag = val;                                 \
       cv_test.notify_all();                                \
index 5a4c177..3c8fa64 100644 (file)
@@ -193,7 +193,7 @@ int ModelLoader::loadDatasetConfigIni(dictionary *ini, NeuralNetwork &model) {
   int status = ML_ERROR_NONE;
 
   if (iniparser_find_entry(ini, "Dataset") == 0) {
-    model.data_buffer = nntrainer::createDataBuffer(DataBufferType::GENERATOR);
+    model.data_buffer = nntrainer::createDataBuffer(DatasetType::GENERATOR);
     status = model.data_buffer->setBatchSize(model.batch_size);
     return status;
   }
@@ -203,12 +203,12 @@ int ModelLoader::loadDatasetConfigIni(dictionary *ini, NeuralNetwork &model) {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
-  model.data_buffer = nntrainer::createDataBuffer(DataBufferType::FILE);
+  model.data_buffer = nntrainer::createDataBuffer(DatasetType::FILE);
   std::shared_ptr<DataBufferFromDataFile> dbuffer =
     std::static_pointer_cast<DataBufferFromDataFile>(model.data_buffer);
 
-  std::function<int(const char *, DataType, bool)> parse_and_set =
-    [&](const char *key, DataType dt, bool required) -> int {
+  std::function<int(const char *, DatasetDataUsageType, bool)> parse_and_set =
+    [&](const char *key, DatasetDataUsageType dt, bool required) -> int {
     const char *path = iniparser_getstring(ini, key, NULL);
 
     if (path == NULL) {
@@ -218,11 +218,14 @@ int ModelLoader::loadDatasetConfigIni(dictionary *ini, NeuralNetwork &model) {
     return dbuffer->setDataFile(dt, resolvePath(path));
   };
 
-  status = parse_and_set("DataSet:TrainData", DATA_TRAIN, true);
+  status =
+    parse_and_set("DataSet:TrainData", DatasetDataUsageType::DATA_TRAIN, true);
   NN_RETURN_STATUS();
-  status = parse_and_set("DataSet:ValidData", DATA_VAL, false);
+  status =
+    parse_and_set("DataSet:ValidData", DatasetDataUsageType::DATA_VAL, false);
   NN_RETURN_STATUS();
-  status = parse_and_set("DataSet:TestData", DATA_TEST, false);
+  status =
+    parse_and_set("DataSet:TestData", DatasetDataUsageType::DATA_TEST, false);
   NN_RETURN_STATUS();
   const char *path = iniparser_getstring(ini, "Dataset:LabelData", NULL);
   if (path != NULL) {
index 661dfb7..544f1ab 100644 (file)
@@ -600,16 +600,17 @@ int NeuralNetwork::train_run() {
 
   for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
     training.loss = 0.0f;
-    status = data_buffer->run(nntrainer::BufferType::BUF_TRAIN);
+    status = data_buffer->run(nntrainer::DatasetDataUsageType::DATA_TRAIN);
     if (status != ML_ERROR_NONE) {
-      data_buffer->clear(BufferType::BUF_TRAIN);
+      data_buffer->clear(DatasetDataUsageType::DATA_TRAIN);
       return status;
     }
 
-    if (data_buffer->getValidation()[(int)nntrainer::BufferType::BUF_TEST]) {
-      status = data_buffer->run(nntrainer::BufferType::BUF_TEST);
+    if (data_buffer
+          ->getValidation()[(int)nntrainer::DatasetDataUsageType::DATA_TEST]) {
+      status = data_buffer->run(nntrainer::DatasetDataUsageType::DATA_TEST);
       if (status != ML_ERROR_NONE) {
-        data_buffer->clear(BufferType::BUF_TEST);
+        data_buffer->clear(DatasetDataUsageType::DATA_TEST);
         return status;
       }
     }
@@ -617,25 +618,25 @@ int NeuralNetwork::train_run() {
     int count = 0;
 
     while (true) {
-
-      if (data_buffer->getDataFromBuffer(nntrainer::BufferType::BUF_TRAIN,
-                                         in.getData(), label.getData())) {
+      if (data_buffer->getDataFromBuffer(
+            nntrainer::DatasetDataUsageType::DATA_TRAIN, in.getData(),
+            label.getData())) {
         try {
           forwarding(true);
           backwarding(iter++);
         } catch (std::exception &e) {
-          data_buffer->clear(nntrainer::BufferType::BUF_TRAIN);
+          data_buffer->clear(nntrainer::DatasetDataUsageType::DATA_TRAIN);
           ml_loge("Error: training error in #%d/%d. %s", epoch_idx, epochs,
                   e.what());
           throw;
         }
         std::cout << "#" << epoch_idx << "/" << epochs;
         float loss = getLoss();
-        data_buffer->displayProgress(count++, nntrainer::BufferType::BUF_TRAIN,
-                                     loss);
+        data_buffer->displayProgress(
+          count++, nntrainer::DatasetDataUsageType::DATA_TRAIN, loss);
         training.loss += loss;
       } else {
-        data_buffer->clear(nntrainer::BufferType::BUF_TRAIN);
+        data_buffer->clear(nntrainer::DatasetDataUsageType::DATA_TRAIN);
         break;
       }
     }
@@ -649,20 +650,22 @@ int NeuralNetwork::train_run() {
     std::cout << "#" << epoch_idx << "/" << epochs
               << " - Training Loss: " << training.loss;
 
-    if (data_buffer->getValidation()[(int)nntrainer::BufferType::BUF_VAL]) {
+    if (data_buffer
+          ->getValidation()[(int)nntrainer::DatasetDataUsageType::DATA_VAL]) {
       int right = 0;
       validation.loss = 0.0f;
       unsigned int tcases = 0;
 
-      status = data_buffer->run(nntrainer::BufferType::BUF_VAL);
+      status = data_buffer->run(nntrainer::DatasetDataUsageType::DATA_VAL);
       if (status != ML_ERROR_NONE) {
-        data_buffer->clear(BufferType::BUF_VAL);
+        data_buffer->clear(DatasetDataUsageType::DATA_VAL);
         return status;
       }
 
       while (true) {
-        if (data_buffer->getDataFromBuffer(nntrainer::BufferType::BUF_VAL,
-                                           in.getData(), label.getData())) {
+        if (data_buffer->getDataFromBuffer(
+              nntrainer::DatasetDataUsageType::DATA_VAL, in.getData(),
+              label.getData())) {
           forwarding(false);
           auto model_out = output.argmax();
           auto label_out = label.argmax();
@@ -673,7 +676,7 @@ int NeuralNetwork::train_run() {
           validation.loss += getLoss();
           tcases++;
         } else {
-          data_buffer->clear(nntrainer::BufferType::BUF_VAL);
+          data_buffer->clear(nntrainer::DatasetDataUsageType::DATA_VAL);
           break;
         }
       }
index d199793..ea6a9b4 100644 (file)
@@ -41,7 +41,7 @@ TEST(nntrainer_DataBuffer, setFeatureSize_01_p) {
   dim.setTensorDim("32:1:1:62720");
   status = data_buffer.setClassNum(10);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.setDataFile(nntrainer::DATA_TRAIN,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN,
                                    getTestResPath("trainingSet.dat"));
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.setFeatureSize(dim);
@@ -80,13 +80,13 @@ TEST(nntrainer_DataBuffer, init_01_p) {
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.setClassNum(10);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.setDataFile(nntrainer::DATA_TRAIN,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN,
                                    getTestResPath("trainingSet.dat"));
   EXPECT_EQ(status, ML_ERROR_NONE);
-  status =
-    data_buffer.setDataFile(nntrainer::DATA_VAL, getTestResPath("valSet.dat"));
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_VAL,
+                                   getTestResPath("valSet.dat"));
   EXPECT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.setDataFile(nntrainer::DATA_TEST,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST,
                                    getTestResPath("testSet.dat"));
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.setFeatureSize(dim);
@@ -123,7 +123,7 @@ TEST(nntrainer_DataBuffer, setClassNum_02_n) {
 TEST(nntrainer_DataBuffer, setDataFile_01_p) {
   int status = ML_ERROR_NONE;
   nntrainer::DataBufferFromDataFile data_buffer;
-  status = data_buffer.setDataFile(nntrainer::DATA_TRAIN,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN,
                                    getTestResPath("trainingSet.dat"));
   EXPECT_EQ(status, ML_ERROR_NONE);
 }
@@ -134,7 +134,8 @@ TEST(nntrainer_DataBuffer, setDataFile_01_p) {
 TEST(nntrainer_DataBuffer, setDataFile_02_n) {
   int status = ML_ERROR_NONE;
   nntrainer::DataBufferFromDataFile data_buffer;
-  status = data_buffer.setDataFile(nntrainer::DATA_TRAIN, "./no_exist.dat");
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN,
+                                   "./no_exist.dat");
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
 }
 
@@ -150,24 +151,24 @@ TEST(nntrainer_DataBuffer, clear_01_p) {
   ASSERT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.setClassNum(10);
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.setDataFile(nntrainer::DATA_TRAIN,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TRAIN,
                                    getTestResPath("trainingSet.dat"));
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status =
-    data_buffer.setDataFile(nntrainer::DATA_VAL, getTestResPath("valSet.dat"));
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_VAL,
+                                   getTestResPath("valSet.dat"));
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.setDataFile(nntrainer::DATA_TEST,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST,
                                    getTestResPath("testSet.dat"));
   ASSERT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.setFeatureSize(dim);
   ASSERT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.init();
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.run(nntrainer::BufferType::BUF_TRAIN);
+  status = data_buffer.run(nntrainer::DatasetDataUsageType::DATA_TRAIN);
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.run(nntrainer::BufferType::BUF_TEST);
+  status = data_buffer.run(nntrainer::DatasetDataUsageType::DATA_TEST);
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.run(nntrainer::BufferType::BUF_VAL);
+  status = data_buffer.run(nntrainer::DatasetDataUsageType::DATA_VAL);
   ASSERT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.clear();
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -179,10 +180,10 @@ TEST(nntrainer_DataBuffer, clear_01_p) {
 TEST(nntrainer_DataBuffer, clear_02_p) {
   int status = ML_ERROR_NONE;
   nntrainer::DataBufferFromDataFile data_buffer;
-  status = data_buffer.setDataFile(nntrainer::DATA_TEST,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST,
                                    getTestResPath("testSet.dat"));
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.clear(nntrainer::BufferType::BUF_TEST);
+  status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_TEST);
   EXPECT_EQ(status, ML_ERROR_NONE);
 }
 
@@ -192,10 +193,10 @@ TEST(nntrainer_DataBuffer, clear_02_p) {
 TEST(nntrainer_DataBuffer, clear_03_p) {
   int status = ML_ERROR_NONE;
   nntrainer::DataBufferFromDataFile data_buffer;
-  status = data_buffer.setDataFile(nntrainer::DATA_TEST,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST,
                                    getTestResPath("testSet.dat"));
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.clear(nntrainer::BufferType::BUF_TEST);
+  status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_TEST);
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.clear();
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -207,22 +208,22 @@ TEST(nntrainer_DataBuffer, clear_03_p) {
 TEST(nntrainer_DataBuffer, clear_04_p) {
   int status = ML_ERROR_NONE;
   nntrainer::DataBufferFromDataFile data_buffer;
-  status = data_buffer.setDataFile(nntrainer::DATA_TEST,
+  status = data_buffer.setDataFile(nntrainer::DatasetDataUsageType::DATA_TEST,
                                    getTestResPath("testSet.dat"));
   ASSERT_EQ(status, ML_ERROR_NONE);
-  status = data_buffer.clear(nntrainer::BufferType::BUF_TEST);
+  status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_TEST);
   EXPECT_EQ(status, ML_ERROR_NONE);
   status = data_buffer.clear();
   EXPECT_EQ(status, ML_ERROR_NONE);
 }
 
 /**
- * @brief Data buffer clear BufferType::BUF_UNKNOWN
+ * @brief Data buffer clear BufferType::DATA_UNKNOWN
  */
 TEST(nntrainer_DataBuffer, clear_05_n) {
   int status = ML_ERROR_NONE;
   nntrainer::DataBufferFromDataFile data_buffer;
-  status = data_buffer.clear(nntrainer::BufferType::BUF_UNKNOWN);
+  status = data_buffer.clear(nntrainer::DatasetDataUsageType::DATA_UNKNOWN);
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
 }