2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "../C/gemm.h"
20 #include "primitive.hpp"
24 /// @addtogroup cpp_api C++ API
26 /// @addtogroup cpp_topology Network Topology
28 /// @addtogroup cpp_primitives Primitives
30 /// @brief General matrix multiplication (GEMM) primitive.
32 /// @brief Adds gemm input.
34 /// @details General Matrix Multiplication with batch support,
35 /// A(B,Z,X)xA2(B,Y,Z)=C(B,X,Y)
37 /// @n@b Requirements:
38 /// @n - @c input - first matrix
39 /// @n - @c input2 - second matrix
40 /// @n - @c optional: input3 matrix, alpha, beta, transpose
41 /// @n - @c computations with optional params: output = alpha x (input3 x beta + input x input2)
42 /// @n - @c transpose params transposing second matrix <-TODO
45 struct gemm : public primitive_base<gemm, CLDNN_PRIMITIVE_DESC(gemm)>
47 CLDNN_DECLARE_PRIMITIVE(gemm)
49 /// @brief Constructs gemm layer.
50 /// @param input Primitive id containing first matrix
51 /// @param input2 Primitive id containing second matrix
52 /// @param transpose_input1 Flag for transposing first input matrix
53 /// @param transpose_input2 Flag for transposing second input matrix
54 /// @param alpha Variable containing ALPHA parameter
55 /// @param beta Variable containing BETA parameter
58 const primitive_id& id,
59 const primitive_id& input,
60 const primitive_id& input2,
61 const bool transpose_input1 = false,
62 const bool transpose_input2 = false,
63 const float alpha = 1.0f,
64 const float beta = 0.0f,
65 const padding& output_padding = padding()
67 : primitive_base(id, { input, input2 }, output_padding)
68 , transpose_input1(transpose_input1)
69 , transpose_input2(transpose_input2)
74 /// @brief Constructs gemm layer.
75 /// @param input Primitive id containing first matrix
76 /// @param input2 Primitive id containing second matrix
77 /// @param input3 Primitive id containing third matrix
78 /// @param transpose_input1 Flag for transposing first input matrix
79 /// @param transpose_input2 Flag for transposing second input matrix
80 /// @param alpha Variable containing ALPHA parameter
81 /// @param beta Variable containing BETA parameter
84 const primitive_id& id,
85 const primitive_id& input,
86 const primitive_id& input2,
87 const primitive_id& input3,
88 const bool transpose_input1 = false,
89 const bool transpose_input2 = false,
90 const float alpha = 1.f,
91 const float beta = 0.f,
92 const padding& output_padding = padding()
94 : primitive_base(id, { input, input2, input3 }, output_padding)
95 , transpose_input1(transpose_input1)
96 , transpose_input2(transpose_input2)
103 /// @brief Flag for transposing first input matrix
104 bool transpose_input1;
105 /// @brief Flag for transposing second input matrix
106 bool transpose_input2;
107 /// @brief Variable containing ALPHA parameter
109 /// @brief Variable containing BETA parameter
112 /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{gemm}
114 : primitive_base(dto)
115 , transpose_input1 (dto->transpose_input1)
116 , transpose_input2(dto->transpose_input2)
123 void update_dto(dto& dto) const override
125 dto.transpose_input1 = transpose_input1;
126 dto.transpose_input2 = transpose_input2;