* @bug No known bugs except for NYI items
*/
+#include <connection.h>
+#include <iterator>
#include <layer_node.h>
#include <slice_realizer.h>
namespace nntrainer {
-SliceRealizer::SliceRealizer(const std::vector<std::string> &start_layers,
- const std::vector<std::string> &end_layers) :
- start_layers(start_layers),
- end_layers(end_layers.begin(), end_layers.end()) {}
+SliceRealizer::SliceRealizer(const std::vector<Connection> &start_layers,
+ const std::vector<Connection> &end_layers) {
+ /// discard index information as it is not needed as it is not really needed
+ this->start_layers.reserve(start_layers.size());
+
+ std::transform(start_layers.begin(), start_layers.end(),
+ std::back_inserter(this->start_layers),
+ [](const Connection &c) { return c.getName(); });
+
+ std::transform(end_layers.begin(), end_layers.end(),
+ std::inserter(this->end_layers, this->end_layers.begin()),
+ [](const Connection &c) { return c.getName(); });
+}
SliceRealizer::~SliceRealizer() {}
namespace nntrainer {
+class Connection;
+
/**
* @brief Graph realizer class which slice graph representation
*
/**
* @brief Construct a new Slice Realizer object
*
- * @param start_layers start layers
- * @param end_layers end layers
+ * @param start_connections start layers
+ * @param end_connections end layers
*/
- SliceRealizer(const std::vector<std::string> &start_layers,
- const std::vector<std::string> &end_layers);
+ SliceRealizer(const std::vector<Connection> &start_connections,
+ const std::vector<Connection> &end_connections);
/**
* @brief Destroy the Graph Realizer object
auto start_layers_ = normalize(start_layers);
auto end_layers_ = normalize(end_layers);
- auto start_conns_ =
+ auto start_conns =
std::vector<Connection>(start_layers.begin(), start_layers.end());
+ auto end_conns =
+ std::vector<Connection>(end_layers.begin(), end_layers.end());
+
std::vector<std::unique_ptr<GraphRealizer>> realizers;
- realizers.emplace_back(new PreviousInputRealizer(start_conns_));
- realizers.emplace_back(new SliceRealizer(start_layers_, end_layers_));
+ realizers.emplace_back(new PreviousInputRealizer(start_conns));
+ realizers.emplace_back(new SliceRealizer(start_conns, end_conns));
if (!input_layers_.empty()) {
realizers.emplace_back(new InputRealizer(start_layers_, input_layers_));
{"fully_connected", {"name=d2", "input_layers=c1"}},
};
- SliceRealizer r({"a1", "b1", "b2"}, {"a1", "d1", "d2"});
+ using C = Connection;
+ SliceRealizer r(
+ {
+ C("a1"),
+ C("b1"),
+ C("b2"),
+ },
+ {
+ C("a1"),
+ C("d1"),
+ C("d2"),
+ });
realizeAndEqual(r, before, after);
}
{"concat", {"name=c1", "input_layers=a1, b1"}},
};
- SliceRealizer r({"a1"}, {"c1"});
+ SliceRealizer r({Connection("a1")}, {Connection("c1")});
realizeAndEqual(r, before, after);
}