061ab9b2da08893b6ab5ebebc0893301f04403e5
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / lstm.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
#include "lstm_inst.h"
#include "primitive_type_base.h"
#include "error_handler.h"
#include "json_object.h"
#include <sstream>
#include <string>
23
24 namespace cldnn {
25 primitive_type_id lstm_type_id() {
26     static primitive_type_base<lstm> instance;
27     return &instance;
28 }
29
30 layout lstm_inst::calc_output_layout(lstm_node const& node) {
31     assert(static_cast<bool>(node.get_primitive()->output_data_type) == false &&
32            "Output data type forcing is not supported for lstm_node!");
33     auto input_layout = node.input().get_output_layout();
34     auto hidden_layout = node.inital_hidden().get_output_layout();
35
36     // input     = [ batch,  sequence,       direction,      input_size ]
37     // weights   = [     1, direction, 4 * hidden_size,      input_size ]
38     // recurrent = [     1, direction, 4 * hidden_size,     hidden_size ]
39     // biases    = [     1,         1,       direction, 4 * hidden_size ]
40     // hidden    = [ batch,         1,       direction,     hidden_size ]
41     // cell      = [ batch,         1,       direction,     hidden_size ]
42     // output    = [ batch,  sequence,       direction,     hidden_size ]
43     auto result = layout(input_layout.data_type,
44                          format::bfyx,
45                          tensor(hidden_layout.size.feature[0],
46                                 input_layout.size.feature[0],
47                                 hidden_layout.size.spatial[0],
48                                 hidden_layout.size.spatial[1]));
49     return result;
50 }
51
52 std::string lstm_inst::to_string(lstm_node const& node) {
53     auto desc = node.get_primitive();
54     auto node_info = node.desc_to_json();
55     auto weights_id = desc->weights;
56     auto recurrent_id = desc->recurrent;
57     auto bias_id = desc->bias != "" ? desc->bias : "no bias";
58     auto peepholes_id = desc->peepholes != "" ? desc->peepholes : "no peepholes";
59     auto initial_hidden_id = desc->initial_hidden != "" ? desc->initial_hidden : "no inital hidden";
60     auto initial_cell_id = desc->initial_cell != "" ? desc->initial_cell : "no initial cell";
61
62     std::stringstream primitive_description;
63
64     json_composite lstm_info;
65     lstm_info.add("weights id", weights_id);
66     lstm_info.add("recurrent id", recurrent_id);
67     lstm_info.add("bias id", bias_id);
68     lstm_info.add("peepholes id", peepholes_id);
69     lstm_info.add("initial_hidden id", initial_hidden_id);
70     lstm_info.add("initial_cell id", initial_cell_id);
71     node_info->add("lstm info", lstm_info);
72     node_info->dump(primitive_description);
73
74     return primitive_description.str();
75 }
76
77 lstm_inst::typed_primitive_inst(network_impl& network, lstm_node const& node) : parent(network, node) {
78     auto input_layout = node.input().get_output_layout();
79     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(),
80                                   "input format",
81                                   input_layout.format.value,
82                                   "expected format",
83                                   format::bfyx);
84 }
85
86 }  // namespace cldnn