// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include "mkldnn_gemm_node.h"

#include <cstring>

#include <mkldnn_types.h>
#include <mkldnn_extension_utils.h>

using namespace mkldnn;
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
19 MKLDNNGemmNode::MKLDNNGemmNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, int socket) :
20 MKLDNNNode(layer, eng, socket) {}
22 void MKLDNNGemmNode::getSupportedDescriptors() {
23 auto* gemmLayer = dynamic_cast<GemmLayer*>(getCnnLayer().get());
25 if (gemmLayer == nullptr)
26 THROW_IE_EXCEPTION << "Cannot convert gemm layer.";
28 if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
29 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
30 if (getChildEdges().size() != 1)
31 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
33 auto inDims0 = getParentEdgeAt(0)->getDims();
34 auto inDims1 = getParentEdgeAt(1)->getDims();
35 auto outDims = getChildEdgeAt(0)->getDims();
37 alpha = gemmLayer->alpha;
38 beta = gemmLayer->beta;
39 transposeA = gemmLayer->transpose_a;
40 transposeB = gemmLayer->transpose_b;
42 if ((inDims0.ndims() < 2 || inDims0.ndims() > 4) ||
43 (inDims1.ndims() < 2 || inDims1.ndims() > 4))
44 THROW_IE_EXCEPTION << "Unsupported input dims count for layer " << getName();
46 if (outDims.ndims() < 2 || outDims.ndims() > 4)
47 THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
49 if (inDims0.ndims() != inDims1.ndims() || inDims0.ndims() != outDims.ndims())
50 THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
52 int nDims = inDims0.ndims();
55 auto xAxis0 = transposeA ? yAxis : xAxis;
56 auto yAxis0 = transposeA ? xAxis : yAxis;
57 auto xAxis1 = transposeB ? yAxis : xAxis;
58 auto yAxis1 = transposeB ? xAxis : yAxis;
60 // The check inDims0[xAxis] != inDims1[yAxis] is correct due to layer semantic
61 // coverity[copy_paste_error]
62 if (inDims0[xAxis0] != inDims1[yAxis1] || inDims0[yAxis0] != outDims[yAxis] || inDims1[xAxis1] != outDims[xAxis])
63 THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
65 isThreeInputs = getParentEdges().size() == 3;
68 auto inDims2 = getParentEdgeAt(2)->getDims();
70 if (inDims2.ndims() < 2 || inDims2.ndims() > 4)
71 THROW_IE_EXCEPTION << "Unsupported output dims count for layer " << getName();
73 if (inDims2.ndims() != outDims.ndims())
74 THROW_IE_EXCEPTION << "Invalid dims count for layer " << getName();
76 if (inDims2[yAxis] != outDims[yAxis] || inDims2[xAxis] != outDims[xAxis])
77 THROW_IE_EXCEPTION << "Spatial input and output dimensions are incorrect for layer " << getName();
80 for (int dim_idx = nDims - 3; dim_idx >= 0; dim_idx--) {
82 auto inDims2 = getParentEdgeAt(2)->getDims();
84 if (inDims2[dim_idx] != outDims[dim_idx] && inDims2[dim_idx] != 1)
85 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
88 for (int i = dim_idx + 1; i < nDims; i++)
89 cOffset *= inDims2[i];
90 cOffsets.push_back(inDims2[dim_idx] == outDims[dim_idx] ? cOffset : 0);
93 if ((inDims0[dim_idx] != outDims[dim_idx] && inDims0[dim_idx] != 1) ||
94 (inDims1[dim_idx] != outDims[dim_idx] && inDims1[dim_idx] != 1)) {
95 THROW_IE_EXCEPTION << "Input batch dimensions are incorrect for layer " << getName();
99 for (int i = dim_idx + 1; i < nDims; i++)
100 aOffset *= inDims0[i];
101 aOffsets.push_back(inDims0[dim_idx] == outDims[dim_idx] ? aOffset : 0);
104 for (int i = dim_idx + 1; i < nDims; i++)
105 bOffset *= inDims1[i];
106 bOffsets.push_back(inDims1[dim_idx] == outDims[dim_idx] ? bOffset : 0);
109 for (unsigned long dim_idx = aOffsets.size(); dim_idx < 2; dim_idx++)
110 aOffsets.push_back(0);
111 for (unsigned long dim_idx = bOffsets.size(); dim_idx < 2; dim_idx++)
112 bOffsets.push_back(0);
113 for (unsigned long dim_idx = cOffsets.size(); dim_idx < 2; dim_idx++)
114 cOffsets.push_back(0);
117 void MKLDNNGemmNode::initSupportedPrimitiveDescriptors() {
118 if (!supportedPrimitiveDescriptors.empty())
121 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
122 auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(InferenceEngine::Precision::FP32);
124 auto same = [&] (memory::format fmt) -> PrimitiveDescInfo {
125 InferenceEngine::LayerConfig config;
126 config.dynBatchSupport = true;
127 for (size_t i = 0; i < getParentEdges().size(); i++) {
128 InferenceEngine::DataConfig dataConfig;
129 dataConfig.inPlace = -1;
130 dataConfig.constant = false;
131 dataConfig.desc = MKLDNNMemoryDesc(getParentEdgeAt(i)->getDims(), inputDataType, fmt);
132 config.inConfs.push_back(dataConfig);
135 InferenceEngine::DataConfig dataConfig;
136 dataConfig.inPlace = -1;
137 dataConfig.constant = false;
138 dataConfig.desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, fmt);
139 config.outConfs.push_back(dataConfig);
140 return {config, impl_desc_type::gemm_any, fmt};
143 supportedPrimitiveDescriptors.push_back(same(memory::any));
146 void MKLDNNGemmNode::createPrimitive() {
147 auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
148 auto& src0MemPtr = getParentEdgeAt(0)->getMemoryPtr();
149 auto& src1MemPtr = getParentEdgeAt(1)->getMemoryPtr();
150 if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
151 THROW_IE_EXCEPTION << "Destination memory isn't allocated.";
152 if (!src0MemPtr || !src0MemPtr->GetPrimitivePtr() || !src1MemPtr || !src1MemPtr->GetPrimitivePtr())
153 THROW_IE_EXCEPTION << "Input memory isn't allocated.";
154 if (getSelectedPrimitiveDescriptor() == nullptr)
155 THROW_IE_EXCEPTION << "Preferable primitive descriptor isn't set.";
158 auto& src2MemPtr = getParentEdgeAt(2)->getMemoryPtr();
159 if (!src2MemPtr || !src2MemPtr->GetPrimitivePtr())
160 THROW_IE_EXCEPTION << "Input memory isn't allocated.";
164 void MKLDNNGemmNode::execute(mkldnn::stream strm) {
165 auto inDims0 = getParentEdgeAt(0)->getDims();
166 auto inDims1 = getParentEdgeAt(1)->getDims();
167 auto outDims = getChildEdgeAt(0)->getDims();
169 auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
170 auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
171 const float *src0_ptr = reinterpret_cast<const float*>(srcMemory0.GetData()) +
172 srcMemory0.GetDescriptor().data.layout_desc.blocking.offset_padding;
173 const float *src1_ptr = reinterpret_cast<const float*>(srcMemory1.GetData()) +
174 srcMemory1.GetDescriptor().data.layout_desc.blocking.offset_padding;
175 float *dst_ptr = reinterpret_cast<float*>(getChildEdgeAt(0)->getMemory().GetData()) +
176 getChildEdgeAt(0)->getMemory().GetDescriptor().data.layout_desc.blocking.offset_padding;
178 int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
179 int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
180 int M = outDims[yAxis];
181 int N = outDims[xAxis];
182 int K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
184 const char transa = transposeA ? 'T' : 'N';
185 const char transb = transposeB ? 'T' : 'N';
187 int lda = transposeA ? M : K;
188 int ldb = transposeB ? K : N;
191 const float *src2_ptr;
193 auto& srcMemory2 = getParentEdgeAt(2)->getMemory();
194 src2_ptr = reinterpret_cast<const float *>(srcMemory2.GetData()) +
195 srcMemory2.GetDescriptor().data.layout_desc.blocking.offset_padding;
200 if (!isThreeInputs) {
204 for (int b1 = 0; b1 < MB1; b1++) {
205 const float *a_ptr = src0_ptr;
206 const float *b_ptr = src1_ptr;
207 const float *c_ptr = src2_ptr;
208 float *d_ptr = dst_ptr;
210 for (int b2 = 0; b2 < MB2; b2++) {
212 memcpy(d_ptr, c_ptr, M * N * sizeof(float));
213 c_ptr += cOffsets[0];
216 mkldnn_sgemm(&transb, &transa, &N, &M, &K, &alpha, b_ptr, &ldb, a_ptr, &lda, &beta, d_ptr, &ldc);
218 a_ptr += aOffsets[0];
219 b_ptr += bOffsets[0];
223 src0_ptr += aOffsets[1];
224 src1_ptr += bOffsets[1];
225 dst_ptr += MB2 * M * N;
228 src2_ptr += cOffsets[1];
233 bool MKLDNNGemmNode::created() const {
234 return getType() == Gemm;
237 int MKLDNNGemmNode::getMaxBatch() {
238 if (!outDims.empty())
239 return outDims[0][0];
242 REG_MKLDNN_PRIM_FOR(MKLDNNGemmNode, Gemm);