// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_gemm_node.h"
12 #include <mkldnn_types.h>
13 #include <mkldnn_extension_utils.h>
15 using namespace mkldnn;
16 using namespace MKLDNNPlugin;
17 using namespace InferenceEngine;
19 MKLDNNGemmNode::MKLDNNGemmNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
21 void MKLDNNGemmNode::getSupportedDescriptors() {
22 auto* gemmLayer = dynamic_cast<GemmLayer*>(getCnnLayer().get());
24 if (gemmLayer == nullptr)
25 THROW_IE_EXCEPTION << "Cannot convert gemm layer.";
27 if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
28 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
29 if (getChildEdges().size() != 1)
30 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
32 auto inDims0 = getParentEdgeAt(0)->getDims();
33 auto inDims1 = getParentEdgeAt(1)->getDims();
34 auto outDims = getChildEdgeAt(0)->getDims();
36 alpha = gemmLayer->alpha;
37 beta = gemmLayer->beta;
38 transposeA = gemmLayer->transpose_a;
39 transposeB = gemmLayer->transpose_b;
41 if ((inDims0.ndims() < 2 || inDims0.ndims() > 4) ||
42 (inDims1.ndims() < 2 || inDims1.ndims() > 4))
43 THROW_IE_EXCEPTION << "Unsupported input dims count for layer " << getName();
45 if (outDims.ndims() < 2 || outDims.ndims() > 4)
46 THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
48 if (inDims0.ndims() != inDims1.ndims() || inDims0.ndims() != outDims.ndims())
49 THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
51 int nDims = inDims0.ndims();
55 if (inDims0[xAxis] != inDims1[yAxis] || inDims0[yAxis] != outDims[yAxis] || inDims1[xAxis] != outDims[xAxis])
56 THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
58 isThreeInputs = getParentEdges().size() == 3;
61 auto inDims2 = getParentEdgeAt(2)->getDims();
63 if (inDims2.ndims() < 2 || inDims2.ndims() > 4)
64 THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
66 if (inDims2.ndims() != outDims.ndims())
67 THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
69 if (inDims2[yAxis] != outDims[yAxis] || inDims2[xAxis] != outDims[xAxis])
70 THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
73 for (int dim_idx = nDims - 3; dim_idx >= 0; dim_idx--) {
75 auto inDims2 = getParentEdgeAt(2)->getDims();
77 if (inDims2[dim_idx] != outDims[dim_idx] && inDims2[dim_idx] != 1)
78 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
81 for (int i = dim_idx + 1; i < nDims; i++)
82 cOffset *= inDims2[i];
83 cOffsets.push_back(inDims2[dim_idx] == outDims[dim_idx] ? cOffset : 0);
86 if ((inDims0[dim_idx] != outDims[dim_idx] && inDims0[dim_idx] != 1) ||
87 (inDims1[dim_idx] != outDims[dim_idx] && inDims1[dim_idx] != 1)) {
88 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
92 for (int i = dim_idx + 1; i < nDims; i++)
93 aOffset *= inDims0[i];
94 aOffsets.push_back(inDims0[dim_idx] == outDims[dim_idx] ? aOffset : 0);
97 for (int i = dim_idx + 1; i < nDims; i++)
98 bOffset *= inDims1[i];
99 bOffsets.push_back(inDims1[dim_idx] == outDims[dim_idx] ? bOffset : 0);
102 for (unsigned long dim_idx = aOffsets.size(); dim_idx < 2; dim_idx++)
103 aOffsets.push_back(0);
104 for (unsigned long dim_idx = bOffsets.size(); dim_idx < 2; dim_idx++)
105 bOffsets.push_back(0);
106 for (unsigned long dim_idx = cOffsets.size(); dim_idx < 2; dim_idx++)
107 cOffsets.push_back(0);
110 void MKLDNNGemmNode::initSupportedPrimitiveDescriptors() {
111 if (!supportedPrimitiveDescriptors.empty())
114 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
115 auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
117 auto same = [&] (memory::format fmt) -> PrimitiveDescInfo {
118 InferenceEngine::LayerConfig config;
119 config.dynBatchSupport = true;
120 for (size_t i = 0; i < getParentEdges().size(); i++) {
121 InferenceEngine::DataConfig dataConfig;
122 dataConfig.inPlace = -1;
123 dataConfig.constant = false;
124 dataConfig.desc = MKLDNNMemoryDesc(getParentEdgeAt(i)->getDims(), inputDataType, fmt);
125 config.inConfs.push_back(dataConfig);
128 InferenceEngine::DataConfig dataConfig;
129 dataConfig.inPlace = -1;
130 dataConfig.constant = false;
131 dataConfig.desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, fmt);
132 config.outConfs.push_back(dataConfig);
133 return {config, impl_desc_type::gemm_any};
136 supportedPrimitiveDescriptors.push_back(same(memory::any));
139 void MKLDNNGemmNode::createPrimitive() {
140 auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
141 auto& src0MemPtr = getParentEdgeAt(0)->getMemoryPtr();
142 auto& src1MemPtr = getParentEdgeAt(1)->getMemoryPtr();
143 if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
144 THROW_IE_EXCEPTION << "Destination memory isn't allocated.";
145 if (!src0MemPtr || !src0MemPtr->GetPrimitivePtr() || !src1MemPtr || !src1MemPtr->GetPrimitivePtr())
146 THROW_IE_EXCEPTION << "Input memory isn't allocated.";
147 if (getSelectedPrimitiveDescriptor() == nullptr)
148 THROW_IE_EXCEPTION << "Preferable primitive descriptor isn't set.";
151 auto& src2MemPtr = getParentEdgeAt(2)->getMemoryPtr();
152 if (!src2MemPtr || !src2MemPtr->GetPrimitivePtr())
153 THROW_IE_EXCEPTION << "Input memory isn't allocated.";
157 void MKLDNNGemmNode::execute(mkldnn::stream strm) {
158 auto inDims0 = getParentEdgeAt(0)->getDims();
159 auto inDims1 = getParentEdgeAt(1)->getDims();
160 auto outDims = getChildEdgeAt(0)->getDims();
162 auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
163 auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
164 const float *src0_ptr = reinterpret_cast<const float*>(srcMemory0.GetData()) +
165 srcMemory0.GetDescriptor().data.layout_desc.blocking.offset_padding;
166 const float *src1_ptr = reinterpret_cast<const float*>(srcMemory1.GetData()) +
167 srcMemory1.GetDescriptor().data.layout_desc.blocking.offset_padding;
168 float *dst_ptr = reinterpret_cast<float*>(getChildEdgeAt(0)->getMemory().GetData()) +
169 getChildEdgeAt(0)->getMemory().GetDescriptor().data.layout_desc.blocking.offset_padding;
171 int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
172 int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
173 int M = inDims0[yAxis];
174 int N = inDims1[xAxis];
175 int K = inDims0[xAxis];
177 const char transa = transposeA ? 'T' : 'N';
178 const char transb = transposeB ? 'T' : 'N';
180 int lda = transposeA ? M : K;
181 int ldb = transposeB ? K : N;
184 const float *src2_ptr;
186 auto& srcMemory2 = getParentEdgeAt(2)->getMemory();
187 src2_ptr = reinterpret_cast<const float *>(srcMemory2.GetData()) +
188 srcMemory2.GetDescriptor().data.layout_desc.blocking.offset_padding;
193 if (!isThreeInputs) {
197 for (int b1 = 0; b1 < MB1; b1++) {
198 const float *a_ptr = src0_ptr;
199 const float *b_ptr = src1_ptr;
200 const float *c_ptr = src2_ptr;
201 float *d_ptr = dst_ptr;
203 for (int b2 = 0; b2 < MB2; b2++) {
205 memcpy(d_ptr, c_ptr, M * N * sizeof(float));
206 c_ptr += cOffsets[0];
209 mkldnn_sgemm(&transb, &transa, &N, &M, &K, &alpha, b_ptr, &ldb, a_ptr, &lda, &beta, d_ptr, &ldc);
211 a_ptr += aOffsets[0];
212 b_ptr += bOffsets[0];
216 src0_ptr += aOffsets[1];
217 src1_ptr += bOffsets[1];
218 dst_ptr += MB2 * M * N;
221 src2_ptr += cOffsets[1];
226 bool MKLDNNGemmNode::created() const {
227 return getType() == Gemm;
230 int MKLDNNGemmNode::getMaxBatch() {
231 if (!outDims.empty())
232 return outDims[0][0];