layout lstm_inst::calc_output_layout(lstm_node const& node)
{
+ assert((bool)node.get_primitive()->output_data_type == false
+ && "Output data type forcing is not supported for lstm_node!");
auto input_layout = node.input().get_output_layout();
auto hidden_layout = node.inital_hidden().get_output_layout();
- // input = [ 1, sequence, batch, input_size ]
- // weights = [ 1, direction, 4 * hidden_size, input_size ]
- // recurrent = [ 1, direction, 4 * hidden_size, hidden_size ]
- // biases = [ 1, 1, direction, 4 * hidden_size ]
- // hidden = [ 1, direction, batch, hidden_size ]
- // cell = [ 1, direction, batch, hidden_size ]
- // output = [ sequence, direction, batch, hidden_size ]
+ // input = [ batch, sequence, direction, input_size ]
+ // weights = [ 1, direction, 4 * hidden_size, input_size ]
+ // recurrent = [ 1, direction, 4 * hidden_size, hidden_size ]
+ // biases = [ 1, 1, direction, 4 * hidden_size ]
+ // hidden = [ batch, 1, direction, hidden_size ]
+ // cell = [ batch, 1, direction, hidden_size ]
+ // output = [ batch, sequence, direction, hidden_size ]
auto result = layout(input_layout.data_type, format::bfyx,
- tensor(hidden_layout.size.feature[0], input_layout.size.feature[0], hidden_layout.size.spatial[0], hidden_layout.size.spatial[1]));
+ tensor(hidden_layout.size.feature[0], input_layout.size.feature[0],
+ hidden_layout.size.spatial[0], hidden_layout.size.spatial[1]));
return result;
}
lstm_inst::typed_primitive_inst(network_impl& network, lstm_node const& node)
:parent(network, node)
{
- // [ARIEL] TODO: That do we need to check here??
- auto input_size = node.input().get_output_layout();
- // auto output_size = output_memory().get_layout();
- CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_size.format.value, "expected format", format::bfyx);
- //CLDNN_ERROR_NOT_EQUAL(node.id(), "Input size", input_size.size.raw.size(), "output size", output_size.size.raw.size(), "");
+ auto input_layout = node.input().get_output_layout();
+ CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_layout.format.value, "expected format", format::bfyx);
}
+
}