1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_gemm_node.h"
12 #include <mkldnn_types.h>
13 #include <mkldnn_extension_utils.h>
15 using namespace mkldnn;
16 using namespace MKLDNNPlugin;
17 using namespace InferenceEngine;
19 MKLDNNGemmNode::MKLDNNGemmNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
21 void MKLDNNGemmNode::getSupportedDescriptors() {
22 auto* gemmLayer = dynamic_cast<GemmLayer*>(getCnnLayer().get());
24 if (gemmLayer == nullptr)
25 THROW_IE_EXCEPTION << "Cannot convert gemm layer.";
27 if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
28 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
29 if (getChildEdges().size() != 1)
30 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
32 auto inDims0 = getParentEdgeAt(0)->getDims();
33 auto inDims1 = getParentEdgeAt(1)->getDims();
34 auto outDims = getChildEdgeAt(0)->getDims();
36 alpha = gemmLayer->alpha;
37 beta = gemmLayer->beta;
38 transposeA = gemmLayer->transpose_a;
39 transposeB = gemmLayer->transpose_b;
41 if ((inDims0.ndims() < 2 || inDims0.ndims() > 4) ||
42 (inDims1.ndims() < 2 || inDims1.ndims() > 4))
43 THROW_IE_EXCEPTION << "Unsupported input dims count for layer " << getName();
45 if (outDims.ndims() < 2 || outDims.ndims() > 4)
46 THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
48 if (inDims0.ndims() != inDims1.ndims() || inDims0.ndims() != outDims.ndims())
49 THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
51 int nDims = inDims0.ndims();
55 // The check inDims0[xAxis] != inDims1[yAxis] is correct due to layer semantic
56 // coverity[copy_paste_error]
57 if (inDims0[xAxis] != inDims1[yAxis] || inDims0[yAxis] != outDims[yAxis] || inDims1[xAxis] != outDims[xAxis])
58 THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
60 isThreeInputs = getParentEdges().size() == 3;
63 auto inDims2 = getParentEdgeAt(2)->getDims();
65 if (inDims2.ndims() < 2 || inDims2.ndims() > 4)
66 THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
68 if (inDims2.ndims() != outDims.ndims())
69 THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
71 if (inDims2[yAxis] != outDims[yAxis] || inDims2[xAxis] != outDims[xAxis])
72 THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
75 for (int dim_idx = nDims - 3; dim_idx >= 0; dim_idx--) {
77 auto inDims2 = getParentEdgeAt(2)->getDims();
79 if (inDims2[dim_idx] != outDims[dim_idx] && inDims2[dim_idx] != 1)
80 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
83 for (int i = dim_idx + 1; i < nDims; i++)
84 cOffset *= inDims2[i];
85 cOffsets.push_back(inDims2[dim_idx] == outDims[dim_idx] ? cOffset : 0);
88 if ((inDims0[dim_idx] != outDims[dim_idx] && inDims0[dim_idx] != 1) ||
89 (inDims1[dim_idx] != outDims[dim_idx] && inDims1[dim_idx] != 1)) {
90 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
94 for (int i = dim_idx + 1; i < nDims; i++)
95 aOffset *= inDims0[i];
96 aOffsets.push_back(inDims0[dim_idx] == outDims[dim_idx] ? aOffset : 0);
99 for (int i = dim_idx + 1; i < nDims; i++)
100 bOffset *= inDims1[i];
101 bOffsets.push_back(inDims1[dim_idx] == outDims[dim_idx] ? bOffset : 0);
104 for (unsigned long dim_idx = aOffsets.size(); dim_idx < 2; dim_idx++)
105 aOffsets.push_back(0);
106 for (unsigned long dim_idx = bOffsets.size(); dim_idx < 2; dim_idx++)
107 bOffsets.push_back(0);
108 for (unsigned long dim_idx = cOffsets.size(); dim_idx < 2; dim_idx++)
109 cOffsets.push_back(0);
112 void MKLDNNGemmNode::initSupportedPrimitiveDescriptors() {
113 if (!supportedPrimitiveDescriptors.empty())
116 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
117 auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
119 auto same = [&] (memory::format fmt) -> PrimitiveDescInfo {
120 InferenceEngine::LayerConfig config;
121 config.dynBatchSupport = true;
122 for (size_t i = 0; i < getParentEdges().size(); i++) {
123 InferenceEngine::DataConfig dataConfig;
124 dataConfig.inPlace = -1;
125 dataConfig.constant = false;
126 dataConfig.desc = MKLDNNMemoryDesc(getParentEdgeAt(i)->getDims(), inputDataType, fmt);
127 config.inConfs.push_back(dataConfig);
130 InferenceEngine::DataConfig dataConfig;
131 dataConfig.inPlace = -1;
132 dataConfig.constant = false;
133 dataConfig.desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, fmt);
134 config.outConfs.push_back(dataConfig);
135 return {config, impl_desc_type::gemm_any};
138 supportedPrimitiveDescriptors.push_back(same(memory::any));
141 void MKLDNNGemmNode::createPrimitive() {
142 auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
143 auto& src0MemPtr = getParentEdgeAt(0)->getMemoryPtr();
144 auto& src1MemPtr = getParentEdgeAt(1)->getMemoryPtr();
145 if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
146 THROW_IE_EXCEPTION << "Destination memory isn't allocated.";
147 if (!src0MemPtr || !src0MemPtr->GetPrimitivePtr() || !src1MemPtr || !src1MemPtr->GetPrimitivePtr())
148 THROW_IE_EXCEPTION << "Input memory isn't allocated.";
149 if (getSelectedPrimitiveDescriptor() == nullptr)
150 THROW_IE_EXCEPTION << "Preferable primitive descriptor isn't set.";
153 auto& src2MemPtr = getParentEdgeAt(2)->getMemoryPtr();
154 if (!src2MemPtr || !src2MemPtr->GetPrimitivePtr())
155 THROW_IE_EXCEPTION << "Input memory isn't allocated.";
159 void MKLDNNGemmNode::execute(mkldnn::stream strm) {
160 auto inDims0 = getParentEdgeAt(0)->getDims();
161 auto inDims1 = getParentEdgeAt(1)->getDims();
162 auto outDims = getChildEdgeAt(0)->getDims();
164 auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
165 auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
166 const float *src0_ptr = reinterpret_cast<const float*>(srcMemory0.GetData()) +
167 srcMemory0.GetDescriptor().data.layout_desc.blocking.offset_padding;
168 const float *src1_ptr = reinterpret_cast<const float*>(srcMemory1.GetData()) +
169 srcMemory1.GetDescriptor().data.layout_desc.blocking.offset_padding;
170 float *dst_ptr = reinterpret_cast<float*>(getChildEdgeAt(0)->getMemory().GetData()) +
171 getChildEdgeAt(0)->getMemory().GetDescriptor().data.layout_desc.blocking.offset_padding;
173 int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
174 int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
175 int M = inDims0[yAxis];
176 int N = inDims1[xAxis];
177 int K = inDims0[xAxis];
179 const char transa = transposeA ? 'T' : 'N';
180 const char transb = transposeB ? 'T' : 'N';
182 int lda = transposeA ? M : K;
183 int ldb = transposeB ? K : N;
186 const float *src2_ptr;
188 auto& srcMemory2 = getParentEdgeAt(2)->getMemory();
189 src2_ptr = reinterpret_cast<const float *>(srcMemory2.GetData()) +
190 srcMemory2.GetDescriptor().data.layout_desc.blocking.offset_padding;
195 if (!isThreeInputs) {
199 for (int b1 = 0; b1 < MB1; b1++) {
200 const float *a_ptr = src0_ptr;
201 const float *b_ptr = src1_ptr;
202 const float *c_ptr = src2_ptr;
203 float *d_ptr = dst_ptr;
205 for (int b2 = 0; b2 < MB2; b2++) {
207 memcpy(d_ptr, c_ptr, M * N * sizeof(float));
208 c_ptr += cOffsets[0];
211 mkldnn_sgemm(&transb, &transa, &N, &M, &K, &alpha, b_ptr, &ldb, a_ptr, &lda, &beta, d_ptr, &ldc);
213 a_ptr += aOffsets[0];
214 b_ptr += bOffsets[0];
218 src0_ptr += aOffsets[1];
219 src1_ptr += bOffsets[1];
220 dst_ptr += MB2 * M * N;
223 src2_ptr += cOffsets[1];
228 bool MKLDNNGemmNode::created() const {
229 return getType() == Gemm;
232 int MKLDNNGemmNode::getMaxBatch() {
233 if (!outDims.empty())
234 return outDims[0][0];