2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "gemm_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
25 primitive_type_id gemm_type_id()
27 static primitive_type_base<gemm> instance;
32 layout gemm_inst::calc_output_layout(gemm_node const& node)
34 assert((bool)node.get_primitive()->output_data_type == false
35 && "Output data type forcing is not supported for gemm_node!");
36 auto input1_layout = node.input(0).get_output_layout();
37 auto input2_layout = node.input(1).get_output_layout();
38 bool transpose_input1 = node.get_primitive()->transpose_input1;
39 bool transpose_input2 = node.get_primitive()->transpose_input2;
41 if (!transpose_input1 && !transpose_input2)
42 return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
43 input2_layout.size.spatial[0], input1_layout.size.spatial[1]));
44 else if (!transpose_input1 && transpose_input2)
45 return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
46 input2_layout.size.spatial[1], input1_layout.size.spatial[1]));
47 else if (transpose_input1 && !transpose_input2)
48 return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
49 input2_layout.size.spatial[0], input1_layout.size.spatial[0]));
51 return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
52 input2_layout.size.spatial[1], input1_layout.size.spatial[0]));
56 std::string gemm_inst::to_string(gemm_node const& node)
58 auto desc = node.get_primitive();
59 auto node_info = node.desc_to_json();
60 auto alpha = desc->alpha;
61 auto beta = desc->beta;
62 auto transpose_input1 = desc->transpose_input1 ? " true" : "false";
63 auto transpose_input2 = desc->transpose_input2 ? " true" : "false";
64 std::stringstream primitive_description;
66 json_composite gemm_info;
67 for (size_t i = 0; i < node.inputs_count(); i++)
69 gemm_info.add("input_" + std::to_string(i), node.input(i).id());
71 gemm_info.add("alpha", alpha);
72 gemm_info.add("beta", beta);
73 gemm_info.add("trasnpose_input1", transpose_input1);
74 gemm_info.add("transpose_input2", transpose_input2);
75 node_info->dump(primitive_description);
77 return primitive_description.str();
80 gemm_inst::typed_primitive_inst(network_impl& network, gemm_node const& node)
81 :parent(network, node)
83 auto input_layout = node.input(0).get_output_layout();
84 auto input2_layout = node.input(1).get_output_layout();
85 bool transpose_input1 = node.get_primitive()->transpose_input1;
86 bool transpose_input2 = node.get_primitive()->transpose_input2;
88 if (!transpose_input1 && !transpose_input2)
90 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[1], "");
91 if (node.inputs_count() > 2)
93 auto input3_layout = node.input(2).get_output_layout();
94 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Columns count", input3_layout.size.spatial[0], "Input2 Columns count", input2_layout.size.spatial[0], "");
95 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Rows count", input_layout.size.spatial[1], "");
99 else if (!transpose_input1 && transpose_input2)
101 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[0], "");
102 if (node.inputs_count() > 2)
104 auto input3_layout = node.input(2).get_output_layout();
105 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input13 Columns count", input3_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[1], "");
106 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Rows count", input_layout.size.spatial[1], "");
109 else if (transpose_input1 && !transpose_input2)
111 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[1], "Input2 Rows count", input2_layout.size.spatial[1], "");
112 if (node.inputs_count() > 2)
114 auto input3_layout = node.input(2).get_output_layout();
115 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Columns count", input3_layout.size.spatial[0], "Input2 Columns count", input2_layout.size.spatial[0], "");
116 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Columns count", input_layout.size.spatial[0], "");
121 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[1], "Input2 Rows count", input2_layout.size.spatial[0], "");
122 if (node.inputs_count() > 2)
124 auto input3_layout = node.input(2).get_output_layout();
125 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Columns count", input3_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[1], "");
126 CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Columns count", input_layout.size.spatial[0], "");