updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_gemm_node.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_gemm_node.h"
6 #include <ie_layers.h>
7 #include <string>
8 #include <vector>
9 #include <memory>
10 #include <algorithm>
11 #include <cmath>
12 #include <mkldnn_types.h>
13 #include <mkldnn_extension_utils.h>
14
15 using namespace mkldnn;
16 using namespace MKLDNNPlugin;
17 using namespace InferenceEngine;
18
19 MKLDNNGemmNode::MKLDNNGemmNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, int socket) :
20         MKLDNNNode(layer, eng, socket) {}
21
22 void MKLDNNGemmNode::getSupportedDescriptors() {
23     auto* gemmLayer = dynamic_cast<GemmLayer*>(getCnnLayer().get());
24
25     if (gemmLayer == nullptr)
26         THROW_IE_EXCEPTION << "Cannot convert gemm layer.";
27
28     if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
29         THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
30     if (getChildEdges().size() != 1)
31         THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
32
33     auto inDims0 = getParentEdgeAt(0)->getDims();
34     auto inDims1 = getParentEdgeAt(1)->getDims();
35     auto outDims = getChildEdgeAt(0)->getDims();
36
37     alpha = gemmLayer->alpha;
38     beta = gemmLayer->beta;
39     transposeA = gemmLayer->transpose_a;
40     transposeB = gemmLayer->transpose_b;
41
42     if ((inDims0.ndims() < 2 || inDims0.ndims() > 4) ||
43         (inDims1.ndims() < 2 || inDims1.ndims() > 4))
44         THROW_IE_EXCEPTION << "Unsupported input dims count for layer " << getName();
45
46     if (outDims.ndims() < 2 || outDims.ndims() > 4)
47         THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
48
49     if (inDims0.ndims() != inDims1.ndims() || inDims0.ndims() != outDims.ndims())
50         THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
51
52     int nDims = inDims0.ndims();
53     xAxis = nDims - 1;
54     yAxis = nDims - 2;
55     auto xAxis0 = transposeA ? yAxis : xAxis;
56     auto yAxis0 = transposeA ? xAxis : yAxis;
57     auto xAxis1 = transposeB ? yAxis : xAxis;
58     auto yAxis1 = transposeB ? xAxis : yAxis;
59
60     // The check inDims0[xAxis] != inDims1[yAxis] is correct due to layer semantic
61     // coverity[copy_paste_error]
62     if (inDims0[xAxis0] != inDims1[yAxis1] || inDims0[yAxis0] != outDims[yAxis] || inDims1[xAxis1] != outDims[xAxis])
63         THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
64
65     isThreeInputs = getParentEdges().size() == 3;
66
67     if (isThreeInputs) {
68         auto inDims2 = getParentEdgeAt(2)->getDims();
69
70         if (inDims2.ndims() < 2 || inDims2.ndims() > 4)
71             THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
72
73         if (inDims2.ndims() != outDims.ndims())
74             THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
75
76         if (inDims2[yAxis] != outDims[yAxis] || inDims2[xAxis] != outDims[xAxis])
77             THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
78     }
79
80     for (int dim_idx = nDims - 3; dim_idx >= 0; dim_idx--) {
81         if (isThreeInputs) {
82             auto inDims2 = getParentEdgeAt(2)->getDims();
83
84             if (inDims2[dim_idx] != outDims[dim_idx] && inDims2[dim_idx] != 1)
85                 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
86
87             int cOffset = 1;
88             for (int i = dim_idx + 1; i < nDims; i++)
89                 cOffset *= inDims2[i];
90             cOffsets.push_back(inDims2[dim_idx] == outDims[dim_idx] ? cOffset : 0);
91         }
92
93         if ((inDims0[dim_idx] != outDims[dim_idx] && inDims0[dim_idx] != 1) ||
94             (inDims1[dim_idx] != outDims[dim_idx] && inDims1[dim_idx] != 1)) {
95             THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
96         }
97
98         int aOffset = 1;
99         for (int i = dim_idx + 1; i < nDims; i++)
100             aOffset *= inDims0[i];
101         aOffsets.push_back(inDims0[dim_idx] == outDims[dim_idx] ? aOffset : 0);
102
103         int bOffset = 1;
104         for (int i = dim_idx + 1; i < nDims; i++)
105             bOffset *= inDims1[i];
106         bOffsets.push_back(inDims1[dim_idx] == outDims[dim_idx] ? bOffset : 0);
107     }
108
109     for (unsigned long dim_idx = aOffsets.size(); dim_idx < 2; dim_idx++)
110         aOffsets.push_back(0);
111     for (unsigned long dim_idx = bOffsets.size(); dim_idx < 2; dim_idx++)
112         bOffsets.push_back(0);
113     for (unsigned long dim_idx = cOffsets.size(); dim_idx < 2; dim_idx++)
114         cOffsets.push_back(0);
115 }
116
117 void MKLDNNGemmNode::initSupportedPrimitiveDescriptors() {
118     if (!supportedPrimitiveDescriptors.empty())
119         return;
120
121     auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
122     auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
123
124     auto same = [&] (memory::format fmt) -> PrimitiveDescInfo {
125         InferenceEngine::LayerConfig config;
126         config.dynBatchSupport = true;
127         for (size_t i = 0; i < getParentEdges().size(); i++) {
128             InferenceEngine::DataConfig dataConfig;
129             dataConfig.inPlace = -1;
130             dataConfig.constant = false;
131             dataConfig.desc = MKLDNNMemoryDesc(getParentEdgeAt(i)->getDims(), inputDataType, fmt);
132             config.inConfs.push_back(dataConfig);
133         }
134
135         InferenceEngine::DataConfig dataConfig;
136             dataConfig.inPlace = -1;
137             dataConfig.constant = false;
138             dataConfig.desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, fmt);
139             config.outConfs.push_back(dataConfig);
140         return {config, impl_desc_type::gemm_any, fmt};
141     };
142
143     supportedPrimitiveDescriptors.push_back(same(memory::any));
144 }
145
146 void MKLDNNGemmNode::createPrimitive() {
147     auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
148     auto& src0MemPtr = getParentEdgeAt(0)->getMemoryPtr();
149     auto& src1MemPtr = getParentEdgeAt(1)->getMemoryPtr();
150     if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
151         THROW_IE_EXCEPTION << "Destination memory isn't allocated.";
152     if (!src0MemPtr || !src0MemPtr->GetPrimitivePtr() || !src1MemPtr || !src1MemPtr->GetPrimitivePtr())
153         THROW_IE_EXCEPTION << "Input memory isn't allocated.";
154     if (getSelectedPrimitiveDescriptor() == nullptr)
155         THROW_IE_EXCEPTION << "Preferable primitive descriptor isn't set.";
156
157     if (isThreeInputs) {
158         auto& src2MemPtr = getParentEdgeAt(2)->getMemoryPtr();
159         if (!src2MemPtr || !src2MemPtr->GetPrimitivePtr())
160             THROW_IE_EXCEPTION << "Input memory isn't allocated.";
161     }
162 }
163
164 void MKLDNNGemmNode::execute(mkldnn::stream strm) {
165     auto inDims0 = getParentEdgeAt(0)->getDims();
166     auto inDims1 = getParentEdgeAt(1)->getDims();
167     auto outDims = getChildEdgeAt(0)->getDims();
168
169     auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
170     auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
171     const float *src0_ptr = reinterpret_cast<const float*>(srcMemory0.GetData()) +
172                             srcMemory0.GetDescriptor().data.layout_desc.blocking.offset_padding;
173     const float *src1_ptr = reinterpret_cast<const float*>(srcMemory1.GetData()) +
174                             srcMemory1.GetDescriptor().data.layout_desc.blocking.offset_padding;
175     float *dst_ptr = reinterpret_cast<float*>(getChildEdgeAt(0)->getMemory().GetData()) +
176                      getChildEdgeAt(0)->getMemory().GetDescriptor().data.layout_desc.blocking.offset_padding;
177
178     int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
179     int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
180     int M = outDims[yAxis];
181     int N = outDims[xAxis];
182     int K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
183
184     const char transa = transposeA ? 'T' : 'N';
185     const char transb = transposeB ? 'T' : 'N';
186
187     int lda = transposeA ? M : K;
188     int ldb = transposeB ? K : N;
189     int ldc = N;
190
191     const float *src2_ptr;
192     if (isThreeInputs) {
193         auto& srcMemory2 = getParentEdgeAt(2)->getMemory();
194         src2_ptr = reinterpret_cast<const float *>(srcMemory2.GetData()) +
195                                 srcMemory2.GetDescriptor().data.layout_desc.blocking.offset_padding;
196     } else {
197         src2_ptr = dst_ptr;
198     }
199
200     if (!isThreeInputs) {
201         beta = 0.f;
202     }
203
204     for (int b1 = 0; b1 < MB1; b1++) {
205         const float *a_ptr = src0_ptr;
206         const float *b_ptr = src1_ptr;
207         const float *c_ptr = src2_ptr;
208         float *d_ptr = dst_ptr;
209
210         for (int b2 = 0; b2 < MB2; b2++) {
211             if (isThreeInputs) {
212                 memcpy(d_ptr, c_ptr, M * N * sizeof(float));
213                 c_ptr += cOffsets[0];
214             }
215
216             mkldnn_sgemm(&transb, &transa, &N, &M, &K, &alpha, b_ptr, &ldb, a_ptr, &lda, &beta, d_ptr, &ldc);
217
218             a_ptr += aOffsets[0];
219             b_ptr += bOffsets[0];
220             d_ptr += M * N;
221         }
222
223         src0_ptr += aOffsets[1];
224         src1_ptr += bOffsets[1];
225         dst_ptr += MB2 * M * N;
226
227         if (isThreeInputs) {
228             src2_ptr += cOffsets[1];
229         }
230     }
231 }
232
233 bool MKLDNNGemmNode::created() const {
234     return getType() == Gemm;
235 }
236
237 int MKLDNNGemmNode::getMaxBatch() {
238     if (!outDims.empty())
239         return outDims[0][0];
240     return 0;
241 }
242 REG_MKLDNN_PRIM_FOR(MKLDNNGemmNode, Gemm);