* the batch dimension and assume it to be one.
* @param input_dims input dimensions.
* @param label_dims label dimensions.
- * @return Generator generator is a function taht generates a sample upon
+ * @param user_data user data to be used when finalize.
+ * @return Generator generator is a function that generates a sample upon
* call.
*/
virtual Generator_sample
finalize_sample(const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) {
+ const std::vector<TensorDim> &label_dims,
+ void *user_data = nullptr) {
return Generator_sample();
}
/**
* @brief denote if given producer is thread safe and can be parallelized.
+ * @note if size() == SIZE_UNDEFIEND, thread safe shall be false
*
* @return bool true if thread safe.
*/
DataProducer::Generator_sample
FuncDataProducer::finalize_sample(const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) {
+ const std::vector<TensorDim> &label_dims,
+ void *user_data) {
NNTR_THROW_IF(!this->cb, std::invalid_argument)
<< "given callback is nullptr!";
/**
* @copydoc DataProducer::finalize_sample(const std::vector<TensorDim>, const
- * std::vector<TensorDim>)
+ * std::vector<TensorDim>, void* user_data)
*/
DataProducer::Generator_sample
finalize_sample(const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) override;
+ const std::vector<TensorDim> &label_dims,
+ void *user_data = nullptr) override;
private:
datagen_cb cb;
DataProducer::Generator_sample RandomDataOneHotProducer::finalize_sample(
const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) {
+ const std::vector<TensorDim> &label_dims, void *user_data) {
/** check if the given producer is ready to finalize */
nntrainer::PropsMin min_;
nntrainer::PropsMax max_;
/**
* @copydoc DataProducer::finalize_sample(const std::vector<TensorDim>, const
- * std::vector<TensorDim>)
+ * std::vector<TensorDim>, void *)
*/
DataProducer::Generator_sample
finalize_sample(const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) override;
+ const std::vector<TensorDim> &label_dims,
+ void *user_data = nullptr) override;
/**
* @copydoc DataProducer::size_sample(const std::vector<TensorDim>, const
DataProducer::Generator_sample
RawFileDataProducer::finalize_sample(const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) {
+ const std::vector<TensorDim> &label_dims,
+ void *user_data) {
auto sz = size(input_dims, label_dims);
auto path_prop = std::get<props::FilePath>(*raw_file_props);
*/
DataProducer::Generator_sample
finalize_sample(const std::vector<TensorDim> &input_dims,
- const std::vector<TensorDim> &label_dims) override;
+ const std::vector<TensorDim> &label_dims,
+ void *user_data = nullptr) override;
/**
* @copydoc DataProducer::size_sample(const std::vector<TensorDim>, const