From e7aebbc54f8d2df9d99fa6f6e8e826fc7da844ef Mon Sep 17 00:00:00 2001
From: hyeonseok lee
Date: Fri, 14 Jan 2022 13:29:57 +0900
Subject: [PATCH] [lstm] implement bidirectional lstm forward

 - Added a batch_first_forwarding helper function
 - For now, only forwarding is supported for bidirectional lstm

Signed-off-by: hyeonseok lee
---
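Note for reviewers: the forward pass added here is meant to match the
semantics of torch.nn.LSTM with bidirectional=True, i.e. the output feature
axis is [forward unit | backward unit], and the backward direction is a
plain LSTM run over the time-reversed input whose hidden states are written
back at the original timestep indices. A quick standalone check of that
assumption (illustrative only, not part of the patch; the sizes match the
lstm_single test case below):

    import torch

    batch_size, seq_len, feature_size, unit = 3, 2, 2, 2
    lstm = torch.nn.LSTM(feature_size, unit, batch_first=True,
                         bidirectional=True)
    x = torch.randn(batch_size, seq_len, feature_size)
    out, _ = lstm(x)
    # output feature axis is [forward unit | backward unit]
    assert out.shape == (batch_size, seq_len, 2 * unit)

    # backward direction == unidirectional LSTM over the reversed input,
    # with its outputs flipped back to the original time order
    bwd = torch.nn.LSTM(feature_size, unit, batch_first=True)
    bwd.weight_ih_l0.data.copy_(lstm.weight_ih_l0_reverse.data)
    bwd.weight_hh_l0.data.copy_(lstm.weight_hh_l0_reverse.data)
    bwd.bias_ih_l0.data.copy_(lstm.bias_ih_l0_reverse.data)
    bwd.bias_hh_l0.data.copy_(lstm.bias_hh_l0_reverse.data)
    out_rev, _ = bwd(torch.flip(x, dims=[1]))
    assert torch.allclose(out[..., unit:], torch.flip(out_rev, dims=[1]),
                          atol=1e-6)

This is the same alignment that the reverse indexing in
batch_first_forwarding() reproduces.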
 nntrainer/layers/common_properties.cpp   |   2 +
 nntrainer/layers/common_properties.h     |  15 +
 nntrainer/layers/lstm.cpp                | 343 +++++++++++++-----
 nntrainer/layers/lstm.h                  |   5 +-
 test/input_gen/genModelsRecurrent_v2.py  |  44 ++-
 test/input_gen/transLayer_v2.py          |  16 +-
 .../models/unittest_models_recurrent.cpp |  45 +++
 7 files changed, 374 insertions(+), 96 deletions(-)

diff --git a/nntrainer/layers/common_properties.cpp b/nntrainer/layers/common_properties.cpp
index a17021cb..c0d6c2c8 100644
--- a/nntrainer/layers/common_properties.cpp
+++ b/nntrainer/layers/common_properties.cpp
@@ -75,6 +75,8 @@ std::ifstream::pos_type FilePath::file_size() { return cached_pos_size; }
 
 ReturnSequences::ReturnSequences(bool value) { set(value); }
 
+Bidirectional::Bidirectional(bool value) { set(value); }
+
 bool NumClass::isValid(const unsigned int &v) const { return v > 0; }
 
 InputConnection::InputConnection() : nntrainer::Property<Connection>() {}
diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h
index a12eda7b..13cf6a50 100644
--- a/nntrainer/layers/common_properties.h
+++ b/nntrainer/layers/common_properties.h
@@ -573,6 +573,21 @@ public:
   using prop_tag = bool_prop_tag;
 };
 
+/**
+ * @brief bidirectional property, used to make bidirectional layers
+ *
+ */
+class Bidirectional : public nntrainer::Property<bool> {
+public:
+  /**
+   * @brief Construct a new Bidirectional object
+   *
+   */
+  Bidirectional(bool value = false);
+  static constexpr const char *key = "bidirectional";
+  using prop_tag = bool_prop_tag;
+};
+
 /**
  * @brief Identifiers to locate a connection which should be returned as whole
  * used in recurrent realizer
diff --git a/nntrainer/layers/lstm.cpp b/nntrainer/layers/lstm.cpp
index 6afca495..b0a35abd 100644
--- a/nntrainer/layers/lstm.cpp
+++ b/nntrainer/layers/lstm.cpp
@@ -31,16 +31,109 @@ enum LSTMParams {
   hidden_state,
   cell_state,
   ifgo,
+  reverse_weight_ih,
+  reverse_weight_hh,
+  reverse_bias_h,
+  reverse_bias_ih,
+  reverse_bias_hh,
+  reverse_hidden_state,
+  reverse_cell_state,
+  reverse_ifgo,
   dropout_mask
 };
 
+/**
+ * @brief run lstm forwarding for batch_first input
+ *
+ * @param NUM_GATE number of gates, which is 4 for lstm
+ * @param unit number of output neurons
+ * @param batch_size batch size
+ * @param max_timestep maximum timestep for lstm
+ * @param feature_size size of the input feature
+ * @param disable_bias whether to disable the bias
+ * @param integrate_bias whether bias_ih and bias_hh are integrated into bias_h
+ * @param acti_func activation function for memory cell, cell state
+ * @param recurrent_acti_func activation function for input/output/forget
+ * gate
+ * @param reverse whether to process the input in reverse (backward pass)
+ * @param enable_dropout whether to apply dropout
+ * @param dropout_rate dropout rate
+ * @param input_ input
+ * @param weight_ih weight_ih. weight for input to hidden
+ * @param weight_hh weight_hh. weight for hidden to hidden
+ * @param bias_h bias_h. bias for input and hidden.
+ * @param bias_ih bias_ih. bias for input
+ * @param bias_hh bias_hh. bias for hidden
+ * @param hidden_state_ hidden state
+ * @param cell_state_ cell state
+ * @param ifgo_ input gate, forget gate, memory cell, output gate
+ * @param mask_ dropout mask
+ */
+static void batch_first_forwarding(
+  unsigned int NUM_GATE, const unsigned int unit, const unsigned int batch_size,
+  const unsigned int max_timestep, const unsigned int feature_size,
+  const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
+  ActiFunc &recurrent_acti_func, const bool reverse, const bool enable_dropout,
+  const float dropout_rate, const Tensor &input_, const Tensor &weight_ih,
+  const Tensor &weight_hh, const Tensor &bias_h, const Tensor &bias_ih,
+  const Tensor &bias_hh, Tensor &hidden_state_, Tensor &cell_state_,
+  Tensor &ifgo_, const Tensor &mask_) {
+  hidden_state_.setZero();
+  cell_state_.setZero();
+
+  for (unsigned int batch = 0; batch < batch_size; ++batch) {
+    const Tensor input_sample = input_.getBatchSlice(batch, 1);
+    Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
+    Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
+    Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
+
+    for (unsigned int t = 0; t < max_timestep; ++t) {
+      Tensor input = input_sample.getSharedDataTensor(
+        {feature_size}, (reverse ? max_timestep - 1 - t : t) * feature_size);
+
+      Tensor prev_hidden_state;
+      if (!t) {
+        prev_hidden_state = Tensor(unit);
+        prev_hidden_state.setZero();
+      } else {
+        prev_hidden_state = hidden_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+      }
+      Tensor hidden_state = hidden_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+      Tensor prev_cell_state;
+      if (!t) {
+        prev_cell_state = Tensor(unit);
+        prev_cell_state.setZero();
+      } else {
+        // cell state indexing must mirror the hidden state when reversed
+        prev_cell_state = cell_state_sample.getSharedDataTensor(
+          {unit}, (reverse ? (max_timestep - t) : (t - 1)) * unit);
+      }
+      Tensor cell_state = cell_state_sample.getSharedDataTensor(
+        {unit}, (reverse ? max_timestep - 1 - t : t) * unit);
+      Tensor ifgo = ifgo_sample.getSharedDataTensor(
+        {unit * NUM_GATE}, (reverse ? max_timestep - 1 - t : t) * unit * NUM_GATE);
+
+      lstmcell_forwarding(unit, 1, disable_bias, integrate_bias, acti_func,
+                          recurrent_acti_func, input, prev_hidden_state,
+                          prev_cell_state, hidden_state, cell_state, weight_ih,
+                          weight_hh, bias_h, bias_ih, bias_hh, ifgo);
+
+      if (enable_dropout) {
+        Tensor mask_sample = mask_.getBatchSlice(batch, 1);
+        Tensor mask = mask_sample.getSharedDataTensor({unit}, t * unit);
+        mask.dropout_mask(dropout_rate);
+        hidden_state.multiply_i(mask);
+      }
+    }
+  }
+}
+
 LSTMLayer::LSTMLayer() :
   LayerImpl(),
   lstm_props(props::Unit(), props::IntegrateBias(),
              props::HiddenStateActivation() = ActivationType::ACT_TANH,
              props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
-             props::ReturnSequences(), props::DropOutRate(),
-             props::MaxTimestep()),
+             props::ReturnSequences(), props::Bidirectional(),
+             props::DropOutRate(), props::MaxTimestep()),
   acti_func(ActivationType::ACT_NONE, true),
   recurrent_acti_func(ActivationType::ACT_NONE, true),
   epsilon(1e-3) {
@@ -70,6 +163,7 @@ void LSTMLayer::finalize(InitLayerContext &context) {
     std::get<props::RecurrentActivation>(lstm_props).get();
   const bool return_sequences =
     std::get<props::ReturnSequences>(lstm_props).get();
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
 
   if (context.getNumInputs() != 1) {
@@ -91,34 +185,32 @@
   std::get<props::MaxTimestep>(lstm_props).set(max_timestep);
   const unsigned int feature_size = input_dim.width();
 
-  // if return_sequences == false :
-  //      output_dim = [ batch_size, 1, 1, unit ]
-  // else:
-  //      output_dim = [ batch_size, 1, time_iteration, unit ]
+  // output_dim = [ batch_size, 1, return_sequences ? time_iteration : 1,
+  // bidirectional ? 2 * unit : unit ]
   const TensorDim output_dim(batch_size, 1, return_sequences ? max_timestep : 1,
-                             unit);
+                             bidirectional ? 2 * unit : unit);
   context.setOutputDimensions({output_dim});
 
   // weight_initializer can be set seperately. weight_ih initializer,
-  // weight_hh initializer kernel initializer & recurrent_initializer in keras
-  // for now, it is set same way.
+  // weight_hh initializer kernel initializer & recurrent_initializer in
+  // keras for now, it is set same way.
 
-  // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ] ->
-  // i, f, g, o
+  // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ]
+  // -> i, f, g, o
   const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
   wt_idx[LSTMParams::weight_ih] =
     context.requestWeight(weight_ih_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_ih", true);
-  // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i, f,
-  // g, o
+  // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
+  // f, g, o
   const TensorDim weight_hh_dim({unit, NUM_GATE * unit});
   wt_idx[LSTMParams::weight_hh] =
     context.requestWeight(weight_hh_dim, weight_initializer, weight_regularizer,
                           weight_regularizer_constant, "weight_hh", true);
   if (!disable_bias) {
     if (integrate_bias) {
-      // bias_h ( input bias, hidden bias are integrate to 1 bias ) : [ 1, 1, 1,
-      // NUM_GATE * unit ] -> i, f, g, o
+      // bias_h ( input bias and hidden bias integrated into one bias ) : [ 1,
+      // 1, 1, NUM_GATE * unit ] -> i, f, g, o
       const TensorDim bias_h_dim({NUM_GATE * unit});
       wt_idx[LSTMParams::bias_h] =
         context.requestWeight(bias_h_dim, bias_initializer,
@@ -129,7 +221,8 @@
       wt_idx[LSTMParams::bias_ih] =
         context.requestWeight(bias_ih_dim, bias_initializer,
                               WeightRegularizer::NONE, 1.0f, "bias_ih", true);
-      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
+      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g,
+      // o
       const TensorDim bias_hh_dim({NUM_GATE * unit});
       wt_idx[LSTMParams::bias_hh] =
         context.requestWeight(bias_hh_dim, bias_initializer,
@@ -154,6 +247,67 @@
     context.requestTensor(ifgo_dim, "ifgo", Tensor::Initializer::NONE, true,
                           TensorLifespan::ITERATION_LIFESPAN);
 
+  if (bidirectional) {
+    // weight_initializer can be set separately. weight_ih initializer,
+    // weight_hh initializer kernel initializer & recurrent_initializer in
+    // keras. for now, it is set same way.
+
+    // reverse_weight_ih ( input to hidden ) : [ 1, 1, feature_size,
+    // NUM_GATE * unit ] -> i, f, g, o
+    const TensorDim reverse_weight_ih_dim({feature_size, NUM_GATE * unit});
+    wt_idx[LSTMParams::reverse_weight_ih] = context.requestWeight(
+      reverse_weight_ih_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, "reverse_weight_ih", true);
+    // reverse_weight_hh ( hidden to hidden ) : [ 1, 1, unit,
+    // NUM_GATE * unit ]
+    // -> i, f, g, o
+    const TensorDim reverse_weight_hh_dim({unit, NUM_GATE * unit});
+    wt_idx[LSTMParams::reverse_weight_hh] = context.requestWeight(
+      reverse_weight_hh_dim, weight_initializer, weight_regularizer,
+      weight_regularizer_constant, "reverse_weight_hh", true);
+    if (!disable_bias) {
+      if (integrate_bias) {
+        // reverse_bias_h ( input bias and hidden bias integrated into one
+        // bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
+        const TensorDim reverse_bias_h_dim({NUM_GATE * unit});
+        wt_idx[LSTMParams::reverse_bias_h] = context.requestWeight(
+          reverse_bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+          "reverse_bias_h", true);
+      } else {
+        // reverse_bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
+        // i, f, g, o
+        const TensorDim reverse_bias_ih_dim({NUM_GATE * unit});
+        wt_idx[LSTMParams::reverse_bias_ih] = context.requestWeight(
+          reverse_bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+          "reverse_bias_ih", true);
+        // reverse_bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
+        // i, f, g, o
+        const TensorDim reverse_bias_hh_dim({NUM_GATE * unit});
+        wt_idx[LSTMParams::reverse_bias_hh] = context.requestWeight(
+          reverse_bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
+          "reverse_bias_hh", true);
+      }
+    }
+
+    // reverse_hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
+    const TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit);
+    wt_idx[LSTMParams::reverse_hidden_state] = context.requestTensor(
+      reverse_hidden_state_dim, "reverse_hidden_state",
+      Tensor::Initializer::NONE, true, TensorLifespan::ITERATION_LIFESPAN);
+    // reverse_cell_state_dim : [ batch_size, 1, max_timestep, unit ]
+    const TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit);
+    wt_idx[LSTMParams::reverse_cell_state] = context.requestTensor(
+      reverse_cell_state_dim, "reverse_cell_state", Tensor::Initializer::NONE,
+      true, TensorLifespan::ITERATION_LIFESPAN);
+
+    // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
+    const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
+                                     NUM_GATE * unit);
+    wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor(
+      reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true,
+      TensorLifespan::ITERATION_LIFESPAN);
+  }
+
   if (dropout_rate > epsilon) {
     // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
     const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit);
@@ -186,12 +340,16 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
   const bool integrate_bias = std::get<props::IntegrateBias>(lstm_props).get();
   const bool return_sequences =
     std::get<props::ReturnSequences>(lstm_props).get();
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
   const unsigned int max_timestep =
     std::get<props::MaxTimestep>(lstm_props).get();
 
-  const Tensor &inputs = context.getInput(SINGLE_INOUT_IDX);
-  const TensorDim input_dim = inputs.getDim();
+  unsigned int bidirectional_constant = bidirectional ? 2 : 1;
+  bool enable_dropout = dropout_rate > epsilon && training;
+
+  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+  const TensorDim input_dim = input.getDim();
   const unsigned int batch_size = input_dim.batch();
   const unsigned int feature_size = input_dim.width();
   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
@@ -209,70 +367,81 @@ void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
       ? context.getWeight(wt_idx[LSTMParams::bias_hh])
       : empty;
 
-  Tensor &hs = context.getTensor(wt_idx[LSTMParams::hidden_state]);
-  Tensor &cs = context.getTensor(wt_idx[LSTMParams::cell_state]);
-  Tensor &ifgos = context.getTensor(wt_idx[LSTMParams::ifgo]);
-
-  hs.setZero();
-  cs.setZero();
-
-  for (unsigned int batch = 0; batch < batch_size; ++batch) {
-    const Tensor input_batch = inputs.getBatchSlice(batch, 1);
-    Tensor hs_batch = hs.getBatchSlice(batch, 1);
-    Tensor cs_batch = cs.getBatchSlice(batch, 1);
-    Tensor ifgo_batch = ifgos.getBatchSlice(batch, 1);
-
-    for (unsigned int t = 0; t < max_timestep; ++t) {
-      Tensor input;
-      if (input_batch.height() != 1)
-        input =
-          input_batch.getSharedDataTensor({feature_size}, t * feature_size);
-      else
-        input = input_batch;
-
-      Tensor prev_hidden_state;
-      if (!t) {
-        prev_hidden_state = Tensor(unit);
-        prev_hidden_state.setZero();
-      } else {
-        prev_hidden_state =
-          hs_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-      }
-      Tensor hidden_state = hs_batch.getSharedDataTensor({unit}, t * unit);
-      Tensor prev_cell_state;
-      if (!t) {
-        prev_cell_state = Tensor(unit);
-        prev_cell_state.setZero();
-      } else {
-        prev_cell_state = cs_batch.getSharedDataTensor({unit}, (t - 1) * unit);
-      }
-      Tensor cell_state = cs_batch.getSharedDataTensor({unit}, t * unit);
-      Tensor ifgo =
-        ifgo_batch.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
-
-      lstmcell_forwarding(unit, 1, disable_bias, integrate_bias, acti_func,
-                          recurrent_acti_func, input, prev_hidden_state,
-                          prev_cell_state, hidden_state, cell_state, weight_ih,
-                          weight_hh, bias_h, bias_ih, bias_hh, ifgo);
-
-      if (dropout_rate > epsilon && training) {
-        Tensor masks = context.getTensor(wt_idx[LSTMParams::dropout_mask])
-                         .getBatchSlice(batch, 1);
-        Tensor mask = masks.getSharedDataTensor({unit}, t * unit);
-        mask.dropout_mask(dropout_rate);
-        hidden_state.multiply_i(mask);
-      }
-    }
+  Tensor &hidden_state = context.getTensor(wt_idx[LSTMParams::hidden_state]);
+  Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
+  Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);
+
+  Tensor &mask = enable_dropout
+                   ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
+                   : empty;
+
+  batch_first_forwarding(
+    NUM_GATE, unit, batch_size, max_timestep, feature_size, disable_bias,
+    integrate_bias, acti_func, recurrent_acti_func, false, enable_dropout,
+    dropout_rate, input, weight_ih, weight_hh, bias_h, bias_ih, bias_hh,
+    hidden_state, cell_state, ifgo, mask);
+
+  if (bidirectional) {
+    const Tensor &reverse_weight_ih =
+      context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
+    const Tensor &reverse_weight_hh =
+      context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
+    const Tensor &reverse_bias_h =
+      !disable_bias && integrate_bias
+        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_h])
+        : empty;
+    const Tensor &reverse_bias_ih =
+      !disable_bias && !integrate_bias
+        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_ih])
+        : empty;
+    const Tensor &reverse_bias_hh =
+      !disable_bias && !integrate_bias
+        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_hh])
+        : empty;
+
+    Tensor &reverse_hidden_state =
+      context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
+    Tensor &reverse_cell_state =
+      context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
+    Tensor &reverse_ifgo = context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
+
+    batch_first_forwarding(
+      NUM_GATE, unit, batch_size, max_timestep, feature_size, disable_bias,
+      integrate_bias, acti_func, recurrent_acti_func, true, enable_dropout,
+      dropout_rate, input, reverse_weight_ih, reverse_weight_hh,
+      reverse_bias_h, reverse_bias_ih, reverse_bias_hh, reverse_hidden_state,
+      reverse_cell_state, reverse_ifgo, mask);
   }
 
-  if (return_sequences) {
-    std::copy(hs.getData(), hs.getData() + hs.size(), output.getData());
+  if (return_sequences && !bidirectional) {
+    std::copy(hidden_state.getData(),
+              hidden_state.getData() + hidden_state.size(), output.getData());
   } else {
+    unsigned int start_timestep = 0;
+    unsigned int end_timestep = return_sequences ? max_timestep : 1;
     for (unsigned int batch = 0; batch < batch_size; ++batch) {
-      float *hidden_state_data =
-        hs.getAddress(batch * max_timestep * unit + (max_timestep - 1) * unit);
-      float *output_data = output.getAddress(batch * unit);
-      std::copy(hidden_state_data, hidden_state_data + unit, output_data);
+      for (unsigned int timestep = start_timestep; timestep < end_timestep;
+           ++timestep) {
+        float *hidden_state_data = hidden_state.getAddress(
+          batch * max_timestep * unit +
+          (return_sequences ? 0 : (max_timestep - 1) * unit) +
+          timestep * unit);
+        float *output_data =
+          output.getAddress(batch * (return_sequences ? max_timestep : 1) *
+                              bidirectional_constant * unit +
+                            timestep * bidirectional_constant * unit);
+        std::copy(hidden_state_data, hidden_state_data + unit, output_data);
+
+        if (bidirectional) {
+          Tensor &reverse_hidden_state =
+            context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
+          float *reverse_hidden_state_data = reverse_hidden_state.getAddress(
+            batch * max_timestep * unit +
+            (return_sequences ? 0 : (max_timestep - 1) * unit) +
+            timestep * unit);
+          std::copy(reverse_hidden_state_data,
+                    reverse_hidden_state_data + unit, output_data + unit);
+        }
+      }
     }
   }
 }
@@ -354,8 +523,10 @@
       Tensor rdata =
         incoming_derivative.getSharedDataTensor({unit}, batch * unit);
 
-      /// @note this is not copying from start ~ end but only start time step
-      // This is copying for self rolling as well as last recurrent unrolled.
+      /// @note this is not copying from start ~ end but only start time
+      /// step
+      // This is copying for self rolling as well as last recurrent
+      // unrolled.
       if ((unsigned)start_timestep + 1 == max_timestep) {
         data.fill(rdata);
       } else {
@@ -427,8 +598,8 @@
       Tensor d_ifgo =
         d_ifgo_batch.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
 
-      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state already
-      // have precalculated values from incomming derivatives
+      // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
+      // already has precalculated values from incoming derivatives
       Tensor d_prev_hidden_state_temp;
 
       lstmcell_calcGradient(unit, 1, disable_bias, integrate_bias, acti_func,
@@ -443,10 +614,18 @@
 }
 
 void LSTMLayer::setBatch(RunLayerContext &context, unsigned int batch) {
+  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
+
   context.updateTensor(wt_idx[LSTMParams::hidden_state], batch);
   context.updateTensor(wt_idx[LSTMParams::cell_state], batch);
   context.updateTensor(wt_idx[LSTMParams::ifgo], batch);
 
+  if (bidirectional) {
+    context.updateTensor(wt_idx[LSTMParams::reverse_hidden_state], batch);
+    context.updateTensor(wt_idx[LSTMParams::reverse_cell_state], batch);
+    context.updateTensor(wt_idx[LSTMParams::reverse_ifgo], batch);
+  }
+
   if (std::get<props::DropOutRate>(lstm_props).get() > epsilon) {
     context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
   }
diff --git a/nntrainer/layers/lstm.h b/nntrainer/layers/lstm.h
index 804ba0e0..e68369d5 100644
--- a/nntrainer/layers/lstm.h
+++ b/nntrainer/layers/lstm.h
@@ -106,15 +106,16 @@ private:
    * HiddenStateActivation: activation type for hidden state. default is tanh
    * RecurrentActivation: activation type for recurrent. default is sigmoid
    * ReturnSequence: option for return sequence
+   * Bidirectional: option for bidirectional
    * DropOutRate: dropout rate
    * MaxTimestep: maximum timestep for lstm
    *
    */
   std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
              props::RecurrentActivation, props::ReturnSequences,
-             props::DropOutRate, props::MaxTimestep>
+             props::Bidirectional, props::DropOutRate, props::MaxTimestep>
     lstm_props;
-  std::array<unsigned int, 9> wt_idx;  /**< indices of the weights */
+  std::array<unsigned int, 17> wt_idx; /**< indices of the weights */
 
   /**
    * @brief activation function for h_t : default is tanh
diff --git a/test/input_gen/genModelsRecurrent_v2.py b/test/input_gen/genModelsRecurrent_v2.py
index 6197a19d..f704d8c0 100644
--- a/test/input_gen/genModelsRecurrent_v2.py
+++ b/test/input_gen/genModelsRecurrent_v2.py
@@ -56,15 +56,17 @@ class RNNCellStacked(torch.nn.Module):
         return ret, loss
 
 class LSTMStacked(torch.nn.Module):
-    def __init__(self, num_lstm=1):
+    def __init__(self, num_lstm=1, bidirectional=False):
         super().__init__()
         self.input_size = self.hidden_size = 2
         self.num_lstm = num_lstm
+        self.bidirectional = bidirectional
         self.lstms = torch.nn.ModuleList(
             [
-                torch.nn.LSTM(self.input_size, self.hidden_size, batch_first=True)
-                # torch.nn.LSTM(self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True)
-                for _ in range(num_lstm)
+                torch.nn.LSTM(self.input_size if not self.bidirectional or i == 0 else 2 * self.input_size, self.hidden_size, batch_first=True, bidirectional=bidirectional)
+                # torch.nn.LSTM(self.input_size if not self.bidirectional or i == 0 else 2 * self.input_size, self.hidden_size, num_layers=num_lstm, batch_first=True, bidirectional=bidirectional)
+                for i in range(num_lstm)
             ]
         )
         self.loss = torch.nn.MSELoss()
@@ -73,12 +75,12 @@ class LSTMStacked(torch.nn.Module):
         out = inputs[0]
         states = inputs[1:]
         # hs = [states[2 * i] for i in range(self.num_lstm)]
-        hs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
+        hs = [torch.zeros((2, 3, 2)) if self.bidirectional else torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
         # cs = [states[2 * i + 1] for i in range(self.num_lstm)]
-        cs = [torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
+        cs = [torch.zeros((2, 3, 2)) if self.bidirectional else torch.zeros((1, 3, 2)) for _ in range(self.num_lstm)]
         for i, (lstm, h, c) in enumerate(zip(self.lstms, hs, cs)):
             out, (hs[i], cs[i]) = lstm(out, (h, c))
-
         loss = self.loss(out, labels[0])
         return out, loss
@@ -212,9 +214,9 @@ if __name__ == "__main__":
         name="rnncell_stacked",
     )
 
-    unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 1, 3, 2, 2, 2]
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 1, 3, 2, 2, 2, False]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm),
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
         iteration=iteration,
         input_dims=[(batch_size, unroll_for, feature_size)],
         # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
@@ -222,9 +224,9 @@ if __name__ == "__main__":
         name="lstm_single",
     )
 
-    unroll_for, num_lstm, batch_size, unit, feature_size, iteration = [2, 2, 3, 2, 2, 2]
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 2, 3, 2, 2, 2, False]
     record_v2(
-        LSTMStacked(num_lstm=num_lstm),
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
        iteration=iteration,
         input_dims=[(batch_size, unroll_for, feature_size)],
         # input_dims=[(batch_size, unroll_for, feature_size)] + [(1, batch_size, unit) for _ in range(2 * num_lstm)],
@@ -232,6 +234,26 @@ if __name__ == "__main__":
         name="lstm_stacked",
     )
 
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 1, 3, 2, 2, 2, True]
+    record_v2(
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
+        iteration=iteration,
+        input_dims=[(batch_size, unroll_for, feature_size)],
+        # input_dims=[(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims=[(batch_size, unroll_for, 2 * unit)],
+        name="bidirectional_lstm_single",
+    )
+
+    unroll_for, num_lstm, batch_size, unit, feature_size, iteration, bidirectional = [2, 2, 3, 2, 2, 2, True]
+    record_v2(
+        LSTMStacked(num_lstm=num_lstm, bidirectional=bidirectional),
+        iteration=iteration,
+        input_dims=[(batch_size, unroll_for, feature_size)],
+        # input_dims=[(batch_size, unroll_for, feature_size)] + [(2, batch_size, unit) for _ in range(2 * num_lstm)],
+        label_dims=[(batch_size, unroll_for, 2 * unit)],
+        name="bidirectional_lstm_stacked",
+    )
+
     unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 1, 2, 3, 2, 2, 2]
     record_v2(
         LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
diff --git a/test/input_gen/transLayer_v2.py b/test/input_gen/transLayer_v2.py
index 9373d673..ca0b6215 100644
--- a/test/input_gen/transLayer_v2.py
+++ b/test/input_gen/transLayer_v2.py
@@ -70,7 +70,21 @@ def zoneout_translate(model):
     new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3], hidden_state, cell_state]
     yield from new_params
 
-@register_for_((torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell))
+@register_for_((torch.nn.LSTM))
+def lstm_translate(model):
+    params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
+    # [hidden, input] -> [input, hidden]
+    def transpose_(weight):
+        return (weight[0], weight[1].transpose(1, 0))
+
+    new_params = [transpose_(params[0]), transpose_(params[1]), params[2], params[3]]
+    if model.bidirectional:
+        reverse_params = [transpose_(params[4]), transpose_(params[5]), params[6], params[7]]
+        new_params += reverse_params
+
+    yield from new_params
+
+@register_for_((torch.nn.RNNCell, torch.nn.LSTMCell))
 def rnn_lstm_translate(model):
     params = [(name, tensor.detach()) for name, tensor in model.named_parameters()]
     # [hidden, input] -> [input, hidden]
diff --git a/test/unittest/models/unittest_models_recurrent.cpp b/test/unittest/models/unittest_models_recurrent.cpp
index 9aa33644..ae812f95 100644
--- a/test/unittest/models/unittest_models_recurrent.cpp
+++ b/test/unittest/models/unittest_models_recurrent.cpp
@@ -175,6 +175,47 @@ static std::unique_ptr<NeuralNetwork> makeStackedLSTM() {
   return nn;
 }
 
+// static std::unique_ptr<NeuralNetwork> makeSingleBidirectionalLSTM() {
+//   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+//   nn->setProperty({"batch_size=3"});
+
+//   auto outer_graph = makeGraph({
+//     {"input", {"name=input", "input_shape=1:2:2"}},
+//     {"lstm",
+//      {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
+//       "bidirectional=true"}},
+//     {"mse", {"name=loss", "input_layers=a1"}},
+//   });
+//   for (auto &node : outer_graph) {
+//     nn->addLayer(node);
+//   }
+
+//   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+//   return nn;
+// }
+
+// static std::unique_ptr<NeuralNetwork> makeStackedBidirectionalLSTM() {
+//   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+//   nn->setProperty({"batch_size=3"});
+
+//   auto outer_graph = makeGraph({
+//     {"input", {"name=input", "input_shape=1:2:2"}},
+//     {"lstm",
+//      {"name=a1", "unit=2", "integrate_bias=false", "return_sequences=true",
+//       "bidirectional=true"}},
+//     {"lstm",
+//      {"name=a2", "unit=2", "integrate_bias=false", "return_sequences=true",
+//       "bidirectional=true"}},
+//     {"mse", {"name=loss"}},
+//   });
+//   for (auto &node : outer_graph) {
+//     nn->addLayer(node);
+//   }
+
+//   nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+//   return nn;
+// }
+
 static std::unique_ptr<NeuralNetwork> makeSingleLSTMCell() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=3"});
@@ -526,6 +567,10 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::ALL_V2),
+    // mkModelTc_V2(makeSingleBidirectionalLSTM, "bidirectional_lstm_single",
+    //              ModelTestOption::ALL_V2),
+    // mkModelTc_V2(makeStackedBidirectionalLSTM, "bidirectional_lstm_stacked",
+    //              ModelTestOption::ALL_V2),
     mkModelTc_V2(makeSingleLSTMCell, "lstmcell_single",
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeStackedLSTMCell, "lstmcell_stacked",
-- 
2.34.1
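The flat-offset arithmetic in the output-copy loop at the end of
LSTMLayer::forwarding() can be cross-checked with an array-index sketch
(illustrative only, not part of the patch; it assumes the row-major
[batch, 1, timestep, unit] buffers requested in finalize()):

    import numpy as np

    batch_size, max_timestep, unit = 3, 2, 2
    return_sequences, bidirectional = True, True
    bc = 2 if bidirectional else 1  # bidirectional_constant

    h = np.random.rand(batch_size, max_timestep, unit)      # hidden_state
    h_rev = np.random.rand(batch_size, max_timestep, unit)  # reverse_hidden_state
    out_steps = max_timestep if return_sequences else 1
    out = np.zeros((batch_size, out_steps, bc * unit))

    for b in range(batch_size):
        for t in range(out_steps):
            # (return_sequences ? 0 : (max_timestep - 1) * unit) + t * unit
            src_t = t if return_sequences else max_timestep - 1
            out[b, t, :unit] = h[b, src_t]
            if bidirectional:
                out[b, t, unit:] = h_rev[b, src_t]

    # same thing, vectorized
    expected = np.concatenate([h, h_rev], axis=-1) if bidirectional else h
    if not return_sequences:
        expected = expected[:, -1:, :]
    assert np.allclose(out, expected)

If that reading is right, the forward and reverse hidden states for a given
input timestep end up concatenated per timestep along the feature axis of
the output buffer, which is also what the PyTorch reference produces.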