#include <input_layer.h>
#include <layer_node.h>
#include <lstm.h>
-#include <lstmcell.h>
#include <nntrainer_error.h>
#include <node_exporter.h>
#include <recurrent_realizer.h>
auto is_recurrent_type = [](LayerNode *node) {
return node->getType() == RNNCellLayer::type ||
node->getType() == LSTMLayer::type ||
- node->getType() == LSTMCellLayer::type ||
node->getType() == ZoneoutLSTMCellLayer::type ||
node->getType() == GRUCellLayer::type;
};
namespace nntrainer {
-static constexpr size_t SINGLE_INOUT_IDX = 0;
-
enum LSTMCellParams {
weight_ih,
weight_hh,
bias_h,
bias_ih,
bias_hh,
- hidden_state,
- cell_state,
ifgo,
dropout_mask
};
lstmcell_props(props::Unit(), props::IntegrateBias(),
props::HiddenStateActivation() = ActivationType::ACT_TANH,
props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
- props::DropOutRate(), props::MaxTimestep(), props::Timestep()),
+ props::DropOutRate()),
acti_func(ActivationType::ACT_NONE, true),
recurrent_acti_func(ActivationType::ACT_NONE, true),
epsilon(1e-3) {
const ActivationType recurrent_activation_type =
std::get<props::RecurrentActivation>(lstmcell_props).get();
const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();
- const unsigned int max_timestep =
- std::get<props::MaxTimestep>(lstmcell_props).get();
- if (context.getNumInputs() != 1)
- throw std::invalid_argument("LSTMCell layer takes only one input");
- if (std::get<props::MaxTimestep>(lstmcell_props).empty())
- throw std::invalid_argument(
- "Number of unroll steps(max timestep) must be provided to LSTM cell");
- if (std::get<props::Timestep>(lstmcell_props).empty())
+ if (context.getNumInputs() != 3) {
throw std::invalid_argument(
- "Current Timestep must be provided to LSTM cell");
+ "Number of input is not 3. LSTMCell layer should takes 3 inputs");
+ }
// input_dim = [ batch_size, 1, 1, feature_size ]
- const TensorDim &input_dim = context.getInputDimensions()[0];
- if (input_dim.channel() != 1 || input_dim.height() != 1)
+ const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
+ if (input_dim.channel() != 1 || input_dim.height() != 1) {
throw std::invalid_argument(
"Input must be single time dimension for LSTMCell (shape should be "
"[batch_size, 1, 1, feature_size])");
+ }
+ // input_hidden_state_dim = [ batch, 1, 1, unit ]
+ const TensorDim &input_hidden_state_dim =
+ context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
+ if (input_hidden_state_dim.channel() != 1 ||
+ input_hidden_state_dim.height() != 1) {
+ throw std::invalid_argument("Input hidden state's dimension should be "
+ "[batch, 1, 1, unit] for LSTMCell");
+ }
+ // input_cell_state_dim = [ batch, 1, 1, unit ]
+ const TensorDim &input_cell_state_dim =
+ context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
+ if (input_cell_state_dim.channel() != 1 ||
+ input_cell_state_dim.height() != 1) {
+ throw std::invalid_argument("Input cell state's dimension should be "
+ "[batch, 1, 1, unit] for LSTMCell");
+ }
const unsigned int batch_size = input_dim.batch();
const unsigned int feature_size = input_dim.width();
- // output_dim = [ batch_size, 1, 1, unit ]
- const TensorDim output_dim(batch_size, 1, 1, unit);
- context.setOutputDimensions({output_dim});
+ // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
+ const TensorDim output_hidden_state_dim = input_hidden_state_dim;
+ // output_cell_state_dim = [ batch_size, 1, 1, unit ]
+ const TensorDim output_cell_state_dim = input_cell_state_dim;
+
+ std::vector<VarGradSpecV2> out_specs;
+ out_specs.push_back(
+ InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
+ TensorLifespan::FORWARD_DERIV_LIFESPAN));
+ out_specs.push_back(
+ InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
+ TensorLifespan::FORWARD_DERIV_LIFESPAN));
+ context.requestOutputs(std::move(out_specs));
// weight_initializer can be set seperately. weight_ih initializer,
// weight_hh initializer kernel initializer & recurrent_initializer in keras
}
}
- /**
- * TODO: hidden_state is only used from the previous timestep. Once it is
- * supported as input, no need to cache the hidden_state itself
- */
- /** hidden_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
- const TensorDim hidden_state_dim(max_timestep * batch_size, 1, 1, unit);
- wt_idx[LSTMCellParams::hidden_state] = context.requestTensor(
- hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
- /** cell_state_dim = [ max_timestep * batch_size, 1, 1, unit ] */
- const TensorDim cell_state_dim(max_timestep * batch_size, 1, 1, unit);
- wt_idx[LSTMCellParams::cell_state] = context.requestTensor(
- cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
- TensorLifespan::ITERATION_LIFESPAN, false);
-
/** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
wt_idx[LSTMCellParams::ifgo] =
const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();
const bool integrate_bias =
std::get<props::IntegrateBias>(lstmcell_props).get();
- const unsigned int max_timestep =
- std::get<props::MaxTimestep>(lstmcell_props).get();
- const unsigned int timestep = std::get<props::Timestep>(lstmcell_props).get();
- const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
- Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+ const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+ const Tensor &prev_hidden_state =
+ context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+ const Tensor &prev_cell_state =
+ context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
+ Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
+ Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
+
const unsigned int batch_size = input.getDim().batch();
const Tensor &weight_ih =
const Tensor &weight_hh =
context.getWeight(wt_idx[LSTMCellParams::weight_hh]);
Tensor empty;
- Tensor &bias_h = !disable_bias && integrate_bias
- ? context.getWeight(wt_idx[LSTMCellParams::bias_h])
- : empty;
- Tensor &bias_ih = !disable_bias && !integrate_bias
- ? context.getWeight(wt_idx[LSTMCellParams::bias_ih])
- : empty;
- Tensor &bias_hh = !disable_bias && !integrate_bias
- ? context.getWeight(wt_idx[LSTMCellParams::bias_hh])
- : empty;
-
- Tensor &hs = context.getTensor(wt_idx[LSTMCellParams::hidden_state]);
- hs.reshape({max_timestep, 1, batch_size, unit});
- Tensor prev_hidden_state;
- if (!timestep) {
- prev_hidden_state = Tensor(batch_size, unit);
- prev_hidden_state.setZero();
- } else {
- prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
- }
- prev_hidden_state.reshape({batch_size, 1, 1, unit});
- Tensor hidden_state = hs.getBatchSlice(timestep, 1);
- hidden_state.reshape({batch_size, 1, 1, unit});
-
- Tensor &cs = context.getTensor(wt_idx[LSTMCellParams::cell_state]);
- cs.reshape({max_timestep, 1, batch_size, unit});
- Tensor prev_cell_state;
- if (!timestep) {
- prev_cell_state = Tensor(batch_size, unit);
- prev_cell_state.setZero();
- } else {
- prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
- }
- prev_cell_state.reshape({batch_size, 1, 1, unit});
- Tensor cell_state = cs.getBatchSlice(timestep, 1);
- cell_state.reshape({batch_size, 1, 1, unit});
+ const Tensor &bias_h = !disable_bias && integrate_bias
+ ? context.getWeight(wt_idx[LSTMCellParams::bias_h])
+ : empty;
+ const Tensor &bias_ih = !disable_bias && !integrate_bias
+ ? context.getWeight(wt_idx[LSTMCellParams::bias_ih])
+ : empty;
+ const Tensor &bias_hh = !disable_bias && !integrate_bias
+ ? context.getWeight(wt_idx[LSTMCellParams::bias_hh])
+ : empty;
Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
dropout_mask.dropout_mask(dropout_rate);
hidden_state.multiply_i(dropout_mask);
}
-
- output.copyData(hidden_state);
}
void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
const Tensor &weight_ih =
context.getWeight(wt_idx[LSTMCellParams::weight_ih]);
- Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+ Tensor &outgoing_derivative =
+ context.getOutgoingDerivative(INOUT_INDEX::INPUT);
lstmcell_calcDerivative(d_ifgo, weight_ih, outgoing_derivative);
}
const bool integrate_bias =
std::get<props::IntegrateBias>(lstmcell_props).get();
const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
- const unsigned int max_timestep =
- std::get<props::MaxTimestep>(lstmcell_props);
- const unsigned int timestep = std::get<props::Timestep>(lstmcell_props);
- const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
- const Tensor &incoming_derivative =
- context.getIncomingDerivative(SINGLE_INOUT_IDX);
+ const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
+ const Tensor &prev_hidden_state =
+ context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
+ Tensor &d_prev_hidden_state =
+ context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
+ const Tensor &prev_cell_state =
+ context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
+ Tensor &d_prev_cell_state =
+ context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
+ const Tensor &d_hidden_state =
+ context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
+ const Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
+ const Tensor &d_cell_state =
+ context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);
+
unsigned int batch_size = input.getDim().batch();
Tensor &d_weight_ih =
? context.getWeightGrad(wt_idx[LSTMCellParams::bias_hh])
: empty;
- Tensor &hs = context.getTensor(wt_idx[LSTMCellParams::hidden_state]);
- hs.reshape({max_timestep, 1, batch_size, unit});
- Tensor prev_hidden_state;
- if (!timestep) {
- prev_hidden_state = Tensor(batch_size, unit);
- prev_hidden_state.setZero();
- } else {
- prev_hidden_state = hs.getBatchSlice(timestep - 1, 1);
- }
- prev_hidden_state.reshape({batch_size, 1, 1, unit});
-
- Tensor &d_hs = context.getTensorGrad(wt_idx[LSTMCellParams::hidden_state]);
- d_hs.reshape({max_timestep, 1, batch_size, unit});
- Tensor d_prev_hidden_state;
- if (!timestep) {
- d_prev_hidden_state = Tensor(batch_size, unit);
- d_prev_hidden_state.setZero();
- } else {
- d_prev_hidden_state = d_hs.getBatchSlice(timestep - 1, 1);
- }
- d_prev_hidden_state.reshape({batch_size, 1, 1, unit});
- Tensor d_hidden_state = d_hs.getBatchSlice(timestep, 1);
- d_hidden_state.reshape({batch_size, 1, 1, unit});
-
- Tensor &cs = context.getTensor(wt_idx[LSTMCellParams::cell_state]);
- cs.reshape({max_timestep, 1, batch_size, unit});
- Tensor prev_cell_state;
- if (!timestep) {
- prev_cell_state = Tensor(batch_size, unit);
- prev_cell_state.setZero();
- } else {
- prev_cell_state = cs.getBatchSlice(timestep - 1, 1);
- }
- prev_cell_state.reshape({batch_size, 1, 1, unit});
- Tensor cell_state = cs.getBatchSlice(timestep, 1);
- cell_state.reshape({batch_size, 1, 1, unit});
-
- Tensor &d_cs = context.getTensorGrad(wt_idx[LSTMCellParams::cell_state]);
- d_cs.reshape({max_timestep, 1, batch_size, unit});
- Tensor d_prev_cell_state;
- if (!timestep) {
- d_prev_cell_state = Tensor(batch_size, unit);
- d_prev_cell_state.setZero();
- } else {
- d_prev_cell_state = d_cs.getBatchSlice(timestep - 1, 1);
- }
- d_prev_cell_state.reshape({batch_size, 1, 1, unit});
- Tensor d_cell_state = d_cs.getBatchSlice(timestep, 1);
- d_cell_state.reshape({batch_size, 1, 1, unit});
-
const Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
- if (timestep + 1 == max_timestep) {
+ if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_ih])) {
d_weight_ih.setZero();
+ }
+ if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_hh])) {
d_weight_hh.setZero();
- if (!disable_bias) {
- if (integrate_bias) {
+ }
+ if (!disable_bias) {
+ if (integrate_bias) {
+ if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_h])) {
d_bias_h.setZero();
- } else {
+ }
+ } else {
+ if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_ih])) {
d_bias_ih.setZero();
+ }
+ if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_hh])) {
d_bias_hh.setZero();
}
}
-
- d_hidden_state.setZero();
- d_cell_state.setZero();
}
+ Tensor d_hidden_state_masked;
if (dropout_rate > epsilon) {
Tensor &dropout_mask =
context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
- d_hidden_state.multiply_i(dropout_mask);
+ d_hidden_state.multiply(dropout_mask, d_hidden_state_masked);
}
- d_hidden_state.add_i(incoming_derivative);
-
- lstmcell_calcGradient(unit, batch_size, disable_bias, integrate_bias,
- acti_func, recurrent_acti_func, input,
- prev_hidden_state, d_prev_hidden_state, prev_cell_state,
- d_prev_cell_state, d_hidden_state, cell_state,
- d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
- d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
+ lstmcell_calcGradient(
+ unit, batch_size, disable_bias, integrate_bias, acti_func,
+ recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
+ prev_cell_state, d_prev_cell_state,
+ dropout_rate > epsilon ? d_hidden_state_masked : d_hidden_state, cell_state,
+ d_cell_state, d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
+ d_bias_hh, ifgo, d_ifgo);
}
void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
- const unsigned int max_timestep =
- std::get<props::MaxTimestep>(lstmcell_props);
- context.updateTensor(wt_idx[LSTMCellParams::hidden_state],
- max_timestep * batch);
- context.updateTensor(wt_idx[LSTMCellParams::cell_state],
- max_timestep * batch);
context.updateTensor(wt_idx[LSTMCellParams::ifgo], batch);
if (dropout_rate > epsilon) {
context.updateTensor(wt_idx[LSTMCellParams::dropout_mask], batch);
* @copydoc Layer::calcGradient(RunLayerContext &context)
*/
void calcGradient(RunLayerContext &context) override;
+
/**
* @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
*/
private:
static constexpr unsigned int NUM_GATE = 4;
+ enum INOUT_INDEX {
+ INPUT = 0,
+ INPUT_HIDDEN_STATE = 1,
+ INPUT_CELL_STATE = 2,
+ OUTPUT_HIDDEN_STATE = 0,
+ OUTPUT_CELL_STATE = 1
+ };
/**
* Unit: number of output neurons
* HiddenStateActivation: activation type for hidden state. default is tanh
* RecurrentActivation: activation type for recurrent. default is sigmoid
* DropOutRate: dropout rate
- * MaxTimestep: maximum timestep for lstmcell
- * TimeStep: timestep for which lstm should operate
*
* */
std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
- props::RecurrentActivation, props::DropOutRate, props::MaxTimestep,
- props::Timestep>
+ props::RecurrentActivation, props::DropOutRate>
lstmcell_props;
- std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
+ std::array<unsigned int, 7> wt_idx; /**< indices of the weights */
/**
* @brief activation function for h_t : default is tanh
return_state=False)
record_single(lstm, (3, 4, 7), "lstm_multi_step_seq_act")
+ unit, batch_size, unroll_for, feature_size, state_num = [5, 3, 1, 7, 2]
+ lstmcell = K.layers.LSTMCell(units=unit,
+ activation="tanh",
+ recurrent_activation="sigmoid",
+ bias_initializer='glorot_uniform')
+ record_single(lstmcell, [(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num)], "lstmcell_single_step", input_type='float')
+
gru = K.layers.GRU(units=5, activation="tanh",
recurrent_activation="sigmoid",
bias_initializer='GlorotUniform',
loss = self.loss(ret, labels[0])
return ret, loss
+class LSTMCellStacked(torch.nn.Module):
+ def __init__(self, unroll_for=2, num_lstmcell=1):
+ super().__init__()
+ self.input_size = self.hidden_size = 2
+ self.lstmcells = torch.nn.ModuleList(
+ [
+ torch.nn.LSTMCell(self.input_size, self.hidden_size, bias=True)
+ for _ in range(num_lstmcell)
+ ]
+ )
+ self.unroll_for = unroll_for
+ self.num_lstmcell = num_lstmcell
+ self.loss = torch.nn.MSELoss()
+
+ def forward(self, inputs, labels):
+ out = inputs[0]
+ states = inputs[1:]
+ hs = [states[2 * i] for i in range(self.num_lstmcell)]
+ cs = [states[2 * i + 1] for i in range(self.num_lstmcell)]
+ ret = []
+ for _ in range(self.unroll_for):
+ for i, (lstm, h, c) in enumerate(zip(self.lstmcells, hs, cs)):
+ hs[i], cs[i] = lstm(out, (h, c))
+ out = hs[i]
+ ret.append(out)
+
+ ret = torch.stack(ret, dim=1)
+ loss = self.loss(ret, labels[0])
+ return ret, loss
+
class ZoneoutLSTMStacked(torch.nn.Module):
def __init__(self, batch_size=3, unroll_for=2, num_lstm=1, hidden_state_zoneout_rate=1, cell_state_zoneout_rate=1):
super().__init__()
name="lstm_stacked",
)
+ unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 1, 2, 3, 2, 2, 2]
+ record_v2(
+ LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
+ iteration=iteration,
+ input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
+ label_dims=[(batch_size, unroll_for, unit)],
+ name="lstmcell_single",
+ )
+
+ unroll_for, num_lstmcell, state_num, batch_size, unit, feature_size, iteration = [2, 2, 2, 3, 2, 2, 2]
+ record_v2(
+ LSTMCellStacked(unroll_for=unroll_for, num_lstmcell=num_lstmcell),
+ iteration=iteration,
+ input_dims=[(batch_size, feature_size)] + [(batch_size, unit) for _ in range(state_num * num_lstmcell)],
+ label_dims=[(batch_size, unroll_for, unit)],
+ name="lstmcell_stacked",
+ )
+
unroll_for, num_lstm, state_num, batch_size, unit, feature_size, iteration, hidden_state_zoneout_rate, cell_state_zoneout_rate = [2, 1, 2, 1, 2, 2, 2, 0.0, 0.0]
record_v2(
ZoneoutLSTMStacked(batch_size=batch_size, unroll_for=unroll_for, num_lstm=num_lstm, hidden_state_zoneout_rate=hidden_state_zoneout_rate, cell_state_zoneout_rate=cell_state_zoneout_rate),
return [layer(tf_output) for layer in self.stub_layers]
##
+# @brief Translayer for lstmcell layer
+class LSTMCellTransLayer(IdentityTransLayer):
+ def build(self, input_shape):
+ if not self.built:
+ self.tf_layer.build(input_shape[0])
+ super().build(input_shape[0])
+
+ ##
+ # @brief call function
+ # @param inputs input with nntrainer layout
+ def call(self, inputs):
+ input = inputs[0]
+ states = inputs[1:]
+ _, states = self.tf_layer.call(input, states)
+ return states
+
+##
# @brief Translayer for gru layer
class GRUTransLayer(IdentityTransLayer):
def to_nntr_weights(self, tensorOrList):
if isinstance(layer, CHANNEL_LAST_LAYERS):
return ChannelLastTransLayer(layer)
+ if isinstance(layer, K.layers.LSTMCell):
+ return LSTMCellTransLayer(layer)
+
if isinstance(layer, K.layers.GRU):
return GRUTransLayer(layer)
auto semantic_lstmcell = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::LSTMCellLayer>,
- nntrainer::LSTMCellLayer::type, {"unit=1", "timestep=0", "max_timestep=1"}, 0,
- false, 1);
+ nntrainer::LSTMCellLayer::type, {"unit=1"}, 0, false, 3);
INSTANTIATE_TEST_CASE_P(LSTMCell, LayerSemantics,
::testing::Values(semantic_lstmcell));
auto lstmcell_single_step = LayerGoldenTestParamType(
nntrainer::createLayer<nntrainer::LSTMCellLayer>,
- {"unit=5", "timestep=0", "max_timestep=1", "integrate_bias=true"}, "3:1:1:7",
- "lstm_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+ {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5,3:1:1:5",
+ "lstmcell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
INSTANTIATE_TEST_CASE_P(LSTMCell, LayerGoldenTest,
::testing::Values(lstmcell_single_step));
auto outer_graph = makeGraph({
{"input", {"name=input", "input_shape=1:1:2"}},
+ {"input", {"name=input_hidden_state", "input_shape=1:1:2"}},
+ {"input", {"name=input_cell_state", "input_shape=1:1:2"}},
/// here lstm_cells is being inserted
- {"mse", {"name=loss", "input_layers=lstm_scope/a1"}},
+ {"mse", {"name=loss", "input_layers=lstmcell_scope/a1(0)"}},
});
for (auto &node : outer_graph) {
nn->addLayer(node);
}
- auto lstm = makeGraph({
- {"lstmcell", {"name=a1", "unit=2", "integrate_bias=false"}},
+ auto lstmcell = makeGraph({
+ {"input", {"name=dummy_0", "input_shape=1"}},
+ {"input", {"name=dummy_1", "input_shape=1"}},
+ {"input", {"name=dummy_2", "input_shape=1"}},
+ {"lstmcell",
+ {"name=a1", "unit=2", "input_layers=dummy_0, dummy_1, dummy_2"}},
});
- nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a1"},
- ml::train::ReferenceLayersType::RECURRENT,
- {
- "unroll_for=2",
- "as_sequence=a1",
- "recurrent_input=a1",
- "recurrent_output=a1",
- });
+ nn->addWithReferenceLayers(
+ lstmcell, "lstmcell_scope",
+ {"input", "input_hidden_state", "input_cell_state"},
+ {"a1(0)", "a1(1)", "a1(2)"}, {"a1"},
+ ml::train::ReferenceLayersType::RECURRENT,
+ {
+ "unroll_for=2",
+ "as_sequence=a1",
+ "recurrent_input=a1(0), a1(1), a1(2)",
+ "recurrent_output=a1(0), a1(0), a1(1)",
+ });
+ nn->setProperty({"input_layers=input, input_hidden_state, input_cell_state"});
nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
return nn;
}
auto outer_graph = makeGraph({
{"input", {"name=input", "input_shape=1:1:2"}},
+ {"input", {"name=a1_input_hidden_state", "input_shape=1:1:2"}},
+ {"input", {"name=a1_input_cell_state", "input_shape=1:1:2"}},
+ {"input", {"name=a2_input_hidden_state", "input_shape=1:1:2"}},
+ {"input", {"name=a2_input_cell_state", "input_shape=1:1:2"}},
/// here lstm_cells is being inserted
- {"mse", {"name=loss", "input_layers=lstm_scope/a2"}},
+ {"mse", {"name=loss", "input_layers=lstmcell_scope/a2(0)"}},
});
for (auto &node : outer_graph) {
nn->addLayer(node);
}
- auto lstm = makeGraph({
- {"lstmcell", {"name=a1", "unit=2", "integrate_bias=false"}},
+ auto lstmcell = makeGraph({
+ {"input", {"name=dummy_0", "input_shape=1"}},
+ {"input", {"name=dummy_1", "input_shape=1"}},
+ {"input", {"name=dummy_2", "input_shape=1"}},
+ {"input", {"name=dummy_3", "input_shape=1"}},
+ {"input", {"name=dummy_4", "input_shape=1"}},
{"lstmcell",
- {"name=a2", "unit=2", "integrate_bias=false", "input_layers=a1"}},
+ {"name=a1", "unit=2", "input_layers=dummy_0, dummy_1, dummy_2"}},
+ {"lstmcell", {"name=a2", "unit=2", "input_layers=a1(0), dummy_3, dummy_4"}},
});
- nn->addWithReferenceLayers(lstm, "lstm_scope", {"input"}, {"a1"}, {"a2"},
- ml::train::ReferenceLayersType::RECURRENT,
- {
- "unroll_for=2",
- "as_sequence=a2",
- "recurrent_input=a1",
- "recurrent_output=a2",
- });
+ nn->addWithReferenceLayers(
+ lstmcell, "lstmcell_scope",
+ {
+ "input",
+ "a1_input_hidden_state",
+ "a1_input_cell_state",
+ "a2_input_hidden_state",
+ "a2_input_cell_state",
+ },
+ {"a1(0)", "a1(1)", "a1(2)", "a2(1)", "a2(2)"}, {"a2"},
+ ml::train::ReferenceLayersType::RECURRENT,
+ {
+ "unroll_for=2",
+ "as_sequence=a2",
+ "recurrent_input=a1(0), a1(1), a1(2), a2(1), a2(2)",
+ "recurrent_output=a2(0), a1(0), a1(1), a2(0), a2(1)",
+ });
+ nn->setProperty(
+ {"input_layers=input, a1_input_hidden_state, a1_input_cell_state, "
+ "a2_input_hidden_state, a2_input_cell_state"});
nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
return nn;
}
nn->addWithReferenceLayers(
zoneout_lstm, "zoneout_lstm_scope",
- {"input", "input_hidden_state", "input_cell_state"}, {"a1"}, {"a1"},
+ {"input", "input_hidden_state", "input_cell_state"},
+ {"a1(0)", "a1(1)", "a1(2)"}, {"a1"},
ml::train::ReferenceLayersType::RECURRENT,
{
"unroll_for=2",
"a2_input_hidden_state",
"a2_input_cell_state",
},
- {"a1", "a2"}, {"a2"}, ml::train::ReferenceLayersType::RECURRENT,
+ {"a1(0)", "a1(1)", "a1(2)", "a2(1)", "a2(2)"}, {"a2"},
+ ml::train::ReferenceLayersType::RECURRENT,
{
"unroll_for=2",
"as_sequence=a2",
mkModelTc_V2(makeFC, "fc_unroll_stacked", ModelTestOption::COMPARE_V2),
mkModelTc_V2(makeFCClipped, "fc_unroll_stacked_clipped",
ModelTestOption::COMPARE_V2),
- mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::COMPARE_V2),
- mkModelTc_V2(makeSingleLSTMCell, "lstm_single__1",
- ModelTestOption::COMPARE_V2),
- mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::COMPARE_V2),
- mkModelTc_V2(makeStackedLSTMCell, "lstm_stacked__1",
- ModelTestOption::COMPARE_V2),
+ mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::ALL_V2),
+ mkModelTc_V2(makeStackedLSTM, "lstm_stacked", ModelTestOption::ALL_V2),
+ mkModelTc_V2(makeSingleLSTMCell, "lstmcell_single",
+ ModelTestOption::ALL_V2),
+ mkModelTc_V2(makeStackedLSTMCell, "lstmcell_stacked",
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_000",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_000",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_000",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_000",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_000",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_000",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_050",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_050",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_050",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_050",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_050",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_050",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_000_100",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_000_100",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_050_100",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_050_100",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleZoneoutLSTMCell, "zoneout_lstm_single_100_100",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedZoneoutLSTMCell, "zoneout_lstm_stacked_100_100",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleRNNCell, "rnncell_single__1",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedRNNCell, "rnncell_stacked__1",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeSingleGRUCell, "grucell_single__1",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
mkModelTc_V2(makeStackedGRUCell, "grucell_stacked__1",
- ModelTestOption::COMPARE_V2),
+ ModelTestOption::ALL_V2),
}),
[](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info) {
return std::get<1>(info.param);