From 5b5bb01f1084ab54b1e180ef60862a212a95f9b0 Mon Sep 17 00:00:00 2001 From: Jihoon Lee Date: Mon, 3 May 2021 15:10:51 +0900 Subject: [PATCH] [Props] Add dimension property This patch add dimension property **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Jihoon Lee --- nntrainer/utils/base_properties.cpp | 36 +++++++++++++ nntrainer/utils/base_properties.h | 1 + test/unittest/unittest_properties.cpp | 73 +++++++++++++++++++++++++-- 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/nntrainer/utils/base_properties.cpp b/nntrainer/utils/base_properties.cpp index 64435410..4cdedb24 100644 --- a/nntrainer/utils/base_properties.cpp +++ b/nntrainer/utils/base_properties.cpp @@ -11,7 +11,9 @@ */ #include +#include #include +#include namespace nntrainer { @@ -48,4 +50,38 @@ unsigned int str_converter::from_string( const std::string &value) { return std::stoul(value); } + +template <> +std::string str_converter::to_string( + const TensorDim &dimension) { + std::stringstream ss; + ss << dimension.batch() << ':' << dimension.channel() << ':' + << dimension.height() << ':' << dimension.width(); + return ss.str(); +} + +template <> +TensorDim str_converter::from_string( + const std::string &value) { + std::vector tokens; + std::string token; + std::istringstream iss(value); + + while (std::getline(iss, token, ':')) { + tokens.push_back(token); + } + + NNTR_THROW_IF(tokens.size() > MAXDIM, std::invalid_argument) + << "More than 4 axes is not supported, target string: " << value; + + TensorDim target; + + int cur_axis = 3; + for (auto iter = tokens.rbegin(); iter != tokens.rend(); iter++) { + target.setTensorDim(cur_axis--, std::stoul(*iter)); + } + + return target; +} + } // namespace nntrainer diff --git a/nntrainer/utils/base_properties.h b/nntrainer/utils/base_properties.h index 6ec21ef1..69cfdf72 100644 --- a/nntrainer/utils/base_properties.h +++ b/nntrainer/utils/base_properties.h @@ -11,6 +11,7 @@ */ #include #include +#include #ifndef __BASE_PROPERTIES_H__ #define __BASE_PROPERTIES_H__ diff --git a/test/unittest/unittest_properties.cpp b/test/unittest/unittest_properties.cpp index a3202522..4e0ddc9e 100644 --- a/test/unittest/unittest_properties.cpp +++ b/test/unittest/unittest_properties.cpp @@ -54,6 +54,22 @@ public: return nntrainer::endswith(v, "good"); } }; + +/** + * @brief DimensionOfBanana property for example, this has to have batch size of + * 1 + * + */ +class DimensionOfBanana : public nntrainer::Property { +public: + static constexpr const char *key = "banana_size"; + using prop_tag = nntrainer::dimension_prop_tag; + + bool isValid(const nntrainer::TensorDim &dim) const override { + std::cerr << dim; + return dim.batch() == 1; + } +}; } // namespace TEST(BasicProperty, tagCast) { @@ -111,6 +127,27 @@ TEST(BasicProperty, valid_p) { EXPECT_EQ(nntrainer::to_string(q), "this is good"); } + { /** set -> get / to_string, dimension*/ + DimensionOfBanana q; + q.set({1, 2, 3, 4}); + EXPECT_EQ(q.get(), nntrainer::TensorDim(1, 2, 3, 4)); + EXPECT_EQ(nntrainer::to_string(q), "1:2:3:4"); + } + + { /**< from_string -> get / to_string, dimension */ + DimensionOfBanana q; + nntrainer::from_string("1:2:3:4", q); + EXPECT_EQ(q.get(), nntrainer::TensorDim(1, 2, 3, 4)); + EXPECT_EQ(nntrainer::to_string(q), "1:2:3:4"); + } + + { /**< from_string -> get / to_string, dimension */ + DimensionOfBanana q; + nntrainer::from_string("3:4", q); + EXPECT_EQ(q.get(), nntrainer::TensorDim(1, 1, 3, 4)); + EXPECT_EQ(nntrainer::to_string(q), "1:1:3:4"); + } + { /**< exporter test */ auto props = std::make_tuple(NumBanana(), QualityOfBanana()); @@ -139,12 +176,13 @@ TEST(BasicProperty, valid_p) { } { /**< load from layer */ - auto props = std::make_tuple(NumBanana(), QualityOfBanana()); + auto props = + std::make_tuple(NumBanana(), QualityOfBanana(), DimensionOfBanana()); - auto v = - nntrainer::loadProperties({"num_banana=2", "quality_banana=thisisgood", - "num_banana=42", "not_used=key"}, - props); + auto v = nntrainer::loadProperties( + {"num_banana=2", "quality_banana=thisisgood", "num_banana=42", + "banana_size=2:2:3", "not_used=key"}, + props); EXPECT_EQ(v, std::vector{"not_used=key"}); EXPECT_EQ(std::get<0>(props).get(), 42); @@ -162,6 +200,11 @@ TEST(BasicProperty, setNotValid_02_n) { EXPECT_THROW(q.set("invalid_str"), std::invalid_argument); } +TEST(BasicProperty, setNotValid_03_n) { + DimensionOfBanana d; + EXPECT_THROW(d.set({3, 3, 2, 4}), std::invalid_argument); +} + TEST(BasicProperty, fromStringNotValid_01_n) { NumBanana b; EXPECT_THROW(nntrainer::from_string("not integer", b), std::invalid_argument); @@ -177,6 +220,26 @@ TEST(BasicProperty, fromStringNotValid_03_n) { EXPECT_THROW(nntrainer::from_string("invalid_str", q), std::invalid_argument); } +TEST(BasicProperty, fromStringNotValid_04_n) { + DimensionOfBanana d; + EXPECT_THROW(nntrainer::from_string("1:1:2:3:5", d), std::invalid_argument); +} + +TEST(BasicProperty, fromStringNotValid_05_n) { + DimensionOfBanana d; + EXPECT_THROW(nntrainer::from_string("2:2:3:5", d), std::invalid_argument); +} + +TEST(BasicProperty, fromStringNotValid_06_n) { + DimensionOfBanana d; + EXPECT_THROW(nntrainer::from_string("", d), std::invalid_argument); +} + +TEST(BasicProperty, fromStringNotValid_07_n) { + DimensionOfBanana d; + EXPECT_THROW(nntrainer::from_string(":2:3:5", d), std::invalid_argument); +} + TEST(Exporter, invalidMethods_n) { auto props = std::make_tuple(NumBanana(), QualityOfBanana()); -- 2.34.1