*
*/
-#include "include/databuffer.h"
+#include "databuffer.h"
#include <assert.h>
+#include <nntrainer_log.h>
#include <stdio.h>
#include <stdlib.h>
+#include <climits>
#include <condition_variable>
#include <cstring>
#include <functional>
}
void DataBuffer::UpdateData(buffer_type type, std::ifstream &file) {
+ unsigned int max_size = 0;
+ unsigned int buf_size = 0;
+ unsigned int *rest_size = NULL;
+ unsigned int *cur_size = NULL;
+ bool *running = NULL;
+ std::vector<std::vector<float>> *data = NULL;
+ std::vector<std::vector<float>> *datalabel = NULL;
+
switch (type) {
- case BUF_TRAIN: {
- std::vector<unsigned int> mark;
- mark.resize(max_train);
- file.seekg(0, std::ios_base::end);
- int64_t file_length = file.tellg();
+ case BUF_TRAIN:
+ max_size = max_train;
+ buf_size = train_bufsize;
+ rest_size = &rest_train;
+ cur_size = &cur_train_bufsize;
+ running = &train_running;
+ data = &trainData;
+ datalabel = &trainDataLabel;
+ break;
+ case BUF_VAL:
+ max_size = max_val;
+ buf_size = val_bufsize;
+ rest_size = &rest_val;
+ cur_size = &cur_val_bufsize;
+ running = &val_running;
+ data = &valData;
+ datalabel = &valDataLabel;
+ break;
+ case BUF_TEST:
+ max_size = max_test;
+ buf_size = test_bufsize;
+ rest_size = &rest_test;
+ cur_size = &cur_test_bufsize;
+ running = &test_running;
+ data = &testData;
+ datalabel = &testDataLabel;
+ break;
+ default:
+ break;
+ }
- for (unsigned int i = 0; i < max_train; ++i) {
- mark[i] = i;
- }
+ unsigned int I;
+ std::vector<unsigned int> mark;
+ mark.resize(max_size);
+ file.seekg(0, std::ios_base::end);
+ uint64_t file_length = file.tellg();
- unsigned int I;
- while (train_running && mark.size() != 0) {
- if (train_bufsize - cur_train_bufsize > 0 && rest_train > 0) {
- data_lock.lock();
- std::vector<float> vec;
- std::vector<float> veclabel;
+ for (unsigned int i = 0; i < max_size; ++i) {
+ mark[i] = i;
+ }
- unsigned int id = rangeRandom(0, mark.size() - 1);
- I = mark[id];
- if (I > max_test)
- throw std::runtime_error("Error: Test case id cannot exceed maximum number of test");
+ while ((*running) && mark.size() != 0) {
+ if (buf_size - (*cur_size) > 0 && (*rest_size) > 0) {
+ std::vector<float> vec;
+ std::vector<float> veclabel;
- mark.erase(mark.begin() + id);
+ unsigned int id = rangeRandom(0, mark.size() - 1);
+ I = mark[id];
+ if (I > max_size)
+ ml_loge("Error: Test case id cannot exceed maximum number of test");
- int64_t position = (I * input_size + I * class_num) * sizeof(float);
+ mark.erase(mark.begin() + id);
+ uint64_t position = (I * input_size + I * class_num) * sizeof(float);
- if (position > file_length)
- throw std::runtime_error("Error: Cannot exceed max file size");
+ if (position > file_length || position > ULLONG_MAX)
+ ml_loge("Error: Cannot exceed max file size");
- file.seekg(position, std::ios::beg);
+ file.seekg(position, std::ios::beg);
- for (unsigned int j = 0; j < input_size; ++j) {
- float data;
- file.read((char *)&data, sizeof(float));
- vec.push_back(data);
- }
- trainData.push_back(vec);
- for (unsigned int j = 0; j < class_num; ++j) {
- float data;
- file.read((char *)&data, sizeof(float));
- veclabel.push_back(data);
- }
- trainDataLabel.push_back(veclabel);
- rest_train--;
- cur_train_bufsize++;
- data_lock.unlock();
- }
- if (train_bufsize == cur_train_bufsize) {
- std::lock_guard<std::mutex> lgtrain(readyTrainData);
- trainReadyFlag = true;
- cv_train.notify_all();
- }
- }
- } break;
- case BUF_VAL: {
- unsigned int I;
- std::vector<unsigned int> mark;
- mark.resize(max_val);
- file.seekg(0, std::ios_base::end);
- int64_t file_length = file.tellg();
-
- for (unsigned int i = 0; i < max_val; ++i) {
- mark[i] = i;
+ for (unsigned int j = 0; j < input_size; ++j) {
+ float d;
+ file.read((char *)&d, sizeof(float));
+ vec.push_back(d);
}
- while (val_running && mark.size() != 0) {
- if (val_bufsize - cur_val_bufsize > 0 && rest_val > 0) {
- data_lock.lock();
- std::vector<float> vec;
- std::vector<float> veclabel;
-
- unsigned int id = rangeRandom(0, mark.size() - 1);
- I = mark[id];
- if (I > max_test)
- throw std::runtime_error("Error: Test case id cannot exceed maximum number of test");
-
- mark.erase(mark.begin() + id);
-
- int64_t position = (I * input_size + I * class_num) * sizeof(float);
-
- if (position > file_length)
- throw std::runtime_error("Error: Cannot exceed max file size");
-
- file.seekg(position, std::ios::beg);
+ for (unsigned int j = 0; j < class_num; ++j) {
+ float d;
+ file.read((char *)&d, sizeof(float));
+ veclabel.push_back(d);
+ }
- for (unsigned int j = 0; j < input_size; ++j) {
- float data;
- file.read((char *)&data, sizeof(float));
- vec.push_back(data);
- }
- valData.push_back(vec);
- for (unsigned int j = 0; j < class_num; ++j) {
- float data;
- file.read((char *)&data, sizeof(float));
- veclabel.push_back(data);
- }
- valDataLabel.push_back(veclabel);
- rest_val--;
- cur_val_bufsize++;
- data_lock.unlock();
- }
- if (val_bufsize == cur_val_bufsize) {
+ data_lock.lock();
+ data->push_back(vec);
+ datalabel->push_back(veclabel);
+ (*rest_size)--;
+ (*cur_size)++;
+ data_lock.unlock();
+ }
+
+ if (buf_size == (*cur_size)) {
+ switch (type) {
+ case ::BUF_TRAIN: {
+ std::lock_guard<std::mutex> lgtrain(readyTrainData);
+ trainReadyFlag = true;
+ cv_train.notify_all();
+ } break;
+ case ::BUF_VAL: {
std::lock_guard<std::mutex> lgval(readyValData);
valReadyFlag = true;
cv_val.notify_all();
- }
- }
- } break;
- case BUF_TEST: {
- unsigned int I;
- std::vector<int> mark;
- mark.resize(max_test);
- file.seekg(0, std::ios_base::end);
- int64_t file_length = file.tellg();
-
- for (unsigned int i = 0; i < max_test; ++i) {
- mark[i] = i;
- }
-
- while (test_running && mark.size() != 0) {
- if (test_bufsize - cur_test_bufsize >= 0 && rest_test > 0) {
- data_lock.lock();
- std::vector<float> vec;
- std::vector<float> veclabel;
-
- unsigned int id = rangeRandom(0, mark.size() - 1);
- I = mark[id];
- if (I > max_test)
- throw std::runtime_error("Error: Test case id cannot exceed maximum number of test");
-
- mark.erase(mark.begin() + id);
-
- int64_t position = (I * input_size + I * class_num) * sizeof(float);
-
- if (position > file_length)
- throw std::runtime_error("Error: Cannot exceed max file size");
-
- file.seekg(position, std::ios::beg);
-
- for (unsigned int j = 0; j < input_size; ++j) {
- float data;
- file.read((char *)&data, sizeof(float));
- vec.push_back(data);
- }
- testData.push_back(vec);
- for (unsigned int j = 0; j < class_num; ++j) {
- float data;
- file.read((char *)&data, sizeof(float));
- veclabel.push_back(data);
- }
- testDataLabel.push_back(veclabel);
- rest_test--;
- cur_test_bufsize++;
- data_lock.unlock();
- }
- if (test_bufsize == cur_test_bufsize) {
+ } break;
+ case ::BUF_TEST: {
std::lock_guard<std::mutex> lgtest(readyTestData);
testReadyFlag = true;
cv_test.notify_all();
- }
+ } break;
+ default:
+ break;
}
- } break;
- default:
- break;
+ }
}
}