From 227487ca17bfb6917de06a6c81ee2f1d8845d9ff Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Wed, 11 Aug 2021 14:40:14 +0900 Subject: [PATCH] [Dataset] Change finalize signature This patch changes finalize signature to contain user_data, user_data will be needed to replicate the current worker(with diffrent arguments) in case of MT_safty is not provided by the producer. Please note that this does not change any current behavior but the preparation for the future enhancements. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/dataset/data_producers.h | 7 +++++-- nntrainer/dataset/func_data_producer.cpp | 3 ++- nntrainer/dataset/func_data_producer.h | 5 +++-- nntrainer/dataset/random_data_producers.cpp | 2 +- nntrainer/dataset/random_data_producers.h | 5 +++-- nntrainer/dataset/raw_file_data_producer.cpp | 3 ++- nntrainer/dataset/raw_file_data_producer.h | 3 ++- 7 files changed, 18 insertions(+), 10 deletions(-) diff --git a/nntrainer/dataset/data_producers.h b/nntrainer/dataset/data_producers.h index d42ebe5..175e8de 100644 --- a/nntrainer/dataset/data_producers.h +++ b/nntrainer/dataset/data_producers.h @@ -121,12 +121,14 @@ public: * the batch dimension and assume it to be one. * @param input_dims input dimensions. * @param label_dims label dimensions. - * @return Generator generator is a function taht generates a sample upon + * @param user_data user data to be used when finalize. + * @return Generator generator is a function that generates a sample upon * call. */ virtual Generator_sample finalize_sample(const std::vector &input_dims, - const std::vector &label_dims) { + const std::vector &label_dims, + void *user_data = nullptr) { return Generator_sample(); } @@ -166,6 +168,7 @@ public: /** * @brief denote if given producer is thread safe and can be parallelized. + * @note if size() == SIZE_UNDEFIEND, thread safe shall be false * * @return bool true if thread safe. */ diff --git a/nntrainer/dataset/func_data_producer.cpp b/nntrainer/dataset/func_data_producer.cpp index 08b7ec0..498836c 100644 --- a/nntrainer/dataset/func_data_producer.cpp +++ b/nntrainer/dataset/func_data_producer.cpp @@ -93,7 +93,8 @@ FuncDataProducer::finalize(const std::vector &input_dims, DataProducer::Generator_sample FuncDataProducer::finalize_sample(const std::vector &input_dims, - const std::vector &label_dims) { + const std::vector &label_dims, + void *user_data) { NNTR_THROW_IF(!this->cb, std::invalid_argument) << "given callback is nullptr!"; diff --git a/nntrainer/dataset/func_data_producer.h b/nntrainer/dataset/func_data_producer.h index 35f1793..d6caab0 100644 --- a/nntrainer/dataset/func_data_producer.h +++ b/nntrainer/dataset/func_data_producer.h @@ -70,11 +70,12 @@ public: /** * @copydoc DataProducer::finalize_sample(const std::vector, const - * std::vector) + * std::vector, void* user_data) */ DataProducer::Generator_sample finalize_sample(const std::vector &input_dims, - const std::vector &label_dims) override; + const std::vector &label_dims, + void *user_data = nullptr) override; private: datagen_cb cb; diff --git a/nntrainer/dataset/random_data_producers.cpp b/nntrainer/dataset/random_data_producers.cpp index ba07578..18141a3 100644 --- a/nntrainer/dataset/random_data_producers.cpp +++ b/nntrainer/dataset/random_data_producers.cpp @@ -177,7 +177,7 @@ RandomDataOneHotProducer::finalize(const std::vector &input_dims, DataProducer::Generator_sample RandomDataOneHotProducer::finalize_sample( const std::vector &input_dims, - const std::vector &label_dims) { + const std::vector &label_dims, void *user_data) { /** check if the given producer is ready to finalize */ nntrainer::PropsMin min_; nntrainer::PropsMax max_; diff --git a/nntrainer/dataset/random_data_producers.h b/nntrainer/dataset/random_data_producers.h index 8930c93..49a9d51 100644 --- a/nntrainer/dataset/random_data_producers.h +++ b/nntrainer/dataset/random_data_producers.h @@ -78,11 +78,12 @@ public: /** * @copydoc DataProducer::finalize_sample(const std::vector, const - * std::vector) + * std::vector, void *) */ DataProducer::Generator_sample finalize_sample(const std::vector &input_dims, - const std::vector &label_dims) override; + const std::vector &label_dims, + void *user_data = nullptr) override; /** * @copydoc DataProducer::size_sample(const std::vector, const diff --git a/nntrainer/dataset/raw_file_data_producer.cpp b/nntrainer/dataset/raw_file_data_producer.cpp index f19e3a4..89ffae2 100644 --- a/nntrainer/dataset/raw_file_data_producer.cpp +++ b/nntrainer/dataset/raw_file_data_producer.cpp @@ -161,7 +161,8 @@ RawFileDataProducer::finalize(const std::vector &input_dims, DataProducer::Generator_sample RawFileDataProducer::finalize_sample(const std::vector &input_dims, - const std::vector &label_dims) { + const std::vector &label_dims, + void *user_data) { auto sz = size(input_dims, label_dims); auto path_prop = std::get(*raw_file_props); diff --git a/nntrainer/dataset/raw_file_data_producer.h b/nntrainer/dataset/raw_file_data_producer.h index 22f39ca..39593e9 100644 --- a/nntrainer/dataset/raw_file_data_producer.h +++ b/nntrainer/dataset/raw_file_data_producer.h @@ -93,7 +93,8 @@ public: */ DataProducer::Generator_sample finalize_sample(const std::vector &input_dims, - const std::vector &label_dims) override; + const std::vector &label_dims, + void *user_data = nullptr) override; /** * @copydoc DataProducer::size_sample(const std::vector, const -- 2.7.4