From: hyeonseok lee
Date: Wed, 12 Jan 2022 11:29:41 +0000 (+0900)
Subject: [lstm] remove timestep property
X-Git-Tag: accepted/tizen/unified/20220323.062643~31
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=af0b5de8d62a2f357d0e7b1bbef9b7cd9f2eb31b;p=platform%2Fcore%2Fml%2Fnntrainer.git

[lstm] remove timestep property

 - Remove the timestep property from the lstm layer. This disables
   unrolling of the lstm layer.
 - Adjust the recurrent unittest to a simple lstm unittest.

Signed-off-by: hyeonseok lee
---

diff --git a/nntrainer/compiler/recurrent_realizer.cpp b/nntrainer/compiler/recurrent_realizer.cpp
index 6e56fd6..71f0f0f 100644
--- a/nntrainer/compiler/recurrent_realizer.cpp
+++ b/nntrainer/compiler/recurrent_realizer.cpp
@@ -20,7 +20,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -186,7 +185,6 @@ static void propagateTimestep(LayerNode *node, unsigned int time_step,
   /** @todo add an interface to check if a layer supports a property */
   auto is_recurrent_type = [](LayerNode *node) {
     return node->getType() == RNNCellLayer::type ||
-           node->getType() == LSTMLayer::type ||
            node->getType() == ZoneoutLSTMCellLayer::type;
   };

diff --git a/nntrainer/layers/lstm.cpp b/nntrainer/layers/lstm.cpp
index 71a0bb7..6afca49 100644
--- a/nntrainer/layers/lstm.cpp
+++ b/nntrainer/layers/lstm.cpp
@@ -40,7 +40,7 @@ LSTMLayer::LSTMLayer() :
     props::HiddenStateActivation() = ActivationType::ACT_TANH,
     props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
     props::ReturnSequences(), props::DropOutRate(),
-    props::MaxTimestep(), props::Timestep()),
+    props::MaxTimestep()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
@@ -141,18 +141,18 @@ void LSTMLayer::finalize(InitLayerContext &context) {
   const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit);
   wt_idx[LSTMParams::hidden_state] = context.requestTensor(
     hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
+    TensorLifespan::ITERATION_LIFESPAN);
   // cell_state_dim : [ batch_size, 1, max_timestep, unit ]
   const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit);
   wt_idx[LSTMParams::cell_state] = context.requestTensor(
     cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
-    TensorLifespan::ITERATION_LIFESPAN, false);
+    TensorLifespan::ITERATION_LIFESPAN);
   // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
   const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit);
   wt_idx[LSTMParams::ifgo] =
     context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
-                          TensorLifespan::ITERATION_LIFESPAN, false);
+                          TensorLifespan::ITERATION_LIFESPAN);

   if (dropout_rate > epsilon) {
     // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
@@ -189,23 +189,11 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
   const float dropout_rate = std::get(lstm_props).get();
   const unsigned int max_timestep = std::get(lstm_props).get();
-  const props::Timestep timestep = std::get(lstm_props);
-
-  unsigned int start_timestep = 0;
-  unsigned int end_timestep = max_timestep;
-  if (!timestep.empty()) {
-    const unsigned int current_timestep = timestep.get();
-    if (current_timestep >= end_timestep) {
-      throw std::runtime_error("Timestep to run exceeds input dimensions");
-    }
-
-    start_timestep = current_timestep;
-    end_timestep = current_timestep + 1;
-  }

   const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX);
-  const unsigned int batch_size = inputs.getDim().batch();
-
const unsigned int feature_size = inputs.getDim().width(); + const TensorDim input_dim = inputs.getDim(); + const unsigned int batch_size = input_dim.batch(); + const unsigned int feature_size = input_dim.width(); Tensor &output = context.getOutput(SINGLE_INOUT_IDX); const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]); @@ -225,17 +213,8 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { Tensor &cs = context.getTensor(wt_idx[LSTMParams::cell_state]); Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]); - if (!start_timestep) { - hs.setZero(); - cs.setZero(); - } - - /** - * @note when the recurrent realization happens, different instances of lstm - * will share the weights, hidden state, cell and ifgo memory. However, they - * do not share the input, output and derivatives memory. The input/output - * will be contain a single timestep data only. - */ + hs.setZero(); + cs.setZero(); for (unsigned int batch = 0; batch < batch_size; ++batch) { const Tensor input_batch = inputs.getBatchSlice(batch, 1); @@ -243,7 +222,7 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { Tensor cs_batch = cs.getBatchSlice(batch, 1); Tensor ifgo_batch = ifgos.getBatchSlice(batch, 1); - for (unsigned int t = start_timestep; t < end_timestep; ++t) { + for (unsigned int t = 0; t < max_timestep; ++t) { Tensor input; if (input_batch.height() != 1) input = @@ -286,12 +265,12 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { } } - if (start_timestep == 0 && end_timestep == max_timestep && return_sequences) { + if (return_sequences) { std::copy(hs.getData(), hs.getData() + hs.size(), output.getData()); } else { for (unsigned int batch = 0; batch < batch_size; ++batch) { float *hidden_state_data = - hs.getAddress(batch * max_timestep * unit + (end_timestep - 1) * unit); + hs.getAddress(batch * max_timestep * unit + (max_timestep - 1) * unit); float *output_data = output.getAddress(batch * unit); std::copy(hidden_state_data, hidden_state_data + unit, output_data); } @@ -299,63 +278,11 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) { } void LSTMLayer::calcDerivative(RunLayerContext &context) { - const unsigned int unit = std::get(lstm_props).get(); - const unsigned int max_timestep = - std::get(lstm_props).get(); - const props::Timestep timestep = std::get(lstm_props); - - unsigned int start_timestep = 0; - unsigned int end_timestep = max_timestep; - if (!timestep.empty()) { - const unsigned int cur_timestep = timestep.get(); - // Todo: replace end_timestep with input's time iteration - if (cur_timestep >= end_timestep) { - throw std::runtime_error("Timestep to run exceeds input dimensions"); - } - - start_timestep = cur_timestep; - end_timestep = cur_timestep + 1; - } - const unsigned int timestep_diff = end_timestep - start_timestep; - - const TensorDim input_dim = context.getInput(SINGLE_INOUT_IDX).getDim(); - const unsigned int batch_size = input_dim.batch(); - const unsigned int feature_size = input_dim.width(); - - const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]); - const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]); Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); + const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]); + const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]); - if (start_timestep == 0 && end_timestep == max_timestep) { - /** - * this if is only for optimization 
purpose. The else calculates for - * this scenario as well. - */ - lstmcell_calcDerivative(d_ifgos, weight_ih, outgoing_derivative); - } else { - for (unsigned int b = 0; b < batch_size; ++b) { - const Tensor d_ifgo_batch = d_ifgos.getBatchSlice(b, 1); - Tensor outgoing_derivative_batch = - outgoing_derivative.getBatchSlice(b, 1); - Tensor d_ifgo, outgoing_derivative_; - - if (d_ifgo_batch.height() != 1) { - d_ifgo = d_ifgo_batch.getSharedDataTensor( - {timestep_diff, NUM_GATE * unit}, start_timestep * NUM_GATE * unit); - } else { - d_ifgo = d_ifgo_batch; - } - - if (outgoing_derivative_batch.height() != 1) { - outgoing_derivative_ = outgoing_derivative_batch.getSharedDataTensor( - {timestep_diff, feature_size}, start_timestep * feature_size); - } else { - outgoing_derivative_ = outgoing_derivative_batch; - } - - lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative_); - } - } + lstmcell_calcDerivative(d_ifgos, weight_ih, outgoing_derivative); } void LSTMLayer::calcGradient(RunLayerContext &context) { @@ -369,18 +296,9 @@ void LSTMLayer::calcGradient(RunLayerContext &context) { const float dropout_rate = std::get(lstm_props).get(); const unsigned int max_timestep = std::get(lstm_props).get(); - const props::Timestep timestep = std::get(lstm_props); unsigned int start_timestep = max_timestep - 1; int end_timestep = -1; - if (!timestep.empty()) { - const unsigned int cur_timestep = timestep.get(); - NNTR_THROW_IF(cur_timestep > start_timestep, std::runtime_error) - << "Timestep to run exceeds input dimension current timestep" - << cur_timestep << "start_timestep" << start_timestep; - start_timestep = cur_timestep; - end_timestep = cur_timestep - 1; - } const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX); const Tensor &incoming_derivative = @@ -411,24 +329,21 @@ void LSTMLayer::calcGradient(RunLayerContext &context) { Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]); Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]); - if (start_timestep + 1 == max_timestep) { - d_weight_ih.setZero(); - d_weight_hh.setZero(); - if (!disable_bias) { - if (integrate_bias) { - d_bias_h.setZero(); - } else { - d_bias_ih.setZero(); - d_bias_hh.setZero(); - } + d_weight_ih.setZero(); + d_weight_hh.setZero(); + if (!disable_bias) { + if (integrate_bias) { + d_bias_h.setZero(); + } else { + d_bias_ih.setZero(); + d_bias_hh.setZero(); } - - d_cs.setZero(); - d_hs.setZero(); } - if (start_timestep == max_timestep - 1 && end_timestep == -1 && - return_sequences) { + d_cs.setZero(); + d_hs.setZero(); + + if (return_sequences) { std::copy(incoming_derivative.getData(), incoming_derivative.getData() + incoming_derivative.size(), d_hs.getData()); diff --git a/nntrainer/layers/lstm.h b/nntrainer/layers/lstm.h index 5667f7f..804ba0e 100644 --- a/nntrainer/layers/lstm.h +++ b/nntrainer/layers/lstm.h @@ -107,13 +107,12 @@ private: * RecurrentActivation: activation type for recurrent. 
default is sigmoid * ReturnSequence: option for return sequence * DropOutRate: dropout rate - * MaxTimestep: maximum timestep for lstmcell - * TimeStep: timestep for which lstm should operate + * MaxTimestep: maximum timestep for lstm * * */ std::tuple + props::DropOutRate, props::MaxTimestep> lstm_props; std::array wt_idx; /**< indices of the weights */ diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz index f752d8b..4480c3f 100644 Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ diff --git a/test/input_gen/genModelsRecurrent_v2.py b/test/input_gen/genModelsRecurrent_v2.py index 47e1fd8..6197a19 100644 --- a/test/input_gen/genModelsRecurrent_v2.py +++ b/test/input_gen/genModelsRecurrent_v2.py @@ -56,32 +56,31 @@ class RNNCellStacked(torch.nn.Module): return ret, loss class LSTMStacked(torch.nn.Module): - def __init__(self, unroll_for=2, num_lstm=1): + def __init__(self, num_lstm=1): super().__init__() self.input_size = self.hidden_size = 2 + self.num_lstm = num_lstm self.lstms = torch.nn.ModuleList( [ - torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True) + torch.nn.LSTM(self.input_size, self.hidden_size, batch_first=True) + # torch.nn.LSTM(self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True) for _ in range(num_lstm) ] ) - self.unroll_for = unroll_for self.loss = torch.nn.MSELoss() def forward(self, inputs, labels): - hs = [torch.zeros_like(inputs[0]) for _ in self.lstms] - cs = [torch.zeros_like(inputs[0]) for _ in self.lstms] out = inputs[0] - ret = [] - for _ in range(self.unroll_for): - for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)): - hs[i], cs[i] = lstm(out, (h, c)) - out = hs[i] - ret.append(out) + states = inputs[1:] + # hs = [states[2 * i] for i in range(self.num_lstm)] + hs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)] + # cs = [states[2 * i + 1] for i in range(self.num_lstm)] + cs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)] + for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)): + out, (hs[i], cs[i]) = lstm(out, (h, c)) - ret = torch.stack(ret, dim=1) - loss = self.loss(ret, labels[0]) - return ret, loss + loss = self.loss(out, labels[0]) + return out, loss class LSTMCellStacked(torch.nn.Module): def __init__(self, unroll_for=2, num_lstmcell=1): @@ -89,7 +88,7 @@ class LSTMCellStacked(torch.nn.Module): self.input_size = self.hidden_size = 2 self.lstmcells = torch.nn.ModuleList( [ - torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True) + torch.nn.LSTMCell(self.input_size, self.hidden_size) for _ in range(num_lstmcell) ] ) @@ -213,19 +212,23 @@ if __name__ == "__main__": name="rnncell_stacked", ) + unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 1, 3, 2, 2, 2] record_v2( - LSTMStacked(unroll_for=2, num_lstm=1), - iteration=2, - input_dims=[(3, 2)], - label_dims=[(3, 2, 2)], + LSTMStacked(num_lstm=num_lstm), + iteration=iteration, + input_dims=[(batch_size, unroll_for, feature_size)], + # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)], + label_dims=[(batch_size, unroll_for, unit)], name="lstm_single", ) + unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 2, 3, 2, 2, 2] record_v2( - LSTMStacked(unroll_for=2, num_lstm=2), - iteration=2, - input_dims=[(3, 2)], - label_dims=[(3, 2, 2)], + LSTMStacked(num_lstm=num_lstm), + iteration=iteration, + input_dims=[(batch_size, unroll_for, feature_size)], + # 
input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)], + label_dims=[(batch_size, unroll_for, unit)], name="lstm_stacked", ) diff --git a/test/input_gen/transLayer_v2.py b/test/input_gen/transLayer_v2.py index 4f15325..9373d67 100644 --- a/test/input_gen/transLayer_v2.py +++ b/test/input_gen/transLayer_v2.py @@ -70,7 +70,7 @@ def zoneout_translate(model): new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3], hidden_state, cell_state] yield from new_params -@register_for_((torch.nn.RNNCell, torch.nn.LSTMCell)) +@register_for_((torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell)) def rnn_lstm_translate(model): params = [(name, tensor.detach()) for name, tensor in model.named_parameters()] # [hidden, input] -> [input, hidden] diff --git a/test/unittest/compiler/unittest_realizer.cpp b/test/unittest/compiler/unittest_realizer.cpp index 967883d..43d74cc 100644 --- a/test/unittest/compiler/unittest_realizer.cpp +++ b/test/unittest/compiler/unittest_realizer.cpp @@ -137,34 +137,31 @@ TEST(RecurrentRealizer, recurrent_input_is_sequence_p) { TEST(RecurrentRealizer, recurrent_return_sequence_single_p) { using C = Connection; RecurrentRealizer r({"unroll_for=3", "as_sequence=fc_out", - "recurrent_input=lstm", "recurrent_output=fc_out"}, + "recurrent_input=lstmcell", "recurrent_output=fc_out"}, {C("source")}, {C("fc_out")}); std::vector before = { - {"lstm", {"name=lstm", "input_layers=source"}}, - {"fully_connected", {"name=fc_out", "input_layers=lstm"}}}; + {"lstmcell", {"name=lstmcell", "input_layers=source"}}, + {"fully_connected", {"name=fc_out", "input_layers=lstmcell"}}}; std::vector expected = { /// t - 0 - {"lstm", - {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0", - "shared_from=lstm/0"}}, + {"lstmcell", + {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}}, {"fully_connected", - {"name=fc_out/0", "input_layers=lstm/0", "shared_from=fc_out/0"}}, + {"name=fc_out/0", "input_layers=lstmcell/0", "shared_from=fc_out/0"}}, /// t - 1 - {"lstm", - {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0", - "max_timestep=3", "timestep=1"}}, + {"lstmcell", + {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}}, {"fully_connected", - {"name=fc_out/1", "input_layers=lstm/1", "shared_from=fc_out/0"}}, + {"name=fc_out/1", "input_layers=lstmcell/1", "shared_from=fc_out/0"}}, /// t - 2 - {"lstm", - {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0", - "max_timestep=3", "timestep=2"}}, + {"lstmcell", + {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}}, {"fully_connected", - {"name=fc_out/2", "input_layers=lstm/2", "shared_from=fc_out/0"}}, + {"name=fc_out/2", "input_layers=lstmcell/2", "shared_from=fc_out/0"}}, /// mapping {"concat", @@ -181,53 +178,50 @@ TEST(RecurrentRealizer, recurrent_multi_inout_return_seq_p) { { "unroll_for=3", "as_sequence=fc_out", - "recurrent_input=lstm,add(2)", + "recurrent_input=lstmcell,add(2)", "recurrent_output=fc_out,split(1)", }, {C("source"), C("source2"), C("source3")}, {C("fc_out")}); /// @note for below graph, - /// 1. fc_out feds back to lstm + /// 1. fc_out feds back to lstmcell /// 2. 
ouput_dummy feds back to source2_dummy /// ======================================================== - /// lstm -------- addition - split ---- fc_out (to_lstm) + /// lstmcell -------- addition - split ---- fc_out (to_lstmcell) /// source2_dummy --/ \----- (to addition 3) std::vector before = { - {"lstm", {"name=lstm", "input_layers=source"}}, - {"addition", {"name=add", "input_layers=lstm,source2,source3"}}, + {"lstmcell", {"name=lstmcell", "input_layers=source"}}, + {"addition", {"name=add", "input_layers=lstmcell,source2,source3"}}, {"split", {"name=split", "input_layers=add"}}, {"fully_connected", {"name=fc_out", "input_layers=split(0)"}}, }; std::vector expected = { /// timestep 0 - {"lstm", - {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0", - "shared_from=lstm/0"}}, + {"lstmcell", + {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/0", "input_layers=lstm/0,source2,source3", + {"name=add/0", "input_layers=lstmcell/0,source2,source3", "shared_from=add/0"}}, {"split", {"name=split/0", "input_layers=add/0", "shared_from=split/0"}}, {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)", "shared_from=fc_out/0"}}, /// timestep 1 - {"lstm", - {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0", - "max_timestep=3", "timestep=1"}}, + {"lstmcell", + {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/1", "input_layers=lstm/1,source2,split/0(1)", + {"name=add/1", "input_layers=lstmcell/1,source2,split/0(1)", "shared_from=add/0"}}, {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}}, {"fully_connected", {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}}, /// timestep 2 - {"lstm", - {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0", - "max_timestep=3", "timestep=2"}}, + {"lstmcell", + {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/2", "input_layers=lstm/2,source2,split/1(1)", + {"name=add/2", "input_layers=lstmcell/2,source2,split/1(1)", "shared_from=add/0"}}, {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}}, {"fully_connected", @@ -247,53 +241,50 @@ TEST(RecurrentRealizer, recurrent_multi_inout_using_connection_p) { RecurrentRealizer r( { "unroll_for=3", - "recurrent_input=lstm,add(2)", + "recurrent_input=lstmcell,add(2)", "recurrent_output=fc_out,split(1)", }, {C("source"), C("source2"), C("source3")}, {C("fc_out")}); /// @note for below graph, - /// 1. fc_out feds back to lstm + /// 1. fc_out feds back to lstmcell /// 2. 
ouput_dummy feds back to source2_dummy /// ======================================================== - /// lstm -------- addition - split ---- fc_out (to_lstm) + /// lstmcell -------- addition - split ---- fc_out (to_lstmcell) /// source2_dummy --/ \----- (to addition 3) std::vector before = { - {"lstm", {"name=lstm", "input_layers=source"}}, - {"addition", {"name=add", "input_layers=lstm,source2,source3"}}, + {"lstmcell", {"name=lstmcell", "input_layers=source"}}, + {"addition", {"name=add", "input_layers=lstmcell,source2,source3"}}, {"split", {"name=split", "input_layers=add"}}, {"fully_connected", {"name=fc_out", "input_layers=split(0)"}}, }; std::vector expected = { /// timestep 0 - {"lstm", - {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0", - "shared_from=lstm/0"}}, + {"lstmcell", + {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/0", "input_layers=lstm/0,source2,source3", + {"name=add/0", "input_layers=lstmcell/0,source2,source3", "shared_from=add/0"}}, {"split", {"name=split/0", "input_layers=add/0", "shared_from=split/0"}}, {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)", "shared_from=fc_out/0"}}, /// timestep 1 - {"lstm", - {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0", - "max_timestep=3", "timestep=1"}}, + {"lstmcell", + {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/1", "input_layers=lstm/1,source2,split/0(1)", + {"name=add/1", "input_layers=lstmcell/1,source2,split/0(1)", "shared_from=add/0"}}, {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}}, {"fully_connected", {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}}, /// timestep 2 - {"lstm", - {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0", - "max_timestep=3", "timestep=2"}}, + {"lstmcell", + {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/2", "input_layers=lstm/2,source2,split/1(1)", + {"name=add/2", "input_layers=lstmcell/2,source2,split/1(1)", "shared_from=add/0"}}, {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}}, {"fully_connected", @@ -311,7 +302,7 @@ TEST(RecurrentRealizer, recurrent_multi_inout_multi_connection_end_p) { RecurrentRealizer r( { "unroll_for=3", - "recurrent_input=lstm,add(2)", + "recurrent_input=lstmcell,add(2)", "recurrent_output=fc_out,split(1)", "as_sequence=split(1)", }, @@ -326,47 +317,44 @@ TEST(RecurrentRealizer, recurrent_multi_inout_multi_connection_end_p) { }); /// @note for below graph, - /// 1. fc_out feds back to lstm + /// 1. fc_out feds back to lstmcell /// 2. 
ouput_dummy feds back to source2_dummy /// ======================================================== - /// lstm -------- addition - split ---- fc_out (to_lstm) + /// lstmcell -------- addition - split ---- fc_out (to_lstmcell) /// source2_dummy --/ \----- (to addition 3) std::vector before = { - {"lstm", {"name=lstm", "input_layers=source"}}, - {"addition", {"name=add", "input_layers=lstm,source2,source3"}}, + {"lstmcell", {"name=lstmcell", "input_layers=source"}}, + {"addition", {"name=add", "input_layers=lstmcell,source2,source3"}}, {"split", {"name=split", "input_layers=add"}}, {"fully_connected", {"name=fc_out", "input_layers=split(0)"}}, }; std::vector expected = { /// timestep 0 - {"lstm", - {"name=lstm/0", "input_layers=source", "max_timestep=3", "timestep=0", - "shared_from=lstm/0"}}, + {"lstmcell", + {"name=lstmcell/0", "input_layers=source", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/0", "input_layers=lstm/0,source2,source3", + {"name=add/0", "input_layers=lstmcell/0,source2,source3", "shared_from=add/0"}}, {"split", {"name=split/0", "input_layers=add/0", "shared_from=split/0"}}, {"fully_connected", {"name=fc_out/0", "input_layers=split/0(0)", "shared_from=fc_out/0"}}, /// timestep 1 - {"lstm", - {"name=lstm/1", "input_layers=fc_out/0", "shared_from=lstm/0", - "max_timestep=3", "timestep=1"}}, + {"lstmcell", + {"name=lstmcell/1", "input_layers=fc_out/0", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/1", "input_layers=lstm/1,source2,split/0(1)", + {"name=add/1", "input_layers=lstmcell/1,source2,split/0(1)", "shared_from=add/0"}}, {"split", {"name=split/1", "input_layers=add/1", "shared_from=split/0"}}, {"fully_connected", {"name=fc_out/1", "input_layers=split/1(0)", "shared_from=fc_out/0"}}, /// timestep 2 - {"lstm", - {"name=lstm/2", "input_layers=fc_out/1", "shared_from=lstm/0", - "max_timestep=3", "timestep=2"}}, + {"lstmcell", + {"name=lstmcell/2", "input_layers=fc_out/1", "shared_from=lstmcell/0"}}, {"addition", - {"name=add/2", "input_layers=lstm/2,source2,split/1(1)", + {"name=add/2", "input_layers=lstmcell/2,source2,split/1(1)", "shared_from=add/0"}}, {"split", {"name=split/2", "input_layers=add/2", "shared_from=split/0"}}, {"fully_connected", diff --git a/test/unittest/models/unittest_models_recurrent.cpp b/test/unittest/models/unittest_models_recurrent.cpp index 7859851..9aa3364 100644 --- a/test/unittest/models/unittest_models_recurrent.cpp +++ b/test/unittest/models/unittest_models_recurrent.cpp @@ -142,27 +142,15 @@ static std::unique_ptr makeSingleLSTM() { nn->setProperty({"batch_size=3"}); auto outer_graph = makeGraph({ - {"input", {"name=input", "input_shape=1:1:2"}}, - /// here lstm is being inserted - {"mse", {"name=loss", "input_layers=lstm_scope/a1"}}, + {"input", {"name=input", "input_shape=1:2:2"}}, + {"lstm", + {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true"}}, + {"mse", {"name=loss", "input_layers=a1"}}, }); for (auto &node : outer_graph) { nn->addLayer(node); } - auto lstm = makeGraph({ - {"lstm", {"name=a1", "unit=2", "integrate_bias=false"}}, - }); - - nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a1"}, - ml::train::ReferenceLayersType::RECURRENT, - { - "unroll_for=2", - "as_sequence=a1", - "recurrent_input=a1", - "recurrent_output=a1", - }); - nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); return nn; } @@ -172,28 +160,17 @@ static std::unique_ptr makeStackedLSTM() { nn->setProperty({"batch_size=3"}); auto outer_graph = makeGraph({ - {"input", {"name=input", 
"input_shape=1:1:2"}}, - /// here lstm is being inserted - {"mse", {"name=loss", "input_layers=lstm_scope/a2"}}, + {"input", {"name=input", "input_shape=1:2:2"}}, + {"lstm", + {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true"}}, + {"lstm", + {"name=a2", "unit=2", "integrate_bias=false", "return_sequences=true"}}, + {"mse", {"name=loss"}}, }); for (auto &node : outer_graph) { nn->addLayer(node); } - auto lstm = makeGraph({ - {"lstm", {"name=a1", "unit=2", "integrate_bias=false"}}, - {"lstm", {"name=a2", "unit=2", "integrate_bias=false", "input_layers=a1"}}, - }); - - nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a2"}, - ml::train::ReferenceLayersType::RECURRENT, - { - "unroll_for=2", - "as_sequence=a2", - "recurrent_input=a1", - "recurrent_output=a2", - }); - nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); return nn; }