2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "lstm_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id lstm_type_id()
27 static primitive_type_base<lstm> instance;
32 layout lstm_inst::calc_output_layout(lstm_node const& node)
34 assert((bool)node.get_primitive()->output_data_type == false
35 && "Output data type forcing is not supported for lstm_node!");
36 auto input_layout = node.input().get_output_layout();
37 auto hidden_layout = node.inital_hidden().get_output_layout();
39 // input = [ batch, sequence, direction, input_size ]
40 // weights = [ 1, direction, 4 * hidden_size, input_size ]
41 // recurrent = [ 1, direction, 4 * hidden_size, hidden_size ]
42 // biases = [ 1, 1, direction, 4 * hidden_size ]
43 // hidden = [ batch, 1, direction, hidden_size ]
44 // cell = [ batch, 1, direction, hidden_size ]
45 // output = [ batch, sequence, direction, hidden_size ]
46 auto result = layout(input_layout.data_type, format::bfyx,
47 tensor(hidden_layout.size.feature[0], input_layout.size.feature[0],
48 hidden_layout.size.spatial[0], hidden_layout.size.spatial[1]));
52 std::string lstm_inst::to_string(lstm_node const& node)
54 auto desc = node.get_primitive();
55 auto node_info = node.desc_to_json();
56 auto weights_id = desc->weights;
57 auto recurrent_id = desc->recurrent;
58 auto bias_id = desc->bias != "" ? desc->bias : "no bias";
59 auto peepholes_id = desc->peepholes != "" ? desc->peepholes : "no peepholes";
60 auto initial_hidden_id = desc->initial_hidden != "" ? desc->initial_hidden : "no inital hidden";
61 auto initial_cell_id = desc->initial_cell != "" ? desc->initial_cell : "no initial cell";
63 std::stringstream primitive_description;
65 json_composite lstm_info;
66 lstm_info.add("weights id", weights_id);
67 lstm_info.add("recurrent id", recurrent_id);
68 lstm_info.add("bias id", bias_id);
69 lstm_info.add("peepholes id", peepholes_id);
70 lstm_info.add("initial_hidden id", initial_hidden_id);
71 lstm_info.add("initial_cell id", initial_cell_id);
72 node_info->add("lstm info", lstm_info);
73 node_info->dump(primitive_description);
75 return primitive_description.str();
78 lstm_inst::typed_primitive_inst(network_impl& network, lstm_node const& node)
79 :parent(network, node)
81 auto input_layout = node.input().get_output_layout();
82 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_layout.format.value, "expected format", format::bfyx);