2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "lstm_dynamic_input_inst.h"
19 #include "lstm_dynamic_inst.h"
20 #include "primitive_type_base.h"
21 #include "error_handler.h"
22 #include "json_object.h"
26 primitive_type_id lstm_dynamic_input_type_id() {
27 static primitive_type_base<lstm_dynamic_input> instance;
30 // input_tensor: [b: batch, f: max_sequence_length, x: input_size, y: direction]
31 // weights_tensor: [b: 1, f: direction, x: input_size, y: 4 * hidden_size]
32 // output_tensor: [b: batch, f: max_sequence_length, x: 4 * hidden_size, y: direction]
33 layout lstm_dynamic_input_inst::calc_output_layout(lstm_dynamic_input_node const& node) {
34 assert(static_cast<bool>(node.get_primitive()->output_data_type) == false &&
35 "Output data type forcing is not supported for lstm_dynamic_node!");
36 auto input_layout = node.input().get_output_layout();
37 auto weight_layout = node.weights().get_output_layout();
38 auto batch = input_layout.size.batch[0];
39 auto direction = node.direction();
40 auto output_sequence = input_layout.size.feature[0];
41 return layout(input_layout.data_type,
43 tensor(batch, output_sequence, weight_layout.size.spatial[1], direction));
46 std::string lstm_dynamic_input_inst::to_string(lstm_dynamic_input_node const& node) {
47 auto desc = node.get_primitive();
48 auto node_info = node.desc_to_json();
49 auto bias_id = desc->bias != "" ? desc->bias : "no bias";
51 std::stringstream primitive_description;
52 json_composite lstm_dynamic_input_info;
53 lstm_dynamic_input_info.add("dyn_length id", desc->dyn_length);
54 lstm_dynamic_input_info.add("weights id", desc->weights);
55 lstm_dynamic_input_info.add("bias id", bias_id);
56 lstm_dynamic_input_info.add("max seq len", node.input().get_output_layout().size.feature[0]);
57 lstm_dynamic_input_info.add("hidden size", node.weights().get_output_layout().size.spatial[1] / 4);
58 lstm_dynamic_input_info.add("direction", node.weights().get_output_layout().size.feature[0]);
59 node_info->add("lstm_dynamic_input info", lstm_dynamic_input_info);
60 node_info->dump(primitive_description);
62 return primitive_description.str();
65 lstm_dynamic_input_inst::typed_primitive_inst(network_impl& network, lstm_dynamic_input_node const& node)
66 : parent(network, node) {
68 auto input_layout = node.input().get_output_layout();
69 auto direction = node.direction();
70 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(),
72 input_layout.format.value,
75 lstm_dynamic_inst::check_direction(node.input(), direction, "input");
77 // check dynamic length
78 CLDNN_ERROR_BOOL(node.id(),
79 "Dynamic length memory",
80 !node.dyn_length_term(),
81 "Id of dynamic length memory is not set.");
82 auto dyn_length_size = node.dyn_length().get_output_layout().count();
83 CLDNN_ERROR_NOT_EQUAL(node.id(),
85 node.get_output_layout().size.batch[0],
86 "Dynamic tensor elements count.",
91 CLDNN_ERROR_BOOL(node.id(), "Weights memory", !node.weights_term(), "Id of weights memory is not set.");
92 auto weights_id = node.weights().id();
93 auto weights_tensor = node.weights().get_output_layout().size;
94 auto hidden_size = weights_tensor.spatial[1] / 4;
95 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(),
97 node.weights().get_output_layout().format.value,
98 "expected bfyx format",
100 CLDNN_ERROR_NOT_EQUAL(node.id(),
101 "Weights batch size",
102 weights_tensor.batch[0],
105 "Sizes mismatch, weights_id: " + weights_id);
106 CLDNN_ERROR_NOT_EQUAL(node.id(),
108 weights_tensor.spatial[0],
110 input_layout.size.spatial[0],
111 "Sizes mismatch, weights_id: " + weights_id);
114 if (node.bias_term()) {
115 auto bias_id = node.id();
116 auto bias_tensor = node.bias().get_output_layout().size;
117 CLDNN_ERROR_NOT_EQUAL(node.id(),
120 "direction * 4 * hidden_size",
121 direction * 4 * hidden_size,
122 "Bias count mismtach, bias_id: " + bias_id);
123 lstm_dynamic_inst::check_direction(node.bias(), direction, "bias");