Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gemm.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "gemm_inst.h"
19 #include "primitive_type_base.h"
20 #include "error_handler.h"
21 #include "json_object.h"
22
23 namespace cldnn
24 {
25 primitive_type_id gemm_type_id()
26 {
27     static primitive_type_base<gemm> instance;
28     return &instance;
29 }
30
31
32 layout gemm_inst::calc_output_layout(gemm_node const& node)
33 {
34     assert((bool)node.get_primitive()->output_data_type == false
35            && "Output data type forcing is not supported for gemm_node!");
36     auto input1_layout = node.input(0).get_output_layout();
37     auto input2_layout = node.input(1).get_output_layout();
38     bool transpose_input1 = node.get_primitive()->transpose_input1;
39     bool transpose_input2 = node.get_primitive()->transpose_input2;
40
41     if (!transpose_input1 && !transpose_input2)
42         return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0],  1, 
43                       input2_layout.size.spatial[0], input1_layout.size.spatial[1]));
44     else if (!transpose_input1 && transpose_input2)
45         return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
46             input2_layout.size.spatial[1], input1_layout.size.spatial[1]));
47     else if (transpose_input1 && !transpose_input2)
48         return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
49             input2_layout.size.spatial[0], input1_layout.size.spatial[0]));
50     else
51         return layout(input1_layout.data_type, format::bfyx, tensor(input1_layout.size.batch[0], 1,
52             input2_layout.size.spatial[1], input1_layout.size.spatial[0]));
53     
54 }
55
56 std::string gemm_inst::to_string(gemm_node const& node)
57 {
58     auto desc = node.get_primitive();
59     auto node_info = node.desc_to_json();
60     auto alpha = desc->alpha;
61     auto beta = desc->beta;
62     auto transpose_input1 = desc->transpose_input1 ? " true" : "false";
63     auto transpose_input2 = desc->transpose_input2 ? " true" : "false";
64     std::stringstream primitive_description;
65
66     json_composite gemm_info;
67     for (size_t i = 0; i < node.inputs_count(); i++)
68     {
69         gemm_info.add("input_" + std::to_string(i), node.input(i).id());
70     }
71     gemm_info.add("alpha", alpha);
72     gemm_info.add("beta", beta);
73     gemm_info.add("trasnpose_input1", transpose_input1);
74     gemm_info.add("transpose_input2", transpose_input2);
75     node_info->dump(primitive_description);
76
77     return primitive_description.str();
78 }
79
80 gemm_inst::typed_primitive_inst(network_impl& network, gemm_node const& node)
81     :parent(network, node)
82 {
83     auto input_layout = node.input(0).get_output_layout();
84     auto input2_layout = node.input(1).get_output_layout();
85     bool transpose_input1 = node.get_primitive()->transpose_input1;
86     bool transpose_input2 = node.get_primitive()->transpose_input2;
87
88     if (!transpose_input1 && !transpose_input2)
89     {
90         CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[1], "");
91         if (node.inputs_count() > 2)
92         {
93             auto input3_layout = node.input(2).get_output_layout();
94             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Columns count", input3_layout.size.spatial[0], "Input2 Columns count", input2_layout.size.spatial[0], "");
95             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Rows count", input_layout.size.spatial[1], "");
96         }
97     }
98
99     else if (!transpose_input1 && transpose_input2)
100     {
101         CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[0], "");
102         if (node.inputs_count() > 2)
103         {
104             auto input3_layout = node.input(2).get_output_layout();
105             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input13 Columns count", input3_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[1], "");
106             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Rows count", input_layout.size.spatial[1], "");
107         }
108     }
109     else if (transpose_input1 && !transpose_input2)
110     {
111         CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[1], "Input2 Rows count", input2_layout.size.spatial[1], "");
112         if (node.inputs_count() > 2)
113         {
114             auto input3_layout = node.input(2).get_output_layout();
115             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Columns count", input3_layout.size.spatial[0], "Input2 Columns count", input2_layout.size.spatial[0], "");
116             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Columns count", input_layout.size.spatial[0], "");
117         }
118     }
119     else
120     {
121         CLDNN_ERROR_NOT_EQUAL(node.id(), "Input1 Columns count", input_layout.size.spatial[1], "Input2 Rows count", input2_layout.size.spatial[0], "");
122         if (node.inputs_count() > 2)
123         {
124             auto input3_layout = node.input(2).get_output_layout();
125             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Columns count", input3_layout.size.spatial[0], "Input2 Rows count", input2_layout.size.spatial[1], "");
126             CLDNN_ERROR_NOT_EQUAL(node.id(), "Input3 Rows count", input3_layout.size.spatial[1], "Input1 Columns count", input_layout.size.spatial[0], "");
127         }
128     }
129
130 }
131 }