Add DataBuffer for big data
authorjijoong.moon <jijoong.moon@samsung.com>
Mon, 17 Feb 2020 11:49:16 +0000 (20:49 +0900)
committer문지중/On-Device Lab(SR)/Principal Engineer/삼성전자 <jijoong.moon@samsung.com>
Mon, 17 Feb 2020 12:00:13 +0000 (21:00 +0900)
- Add DataBuffer Class to read/write big data from file
- Multi-Threaded
- Automatic Updated

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
Applications/Classification/jni/main.cpp
Applications/Classification/res/Classification.ini
CMakeLists.txt
include/databuffer.h [new file with mode: 0644]
src/databuffer.cpp [new file with mode: 0644]
src/neuralnet.cpp

index 6c97e66..5bfb245 100644 (file)
@@ -38,6 +38,7 @@
 #include "tensorflow/contrib/lite/string_util.h"
 #include "tensorflow/contrib/lite/tools/gen_op_registration.h"
 
+#include "databuffer.h"
 #include "layers.h"
 #include "neuralnet.h"
 #include "tensor.h"
 /**
  * @brief     Data size for each category
  */
-#define TOTAL_TRAIN_DATA_SIZE 500
+#define TOTAL_TRAIN_DATA_SIZE 100
 
-#define TOTAL_VAL_DATA_SIZE 50
+#define TOTAL_VAL_DATA_SIZE 10
+
+#define TOTAL_TEST_DATA_SIZE 100
+
+#define BUFFER_SIZE 100
 
 /**
  * @brief     Number of category : Three
@@ -213,7 +218,7 @@ void getFeature(const string filename, vector<float> &feature_input) {
  * @param[out] feature_output save label data
  */
 void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector<vector<float>> &feature_output,
-                     std::string type) {
+                     std::string type, std::ofstream &f) {
   string total_label[10] = {"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"};
 
   int data_size = TOTAL_TRAIN_DATA_SIZE;
@@ -222,6 +227,9 @@ void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector
   if (!type.compare("val")) {
     data_size = TOTAL_VAL_DATA_SIZE;
     val = true;
+  } else if (!type.compare("test")) {
+    data_size = TOTAL_TEST_DATA_SIZE;
+    val = false;
   }
 
   int trainingSize = TOTAL_LABEL_SIZE * data_size;
@@ -233,7 +241,11 @@ void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector
 
   for (int i = 0; i < TOTAL_LABEL_SIZE; i++) {
     std::string path = p;
-    path += total_label[i];
+    if (!type.compare("val") || !type.compare("training")) {
+      path += "train/" + total_label[i];
+    } else if (!type.compare("test")) {
+      path += "test/" + total_label[i];
+    }
 
     for (int j = 0; j < data_size; j++) {
       std::string img = path + "/";
@@ -248,24 +260,45 @@ void ExtractFeatures(std::string p, vector<vector<float>> &feature_input, vector
       img += ss.str() + ".bmp";
       printf("%s\n", img.c_str());
 
-      feature_input[count].resize(FEATURE_SIZE);
+      std::vector<float> _input, _output;
+      _input.resize(FEATURE_SIZE);
+      _output.resize(TOTAL_LABEL_SIZE);
+
+      getFeature(img, _input);
+      _output[i] = 1;
+
+      for (unsigned int k = 0; k < FEATURE_SIZE; ++k)
+        f.write((char *)&_input[k], sizeof(float));
+
+      for (unsigned int k = 0; k < TOTAL_LABEL_SIZE; ++k)
+        f.write((char *)&_output[k], sizeof(float));
 
-      getFeature(img, feature_input[count]);
-      feature_output[count].resize(TOTAL_LABEL_SIZE);
-      feature_output[count][i] = 1;
       count++;
     }
   }
 }
 
-bool getMiniBatch(std::vector<std::vector<float>> inVec, std::vector<std::vector<float>> inLabel,
-                  std::vector<std::vector<std::vector<float>>> &outVec,
-                  std::vector<std::vector<std::vector<float>>> &outLabel) {
+bool getData(std::ifstream &F, std::vector<float> &outVec, std::vector<float> &outLabel, int id) {
+  long pos = F.tellg();
+  F.seekg(pos + (FEATURE_SIZE + TOTAL_LABEL_SIZE) * id);
+  for (int i = 0; i < FEATURE_SIZE; i++)
+    F.read((char *)&outVec[i], sizeof(float));
+  for (int i = 0; i < TOTAL_LABEL_SIZE; i++)
+    F.read((char *)&outLabel[i], sizeof(float));
+
+  return true;
+}
+
+bool getMiniBatch(std::vector<std::vector<std::vector<float>>> &outVec,
+                  std::vector<std::vector<std::vector<float>>> &outLabel, std::string type) {
   std::vector<int> memI;
   std::vector<int> memJ;
   int count = 0;
   int data_size = TOTAL_TRAIN_DATA_SIZE;
 
+  std::string filename = type + "Set.dat";
+  std::ifstream F(filename, std::ios::in | std::ios::binary);
+
   for (int i = 0; i < TOTAL_LABEL_SIZE * data_size; i++) {
     if (!duplicate[i])
       count++;
@@ -286,13 +319,23 @@ bool getMiniBatch(std::vector<std::vector<float>> inVec, std::vector<std::vector
 
   for (int i = 0; i < count; i++) {
     std::vector<std::vector<float>> out;
-    out.push_back(inVec[memI[i]]);
-    outVec.push_back(out);
-
     std::vector<std::vector<float>> outL;
-    outL.push_back(inLabel[memI[i]]);
+    std::vector<float> o;
+    std::vector<float> l;
+
+    o.resize(FEATURE_SIZE);
+    l.resize(TOTAL_LABEL_SIZE);
+
+    getData(F, o, l, memI[i]);
+
+    out.push_back(o);
+    outL.push_back(l);
+
+    outVec.push_back(out);
     outLabel.push_back(outL);
   }
+
+  F.close();
   return true;
 }
 
@@ -303,14 +346,16 @@ void save(std::vector<std::vector<float>> inVec, std::vector<std::vector<float>>
     data_size = TOTAL_TRAIN_DATA_SIZE;
   } else if (!type.compare("val")) {
     data_size = TOTAL_VAL_DATA_SIZE;
+  } else if (!type.compare("test")) {
+    data_size = TOTAL_TEST_DATA_SIZE;
   }
 
   std::ofstream TrainigSet(file, std::ios::out | std::ios::binary);
-  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * data_size; i++) {
-    for (unsigned int j = 0; j < FEATURE_SIZE; j++) {
+  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * data_size; ++i) {
+    for (unsigned int j = 0; j < FEATURE_SIZE; ++j) {
       TrainigSet.write((char *)&inVec[i][j], sizeof(float));
     }
-    for (unsigned int j = 0; j < TOTAL_LABEL_SIZE; j++)
+    for (unsigned int j = 0; j < TOTAL_LABEL_SIZE; ++j)
       TrainigSet.write((char *)&inLabel[i][j], sizeof(float));
   }
 }
@@ -318,30 +363,11 @@ void save(std::vector<std::vector<float>> inVec, std::vector<std::vector<float>>
 bool read(std::vector<std::vector<float>> &inVec, std::vector<std::vector<float>> &inLabel, std::string type) {
   std::string file = type + "Set.dat";
 
-  unsigned int data_size;
-  if (!type.compare("training")) {
-    data_size = TOTAL_TRAIN_DATA_SIZE;
-  } else if (!type.compare("val")) {
-    data_size = TOTAL_VAL_DATA_SIZE;
-  }
+  std::ifstream TrainingSet(file, std::ios::in | std::ios::binary);
 
-  std::ifstream TrainigSet(file, std::ios::out | std::ios::binary);
-  if (!TrainigSet.good())
+  if (!TrainingSet.good())
     return false;
 
-  inVec.resize(TOTAL_LABEL_SIZE * data_size);
-  inLabel.resize(TOTAL_LABEL_SIZE * data_size);
-
-  for (unsigned int i = 0; i < TOTAL_LABEL_SIZE * data_size; i++) {
-    inVec[i].resize(FEATURE_SIZE);
-    for (unsigned int j = 0; j < FEATURE_SIZE; j++) {
-      TrainigSet.read((char *)&inVec[i][j], sizeof(float));
-    }
-    inLabel[i].resize(TOTAL_LABEL_SIZE);
-    for (unsigned int j = 0; j < TOTAL_LABEL_SIZE; j++)
-      TrainigSet.read((char *)&inLabel[i][j], sizeof(float));
-  }
-  std::cout << "read done\n" << std::endl;
   return true;
 }
 
@@ -364,21 +390,36 @@ int main(int argc, char *argv[]) {
   std::string ini_file = data_path + "ini.bin";
   std::vector<std::vector<float>> inputVector, outputVector;
   std::vector<std::vector<float>> inputValVector, outputValVector;
+  std::vector<std::vector<float>> inputTestVector, outputTestVector;
 
   if (!read(inputVector, outputVector, "training")) {
     /**
      * @brief     Extract Feature
      */
-    ExtractFeatures(data_path, inputVector, outputVector, "training");
-    save(inputVector, outputVector, "training");
+    std::string filename = "trainingSet.dat";
+    std::ofstream f(filename, std::ios::out | std::ios::binary);
+    ExtractFeatures(data_path, inputVector, outputVector, "training", f);
+    f.close();
   }
 
   if (!read(inputValVector, outputValVector, "val")) {
     /**
      * @brief     Extract Feature
      */
-    ExtractFeatures(data_path, inputValVector, outputValVector, "val");
-    save(inputValVector, outputValVector, "val");
+    std::string filename = "valSet.dat";
+    std::ofstream f(filename, std::ios::out | std::ios::binary);
+    ExtractFeatures(data_path, inputValVector, outputValVector, "val", f);
+    f.close();
+  }
+
+  if (!read(inputValVector, outputValVector, "test")) {
+    /**
+     * @brief     Extract Feature
+     */
+    std::string filename = "testSet.dat";
+    std::ofstream f(filename, std::ios::out | std::ios::binary);
+    ExtractFeatures(data_path, inputTestVector, outputTestVector, "test", f);
+    f.close();
   }
 
   /**
@@ -389,6 +430,18 @@ int main(int argc, char *argv[]) {
   NN.init();
   NN.readModel();
 
+  DataBuffer buf;
+
+  std::ifstream train_file("trainingSet.dat", std::ios::in | std::ios::binary);
+  std::ifstream val_file("valSet.dat", std::ios::in | std::ios::binary);
+  std::ifstream test_file("testSet.dat", std::ios::in | std::ios::binary);
+
+  buf.init(MINI_BATCH, BUFFER_SIZE, BUFFER_SIZE, BUFFER_SIZE, train_file, val_file, test_file,
+           TOTAL_LABEL_SIZE * TOTAL_TRAIN_DATA_SIZE, TOTAL_LABEL_SIZE * TOTAL_VAL_DATA_SIZE,
+           TOTAL_LABEL_SIZE * TOTAL_TEST_DATA_SIZE, FEATURE_SIZE, 10);
+  buf.run(BUF_TRAIN, train_file);
+  buf.run(BUF_VAL, val_file);
+
   /**
    * @brief     back propagation
    */
@@ -402,11 +455,13 @@ int main(int argc, char *argv[]) {
 
       while (true) {
         std::vector<std::vector<std::vector<float>>> in, label;
-        if (getMiniBatch(inputVector, outputVector, in, label)) {
+        if (buf.getDatafromBuffer(BUF_TRAIN, in, label, MINI_BATCH, FEATURE_SIZE, 1, TOTAL_LABEL_SIZE)) {
           NN.backwarding(Tensor(in), Tensor(label), i);
           count++;
           std::cout << count * 32 << " backwoarding done : " << NN.getLoss() << std::endl;
         } else {
+          buf.clear(BUF_TRAIN, train_file);
+          buf.run(BUF_TRAIN, train_file);
           break;
         }
       }
@@ -418,14 +473,21 @@ int main(int argc, char *argv[]) {
       int right = 0;
       float valloss = 0.0;
 
-      for (int j = 0; j < TOTAL_LABEL_SIZE; j++) {
-        for (int k = 0; k < TOTAL_VAL_DATA_SIZE; k++) {
-          Tensor X = Tensor({inputValVector[j * TOTAL_VAL_DATA_SIZE + k]});
-          Tensor Y2 = Tensor({outputValVector[j * TOTAL_VAL_DATA_SIZE + k]});
-          Tensor Y = NN.forwarding(X, Y2);
-          if (Y.argmax() == j)
-            right++;
-          valloss += NN.getLoss();
+      while (true) {
+        std::vector<std::vector<std::vector<float>>> in, label;
+        if (buf.getDatafromBuffer(BUF_VAL, in, label, MINI_BATCH, FEATURE_SIZE, 1, TOTAL_LABEL_SIZE)) {
+          for (int i = 0; i < MINI_BATCH; ++i) {
+            Tensor X = Tensor({in[i]});
+            Tensor Y2 = Tensor({label[i]});
+            Tensor Y = NN.forwarding(X, Y2);
+            if (Y.argmax() == Y2.argmax())
+              right++;
+            valloss += NN.getLoss();
+          }
+        } else {
+          buf.clear(BUF_VAL, val_file);
+          buf.run(BUF_VAL, val_file);
+          break;
         }
       }
 
@@ -436,10 +498,14 @@ int main(int argc, char *argv[]) {
            << " ) >> [ Accuracy : " << right / (float)(TOTAL_LABEL_SIZE * TOTAL_VAL_DATA_SIZE) * 100.0
            << "% ] [ Validation Loss : " << valloss << " ] " << endl;
 
-      NN.setLoss(0.0);
+      // NN.setLoss(0.0);
+
       if (training)
         NN.saveModel();
     }
+    buf.clear(BUF_TRAIN, train_file);
+    buf.clear(BUF_VAL, val_file);
+    buf.clear(BUF_TEST, test_file);
   }
 
   if (!training) {
index cb93043..028294e 100644 (file)
@@ -8,7 +8,7 @@ Learning_rate = 0.001   # Learning Rate
 Decay_rate = 0.96      # for the decay_rate for the decayed learning rate
 Decay_steps = 1000       # decay step for the exponential decayed learning rate
 Epoch = 30000          # Epoch 
-Optimizer = adam       # Optimizer : sgd (stochastic gradien decent),
+Optimizer = sgd        # Optimizer : sgd (stochastic gradien decent),
                        #             adam (Adamtive Moment Estimation)
 Activation = sigmoid   # activation : sigmoid, tanh
 Cost = cross                   # Cost(loss) function : msr (mean square root error)
index a94e533..c883f92 100644 (file)
@@ -24,11 +24,12 @@ set(SRCS
        src/neuralnet.cpp
        src/tensor.cpp
        src/layers.cpp
+       src/databuffer.cpp
        ${INIPARSER}/iniparser.c
        ${INIPARSER}/dictionary.c
        )
 
-set(NNTRAINER_HEADERS include/neuralnet.h include/tensor.h include/layers.h)
+set(NNTRAINER_HEADERS include/neuralnet.h include/tensor.h include/layers.h include/databuffer.h)
 
 if(USE_BLAS)
   ADD_DEFINITIONS(-DUSE_BLAS)
diff --git a/include/databuffer.h b/include/databuffer.h
new file mode 100644 (file)
index 0000000..e0e899d
--- /dev/null
@@ -0,0 +1,171 @@
+/**
+ * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ *
+ * @file       databuffer.h
+ * @date       04 December 2019
+ * @brief      This is Matrix class for calculation using blas library.
+ * @see https://github.sec.samsung.net/jijoong-moon/Transfer-Learning.git
+ * @author     Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#ifndef __DATABUFFER_H__
+#define __DATABUFFER_H__
+
+#include <atomic>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <thread>
+#include <vector>
+
+/**
+ * @brief     Enumeration of buffer type
+ *            0. BUF_TRAIN ( Buffer for training )
+ *            1. BUF_VAL ( Buffer for validation )
+ *            2. BUF_TEST ( Buffer for test )
+ *            3. Unknown
+ */
+typedef enum { BUF_TRAIN, BUF_VAL, BUF_TEST, BUFF_UNKNOWN } buffer_type;
+
+
+/**
+ * @class   DataBuffer Data Buffers
+ * @brief   Data Buffer for read and manage data
+ */
+class DataBuffer {
+ public:
+
+  /**
+   * @brief     Create Buffer
+   * @retval    DataBuffer
+   */  
+  DataBuffer() : train_running(), val_running(), test_running(), train_thread(), val_thread(), test_thread(){};
+
+  /**
+   * @brief     Create Buffer
+   * @param[in] train_bufsize size buffer
+   * @param[in] val_bufsize size buffer
+   * @param[in] test_bufsize size buffer
+   * @retval    DataBuffer
+   */  
+  DataBuffer(int train_bufsize, int val_bufsize, int test_bufsize);
+  
+  /**
+   * @brief     Initialize Buffer
+   * @param[in] mini_batch size of minibatch
+   * @param[in] train_bufsize size of training buffer
+   * @param[in] val_bufsize size of validation buffer
+   * @param[in] test_bufsize size of test buffer
+   * @param[in] train_file input file stream for training
+   * @param[in] val_file input file stream for validataion
+   * @param[in] test_file input file stream for test
+   * @param[in] max_train maximum number of traing data
+   * @param[in] max_val maximum number of validation data
+   * @param[in] max_test maximum number of test data
+   * @param[in] in_size input size
+   * @param[in] c_num number of class
+   * @retval    true / false 
+   */  
+  bool init(int mini_batch, unsigned int train_bufsize, unsigned int val_bufsize, unsigned int test_bufsize,
+            std::ifstream &train_file, std::ifstream &val_file, std::ifstream &test_file, unsigned int max_train,
+            unsigned int max_val, unsigned int max_test, unsigned int in_size, unsigned int c_num);
+
+  /**
+   * @brief     Update Data Buffer ( it is for child thread )
+   * @param[in] buffer_type training, validation, test
+   * @param[in] file input file stream 
+   * @retval    void
+   */  
+  void UpdateData(buffer_type type, std::ifstream &file);
+
+  /**
+   * @brief     function for thread ( training, validation, test )
+   * @param[in] buffer_type training, validation, test
+   * @param[in] file input file stream 
+   * @retval    void
+   */  
+  void run(buffer_type type, std::ifstream &file);
+
+  /**
+   * @brief     clear thread ( training, validation, test )
+   * @param[in] buffer_type training, validation, test
+   * @param[in] file input file stream 
+   * @retval    void
+   */  
+  void clear(buffer_type type, std::ifstream &file);
+
+  /**
+   * @brief     get Status of Buffer. if number of rest data 
+   *            is samller than minibatch, the return false
+   * @param[in] buffer_type training, validation, test
+   * @retval    true/false
+   */  
+  bool getStatus(buffer_type type);
+
+  /**
+   * @brief     get Data from Data Buffer
+   * @param[in] buffer_type training, validation, test
+   * @param[in] outVec feature data ( minibatch size )
+   * @param[in] outLabel label data ( minibatch size )
+   * @param[in] batch size of batch
+   * @param[in] width width
+   * @param[in] height height
+   * @param[in] c_num number of class
+   * @retval    true/false
+   */  
+  bool getDatafromBuffer(buffer_type type, std::vector<std::vector<std::vector<float>>> &outVec,
+                         std::vector<std::vector<std::vector<float>>> &outLabel, unsigned int batch, unsigned int width,
+                         unsigned int height, unsigned int c_num);
+
+ private:
+  std::vector<std::vector<float>> trainData;
+  std::vector<std::vector<float>> trainDataLabel;
+  std::vector<std::vector<float>> valData;
+  std::vector<std::vector<float>> valDataLabel;
+  std::vector<std::vector<float>> testData;
+  std::vector<std::vector<float>> testDataLabel;
+
+  unsigned int input_size;
+  unsigned int class_num;
+
+  unsigned int cur_train_bufsize;
+  unsigned int cur_val_bufsize;
+  unsigned int cur_test_bufsize;
+
+  unsigned int train_bufsize;
+  unsigned int val_bufsize;
+  unsigned int test_bufsize;
+
+  unsigned int max_train;
+  unsigned int max_val;
+  unsigned int max_test;
+
+  unsigned int rest_train;
+  unsigned int rest_val;
+  unsigned int rest_test;
+
+  unsigned int mini_batch;
+
+  bool train_running;
+  bool val_running;
+  bool test_running;
+
+  std::thread train_thread;
+  std::thread val_thread;
+  std::thread test_thread;
+};
+
+#endif
diff --git a/src/databuffer.cpp b/src/databuffer.cpp
new file mode 100644 (file)
index 0000000..1400e12
--- /dev/null
@@ -0,0 +1,501 @@
+/**
+ * Copyright (C) 2019 Samsung Electronics Co., Ltd. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ *
+ * @file       databuffer.cpp
+ * @date       04 December 2019
+ * @brief      This is Matrix class for calculation using blas library
+ * @see https://github.sec.samsung.net/jijoong-moon/Transfer-Learning.git
+ * @author     Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug                No known bugs except for NYI items
+ *
+ */
+
+#include "include/databuffer.h"
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <cstring>
+#include <functional>
+#include <mutex>
+#include <sstream>
+#include <thread>
+
+std::mutex data_lock;
+
+static int rangeRandom(int min, int max) {
+  int n = max - min + 1;
+  int remainder = RAND_MAX % n;
+  int x;
+  do {
+    x = rand();
+  } while (x >= RAND_MAX - remainder);
+  return min + x % n;
+}
+
+DataBuffer::DataBuffer(int train_num, int val_num, int test_num) {
+  this->train_bufsize = train_num;
+  this->val_bufsize = val_num;
+  this->test_bufsize = test_num;
+}
+
+bool DataBuffer::init(int mini_batch, unsigned int train_bufsize, unsigned int val_bufsize, unsigned int test_bufsize,
+                      std::ifstream &train_file, std::ifstream &val_file, std::ifstream &test_file,
+                      unsigned int max_train, unsigned int max_val, unsigned int max_test, unsigned int in_size,
+                      unsigned int c_num) {
+  this->input_size = in_size;
+  this->class_num = c_num;
+
+  this->cur_train_bufsize = train_bufsize;
+  this->cur_val_bufsize = val_bufsize;
+  this->cur_test_bufsize = test_bufsize;
+
+  this->train_bufsize = train_bufsize;
+  this->val_bufsize = val_bufsize;
+  this->test_bufsize = test_bufsize;
+
+  this->mini_batch = mini_batch;
+
+  this->max_train = max_train;
+  this->max_val = max_val;
+  this->max_test = max_test;
+
+  this->rest_train = max_train;
+  this->rest_val = max_val;
+  this->rest_test = max_test;
+
+  this->train_running = true;
+  this->val_running = true;
+  this->test_running = true;
+
+  for (unsigned int i = 0; i < train_bufsize; ++i) {
+    std::vector<float> vec;
+    std::vector<float> veclabel;
+    for (unsigned int j = 0; j < input_size; ++j) {
+      float data;
+      train_file.read((char *)&data, sizeof(float));
+      vec.push_back(data);
+    }
+    trainData.push_back(vec);
+    for (unsigned int j = 0; j < class_num; ++j) {
+      float data;
+      train_file.read((char *)&data, sizeof(float));
+      veclabel.push_back(data);
+    }
+    trainDataLabel.push_back(veclabel);
+    rest_train--;
+  }
+  for (unsigned int i = 0; i < val_bufsize; ++i) {
+    std::vector<float> vec;
+    std::vector<float> veclabel;
+    for (unsigned int j = 0; j < input_size; ++j) {
+      float data;
+      val_file.read((char *)&data, sizeof(float));
+      vec.push_back(data);
+    }
+    valData.push_back(vec);
+    for (unsigned int j = 0; j < class_num; ++j) {
+      float data;
+      val_file.read((char *)&data, sizeof(float));
+      veclabel.push_back(data);
+    }
+    valDataLabel.push_back(veclabel);
+    rest_val--;
+  }
+
+  for (unsigned int i = 0; i < test_bufsize; ++i) {
+    std::vector<float> vec;
+    std::vector<float> veclabel;
+    for (unsigned int j = 0; j < input_size; ++j) {
+      float data;
+      test_file.read((char *)&data, sizeof(float));
+      vec.push_back(data);
+    }
+    testData.push_back(vec);
+    for (unsigned int j = 0; j < class_num; ++j) {
+      float data;
+      test_file.read((char *)&data, sizeof(float));
+      veclabel.push_back(data);
+    }
+    testDataLabel.push_back(veclabel);
+    rest_test--;
+  }
+  return true;
+}
+
+void DataBuffer::UpdateData(buffer_type type, std::ifstream &file) {
+  switch (type) {
+    case BUF_TRAIN: {
+      std::vector<unsigned int> mark;
+      mark.resize(max_train);
+      for (unsigned int i = 0; i < max_train; ++i) {
+        mark[i] = i;
+      }
+      unsigned int I;
+      while (train_running || mark.size() == 0) {
+        if (train_bufsize - cur_train_bufsize >= mini_batch && rest_train > mini_batch) {
+          data_lock.lock();
+          std::vector<int> list;
+          for (unsigned int i = 0; i < mini_batch; ++i) {
+            std::vector<float> vec;
+            std::vector<float> veclabel;
+
+            unsigned int id = rangeRandom(0, mark.size() - 1);
+            I = mark[id];
+            mark.erase(mark.begin() + id);
+
+            list.push_back(I);
+            int64_t position = (I * input_size + I * class_num) * sizeof(float);
+            file.seekg(position, std::ios::beg);
+
+            for (unsigned int j = 0; j < input_size; ++j) {
+              float data;
+              file.read((char *)&data, sizeof(float));
+              vec.push_back(data);
+            }
+            trainData.push_back(vec);
+            for (unsigned int j = 0; j < class_num; ++j) {
+              float data;
+              file.read((char *)&data, sizeof(float));
+              veclabel.push_back(data);
+            }
+            trainDataLabel.push_back(veclabel);
+            rest_train--;
+          }
+          cur_train_bufsize += mini_batch;
+          data_lock.unlock();
+        }
+      }
+    } break;
+    case BUF_VAL: {
+      std::vector<int> mark;
+      mark.resize(max_val);
+
+      memset(mark.data(), 0, sizeof(int) * max_val);
+
+      for (unsigned int i = 0; i < max_val; ++i) {
+        mark[i] = i;
+      }
+
+      int I;
+
+      while (val_running || mark.size() == 0) {
+        if (val_bufsize - cur_val_bufsize >= mini_batch && rest_val > mini_batch) {
+          data_lock.lock();
+          for (unsigned int i = 0; i < mini_batch; ++i) {
+            std::vector<float> vec;
+            std::vector<float> veclabel;
+
+            unsigned int id = rangeRandom(0, mark.size() - 1);
+            I = mark[id];
+            mark.erase(mark.begin() + id);
+
+            int64_t position = (I * input_size + I * class_num) * sizeof(float);
+            file.seekg(position, std::ios::beg);
+
+            for (unsigned int j = 0; j < input_size; ++j) {
+              float data;
+              file.read((char *)&data, sizeof(float));
+              vec.push_back(data);
+            }
+            valData.push_back(vec);
+            for (unsigned int j = 0; j < class_num; ++j) {
+              float data;
+              file.read((char *)&data, sizeof(float));
+              veclabel.push_back(data);
+            }
+            valDataLabel.push_back(veclabel);
+            rest_val--;
+          }
+          cur_val_bufsize += mini_batch;
+          data_lock.unlock();
+        }
+      }
+    } break;
+    case BUF_TEST: {
+      std::vector<int> mark;
+      mark.resize(max_test);
+      for (unsigned int i = 0; i < max_test; ++i) {
+        mark[i] = i;
+      }
+      int I;
+
+      while (test_running || mark.size() == 0) {
+        if (test_bufsize - cur_test_bufsize >= mini_batch && rest_test > mini_batch) {
+          data_lock.lock();
+          for (unsigned int i = 0; i < mini_batch; ++i) {
+            std::vector<float> vec;
+            std::vector<float> veclabel;
+
+            unsigned int id = rangeRandom(0, mark.size() - 1);
+            I = mark[id];
+            mark.erase(mark.begin() + id);
+
+            int64_t position = (I * input_size + I * class_num) * sizeof(float);
+            file.seekg(position, std::ios::beg);
+
+            for (unsigned int j = 0; j < input_size; ++j) {
+              float data;
+              file.read((char *)&data, sizeof(float));
+              vec.push_back(data);
+            }
+            testData.push_back(vec);
+            for (unsigned int j = 0; j < class_num; ++j) {
+              float data;
+              file.read((char *)&data, sizeof(float));
+              veclabel.push_back(data);
+            }
+            testDataLabel.push_back(veclabel);
+            rest_test--;
+          }
+          cur_test_bufsize += mini_batch;
+          data_lock.unlock();
+        }
+      }
+    } break;
+    default:
+      break;
+  }
+}
+
+void DataBuffer::run(buffer_type type, std::ifstream &file) {
+  switch (type) {
+    case BUF_TRAIN:
+      this->train_thread = std::thread(&DataBuffer::UpdateData, this, BUF_TRAIN, std::ref(file));
+      this->train_thread.detach();
+      break;
+    case BUF_VAL:
+      this->val_thread = std::thread(&DataBuffer::UpdateData, this, BUF_VAL, std::ref(file));
+      this->val_thread.detach();
+      break;
+    case BUF_TEST:
+      this->test_thread = std::thread(&DataBuffer::UpdateData, this, BUF_TEST, std::ref(file));
+      this->test_thread.detach();
+      break;
+    default:
+      break;
+  }
+}
+
+bool DataBuffer::getStatus(buffer_type type) {
+  bool ret = true;
+  switch (type) {
+    case BUF_TRAIN:
+      if (trainData.size() < mini_batch)
+        ret = false;
+      break;
+    case BUF_VAL:
+      if (valData.size() < mini_batch)
+        ret = false;
+      break;
+    case BUF_TEST:
+      if (testData.size() < mini_batch)
+        ret = false;
+      break;
+    default:
+      break;
+  }
+  return ret;
+}
+
+void DataBuffer::clear(buffer_type type, std::ifstream &file) {
+  switch (type) {
+    case BUF_TRAIN: {
+      train_running = false;
+      this->trainData.clear();
+      this->trainDataLabel.clear();
+      this->cur_train_bufsize = this->train_bufsize;
+      this->rest_train = max_train;
+
+      this->train_running = true;
+      file.clear();
+      file.seekg(0, std::ios::beg);
+
+      for (unsigned int i = 0; i < train_bufsize; ++i) {
+        std::vector<float> vec;
+        std::vector<float> veclabel;
+        for (unsigned int j = 0; j < input_size; ++j) {
+          float data;
+          file.read((char *)&data, sizeof(float));
+          vec.push_back(data);
+        }
+        trainData.push_back(vec);
+        for (unsigned int j = 0; j < class_num; ++j) {
+          float data;
+          file.read((char *)&data, sizeof(float));
+          veclabel.push_back(data);
+        }
+        trainDataLabel.push_back(veclabel);
+        rest_train--;
+      }
+    } break;
+    case BUF_VAL: {
+      val_running = false;
+      this->valData.clear();
+      this->cur_val_bufsize = this->val_bufsize;
+      this->rest_val = max_val;
+
+      this->val_running = true;
+      file.clear();
+      file.seekg(0, std::ios::beg);
+
+      for (unsigned int i = 0; i < val_bufsize; ++i) {
+        std::vector<float> vec;
+        std::vector<float> veclabel;
+        for (unsigned int j = 0; j < input_size; ++j) {
+          float data;
+          file.read((char *)&data, sizeof(float));
+          vec.push_back(data);
+        }
+        valData.push_back(vec);
+        for (unsigned int j = 0; j < class_num; ++j) {
+          float data;
+          file.read((char *)&data, sizeof(float));
+          veclabel.push_back(data);
+        }
+        valDataLabel.push_back(veclabel);
+        rest_val--;
+      }
+
+    } break;
+    case BUF_TEST: {
+      test_running = false;
+      this->testData.clear();
+      this->cur_test_bufsize = this->test_bufsize;
+      this->rest_test = max_test;
+
+      this->test_running = true;
+
+      file.clear();
+      file.seekg(0, std::ios::beg);
+
+      for (unsigned int i = 0; i < test_bufsize; ++i) {
+        std::vector<float> vec;
+        std::vector<float> veclabel;
+        for (unsigned int j = 0; j < input_size; ++j) {
+          float data;
+          file.read((char *)&data, sizeof(float));
+          vec.push_back(data);
+        }
+        testData.push_back(vec);
+        for (unsigned int j = 0; j < class_num; ++j) {
+          float data;
+          file.read((char *)&data, sizeof(float));
+          veclabel.push_back(data);
+        }
+        testDataLabel.push_back(veclabel);
+        rest_test--;
+      }
+    } break;
+    default:
+      break;
+  }
+}
+
+bool DataBuffer::getDatafromBuffer(buffer_type type, std::vector<std::vector<std::vector<float>>> &outVec,
+                                   std::vector<std::vector<std::vector<float>>> &outLabel, unsigned int batch,
+                                   unsigned int width, unsigned int height, unsigned int c_num) {
+  int nomI;
+  unsigned int J, i, j, k;
+
+  switch (type) {
+    case BUF_TRAIN: {
+      std::vector<int> list;
+
+      if (!getStatus(BUF_TRAIN))
+        return false;
+      data_lock.lock();
+      for (k = 0; k < batch; ++k) {
+        nomI = rangeRandom(0, trainData.size() - 1);
+        std::vector<std::vector<float>> v_height;
+        for (j = 0; j < height; ++j) {
+          J = j * width;
+          std::vector<float> v_width;
+          for (i = 0; i < width; ++i) {
+            v_width.push_back(trainData[nomI][J + i]);
+          }
+          v_height.push_back(v_width);
+        }
+
+        list.push_back(nomI);
+        outVec.push_back(v_height);
+        outLabel.push_back({trainDataLabel[nomI]});
+      }
+      for (i = 0; i < batch; ++i) {
+        trainData.erase(trainData.begin() + list[i]);
+        trainDataLabel.erase(trainDataLabel.begin() + list[i]);
+        cur_train_bufsize--;
+      }
+    } break;
+    case BUF_VAL: {
+      std::vector<int> list;
+      if (!getStatus(BUF_VAL))
+        return false;
+      data_lock.lock();
+      for (k = 0; k < batch; ++k) {
+        nomI = rangeRandom(0, valData.size() - 1);
+        std::vector<std::vector<float>> v_height;
+        for (j = 0; j < height; ++j) {
+          J = j * width;
+          std::vector<float> v_width;
+          for (i = 0; i < width; ++i) {
+            v_width.push_back(valData[nomI][J + i]);
+          }
+          v_height.push_back(v_width);
+        }
+
+        list.push_back(nomI);
+        outVec.push_back(v_height);
+        outLabel.push_back({valDataLabel[nomI]});
+      }
+      for (i = 0; i < batch; ++i) {
+        valData.erase(valData.begin() + list[i]);
+        valDataLabel.erase(valDataLabel.begin() + list[i]);
+        cur_val_bufsize--;
+      }
+    } break;
+    case BUF_TEST: {
+      std::vector<int> list;
+      if (!getStatus(BUF_TEST))
+        return false;
+      data_lock.lock();
+      for (k = 0; k < batch; ++k) {
+        nomI = rangeRandom(0, testData.size() - 1);
+        std::vector<std::vector<float>> v_height;
+        for (j = 0; j < height; ++j) {
+          J = j * width;
+          std::vector<float> v_width;
+          for (i = 0; i < width; ++i) {
+            v_width.push_back(testData[nomI][J + i]);
+          }
+          v_height.push_back(v_width);
+        }
+
+        list.push_back(nomI);
+        outVec.push_back(v_height);
+        outLabel.push_back({testDataLabel[nomI]});
+      }
+      for (i = 0; i < batch; ++i) {
+        testData.erase(testData.begin() + list[i]);
+        testDataLabel.erase(testDataLabel.begin() + list[i]);
+        cur_test_bufsize--;
+      }
+    }
+    break;
+  default:
+    return false;
+    break;
+  }
+  data_lock.unlock();
+  return true;
+}
index 5920439..9043f3c 100644 (file)
@@ -304,7 +304,7 @@ void NeuralNetwork::backwarding(Tensor input, Tensor expected_output, int iterat
   Tensor Y = forwarding(X);
 
   for (unsigned int i = layers.size() - 1; i > 0; i--) {
-    Y2 = layers[i]->backwarding(Y2, i);
+    Y2 = layers[i]->backwarding(Y2, iteration);
   }
 }