[Connection] Add indexing to connection
authorJihoon Lee <jhoon.it.lee@samsung.com>
Mon, 22 Nov 2021 06:43:16 +0000 (15:43 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 30 Nov 2021 03:02:12 +0000 (12:02 +0900)
This patch revisition connection spec to contain identifier and index to
make use of it.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
nntrainer/layers/common_properties.cpp
nntrainer/layers/common_properties.h
test/unittest/unittest_common_properties.cpp

index 27fcb2b..d307c9d 100644 (file)
  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
  * @bug    No known bugs except for NYI items
  */
+#include <base_properties.h>
 #include <common_properties.h>
 
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
+#include <stdexcept>
 #include <tensor_dim.h>
 
 #include <regex>
@@ -75,36 +77,22 @@ ReturnSequences::ReturnSequences(bool value) { set(value); }
 
 bool NumClass::isValid(const unsigned int &v) const { return v > 0; }
 
-ConnectionSpec::ConnectionSpec(const std::vector<props::Name> &layer_ids_,
-                               const std::string &op_type_) :
-  op_type(op_type_),
-  layer_ids(layer_ids_) {
-  NNTR_THROW_IF((op_type != ConnectionSpec::NoneType && layer_ids.size() < 2),
-                std::invalid_argument)
-    << "connection type is not none but has only a single or empty layer id, "
-       "type: "
-    << op_type << " number of names: " << layer_ids.size();
-
-  NNTR_THROW_IF((op_type == ConnectionSpec::NoneType && layer_ids.size() >= 2),
-                std::invalid_argument)
-    << "connection type is none but has only a single or empty layer id, "
-       "number of names: "
-    << layer_ids.size();
-}
+InputConnection::InputConnection() : nntrainer::Property<Connection>() {}
+InputConnection::InputConnection(const Connection &value) :
+  nntrainer::Property<Connection>(value) {} /**< default value if any */
 
-ConnectionSpec::ConnectionSpec(const ConnectionSpec &rhs) = default;
-ConnectionSpec &ConnectionSpec::operator=(const ConnectionSpec &rhs) = default;
-ConnectionSpec::ConnectionSpec(ConnectionSpec &&rhs) noexcept = default;
-ConnectionSpec &ConnectionSpec::
-operator=(ConnectionSpec &&rhs) noexcept = default;
+Connection::Connection(const std::string &layer_name, unsigned int idx) :
+  index(idx),
+  name(layer_name) {}
 
-bool ConnectionSpec::operator==(const ConnectionSpec &rhs) const {
-  return op_type == rhs.op_type && layer_ids == rhs.layer_ids;
-}
+Connection::Connection(const Connection &rhs) = default;
+Connection &Connection::operator=(const Connection &rhs) = default;
+Connection::Connection(Connection &&rhs) noexcept = default;
+Connection &Connection::operator=(Connection &&rhs) noexcept = default;
 
-bool InputSpec::isValid(const ConnectionSpec &v) const {
-  return v.getLayerIds().size() > 0;
-}
+bool Connection::operator==(const Connection &rhs) const noexcept {
+  return index == rhs.index and name == rhs.name;
+};
 
 Epsilon::Epsilon(float value) { set(value); }
 
@@ -255,8 +243,6 @@ std::array<unsigned int, 2> Padding1D::compute(const TensorDim &input,
   return {0, 0};
 }
 
-std::string ConnectionSpec::NoneType = "";
-
 WeightRegularizerConstant::WeightRegularizerConstant(float value) {
   set(value);
 }
@@ -323,83 +309,32 @@ void GenericShape::set(const TensorDim &value) {
 
 } // namespace props
 
-static const std::vector<std::pair<char, std::string>>
-  connection_supported_tokens = {{',', "concat"}, {'+', "addition"}};
-
 template <>
 std::string
-str_converter<props::connection_prop_tag, props::ConnectionSpec>::to_string(
-  const props::ConnectionSpec &value) {
-
-  auto &type = value.getOpType();
-
-  if (type == props::ConnectionSpec::NoneType) {
-    return value.getLayerIds().front();
-  }
-
-  auto &cst = connection_supported_tokens;
-
-  auto find_token = [&type](const std::pair<char, std::string> &token) {
-    return token.second == type;
-  };
-
-  auto token = std::find_if(cst.begin(), cst.end(), find_token);
-
-  NNTR_THROW_IF(token == cst.end(), std::invalid_argument)
-    << "Unsupported type given: " << type;
-
+str_converter<props::connection_prop_tag, props::Connection>::to_string(
+  const props::Connection &value) {
   std::stringstream ss;
-  auto last_iter = value.getLayerIds().end() - 1;
-  for (auto iter = value.getLayerIds().begin(); iter != last_iter; ++iter) {
-    ss << static_cast<std::string>(*iter) << token->first;
-  }
-  ss << static_cast<std::string>(*last_iter);
-
+  ss << value.getName().get() << '(' << value.getIndex() << ')';
   return ss.str();
 }
 
 template <>
-props::ConnectionSpec
-str_converter<props::connection_prop_tag, props::ConnectionSpec>::from_string(
+props::Connection
+str_converter<props::connection_prop_tag, props::Connection>::from_string(
   const std::string &value) {
-  auto generate_regex = [](char token) {
-    std::stringstream ss;
-    ss << "\\s*\\" << token << "\\s*";
+  auto pos = value.find_first_of('(');
+  auto idx = 0u;
+  auto name_part = value.substr(0, pos);
 
-    return std::regex(ss.str());
-  };
+  if (pos != std::string::npos) {
+    NNTR_THROW_IF(value.back() != ')', std::invalid_argument)
+      << "failed to parse connection invalid format: " << value;
 
-  auto generate_name_vector = [](const std::vector<std::string> &values) {
-    props::Name n;
-    std::vector<props::Name> names_;
-    names_.reserve(values.size());
-
-    for (auto &item : values) {
-      if (!n.isValid(item)) {
-        break;
-      }
-      names_.emplace_back(item);
-    }
-
-    return names_;
-  };
-
-  for (auto &token : connection_supported_tokens) {
-    auto reg_ = generate_regex(token.first);
-    auto values = split(value, reg_);
-    if (values.size() == 1) {
-      continue;
-    }
-
-    auto names = generate_name_vector(values);
-    if (names.size() == values.size()) {
-      return props::ConnectionSpec(names, token.second);
-    }
+    auto idx_part = value.substr(pos + 1, value.length() - 1);
+    idx = str_converter<uint_prop_tag, unsigned>::from_string(idx_part);
   }
 
-  props::Name n;
-  n.set(value); // explicitly trigger validation using set method
-  return props::ConnectionSpec({n});
+  return props::Connection(name_part, idx);
 }
 
 } // namespace nntrainer
index 06a56c1..1096261 100644 (file)
@@ -138,65 +138,75 @@ public:
 };
 
 /**
- * @brief RAII class to define the connection spec
+ * @brief RAII class to define the connection
  *
  */
-class ConnectionSpec {
+class Connection {
 public:
-  static std::string NoneType;
-
   /**
-   * @brief Construct a new Connection Spec object
+   * @brief Construct a new Connection object
    *
-   * @param layer_ids_ layer ids that will be an operand
-   * @param op_type_ operator type
+   * @param layer_name layer identifier
    */
-  ConnectionSpec(const std::vector<Name> &layer_ids_,
-                 const std::string &op_type_ = ConnectionSpec::NoneType);
+  Connection(const std::string &layer_name, unsigned int idx);
 
   /**
-   * @brief Construct a new Connection Spec object
+   * @brief Construct a new Connection object
    *
    * @param rhs rhs to copy
    */
-  ConnectionSpec(const ConnectionSpec &rhs);
+  Connection(const Connection &rhs);
 
   /**
    * @brief Copy assignment operator
    *
    * @param rhs rhs to copy
-   * @return ConnectionSpec&
+   * @return Connection&
    */
-  ConnectionSpec &operator=(const ConnectionSpec &rhs);
+  Connection &operator=(const Connection &rhs);
 
   /**
-   * @brief Move Construct Connection Spec object
+   * @brief Move Construct Connection object
    *
    * @param rhs rhs to move
    */
-  ConnectionSpec(ConnectionSpec &&rhs) noexcept;
+  Connection(Connection &&rhs) noexcept;
 
   /**
-   * @brief Move assign a connection spec operator
+   * @brief Move assign a connection operator
    *
    * @param rhs rhs to move
-   * @return ConnectionSpec&
+   * @return Connection&
+   */
+  Connection &operator=(Connection &&rhs) noexcept;
+
+  /**
+   * @brief Get the index
+   *
+   * @return unsigned index
+   */
+  const unsigned getIndex() const { return index; }
+
+  /**
+   * @brief Get the index
+   *
+   * @return unsigned index
    */
-  ConnectionSpec &operator=(ConnectionSpec &&rhs) noexcept;
+  unsigned &getIndex() { return index; }
 
   /**
-   * @brief Get the Op Type object
+   * @brief Get the Layer name object
    *
-   * @return const std::string& op_type (read-only)
+   * @return const Name& name of layer
    */
-  const std::string &getOpType() const { return op_type; }
+  const Name &getName() const { return name; }
 
   /**
-   * @brief Get the Layer Ids object
+   * @brief Get the Layer name object
    *
-   * @return const std::vector<Name>& vector of layer ids (read-only)
+   * @return Name& name of layer
    */
-  const std::vector<Name> &getLayerIds() const { return layer_ids; }
+  Name &getName() { return name; }
 
   /**
    *
@@ -206,11 +216,11 @@ public:
    * @return true if equal
    * @return false if not equal
    */
-  bool operator==(const ConnectionSpec &rhs) const;
+  bool operator==(const Connection &rhs) const noexcept;
 
 private:
-  std::string op_type;
-  std::vector<Name> layer_ids;
+  unsigned index;
+  Name name;
 };
 
 /**
@@ -223,25 +233,23 @@ struct connection_prop_tag {};
  * @brief InputSpec property, this defines connection specification of an input
  *
  */
-class InputSpec : public nntrainer::Property<ConnectionSpec> {
+class InputConnection : public nntrainer::Property<Connection> {
 public:
   /**
    * @brief Construct a new Input Spec object
    *
    */
-  InputSpec() : nntrainer::Property<ConnectionSpec>() {}
+  InputConnection();
 
   /**
    * @brief Construct a new Input Spec object
    *
    * @param value default value of a input spec
    */
-  InputSpec(const ConnectionSpec &value) :
-    nntrainer::Property<ConnectionSpec>(value) {} /**< default value if any */
+  InputConnection(const Connection &value);
   static constexpr const char *key =
     "input_layers";                     /**< unique key to access */
   using prop_tag = connection_prop_tag; /**< property type */
-  bool isValid(const ConnectionSpec &v) const override;
 };
 
 /**
index 6294d6c..6e93fc0 100644 (file)
@@ -101,92 +101,49 @@ TEST(NameProperty, mustStartWithAlphaNumeric_01_n) {
   EXPECT_THROW(n.set("+layer"), std::invalid_argument);
 }
 
-TEST(InputSpecProperty, setPropertyValid_p) {
+TEST(InputConnection, setPropertyValid_p) {
   using namespace nntrainer::props;
   {
-    InputSpec expected(
-      ConnectionSpec({Name("A"), Name("B"), Name("C")}, "concat"));
+    InputConnection expected(Connection("a", 0));
 
-    InputSpec actual;
-    nntrainer::from_string("A, B, C", actual);
+    InputConnection actual;
+    nntrainer::from_string("A", actual);
     EXPECT_EQ(actual, expected);
-    EXPECT_EQ("a,b,c", nntrainer::to_string(actual));
+    EXPECT_EQ("a(0)", nntrainer::to_string(actual));
   }
 
   {
-    InputSpec expected(
-      ConnectionSpec({Name("A"), Name("B"), Name("C")}, "addition"));
-
-    InputSpec actual;
-    nntrainer::from_string("A+ B +C", actual);
-    EXPECT_EQ("a+b+c", nntrainer::to_string(actual));
-
-    EXPECT_EQ(actual, expected);
-  }
-
-  {
-    InputSpec expected(ConnectionSpec({Name("a")}, ConnectionSpec::NoneType));
-
-    InputSpec actual;
-    nntrainer::from_string("a", actual);
-    EXPECT_EQ("a", nntrainer::to_string(actual));
+    InputConnection expected(Connection("a", 2));
 
+    InputConnection actual;
+    nntrainer::from_string("a(2)", actual);
     EXPECT_EQ(actual, expected);
+    EXPECT_EQ("a(2)", nntrainer::to_string(actual));
   }
 }
 
-TEST(InputSpecProperty, emptyString_n_01) {
+TEST(InputConnection, emptyString_n_01) {
   using namespace nntrainer::props;
-  InputSpec actual;
+  InputConnection actual;
   EXPECT_THROW(nntrainer::from_string("", actual), std::invalid_argument);
 }
 
-TEST(InputSpecProperty, combinedOperator_n_01) {
-  using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("A,B+C", actual), std::invalid_argument);
-}
-
-TEST(InputSpecProperty, combinedOperator_n_02) {
-  using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("A+B,C", actual), std::invalid_argument);
-}
-
-TEST(InputSpecProperty, noOperator_n_01) {
-  using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("A B", actual), std::invalid_argument);
-}
-
-TEST(InputSpecProperty, noOperator_n_02) {
-  using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("A B", actual), std::invalid_argument);
-}
-
-TEST(InputSpecProperty, leadingOperator_n_01) {
-  using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string(",A,B", actual), std::invalid_argument);
-}
-
-TEST(InputSpecProperty, leadingOperator_n_02) {
+TEST(InputConnection, onlyIndex_n_01) {
   using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("+A+B", actual), std::invalid_argument);
+  InputConnection actual;
+  EXPECT_THROW(nntrainer::from_string("[0]", actual), std::invalid_argument);
 }
 
-TEST(InputSpecProperty, trailingOperator_n_01) {
+TEST(InputConnection, invalidFormat_n_01) {
   using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("A,B,,", actual), std::invalid_argument);
+  InputConnection actual;
+  EXPECT_THROW(nntrainer::from_string("a[0", actual), std::invalid_argument);
 }
 
-TEST(InputSpecProperty, trailingOperator_n_02) {
+TEST(InputConnection, invalidFormat_n_02) {
   using namespace nntrainer::props;
-  InputSpec actual;
-  EXPECT_THROW(nntrainer::from_string("A+B++", actual), std::invalid_argument);
+  InputConnection actual;
+  EXPECT_THROW(nntrainer::from_string("[0", actual), std::invalid_argument);
 }
 
 TEST(Padding2D, setPropertyValid_p) {
@@ -259,14 +216,14 @@ int main(int argc, char **argv) {
   try {
     testing::InitGoogleTest(&argc, argv);
   } catch (...) {
-    std::cerr << "Error duing IniGoogleTest" << std::endl;
+    std::cerr << "Error during IniGoogleTest" << std::endl;
     return 0;
   }
 
   try {
     result = RUN_ALL_TESTS();
   } catch (...) {
-    std::cerr << "Error duing RUN_ALL_TESTS()" << std::endl;
+    std::cerr << "Error during RUN_ALL_TESTS()" << std::endl;
   }
 
   return result;