NNTrainerDSM();
~NNTrainerDSM() = default;
+ bool IsFeatureVectorAllowed(unsigned int label_idx) override;
void LoadDataSet(const std::string file_name, unsigned int new_label_cnt) override;
void AddDataSet(std::vector<float> &feature_vec, const unsigned int label_idx,
const unsigned int label_cnt) override;
data_set->LoadDataSet(_config.feature_vector_file_path, label_cnt);
}
- // Add new feature vectors.
- data_set->AddDataSet(input_vec, label_idx, label_cnt);
+ // Add new feature vector only in case that feature vector count of given label_idx is less then 5.
+ // It means that only 5 set of feature vector per a label is valid.
+ // TODO. According to feature vector priority, new feature vector should be added.
+ if (data_set->IsFeatureVectorAllowed(label_idx))
+ data_set->AddDataSet(input_vec, label_idx, label_cnt);
_training_model->ApplyDataSet(data_set);
_training_model->Compile();
#include "machine_learning_exception.h"
#include "nntrainer_dsm.h"
+#define MAX_FEATURE_VECTOR_CNT 5
using namespace std;
using namespace mediavision::machine_learning::exception;
NNTrainerDSM::NNTrainerDSM() : DataSetManager()
{}
+bool NNTrainerDSM::IsFeatureVectorAllowed(unsigned int label_idx)
+{
+ return (_fv_cnt_per_label[label_idx] < MAX_FEATURE_VECTOR_CNT);
+}
+
void NNTrainerDSM::AddDataSet(std::vector<float> &feature_vec, const unsigned int label_idx,
const unsigned int label_cnt)
{
_data.push_back(feature_vec);
_label_index.push_back(label_idx);
+ _fv_cnt_per_label[label_idx]++;
vector<float> oneHotEncoding;
_labels.push_back(label);
_label_index.push_back(label_idx);
+
+ _fv_cnt_per_label[label_idx]++;
}
}
\ No newline at end of file
#include <fstream>
#include <vector>
+#include <map>
#include "feature_vector_manager.h"
class DataSetManager
{
protected:
+ std::map<unsigned int, unsigned int> _fv_cnt_per_label;
std::vector<std::vector<float> > _data;
std::vector<std::vector<float> > _labels;
std::vector<unsigned int> _label_index;
size_t GetFeaVecSize(void);
std::vector<unsigned int> &GetLabelIdx(void);
+ virtual bool IsFeatureVectorAllowed(unsigned int label_idx) = 0;
virtual void LoadDataSet(const std::string file_name, unsigned int new_label_cnt) = 0;
virtual void AddDataSet(std::vector<float> &feature_vec, const unsigned int label_idx,
const unsigned int label_cnt) = 0;
_labels.clear();
_label_index.clear();
+ _fv_cnt_per_label.clear();
}
bool DataSetManager::IsFeatureVectorDuplicated(const vector<float> &vec)