* @author Jihoon Lee <jhoon.it.lee@samsung.com>
* @bug No known bugs except for NYI items
*/
+#include <algorithm>
#include <compiler_fwd.h>
#include <memory>
+#include <stdexcept>
#include <vector>
+#include <layer_node.h>
+#include <nntrainer_log.h>
#include <previous_input_realizer.h>
namespace nntrainer {
PreviousInputRealizer::PreviousInputRealizer(
- const std::vector<std::string> &identified_input) {}
+ const std::vector<std::string> &identified_inputs_) :
+ identified_inputs(identified_inputs_) {}
PreviousInputRealizer::~PreviousInputRealizer() {}
GraphRepresentation
PreviousInputRealizer::realize(const GraphRepresentation &reference) {
- return GraphRepresentation();
+ GraphRepresentation processed(reference.begin(), reference.end());
+
+ /**
+ * @brief for node has input connection, below function determines if the node
+ * should be input node or add input_layers from previous layer
+ *
+ */
+ auto is_actually_an_input_node = [this](const LayerNode &node) {
+ return node.hasInputShapeProperty() or
+ std::any_of(identified_inputs.begin(), identified_inputs.end(),
+ [&node](auto &name) { return node.getName() == name; });
+ };
+
+ for (auto iter = processed.begin(); iter != processed.end(); ++iter) {
+ auto &node = *iter;
+ if (node->getNumInputConnections() != 0) {
+ continue;
+ }
+
+ if (is_actually_an_input_node(*node)) {
+ continue;
+ }
+
+ NNTR_THROW_IF(iter == processed.begin(), std::invalid_argument)
+ << "First node must be identified as an input if it is qualified to be "
+ "input, name: "
+ << node->getName();
+
+ auto &prev_node = *(iter - 1);
+ ml_logi(
+ "%s is identified as a non-input node and default input layer(%s) is "
+ "being set ",
+ node->getName().c_str(), prev_node->getName().c_str());
+
+ node->setInputLayers({prev_node->getName()});
+ }
+
+ return processed;
}
} // namespace nntrainer
/**
* @brief Construct a new Previous Input Realizer object
*
- * @param identified_input node that is identified as an input, this must not
+ * @param identified_inputs node that is identified as an input, this must not
* connect to other nodes automatically
*/
- PreviousInputRealizer(const std::vector<std::string> &identified_input);
+ PreviousInputRealizer(const std::vector<std::string> &identified_inputs);
/**
* @brief Destroy the Graph Realizer object
*
*/
GraphRepresentation realize(const GraphRepresentation &reference) override;
+
+private:
+ std::vector<std::string> identified_inputs; /**< inputs are identified */
};
} // namespace nntrainer
realizeAndEqual(r, before, after);
}
-TEST(PreviousInputRealizer, DISABLED_previous_p) {
+TEST(PreviousInputRealizer, previous_p) {
{ /// realization without identifying custom input
std::vector<LayerRepresentation> before = {
{"fully_connected", {"name=fc1", "input_shape=1"}}, // model input
}
}
-TEST(PreviousInputRealizer, DISABLED_user_not_identifying_first_input_n) {
+TEST(PreviousInputRealizer, user_not_identifying_first_input_n) {
/// realization without identifying custom input
std::vector<LayerRepresentation> before = {
{"fully_connected", {"name=fc1"}}, // this should be model input but