2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "lstm_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id lstm_type_id() {
26 static primitive_type_base<lstm> instance;
// Computes the output layout of an lstm node from the layouts of its input and
// initial-hidden-state dependencies.
// - The output data type is always taken from the input layout; an explicitly
//   forced output data type is rejected by the assert below (checked only when
//   assert is enabled, i.e. without NDEBUG).
// - The output tensor is built from the hidden-state layout's feature and
//   spatial sizes plus the input layout's feature size (see NOTE below).
30 layout lstm_inst::calc_output_layout(lstm_node const& node) {
31 assert(static_cast<bool>(node.get_primitive()->output_data_type) == false &&
32 "Output data type forcing is not supported for lstm_node!");
33 auto input_layout = node.input().get_output_layout();
34 auto hidden_layout = node.inital_hidden().get_output_layout();
// Expected shapes of the tensors involved (dimension naming as used below;
// exact mapping onto layout dimensions assumed — TODO confirm):
36 // input = [ batch, sequence, direction, input_size ]
37 // weights = [ 1, direction, 4 * hidden_size, input_size ]
38 // recurrent = [ 1, direction, 4 * hidden_size, hidden_size ]
39 // biases = [ 1, 1, direction, 4 * hidden_size ]
40 // hidden = [ batch, 1, direction, hidden_size ]
41 // cell = [ batch, 1, direction, hidden_size ]
42 // output = [ batch, sequence, direction, hidden_size ]
// NOTE(review): the table above puts `batch` in the hidden state's batch
// dimension and gives it a feature size of 1. Yet hidden_layout.size.feature[0]
// is passed as the first tensor component below. If that first component is
// the batch (assumed tensor(batch, feature, x, y) ordering; confirm against the
// tensor constructor), the output batch would be 1 instead of `batch`. Confirm
// whether hidden_layout.size.batch[0] (or input_layout.size.batch[0]) was
// intended.
// The remaining components presumably map to sequence (input feature),
// hidden_size (hidden spatial[0]) and direction (hidden spatial[1]). This is an
// assumption; verify.
43 auto result = layout(input_layout.data_type,
45 tensor(hidden_layout.size.feature[0],
46 input_layout.size.feature[0],
47 hidden_layout.size.spatial[0],
48 hidden_layout.size.spatial[1]));
// Builds a textual description of an lstm node. The description is the generic
// node description returned by desc_to_json(), extended with an "lstm info"
// entry. That entry lists the ids of the node's weights, recurrent, bias,
// peepholes, initial hidden and initial cell inputs. The combined object is
// dumped into a stringstream and returned as a string.
// Optional inputs with an empty id are reported with a placeholder text
// ("no bias", "no peepholes", ...) instead of an empty string.
52 std::string lstm_inst::to_string(lstm_node const& node) {
53 auto desc = node.get_primitive();
54 auto node_info = node.desc_to_json();
// weights and recurrent ids are reported as-is (no empty-id placeholder).
55 auto weights_id = desc->weights;
56 auto recurrent_id = desc->recurrent;
// Optional inputs: substitute a placeholder when no id was provided.
// NOTE(review): "no inital hidden" is misspelled ("initial"). It is left
// unchanged here because it is emitted output; fixing it alters the produced text.
57 auto bias_id = desc->bias != "" ? desc->bias : "no bias";
58 auto peepholes_id = desc->peepholes != "" ? desc->peepholes : "no peepholes";
59 auto initial_hidden_id = desc->initial_hidden != "" ? desc->initial_hidden : "no inital hidden";
60 auto initial_cell_id = desc->initial_cell != "" ? desc->initial_cell : "no initial cell";
// NOTE(review): <sstream> is not among the visible includes. Presumably it is
// pulled in transitively by one of the project headers; confirm.
62 std::stringstream primitive_description;
// Collect the lstm-specific input ids into their own JSON composite.
64 json_composite lstm_info;
65 lstm_info.add("weights id", weights_id);
66 lstm_info.add("recurrent id", recurrent_id);
67 lstm_info.add("bias id", bias_id);
68 lstm_info.add("peepholes id", peepholes_id);
69 lstm_info.add("initial_hidden id", initial_hidden_id);
70 lstm_info.add("initial_cell id", initial_cell_id);
// Attach the lstm fields to the generic node description and serialize the
// combined object into the returned string.
71 node_info->add("lstm info", lstm_info);
72 node_info->dump(primitive_description);
74 return primitive_description.str();
77 lstm_inst::typed_primitive_inst(network_impl& network, lstm_node const& node) : parent(network, node) {
78 auto input_layout = node.input().get_output_layout();
79 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(),
81 input_layout.format.value,