Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / gemm.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "../C/gemm.h"
20 #include "primitive.hpp"
21
22 namespace cldnn
23 {
24     /// @addtogroup cpp_api C++ API
25     /// @{
26     /// @addtogroup cpp_topology Network Topology
27     /// @{
28     /// @addtogroup cpp_primitives Primitives
29     /// @{
30     /// @brief Type of gemm that will be added to the input by border layer / primitive.
31
32     /// @brief Adds gemm  input.
33     ///
34     /// @details General Matrix Multiplication witch batch support, 
35     ///          A(B,Z,X)xA2(B,Y,Z)=C(B,X,Y)
36     /// @n
37     /// @n@b Requirements:
38     /// @n - @c input - first matrix
39     /// @n - @c input2 - second matrix
40     /// @n - @c optional: input3 matrix, alpha, beta, transpose
41     /// @n - @c computations with optional params: output = alpha x (input3 x beta + input x input2) 
42     /// @n - @c transpose params tranposing second matrix <-TODO
43
44
45 struct gemm : public primitive_base<gemm, CLDNN_PRIMITIVE_DESC(gemm)>
46 {
47     CLDNN_DECLARE_PRIMITIVE(gemm)
48
49         /// @brief Constructs gemm layer.
50         /// @brief Primitive id containing first matrix
51         /// @brief Primitive id containing second matrix
52         /// @brief Flag for transposing first input matrix
53         /// @brief Flag for transposing second input matrix
54         /// @brief Variable containing ALPHA parameter
55         /// @brief Variable containing BETA parameter
56
57         gemm(
58             const primitive_id& id,
59             const primitive_id& input,
60             const primitive_id& input2,
61             const bool transpose_input1 = false,
62             const bool transpose_input2 = false,
63             const float alpha = 1.0f,
64             const float beta = 0.0f,
65             const padding& output_padding = padding()
66         )
67         : primitive_base(id, { input, input2 }, output_padding)
68         , transpose_input1(transpose_input1)
69         , transpose_input2(transpose_input2)
70         , alpha(alpha)
71         , beta(beta)
72     {
73     }
74         /// @brief Constructs gemm layer.
75         /// @brief Primitive id containing first matrix
76         /// @brief Primitive id containing second matrix
77         /// @brief Primitive id containing third matrix 
78         /// @brief Flag for transposing first input matrix
79         /// @brief Flag for transposing second input matrix
80         /// @brief Variable containing ALPHA parameter
81         /// @brief Variable containing BETA parameter
82
83         gemm(
84             const primitive_id& id,
85             const primitive_id& input,
86             const primitive_id& input2,
87             const primitive_id& input3,
88             const bool transpose_input1 = false,
89             const bool transpose_input2 = false,
90             const float alpha = 1.f,
91             const float beta = 0.f,
92             const padding& output_padding = padding()
93         )
94         : primitive_base(id, { input, input2, input3 }, output_padding)
95         , transpose_input1(transpose_input1)
96         , transpose_input2(transpose_input2)
97         , alpha(alpha)
98         , beta(beta)
99
100     {
101     }
102
103     /// @brief Flag for transposing first input matrix
104     bool transpose_input1;
105     /// @brief Flag for transposing second input matrix
106     bool transpose_input2;
107     /// @brief Variable containing ALPHA parameter
108     float alpha;
109     /// @brief Variable containing BETA parameter
110     float beta;
111
112     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{gemm}
113     gemm(const dto* dto)
114         : primitive_base(dto)
115         , transpose_input1 (dto->transpose_input1)
116         , transpose_input2(dto->transpose_input2)
117         , alpha (dto->alpha)
118         , beta (dto->beta)
119     {
120     }
121
122 protected:
123     void update_dto(dto& dto) const override
124     {
125         dto.transpose_input1 = transpose_input1;
126         dto.transpose_input2 = transpose_input2;
127         dto.alpha = alpha;
128         dto.beta = beta;
129     }
130 };
131
132 }
133
134 /// @}
135 /// @}
136 /// @}