[Dataset] Add user_data set Property handler
authorJihoon Lee <jhoon.it.lee@samsung.com>
Thu, 15 Jul 2021 10:14:22 +0000 (19:14 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 18 Aug 2021 07:05:47 +0000 (16:05 +0900)
**Changes proposed in this PR:**
- PtrType properties
- Add user_data set property handler for backward competibility

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
api/capi/src/nntrainer.cpp
api/ccapi/include/dataset.h
nntrainer/dataset/databuffer.cpp
nntrainer/dataset/databuffer.h
nntrainer/dataset/func_data_producer.cpp
nntrainer/dataset/func_data_producer.h
nntrainer/utils/base_properties.h
test/unittest/unittest_base_properties.cpp

index fee9c7c..0f6f66c 100644 (file)
@@ -868,6 +868,7 @@ static int
 ml_train_dataset_set_property_for_mode_(ml_train_dataset_h dataset,
                                         ml_train_dataset_mode_e mode,
                                         const std::vector<void *> &args) {
+  static constexpr char USER_DATA[] = "user_data";
   int status = ML_ERROR_NONE;
   ml_train_dataset *nndataset;
 
@@ -888,7 +889,42 @@ ml_train_dataset_set_property_for_mode_(ml_train_dataset_h dataset,
         return status_;
       }
 
-      status_ = db->setProperty(args);
+      std::vector<std::string> properties;
+      for (unsigned int i = 0; i < args.size(); ++i) {
+        char *key_ptr = (char *)args[i];
+        std::string key = key_ptr;
+        std::for_each(key.begin(), key.end(),
+                      [](char &c) { c = ::tolower(c); });
+        key.erase(std::remove_if(key.begin(), key.end(), ::isspace), key.end());
+
+        /** Handle the user_data as a special case, serialize the address and
+         * pass it to the databuffer */
+        if (key == USER_DATA) {
+          /** This ensures that a valid user_data element is passed by the user
+           */
+          if (i + 1 >= args.size()) {
+            ml_loge("key user_data expects, next value to be a pointer");
+            status_ = ML_ERROR_INVALID_PARAMETER;
+            return status_;
+          }
+          std::ostringstream ss;
+          ss << key << '=' << args[i + 1];
+          properties.push_back(ss.str());
+
+          /** As values of i+1 is consumed, increase i by 1 */
+          i++;
+        } else if (key.rfind("user_data=", 0) == 0) {
+          /** case that user tries to pass something like user_data=5, this is
+           * not allowed */
+          status_ = ML_ERROR_INVALID_PARAMETER;
+          return status_;
+        } else {
+          properties.push_back(key);
+          continue;
+        }
+      }
+
+      db->setProperty(properties);
       return status_;
     };
 
index b659bda..926ab75 100644 (file)
@@ -71,18 +71,6 @@ public:
    *  { std::string property_name, std::string property_val, ...}
    */
   virtual void setProperty(const std::vector<std::string> &values) = 0;
-
-  /**
-   * @brief     set property to allow setting non-string values such as
-   * user_data for callbacks
-   * @param[in] values values of property
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   * @note      this is a superset of the setProperty(std::vector<std::string>)
-   * @details   Properties (values) is in the format -
-   *  { std::string property_name, void * property_val, ...}
-   */
-  virtual int setProperty(std::vector<void *> values) = 0;
 };
 
 /**
index 015fa69..fd43645 100644 (file)
@@ -163,36 +163,6 @@ void DataBuffer::displayProgress(const int count, float loss) {
   std::cout.flush();
 }
 
-int DataBuffer::setProperty(std::vector<void *> values) {
-  int status = ML_ERROR_NONE;
-  std::vector<std::string> properties;
-
-  for (unsigned int i = 0; i < values.size(); ++i) {
-    char *key_ptr = (char *)values[i];
-    std::string key = key_ptr;
-    std::string value;
-
-    /** Handle the user_data as a special case */
-    if (key == USER_DATA) {
-      /** This ensures that a valid user_data element is passed by the user */
-      if (i + 1 >= values.size())
-        return ML_ERROR_INVALID_PARAMETER;
-
-      this->user_data = values[i + 1];
-
-      /** As values of i+1 is consumed, increase i by 1 */
-      i++;
-    } else {
-      properties.push_back(key);
-      continue;
-    }
-  }
-
-  setProperty(properties);
-
-  return status;
-}
-
 void DataBuffer::setProperty(const std::vector<std::string> &values) {
   auto left = loadProperties(values, *db_props);
   if (producer) {
index c72d198..bcbda2f 100644 (file)
@@ -122,15 +122,6 @@ public:
    */
   void setProperty(const std::vector<std::string> &values) override;
 
-  /**
-   * @brief     set property to allow setting user_data for cb
-   * @todo   deprecate
-   * @param[in] values values of property
-   * @retval #ML_ERROR_NONE Successful.
-   * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
-   */
-  int setProperty(std::vector<void *> values);
-
 protected:
   std::shared_ptr<DataProducer> producer;
   std::weak_ptr<BatchQueue> bq_view;
index 31b4ab4..433f436 100644 (file)
 
 #include <func_data_producer.h>
 
+#include <base_properties.h>
 #include <nntrainer_error.h>
+#include <node_exporter.h>
 
 namespace nntrainer {
 
+/**
+ * @brief User data props
+ *
+ */
+class PropsUserData final : public Property<void *> {
+public:
+  static constexpr const char *key = "user_data";
+  PropsUserData(void *user_data) { set(user_data); }
+  using prop_tag = ptr_prop_tag;
+};
+
 FuncDataProducer::FuncDataProducer(datagen_cb datagen_cb, void *user_data_) :
   cb(datagen_cb),
-  user_data(user_data_) {}
+  user_data_prop(new PropsUserData(user_data_)) {}
 
 FuncDataProducer::~FuncDataProducer() {}
 
@@ -28,7 +41,8 @@ const std::string FuncDataProducer::getType() const {
 }
 
 void FuncDataProducer::setProperty(const std::vector<std::string> &properties) {
-  NNTR_THROW_IF(!properties.empty(), std::invalid_argument)
+  auto left = loadProperties(properties, std::tie(*user_data_prop));
+  NNTR_THROW_IF(!left.empty(), std::invalid_argument)
     << "properties is not empty, size: " << properties.size();
 }
 
@@ -43,8 +57,8 @@ FuncDataProducer::finalize(const std::vector<TensorDim> &input_dims,
   auto label_data = std::shared_ptr<float *>(new float *[label_dims.size()],
                                              std::default_delete<float *[]>());
 
-  return [cb = this->cb, ud = this->user_data, input_dims, label_dims,
-          input_data, label_data]() -> DataProducer::Iteration {
+  return [cb = this->cb, ud = this->user_data_prop->get(), input_dims,
+          label_dims, input_data, label_data]() -> DataProducer::Iteration {
     std::vector<Tensor> inputs;
     inputs.reserve(input_dims.size());
 
index 10fc2dc..4e5d375 100644 (file)
@@ -23,6 +23,8 @@
 
 namespace nntrainer {
 
+class PropsUserData;
+
 using datagen_cb = ml::train::datagen_cb;
 
 /**
@@ -68,7 +70,7 @@ public:
 
 private:
   datagen_cb cb;
-  void *user_data;
+  std::unique_ptr<PropsUserData> user_data_prop;
 };
 
 } // namespace nntrainer
index 77b234f..7ef4d41 100644 (file)
@@ -117,6 +117,12 @@ struct bool_prop_tag {};
 struct enum_class_prop_tag {};
 
 /**
+ * @brief property treated as a raw pointer
+ *
+ */
+struct ptr_prop_tag {};
+
+/**
  * @brief base property class, inherit this to make a convenient property
  *
  * @tparam T
@@ -317,6 +323,39 @@ template <typename Tag, typename DataType> struct str_converter {
 };
 
 /**
+ * @brief str converter which serializes a pointer and returns back to a ptr
+ *
+ * @tparam DataType pointer type
+ */
+template <typename DataType> struct str_converter<ptr_prop_tag, DataType> {
+
+  /**
+   * @brief convert underlying value to string
+   *
+   * @param value value to convert to string
+   * @return std::string string
+   */
+  static std::string to_string(const DataType &value) {
+    std::ostringstream ss;
+    ss << value;
+    return ss.str();
+  }
+
+  /**
+   * @brief convert string to underlying value
+   *
+   * @param value value to convert to string
+   * @return DataType converted type
+   */
+  static DataType from_string(const std::string &value) {
+    std::stringstream ss(value);
+    uintptr_t addr = static_cast<uintptr_t>(std::stoull(value, 0, 16));
+    std::cerr << "value: " << value << " addr: " << addr;
+    return reinterpret_cast<DataType>(addr);
+  }
+};
+
+/**
  * @copydoc template <typename Tag, typename DataType> struct str_converter
  */
 template <>
index d06aa07..0bfa649 100644 (file)
@@ -50,7 +50,7 @@ public:
 class QualityOfBanana : public nntrainer::Property<std::string> {
 public:
   QualityOfBanana() : nntrainer::Property<std::string>() {}
-  QualityOfBanana(const char *value) : Property<std::string>(value) {}
+  QualityOfBanana(const char *value) { set(value); }
   static constexpr const char *key = "quality_banana";
   using prop_tag = nntrainer::str_prop_tag;
 
@@ -66,8 +66,7 @@ public:
  */
 class MarkAsGoodBanana : public nntrainer::Property<bool> {
 public:
-  MarkAsGoodBanana(bool val = true) :
-    Property<bool>(val) {}                        /**< default value if any */
+  MarkAsGoodBanana(bool val = true) { set(val); } /**< default value if any */
   static constexpr const char *key = "mark_good"; /**< unique key to access */
   using prop_tag = nntrainer::bool_prop_tag;      /**< property type */
 };
@@ -78,10 +77,9 @@ public:
  */
 class FreshnessOfBanana : public nntrainer::Property<float> {
 public:
-  FreshnessOfBanana(float val = 0.0) :
-    Property<float>(val) {}                       /**< default value if any */
-  static constexpr const char *key = "how_fresh"; /**< unique key to access */
-  using prop_tag = nntrainer::float_prop_tag;     /**< property type */
+  FreshnessOfBanana(float val = 0.0) { set(val); } /**< default value if any */
+  static constexpr const char *key = "how_fresh";  /**< unique key to access */
+  using prop_tag = nntrainer::float_prop_tag;      /**< property type */
 };
 
 /**
@@ -95,10 +93,19 @@ public:
   using prop_tag = nntrainer::dimension_prop_tag;
 
   bool isValid(const nntrainer::TensorDim &dim) const override {
-    std::cerr << dim;
     return dim.batch() == 1;
   }
 };
+
+/**
+ * @brief Pointer of banana property
+ *
+ */
+class PtrOfBanana : public nntrainer::Property<int *> {
+public:
+  static constexpr const char *key = "ptr_banana";
+  using prop_tag = nntrainer::ptr_prop_tag;
+};
 } // namespace
 
 TEST(BasicProperty, tagCast) {
@@ -204,6 +211,28 @@ TEST(BasicProperty, valid_p) {
     EXPECT_EQ(nntrainer::to_string(q), "true");
   }
 
+  { /**< from_string -> get / to_string, ptr */
+    PtrOfBanana pb;
+    int a = 1;
+    std::ostringstream ss;
+    ss << &a;
+    nntrainer::from_string(ss.str(), pb);
+    EXPECT_EQ(pb.get(), &a);
+    EXPECT_EQ(*pb.get(), a);
+    EXPECT_EQ(nntrainer::to_string(pb), ss.str());
+  }
+
+  { /** set -> get / to_string, boolean*/
+    PtrOfBanana pb;
+    int a = 1;
+    pb.set(&a);
+    EXPECT_EQ(pb.get(), &a);
+    EXPECT_EQ(*pb.get(), 1);
+    std::ostringstream ss;
+    ss << &a;
+    EXPECT_EQ(nntrainer::to_string(pb), ss.str());
+  }
+
   { /**< from_string -> get / to_string, float */
     FreshnessOfBanana q;
     nntrainer::from_string("1.3245", q);