fc23642d4fc3a642c2f9a3fefd3942f915d1d452
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / lstm_dynamic_input.cpp
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "lstm_dynamic_input_inst.h"
19 #include "lstm_dynamic_inst.h"
20 #include "primitive_type_base.h"
21 #include "error_handler.h"
22 #include "json_object.h"
23 #include <string>
24
25 namespace cldnn {
26 primitive_type_id lstm_dynamic_input_type_id() {
27     static primitive_type_base<lstm_dynamic_input> instance;
28     return &instance;
29 }
30 // input_tensor:   [b: batch, f: max_sequence_length, x: input_size, y: direction]
31 // weights_tensor: [b: 1, f: direction, x: input_size, y: 4 * hidden_size]
32 // output_tensor:  [b: batch, f: max_sequence_length, x: 4 * hidden_size, y: direction]
33 layout lstm_dynamic_input_inst::calc_output_layout(lstm_dynamic_input_node const& node) {
34     assert(static_cast<bool>(node.get_primitive()->output_data_type) == false &&
35            "Output data type forcing is not supported for lstm_dynamic_node!");
36     auto input_layout = node.input().get_output_layout();
37     auto weight_layout = node.weights().get_output_layout();
38     auto batch = input_layout.size.batch[0];
39     auto direction = node.direction();
40     auto output_sequence = input_layout.size.feature[0];
41     return layout(input_layout.data_type,
42                   input_layout.format,
43                   tensor(batch, output_sequence, weight_layout.size.spatial[1], direction));
44 }
45
46 std::string lstm_dynamic_input_inst::to_string(lstm_dynamic_input_node const& node) {
47     auto desc = node.get_primitive();
48     auto node_info = node.desc_to_json();
49     auto bias_id = desc->bias != "" ? desc->bias : "no bias";
50
51     std::stringstream primitive_description;
52     json_composite lstm_dynamic_input_info;
53     lstm_dynamic_input_info.add("dyn_length id", desc->dyn_length);
54     lstm_dynamic_input_info.add("weights id", desc->weights);
55     lstm_dynamic_input_info.add("bias id", bias_id);
56     lstm_dynamic_input_info.add("max seq len", node.input().get_output_layout().size.feature[0]);
57     lstm_dynamic_input_info.add("hidden size", node.weights().get_output_layout().size.spatial[1] / 4);
58     lstm_dynamic_input_info.add("direction", node.weights().get_output_layout().size.feature[0]);
59     node_info->add("lstm_dynamic_input info", lstm_dynamic_input_info);
60     node_info->dump(primitive_description);
61
62     return primitive_description.str();
63 }
64
65 lstm_dynamic_input_inst::typed_primitive_inst(network_impl& network, lstm_dynamic_input_node const& node)
66     : parent(network, node) {
67     // Check input
68     auto input_layout = node.input().get_output_layout();
69     auto direction = node.direction();
70     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(),
71                                   "input format",
72                                   input_layout.format.value,
73                                   "expected format",
74                                   format::bfyx);
75     lstm_dynamic_inst::check_direction(node.input(), direction, "input");
76
77     // check dynamic length
78     CLDNN_ERROR_BOOL(node.id(),
79                      "Dynamic length memory",
80                      !node.dyn_length_term(),
81                      "Id of dynamic length memory is not set.");
82     auto dyn_length_size = node.dyn_length().get_output_layout().count();
83     CLDNN_ERROR_NOT_EQUAL(node.id(),
84                           "Batch",
85                           node.get_output_layout().size.batch[0],
86                           "Dynamic tensor elements count.",
87                           dyn_length_size,
88                           "Should be equal.");
89
90     // check weights
91     CLDNN_ERROR_BOOL(node.id(), "Weights memory", !node.weights_term(), "Id of weights memory is not set.");
92     auto weights_id = node.weights().id();
93     auto weights_tensor = node.weights().get_output_layout().size;
94     auto hidden_size = weights_tensor.spatial[1] / 4;
95     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(),
96                                   "weights format",
97                                   node.weights().get_output_layout().format.value,
98                                   "expected bfyx format",
99                                   format::bfyx);
100     CLDNN_ERROR_NOT_EQUAL(node.id(),
101                           "Weights batch size",
102                           weights_tensor.batch[0],
103                           "1",
104                           1,
105                           "Sizes mismatch, weights_id: " + weights_id);
106     CLDNN_ERROR_NOT_EQUAL(node.id(),
107                           "Weights x size",
108                           weights_tensor.spatial[0],
109                           "input_size",
110                           input_layout.size.spatial[0],
111                           "Sizes mismatch, weights_id: " + weights_id);
112
113     // check bias
114     if (node.bias_term()) {
115         auto bias_id = node.id();
116         auto bias_tensor = node.bias().get_output_layout().size;
117         CLDNN_ERROR_NOT_EQUAL(node.id(),
118                               "Bias count",
119                               bias_tensor.count(),
120                               "direction * 4 * hidden_size",
121                               direction * 4 * hidden_size,
122                               "Bias count mismtach, bias_id: " + bias_id);
123         lstm_dynamic_inst::check_direction(node.bias(), direction, "bias");
124     }
125 }
126 }  // namespace cldnn