Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_gemm_node.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_gemm_node.h"
6 #include <ie_layers.h>
7 #include <string>
8 #include <vector>
9 #include <memory>
10 #include <algorithm>
11 #include <cmath>
12 #include <mkldnn_types.h>
13 #include <mkldnn_extension_utils.h>
14
15 using namespace mkldnn;
16 using namespace MKLDNNPlugin;
17 using namespace InferenceEngine;
18
19 MKLDNNGemmNode::MKLDNNGemmNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
20
21 void MKLDNNGemmNode::getSupportedDescriptors() {
22     auto* gemmLayer = dynamic_cast<GemmLayer*>(getCnnLayer().get());
23
24     if (gemmLayer == nullptr)
25         THROW_IE_EXCEPTION << "Cannot convert gemm layer.";
26
27     if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
28         THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
29     if (getChildEdges().size() != 1)
30         THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
31
32     auto inDims0 = getParentEdgeAt(0)->getDims();
33     auto inDims1 = getParentEdgeAt(1)->getDims();
34     auto outDims = getChildEdgeAt(0)->getDims();
35
36     alpha = gemmLayer->alpha;
37     beta = gemmLayer->beta;
38     transposeA = gemmLayer->transpose_a;
39     transposeB = gemmLayer->transpose_b;
40
41     if ((inDims0.ndims() < 2 || inDims0.ndims() > 4) ||
42         (inDims1.ndims() < 2 || inDims1.ndims() > 4))
43         THROW_IE_EXCEPTION << "Unsupported input dims count for layer " << getName();
44
45     if (outDims.ndims() < 2 || outDims.ndims() > 4)
46         THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
47
48     if (inDims0.ndims() != inDims1.ndims() || inDims0.ndims() != outDims.ndims())
49         THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
50
51     int nDims = inDims0.ndims();
52     xAxis = nDims - 1;
53     yAxis = nDims - 2;
54
55     if (inDims0[xAxis] != inDims1[yAxis] || inDims0[yAxis] != outDims[yAxis] || inDims1[xAxis] != outDims[xAxis])
56         THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
57
58     isThreeInputs = getParentEdges().size() == 3;
59
60     if (isThreeInputs) {
61         auto inDims2 = getParentEdgeAt(2)->getDims();
62
63         if (inDims2.ndims() < 2 || inDims2.ndims() > 4)
64             THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
65
66         if (inDims2.ndims() != outDims.ndims())
67             THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
68
69         if (inDims2[yAxis] != outDims[yAxis] || inDims2[xAxis] != outDims[xAxis])
70             THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
71     }
72
73     for (int dim_idx = nDims - 3; dim_idx >= 0; dim_idx--) {
74         if (isThreeInputs) {
75             auto inDims2 = getParentEdgeAt(2)->getDims();
76
77             if (inDims2[dim_idx] != outDims[dim_idx] && inDims2[dim_idx] != 1)
78                 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
79
80             int cOffset = 1;
81             for (int i = dim_idx + 1; i < nDims; i++)
82                 cOffset *= inDims2[i];
83             cOffsets.push_back(inDims2[dim_idx] == outDims[dim_idx] ? cOffset : 0);
84         }
85
86         if ((inDims0[dim_idx] != outDims[dim_idx] && inDims0[dim_idx] != 1) ||
87             (inDims1[dim_idx] != outDims[dim_idx] && inDims1[dim_idx] != 1)) {
88             THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
89         }
90
91         int aOffset = 1;
92         for (int i = dim_idx + 1; i < nDims; i++)
93             aOffset *= inDims0[i];
94         aOffsets.push_back(inDims0[dim_idx] == outDims[dim_idx] ? aOffset : 0);
95
96         int bOffset = 1;
97         for (int i = dim_idx + 1; i < nDims; i++)
98             bOffset *= inDims1[i];
99         bOffsets.push_back(inDims1[dim_idx] == outDims[dim_idx] ? bOffset : 0);
100     }
101
102     for (unsigned long dim_idx = aOffsets.size(); dim_idx < 2; dim_idx++)
103         aOffsets.push_back(0);
104     for (unsigned long dim_idx = bOffsets.size(); dim_idx < 2; dim_idx++)
105         bOffsets.push_back(0);
106     for (unsigned long dim_idx = cOffsets.size(); dim_idx < 2; dim_idx++)
107         cOffsets.push_back(0);
108 }
109
110 void MKLDNNGemmNode::initSupportedPrimitiveDescriptors() {
111     if (!supportedPrimitiveDescriptors.empty())
112         return;
113
114     auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
115     auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
116
117     auto same = [&] (memory::format fmt) -> PrimitiveDescInfo {
118         InferenceEngine::LayerConfig config;
119         config.dynBatchSupport = true;
120         for (size_t i = 0; i < getParentEdges().size(); i++) {
121             InferenceEngine::DataConfig dataConfig;
122             dataConfig.inPlace = -1;
123             dataConfig.constant = false;
124             dataConfig.desc = MKLDNNMemoryDesc(getParentEdgeAt(i)->getDims(), inputDataType, fmt);
125             config.inConfs.push_back(dataConfig);
126         }
127
128         InferenceEngine::DataConfig dataConfig;
129             dataConfig.inPlace = -1;
130             dataConfig.constant = false;
131             dataConfig.desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, fmt);
132             config.outConfs.push_back(dataConfig);
133         return {config, impl_desc_type::gemm_any};
134     };
135
136     supportedPrimitiveDescriptors.push_back(same(memory::any));
137 }
138
139 void MKLDNNGemmNode::createPrimitive() {
140     auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
141     auto& src0MemPtr = getParentEdgeAt(0)->getMemoryPtr();
142     auto& src1MemPtr = getParentEdgeAt(1)->getMemoryPtr();
143     if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
144         THROW_IE_EXCEPTION << "Destination memory isn't allocated.";
145     if (!src0MemPtr || !src0MemPtr->GetPrimitivePtr() || !src1MemPtr || !src1MemPtr->GetPrimitivePtr())
146         THROW_IE_EXCEPTION << "Input memory isn't allocated.";
147     if (getSelectedPrimitiveDescriptor() == nullptr)
148         THROW_IE_EXCEPTION << "Preferable primitive descriptor isn't set.";
149
150     if (isThreeInputs) {
151         auto& src2MemPtr = getParentEdgeAt(2)->getMemoryPtr();
152         if (!src2MemPtr || !src2MemPtr->GetPrimitivePtr())
153             THROW_IE_EXCEPTION << "Input memory isn't allocated.";
154     }
155 }
156
157 void MKLDNNGemmNode::execute(mkldnn::stream strm) {
158     auto inDims0 = getParentEdgeAt(0)->getDims();
159     auto inDims1 = getParentEdgeAt(1)->getDims();
160     auto outDims = getChildEdgeAt(0)->getDims();
161
162     auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
163     auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
164     const float *src0_ptr = reinterpret_cast<const float*>(srcMemory0.GetData()) +
165                             srcMemory0.GetDescriptor().data.layout_desc.blocking.offset_padding;
166     const float *src1_ptr = reinterpret_cast<const float*>(srcMemory1.GetData()) +
167                             srcMemory1.GetDescriptor().data.layout_desc.blocking.offset_padding;
168     float *dst_ptr = reinterpret_cast<float*>(getChildEdgeAt(0)->getMemory().GetData()) +
169                      getChildEdgeAt(0)->getMemory().GetDescriptor().data.layout_desc.blocking.offset_padding;
170
171     int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
172     int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
173     int M = inDims0[yAxis];
174     int N = inDims1[xAxis];
175     int K = inDims0[xAxis];
176
177     const char transa = transposeA ? 'T' : 'N';
178     const char transb = transposeB ? 'T' : 'N';
179
180     int lda = transposeA ? M : K;
181     int ldb = transposeB ? K : N;
182     int ldc = N;
183
184     const float *src2_ptr;
185     if (isThreeInputs) {
186         auto& srcMemory2 = getParentEdgeAt(2)->getMemory();
187         src2_ptr = reinterpret_cast<const float *>(srcMemory2.GetData()) +
188                                 srcMemory2.GetDescriptor().data.layout_desc.blocking.offset_padding;
189     } else {
190         src2_ptr = dst_ptr;
191     }
192
193     if (!isThreeInputs) {
194         beta = 0.f;
195     }
196
197     for (int b1 = 0; b1 < MB1; b1++) {
198         const float *a_ptr = src0_ptr;
199         const float *b_ptr = src1_ptr;
200         const float *c_ptr = src2_ptr;
201         float *d_ptr = dst_ptr;
202
203         for (int b2 = 0; b2 < MB2; b2++) {
204             if (isThreeInputs) {
205                 memcpy(d_ptr, c_ptr, M * N * sizeof(float));
206                 c_ptr += cOffsets[0];
207             }
208
209             mkldnn_sgemm(&transb, &transa, &N, &M, &K, &alpha, b_ptr, &ldb, a_ptr, &lda, &beta, d_ptr, &ldc);
210
211             a_ptr += aOffsets[0];
212             b_ptr += bOffsets[0];
213             d_ptr += M * N;
214         }
215
216         src0_ptr += aOffsets[1];
217         src1_ptr += bOffsets[1];
218         dst_ptr += MB2 * M * N;
219
220         if (isThreeInputs) {
221             src2_ptr += cOffsets[1];
222         }
223     }
224 }
225
226 bool MKLDNNGemmNode::created() const {
227     return getType() == Gemm;
228 }
229
230 int MKLDNNGemmNode::getMaxBatch() {
231     if (!outDims.empty())
232         return outDims[0][0];
233     return 0;
234 }