ml_train_dataset_set_property_for_mode_(ml_train_dataset_h dataset,
ml_train_dataset_mode_e mode,
const std::vector<void *> &args) {
+ static constexpr char USER_DATA[] = "user_data";
int status = ML_ERROR_NONE;
ml_train_dataset *nndataset;
return status_;
}
- status_ = db->setProperty(args);
+ std::vector<std::string> properties;
+ for (unsigned int i = 0; i < args.size(); ++i) {
+ char *key_ptr = (char *)args[i];
+ std::string key = key_ptr;
+ std::for_each(key.begin(), key.end(),
+ [](char &c) { c = ::tolower(c); });
+ key.erase(std::remove_if(key.begin(), key.end(), ::isspace), key.end());
+
+ /** Handle the user_data as a special case, serialize the address and
+ * pass it to the databuffer */
+ if (key == USER_DATA) {
+ /** This ensures that a valid user_data element is passed by the user
+ */
+ if (i + 1 >= args.size()) {
+ ml_loge("key user_data expects, next value to be a pointer");
+ status_ = ML_ERROR_INVALID_PARAMETER;
+ return status_;
+ }
+ std::ostringstream ss;
+ ss << key << '=' << args[i + 1];
+ properties.push_back(ss.str());
+
+ /** As values of i+1 is consumed, increase i by 1 */
+ i++;
+ } else if (key.rfind("user_data=", 0) == 0) {
+ /** case that user tries to pass something like user_data=5, this is
+ * not allowed */
+ status_ = ML_ERROR_INVALID_PARAMETER;
+ return status_;
+ } else {
+ properties.push_back(key);
+ continue;
+ }
+ }
+
+ db->setProperty(properties);
return status_;
};
* { std::string property_name, std::string property_val, ...}
*/
virtual void setProperty(const std::vector<std::string> &values) = 0;
-
- /**
- * @brief set property to allow setting non-string values such as
- * user_data for callbacks
- * @param[in] values values of property
- * @retval #ML_ERROR_NONE Successful.
- * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
- * @note this is a superset of the setProperty(std::vector<std::string>)
- * @details Properties (values) is in the format -
- * { std::string property_name, void * property_val, ...}
- */
- virtual int setProperty(std::vector<void *> values) = 0;
};
/**
std::cout.flush();
}
-int DataBuffer::setProperty(std::vector<void *> values) {
- int status = ML_ERROR_NONE;
- std::vector<std::string> properties;
-
- for (unsigned int i = 0; i < values.size(); ++i) {
- char *key_ptr = (char *)values[i];
- std::string key = key_ptr;
- std::string value;
-
- /** Handle the user_data as a special case */
- if (key == USER_DATA) {
- /** This ensures that a valid user_data element is passed by the user */
- if (i + 1 >= values.size())
- return ML_ERROR_INVALID_PARAMETER;
-
- this->user_data = values[i + 1];
-
- /** As values of i+1 is consumed, increase i by 1 */
- i++;
- } else {
- properties.push_back(key);
- continue;
- }
- }
-
- setProperty(properties);
-
- return status;
-}
-
void DataBuffer::setProperty(const std::vector<std::string> &values) {
auto left = loadProperties(values, *db_props);
if (producer) {
*/
void setProperty(const std::vector<std::string> &values) override;
- /**
- * @brief set property to allow setting user_data for cb
- * @todo deprecate
- * @param[in] values values of property
- * @retval #ML_ERROR_NONE Successful.
- * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
- */
- int setProperty(std::vector<void *> values);
-
protected:
std::shared_ptr<DataProducer> producer;
std::weak_ptr<BatchQueue> bq_view;
#include <func_data_producer.h>
+#include <base_properties.h>
#include <nntrainer_error.h>
+#include <node_exporter.h>
namespace nntrainer {
+/**
+ * @brief User data props
+ *
+ */
+class PropsUserData final : public Property<void *> {
+public:
+ static constexpr const char *key = "user_data";
+ PropsUserData(void *user_data) { set(user_data); }
+ using prop_tag = ptr_prop_tag;
+};
+
FuncDataProducer::FuncDataProducer(datagen_cb datagen_cb, void *user_data_) :
cb(datagen_cb),
- user_data(user_data_) {}
+ user_data_prop(new PropsUserData(user_data_)) {}
FuncDataProducer::~FuncDataProducer() {}
}
void FuncDataProducer::setProperty(const std::vector<std::string> &properties) {
- NNTR_THROW_IF(!properties.empty(), std::invalid_argument)
+ auto left = loadProperties(properties, std::tie(*user_data_prop));
+ NNTR_THROW_IF(!left.empty(), std::invalid_argument)
<< "properties is not empty, size: " << properties.size();
}
auto label_data = std::shared_ptr<float *>(new float *[label_dims.size()],
std::default_delete<float *[]>());
- return [cb = this->cb, ud = this->user_data, input_dims, label_dims,
- input_data, label_data]() -> DataProducer::Iteration {
+ return [cb = this->cb, ud = this->user_data_prop->get(), input_dims,
+ label_dims, input_data, label_data]() -> DataProducer::Iteration {
std::vector<Tensor> inputs;
inputs.reserve(input_dims.size());
namespace nntrainer {
+class PropsUserData;
+
using datagen_cb = ml::train::datagen_cb;
/**
private:
datagen_cb cb;
- void *user_data;
+ std::unique_ptr<PropsUserData> user_data_prop;
};
} // namespace nntrainer
*/
struct enum_class_prop_tag {};
+/**
+ * @brief property treated as a raw pointer
+ *
+ */
+struct ptr_prop_tag {};
+
/**
* @brief base property class, inherit this to make a convenient property
*
static DataType from_string(const std::string &value);
};
+/**
+ * @brief str converter which serializes a pointer and returns back to a ptr
+ *
+ * @tparam DataType pointer type
+ */
+template <typename DataType> struct str_converter<ptr_prop_tag, DataType> {
+
+ /**
+ * @brief convert underlying value to string
+ *
+ * @param value value to convert to string
+ * @return std::string string
+ */
+ static std::string to_string(const DataType &value) {
+ std::ostringstream ss;
+ ss << value;
+ return ss.str();
+ }
+
+ /**
+ * @brief convert string to underlying value
+ *
+ * @param value value to convert to string
+ * @return DataType converted type
+ */
+ static DataType from_string(const std::string &value) {
+ std::stringstream ss(value);
+ uintptr_t addr = static_cast<uintptr_t>(std::stoull(value, 0, 16));
+ std::cerr << "value: " << value << " addr: " << addr;
+ return reinterpret_cast<DataType>(addr);
+ }
+};
+
/**
* @copydoc template <typename Tag, typename DataType> struct str_converter
*/
class QualityOfBanana : public nntrainer::Property<std::string> {
public:
QualityOfBanana() : nntrainer::Property<std::string>() {}
- QualityOfBanana(const char *value) : Property<std::string>(value) {}
+ QualityOfBanana(const char *value) { set(value); }
static constexpr const char *key = "quality_banana";
using prop_tag = nntrainer::str_prop_tag;
*/
class MarkAsGoodBanana : public nntrainer::Property<bool> {
public:
- MarkAsGoodBanana(bool val = true) :
- Property<bool>(val) {} /**< default value if any */
+ MarkAsGoodBanana(bool val = true) { set(val); } /**< default value if any */
static constexpr const char *key = "mark_good"; /**< unique key to access */
using prop_tag = nntrainer::bool_prop_tag; /**< property type */
};
*/
class FreshnessOfBanana : public nntrainer::Property<float> {
public:
- FreshnessOfBanana(float val = 0.0) :
- Property<float>(val) {} /**< default value if any */
- static constexpr const char *key = "how_fresh"; /**< unique key to access */
- using prop_tag = nntrainer::float_prop_tag; /**< property type */
+ FreshnessOfBanana(float val = 0.0) { set(val); } /**< default value if any */
+ static constexpr const char *key = "how_fresh"; /**< unique key to access */
+ using prop_tag = nntrainer::float_prop_tag; /**< property type */
};
/**
using prop_tag = nntrainer::dimension_prop_tag;
bool isValid(const nntrainer::TensorDim &dim) const override {
- std::cerr << dim;
return dim.batch() == 1;
}
};
+
+/**
+ * @brief Pointer of banana property
+ *
+ */
+class PtrOfBanana : public nntrainer::Property<int *> {
+public:
+ static constexpr const char *key = "ptr_banana";
+ using prop_tag = nntrainer::ptr_prop_tag;
+};
} // namespace
TEST(BasicProperty, tagCast) {
EXPECT_EQ(nntrainer::to_string(q), "true");
}
+ { /**< from_string -> get / to_string, ptr */
+ PtrOfBanana pb;
+ int a = 1;
+ std::ostringstream ss;
+ ss << &a;
+ nntrainer::from_string(ss.str(), pb);
+ EXPECT_EQ(pb.get(), &a);
+ EXPECT_EQ(*pb.get(), a);
+ EXPECT_EQ(nntrainer::to_string(pb), ss.str());
+ }
+
+ { /** set -> get / to_string, boolean*/
+ PtrOfBanana pb;
+ int a = 1;
+ pb.set(&a);
+ EXPECT_EQ(pb.get(), &a);
+ EXPECT_EQ(*pb.get(), 1);
+ std::ostringstream ss;
+ ss << &a;
+ EXPECT_EQ(nntrainer::to_string(pb), ss.str());
+ }
+
{ /**< from_string -> get / to_string, float */
FreshnessOfBanana q;
nntrainer::from_string("1.3245", q);