2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "lstm_gemm_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id lstm_gemm_type_id()
27 static primitive_type_base<lstm_gemm> instance;
32 layout lstm_gemm_inst::calc_output_layout(lstm_gemm_node const& node)
34 assert((bool)node.get_primitive()->output_data_type == false
35 && "Output data type forcing is not supported for lstm_gemm_node!");
36 auto desc = node.get_primitive();
37 auto input_layout = node.input().get_output_layout();
38 auto weights_layout = node.weights().get_output_layout();
40 // input{bfyx} = [b: batch, f: sequence, x: input_size, y: 1]
41 // weights{bfyx} = [b: 1, f: direction, x: 4 * hidden_size, y: input_size ]
42 // recurrent{bfyx} = [b: 1, f: direction, x: 4 * hidden_size, y: hidden_size ]
43 // biases{bfyx} = [b: 1, f:1 , x: direction, y: 4 * hidden_size ]
44 // hidden{bfyx} = [b: batch, f: direction, x: 1 , y: hidden_size ] optional
45 // tempGEMM{bfyx} = [b: batch, f: direction, x: 4*hidden_size, y: 1] output
46 auto result = layout(input_layout.data_type, input_layout.format, tensor(input_layout.size.batch[0], weights_layout.size.feature[0], weights_layout.size.spatial[1], 1));
50 std::string lstm_gemm_inst::to_string(lstm_gemm_node const& node)
52 auto desc = node.get_primitive();
53 auto node_info = node.desc_to_json();
54 auto weights_id = desc->weights;
55 auto recurrent_id = desc->recurrent;
56 auto bias_id = desc->bias != "" ? desc->bias : "no bias";
57 auto hidden_id = desc->hidden != "" ? desc->hidden : "no inital hidden";
59 std::stringstream primitive_description;
61 json_composite lstm_gemm_info;
62 lstm_gemm_info.add("weights id", weights_id);
63 lstm_gemm_info.add("recurrent id", recurrent_id);
64 lstm_gemm_info.add("bias id", bias_id);
65 lstm_gemm_info.add("hidden id", hidden_id);
66 node_info->add("lstm gemm info", lstm_gemm_info);
67 node_info->dump(primitive_description);
69 return primitive_description.str();
72 lstm_gemm_inst::typed_primitive_inst(network_impl& network, lstm_gemm_node const& node)
73 :parent(network, node)
75 auto input_layout = node.input().get_output_layout();
76 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_layout.format.value, "expected format", format::bfyx, format::fyxb);