* @bug No known bugs except for NYI items
*/
#include <algorithm>
-#include <recurrent_realizer.h>
+#include <stdexcept>
#include <common_properties.h>
+#include <connection.h>
#include <grucell.h>
#include <input_layer.h>
#include <layer_node.h>
#include <lstmcell_core.h>
#include <nntrainer_error.h>
#include <node_exporter.h>
+#include <recurrent_realizer.h>
#include <remap_realizer.h>
#include <rnncell.h>
-#include <stdexcept>
#include <util_func.h>
#include <zoneout_lstmcell.h>
-#include <iostream>
-
namespace nntrainer {
namespace props {
* @brief Property for recurrent inputs
*
*/
-class RecurrentInput final : public Name {
+class RecurrentInput final : public Property<Connection> {
public:
/**
* @brief Construct a new Recurrent Input object
*
* @param name name
*/
- RecurrentInput(const std::string &name);
+ RecurrentInput(const Connection &name);
static constexpr const char *key = "recurrent_input";
- using prop_tag = str_prop_tag;
+ using prop_tag = connection_prop_tag;
};
RecurrentInput::RecurrentInput() {}
-RecurrentInput::RecurrentInput(const std::string &name) { set(name); };
+RecurrentInput::RecurrentInput(const Connection &con) { set(con); };
/**
* @brief Property for recurrent outputs
*
*/
-class RecurrentOutput final : public Name {
+class RecurrentOutput final : public Property<Connection> {
public:
/**
* @brief Construct a new Recurrent Output object
*
* @param name name
*/
- RecurrentOutput(const std::string &name);
+ RecurrentOutput(const Connection &name);
static constexpr const char *key = "recurrent_output";
- using prop_tag = str_prop_tag;
+ using prop_tag = connection_prop_tag;
};
RecurrentOutput::RecurrentOutput() {}
-RecurrentOutput::RecurrentOutput(const std::string &name) { set(name); };
+RecurrentOutput::RecurrentOutput(const Connection &con) { set(con); };
} // namespace props
RecurrentRealizer::RecurrentRealizer(
std::vector<props::AsSequence>(), props::UnrollFor(1))) {
auto left = loadProperties(properties, *recurrent_props);
- /// @todo check input, output number matches
- /// @todo check if as sequence is subset of recurrent output
/// @note AsSequence must be identifier based (not connection based) for now
/// consider A(layer) outputs a0, a1 connection and a0 needs return seq
/// Then it is impossible to locate a0 and a1 with the same name unless we
seq) != end_layers.end();
}),
std::invalid_argument)
- << "as_sequence property must be subset of recurrent_outputs";
+ << "as_sequence property must be subset of end_layers";
std::unordered_set<std::string> check_seqs;
for (auto &name : as_sequence) {
GraphRepresentation
RecurrentRealizer::realize(const GraphRepresentation &reference) {
- auto step0_verify_and_prepare = [this, &reference]() {
- for (auto &node : reference) {
- if (recurrent_info.count(node->getName())) {
- /// @todo this does not have to be the restriction as we
- /// are supporting connections (#1760)
- NNTR_THROW_IF(node->getNumInputConnections() != 1,
- std::invalid_argument)
- << "recurrent input must have single connection: " << node->getName();
- }
- }
+ auto step0_verify_and_prepare = []() {
+ /// empty intended
};
/**
});
/// 2. override first output name to $name/$idx - 1
- auto &name = node->getName();
- auto suffix_len = std::string("/0").length();
- if (auto iter =
- recurrent_info.find(name.substr(0, name.length() - suffix_len));
- iter != recurrent_info.end()) {
- std::string output_name =
- iter->second + "/" + std::to_string(time_idx - 1);
- new_node->remapConnections(
- [&output_name](std::string &name, unsigned &idx) {
- /// @todo alter only when idx matches
- name = output_name;
- });
+ for (auto &[recurrent_input, recurrent_output] : recurrent_info) {
+ if (node->getName() != recurrent_input.getName() + "/0") {
+ continue;
+ }
+ new_node->setInputConnectionIndex(recurrent_input.getIndex(),
+ recurrent_output.getIndex());
+ new_node->setInputConnectionName(recurrent_input.getIndex(),
+ recurrent_output.getName() + "/" +
+ std::to_string(time_idx - 1));
}
/// 3. set shared_from
realizeAndEqual(r, before, expected);
}
-TEST(DISABLED_RecurrentRealizer,
- recurrent_multi_inout_using_connection_return_seq_p) {
- /// NYI, just for specification
- EXPECT_TRUE(false);
+TEST(RecurrentRealizer, recurrent_multi_inout_using_connection_return_seq_p) {
+
+ RecurrentRealizer r(
+ {
+ "unroll_for=3",
+ "as_sequence=fc_out",
+ "recurrent_input=lstm,add(2)",
+ "recurrent_output=fc_out,split(1)",
+ },
+ {"source", "source2", "source3"}, {"fc_out"});
+
+ /// @note for below graph,
+  /// 1. fc_out feeds back to lstm
+  /// 2. output_dummy feeds back to source2_dummy
+ /// ========================================================
+ /// lstm -------- addition - split ---- fc_out (to_lstm)
+ /// source2_dummy --/ \----- (to addition 3)
+ std::vector<LayerRepresentation> before = {
+ {"lstm", {"name=lstm", "input_layers=source"}},
+ {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+ {"split", {"name=split", "input_layers=add"}},
+ {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
+ };
+
+ std::vector<LayerRepresentation> expected = {
+ /// timestep 0
+ {"lstm",
+ {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0"}},
+ {"addition", {"name=add/0", "input_layers=lstm/0,source2,source3"}},
+ {"split", {"name=split/0", "input_layers=add/0"}},
+ {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)"}},
+
+ /// timestep 1
+ {"lstm",
+ {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
+ "max_timestep=3", "timestep=1"}},
+ {"addition",
+ {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+ "shared_from=add/0"}},
+ {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
+ {"fully_connected",
+ {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
+
+ /// timestep 2
+ {"lstm",
+ {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
+ "max_timestep=3", "timestep=2"}},
+ {"addition",
+ {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+ "shared_from=add/0"}},
+ {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
+ {"fully_connected",
+ {"name=fc_out/2", "input_layers=split/2(0)", "shared_from=fc_out/0"}},
+ {"concat", {"name=fc_out", "input_layers=fc_out/0,fc_out/1,fc_out/2"}},
+ };
+
+ realizeAndEqual(r, before, expected);
}
-TEST(DISABLED_RecurrentRealizer, recurrent_multi_inout_using_connection_p) {
- /// NYI, just for specification
- EXPECT_TRUE(false);
+TEST(RecurrentRealizer, recurrent_multi_inout_using_connection_p) {
+ RecurrentRealizer r(
+ {
+ "unroll_for=3",
+ "recurrent_input=lstm,add(2)",
+ "recurrent_output=fc_out,split(1)",
+ },
+ {"source", "source2", "source3"}, {"fc_out"});
+
+ /// @note for below graph,
+  /// 1. fc_out feeds back to lstm
+  /// 2. output_dummy feeds back to source2_dummy
+ /// ========================================================
+ /// lstm -------- addition - split ---- fc_out (to_lstm)
+ /// source2_dummy --/ \----- (to addition 3)
+ std::vector<LayerRepresentation> before = {
+ {"lstm", {"name=lstm", "input_layers=source"}},
+ {"addition", {"name=add", "input_layers=lstm,source2,source3"}},
+ {"split", {"name=split", "input_layers=add"}},
+ {"fully_connected", {"name=fc_out", "input_layers=split(0)"}},
+ };
+
+ std::vector<LayerRepresentation> expected = {
+ /// timestep 0
+ {"lstm",
+ {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0"}},
+ {"addition", {"name=add/0", "input_layers=lstm/0,source2,source3"}},
+ {"split", {"name=split/0", "input_layers=add/0"}},
+ {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)"}},
+
+ /// timestep 1
+ {"lstm",
+ {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0",
+ "max_timestep=3", "timestep=1"}},
+ {"addition",
+ {"name=add/1", "input_layers=lstm/1,source2,split/0(1)",
+ "shared_from=add/0"}},
+ {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}},
+ {"fully_connected",
+ {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}},
+
+ /// timestep 2
+ {"lstm",
+ {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0",
+ "max_timestep=3", "timestep=2"}},
+ {"addition",
+ {"name=add/2", "input_layers=lstm/2,source2,split/1(1)",
+ "shared_from=add/0"}},
+ {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}},
+ {"fully_connected",
+ {"name=fc_out", "input_layers=split/2(0)", "shared_from=fc_out/0"}},
+ };
+
+ realizeAndEqual(r, before, expected);
}
TEST(RemapRealizer, remap_p) {