[Dataset] Change finalize signature
authorJihoon Lee <jhoon.it.lee@samsung.com>
Wed, 11 Aug 2021 05:40:14 +0000 (14:40 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 24 Aug 2021 05:15:09 +0000 (14:15 +0900)
This patch changes finalize signature to contain user_data,
user_data will be needed to replicate the current worker(with diffrent
arguments) in case of
MT_safty is not provided by the producer. Please note that this does not
change any current behavior but the preparation for the future
enhancements.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/dataset/data_producers.h
nntrainer/dataset/func_data_producer.cpp
nntrainer/dataset/func_data_producer.h
nntrainer/dataset/random_data_producers.cpp
nntrainer/dataset/random_data_producers.h
nntrainer/dataset/raw_file_data_producer.cpp
nntrainer/dataset/raw_file_data_producer.h

index d42ebe5..175e8de 100644 (file)
@@ -121,12 +121,14 @@ public:
    * the batch dimension and assume it to be one.
    * @param input_dims input dimensions.
    * @param label_dims label dimensions.
-   * @return Generator generator is a function taht generates a sample upon
+   * @param user_data user data to be used when finalize.
+   * @return Generator generator is a function that generates a sample upon
    * call.
    */
   virtual Generator_sample
   finalize_sample(const std::vector<TensorDim> &input_dims,
-                  const std::vector<TensorDim> &label_dims) {
+                  const std::vector<TensorDim> &label_dims,
+                  void *user_data = nullptr) {
     return Generator_sample();
   }
 
@@ -166,6 +168,7 @@ public:
 
   /**
    * @brief denote if given producer is thread safe and can be parallelized.
+   * @note if size() == SIZE_UNDEFIEND, thread safe shall be false
    *
    * @return bool true if thread safe.
    */
index 08b7ec0..498836c 100644 (file)
@@ -93,7 +93,8 @@ FuncDataProducer::finalize(const std::vector<TensorDim> &input_dims,
 
 DataProducer::Generator_sample
 FuncDataProducer::finalize_sample(const std::vector<TensorDim> &input_dims,
-                                  const std::vector<TensorDim> &label_dims) {
+                                  const std::vector<TensorDim> &label_dims,
+                                  void *user_data) {
   NNTR_THROW_IF(!this->cb, std::invalid_argument)
     << "given callback is nullptr!";
 
index 35f1793..d6caab0 100644 (file)
@@ -70,11 +70,12 @@ public:
 
   /**
    * @copydoc DataProducer::finalize_sample(const std::vector<TensorDim>, const
-   * std::vector<TensorDim>)
+   * std::vector<TensorDim>, void* user_data)
    */
   DataProducer::Generator_sample
   finalize_sample(const std::vector<TensorDim> &input_dims,
-                  const std::vector<TensorDim> &label_dims) override;
+                  const std::vector<TensorDim> &label_dims,
+                  void *user_data = nullptr) override;
 
 private:
   datagen_cb cb;
index ba07578..18141a3 100644 (file)
@@ -177,7 +177,7 @@ RandomDataOneHotProducer::finalize(const std::vector<TensorDim> &input_dims,
 
 DataProducer::Generator_sample RandomDataOneHotProducer::finalize_sample(
   const std::vector<TensorDim> &input_dims,
-  const std::vector<TensorDim> &label_dims) {
+  const std::vector<TensorDim> &label_dims, void *user_data) {
   /** check if the given producer is ready to finalize */
   nntrainer::PropsMin min_;
   nntrainer::PropsMax max_;
index 8930c93..49a9d51 100644 (file)
@@ -78,11 +78,12 @@ public:
 
   /**
    * @copydoc DataProducer::finalize_sample(const std::vector<TensorDim>, const
-   * std::vector<TensorDim>)
+   * std::vector<TensorDim>, void *)
    */
   DataProducer::Generator_sample
   finalize_sample(const std::vector<TensorDim> &input_dims,
-                  const std::vector<TensorDim> &label_dims) override;
+                  const std::vector<TensorDim> &label_dims,
+                  void *user_data = nullptr) override;
 
   /**
    * @copydoc DataProducer::size_sample(const std::vector<TensorDim>, const
index f19e3a4..89ffae2 100644 (file)
@@ -161,7 +161,8 @@ RawFileDataProducer::finalize(const std::vector<TensorDim> &input_dims,
 
 DataProducer::Generator_sample
 RawFileDataProducer::finalize_sample(const std::vector<TensorDim> &input_dims,
-                                     const std::vector<TensorDim> &label_dims) {
+                                     const std::vector<TensorDim> &label_dims,
+                                     void *user_data) {
   auto sz = size(input_dims, label_dims);
   auto path_prop = std::get<props::FilePath>(*raw_file_props);
 
index 22f39ca..39593e9 100644 (file)
@@ -93,7 +93,8 @@ public:
    */
   DataProducer::Generator_sample
   finalize_sample(const std::vector<TensorDim> &input_dims,
-                  const std::vector<TensorDim> &label_dims) override;
+                  const std::vector<TensorDim> &label_dims,
+                  void *user_data = nullptr) override;
 
   /**
    * @copydoc DataProducer::size_sample(const std::vector<TensorDim>, const