#include <parse_util.h>
#include <regex>
+#include <sstream>
+#include <utility>
#include <vector>
namespace nntrainer {
return std::regex_match(v, allowed);
}
+ConnectionSpec::ConnectionSpec(const std::vector<props::Name> &layer_ids_,
+ const std::string &op_type_) :
+ op_type(op_type_),
+ layer_ids(layer_ids_) {
+ NNTR_THROW_IF((op_type != ConnectionSpec::NoneType && layer_ids.size() < 2),
+ std::invalid_argument)
+ << "connection type is not none but has only a single or empty layer id, "
+ "type: "
+ << op_type << " number of names: " << layer_ids.size();
+
+ NNTR_THROW_IF((op_type == ConnectionSpec::NoneType && layer_ids.size() >= 2),
+ std::invalid_argument)
+ << "connection type is none but has only a single or empty layer id, "
+ "number of names: "
+ << layer_ids.size();
+}
+
+ConnectionSpec::ConnectionSpec(const ConnectionSpec &rhs) = default;
+ConnectionSpec &ConnectionSpec::operator=(const ConnectionSpec &rhs) = default;
+ConnectionSpec::ConnectionSpec(ConnectionSpec &&rhs) noexcept = default;
+ConnectionSpec &ConnectionSpec::
+operator=(ConnectionSpec &&rhs) noexcept = default;
+
+bool ConnectionSpec::operator==(const ConnectionSpec &rhs) const {
+ return op_type == rhs.op_type && layer_ids == rhs.layer_ids;
+}
+
+bool InputSpec::isValid(const ConnectionSpec &v) const {
+ return v.getLayerIds().size() > 0;
+}
+
+std::string ConnectionSpec::NoneType = "";
+
} // namespace props
+static const std::vector<std::pair<char, std::string>>
+ connection_supported_tokens = {{',', "concat"}, {'+', "addition"}};
+
+template <>
+std::string
+str_converter<props::connection_prop_tag, props::ConnectionSpec>::to_string(
+ const props::ConnectionSpec &value) {
+
+ auto &type = value.getOpType();
+
+ if (type == props::ConnectionSpec::NoneType) {
+ return value.getLayerIds().front();
+ }
+
+ auto &cst = connection_supported_tokens;
+
+ auto find_token = [&type](const std::pair<char, std::string> &token) {
+ return token.second == type;
+ };
+
+ auto token = std::find_if(cst.begin(), cst.end(), find_token);
+
+ NNTR_THROW_IF(token == cst.end(), std::invalid_argument)
+ << "Unsupported type given: " << type;
+
+ std::stringstream ss;
+ auto last_iter = value.getLayerIds().end() - 1;
+ for (auto iter = value.getLayerIds().begin(); iter != last_iter; ++iter) {
+ ss << static_cast<std::string>(*iter) << token->first;
+ }
+ ss << static_cast<std::string>(*last_iter);
+
+ return ss.str();
+}
+
+template <>
+props::ConnectionSpec
+str_converter<props::connection_prop_tag, props::ConnectionSpec>::from_string(
+ const std::string &value) {
+ auto generate_regex = [](char token) {
+ std::stringstream ss;
+ ss << "\\s*\\" << token << "\\s*";
+
+ return std::regex(ss.str());
+ };
+
+ auto generate_name_vector = [](const std::vector<std::string> &values) {
+ props::Name n;
+ std::vector<props::Name> names_;
+ names_.reserve(values.size());
+
+ for (auto &item : values) {
+ if (!n.isValid(item)) {
+ break;
+ }
+ names_.emplace_back(item);
+ }
+
+ return names_;
+ };
+
+ for (auto &token : connection_supported_tokens) {
+ auto reg_ = generate_regex(token.first);
+ auto values = split(value, reg_);
+ if (values.size() == 1) {
+ continue;
+ }
+
+ auto names = generate_name_vector(values);
+ if (names.size() == values.size()) {
+ return props::ConnectionSpec(names, token.second);
+ }
+ }
+
+ props::Name n;
+ n.set(value); // explicitly trigger validation using set method
+ return props::ConnectionSpec({n});
+}
+
} // namespace nntrainer
bool isValid(const unsigned int &v) const override { return v > 0; }
};
+/**
+ * @brief RAII class to define the connection spec
+ *
+ */
+class ConnectionSpec {
+public:
+ static std::string NoneType;
+
+ /**
+ * @brief Construct a new Connection Spec object
+ */
+ ConnectionSpec() = default;
+ /**
+ * @brief Construct a new Connection Spec object
+ *
+ * @param layer_ids_ layer ids that will be an operand
+ * @param op_type_ operator type
+ */
+ ConnectionSpec(const std::vector<Name> &layer_ids_,
+ const std::string &op_type_ = ConnectionSpec::NoneType);
+
+ /**
+ * @brief Construct a new Connection Spec object
+ *
+ * @param rhs rhs to copy
+ */
+ ConnectionSpec(const ConnectionSpec &rhs);
+
+ /**
+ * @brief Copy assignment operator
+ *
+ * @param rhs rhs to copy
+ * @return ConnectionSpec&
+ */
+ ConnectionSpec &operator=(const ConnectionSpec &rhs);
+
+ /**
+ * @brief Move Construct Connection Spec object
+ *
+ * @param rhs rhs to move
+ */
+ ConnectionSpec(ConnectionSpec &&rhs) noexcept;
+
+ /**
+ * @brief Move assign a connection spec operator
+ *
+ * @param rhs rhs to move
+ * @return ConnectionSpec&
+ */
+ ConnectionSpec &operator=(ConnectionSpec &&rhs) noexcept;
+
+ /**
+ * @brief Get the Op Type object
+ *
+ * @return const std::string& op_type (read-only)
+ */
+ const std::string &getOpType() const { return op_type; }
+
+ /**
+ * @brief Get the Layer Ids object
+ *
+ * @return const std::vector<Name>& vector of layer ids (read-only)
+ */
+ const std::vector<Name> &getLayerIds() const { return layer_ids; }
+
+ /**
+ *
+ * @brief operator==
+ *
+ * @param rhs right side to compare
+ * @return true if equal
+ * @return false if not equal
+ */
+ bool operator==(const ConnectionSpec &rhs) const;
+
+private:
+ std::string op_type;
+ std::vector<Name> layer_ids;
+};
+
+/**
+ * @brief Connection prop tag type
+ *
+ */
+struct connection_prop_tag {};
+
+/**
+ * @brief InputSpec property, this defines connection specification of an input
+ *
+ */
+class InputSpec : public nntrainer::Property<ConnectionSpec> {
+public:
+ InputSpec(const ConnectionSpec &value = ConnectionSpec()) :
+ nntrainer::Property<ConnectionSpec>(value) {} /**< default value if any */
+ static constexpr const char *key =
+ "input_layers"; /**< unique key to access */
+ using prop_tag = connection_prop_tag; /**< property type */
+ bool isValid(const ConnectionSpec &v) const override;
+};
+
} // namespace props
} // namespace nntrainer
EXPECT_THROW(n.set("+layer"), std::invalid_argument);
}
+TEST(InputSpecProperty, setPropertyValid_p) {
+ using namespace nntrainer::props;
+ {
+ InputSpec expected(
+ ConnectionSpec({Name("A"), Name("B"), Name("C")}, "concat"));
+
+ InputSpec actual;
+ nntrainer::from_string("A, B, C", actual);
+ EXPECT_EQ(actual, expected);
+ EXPECT_EQ("A,B,C", nntrainer::to_string(actual));
+ }
+
+ {
+ InputSpec expected(
+ ConnectionSpec({Name("A"), Name("B"), Name("C")}, "addition"));
+
+ InputSpec actual;
+ nntrainer::from_string("A+ B +C", actual);
+ EXPECT_EQ("A+B+C", nntrainer::to_string(actual));
+
+ EXPECT_EQ(actual, expected);
+ }
+
+ {
+ InputSpec expected(ConnectionSpec({Name("A")}, ConnectionSpec::NoneType));
+
+ InputSpec actual;
+ nntrainer::from_string("A", actual);
+ EXPECT_EQ("A", nntrainer::to_string(actual));
+
+ EXPECT_EQ(actual, expected);
+ }
+}
+
+TEST(InputSpecProperty, emptyString_n_01) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, combinedOperator_n_01) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("A,B+C", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, combinedOperator_n_02) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("A+B,C", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, noOperator_n_01) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("A B", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, noOperator_n_02) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("A B", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, leadingOperator_n_01) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string(",A,B", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, leadingOperator_n_02) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("+A+B", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, trailingOperator_n_01) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("A,B,,", actual), std::invalid_argument);
+}
+
+TEST(InputSpecProperty, trailingOperator_n_02) {
+ using namespace nntrainer::props;
+ InputSpec actual;
+ EXPECT_THROW(nntrainer::from_string("A+B++", actual), std::invalid_argument);
+}
+
/**
* @brief Main gtest
*/