1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
16 using namespace ::testing;
18 using namespace mkldnn;
// Parameter bundle describing one GEMM graph test case.
// NOTE(review): only part of this struct is visible in this excerpt. The code below
// also reads batches.*, M, N, K, alpha, beta, transposeA, transposeB and num_prim_desc
// from it, so those members are presumably declared in lines not shown here.
20 struct gemm_test_params {
// Implementation type that the Gemm node's selected primitive descriptor is asserted to report.
44 MKLDNNPlugin::impl_desc_type selectedType;
// Optional checkers; comp[j] is called with the j-th supported primitive descriptor of the Gemm node.
46 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Reference GEMM used to produce the expected output for the tests.
//
// For every (mb1, mb2) batch pair each output element is initialised to
// beta * C[i][j] when three inputs are supplied (0 otherwise) and then accumulates
// alpha * A * B over k, where A and B are indexed as transposed when
// prm.transposeA / prm.transposeB are set.
//
// src - input blobs: src[0] is A, src[1] is B, optional src[2] is C.
// dst - blob that receives the result.
// prm - test parameters (batch sizes, alpha, beta, transpose flags).
// NOTE(review): M, N and K are used below but their definitions are not visible in
// this excerpt — presumably they are read from prm; confirm.
49 template<typename data_t>
50 void ref_gemm(const std::vector<InferenceEngine::TBlob<data_t>> &src, InferenceEngine::TBlob<data_t> &dst,
51 gemm_test_params prm) {
52 const data_t *src0_data = src[0].readOnly();
53 const data_t *src1_data = src[1].readOnly();
// With only two inputs the C pointer falls back to dst's buffer; it is only
// dereferenced below when src.size() == 3.
54 const data_t *src2_data = src.size() == 3 ? src[2].readOnly() : dst.readOnly();
55 data_t *dst_data = dst.data();
// Loop bounds come from the output's batch dimensions.
57 size_t MB1 = prm.batches.MB1_D;
58 size_t MB2 = prm.batches.MB2_D;
63 for (int mb1 = 0; mb1 < MB1; mb1++) {
// Per-outer-batch cursors, reset from the outer pointers on every mb1 iteration.
64 const data_t *a_data = src0_data;
65 const data_t *b_data = src1_data;
66 const data_t *c_data = src2_data;
67 data_t *d_data = dst_data;
69 for (int mb2 = 0; mb2 < MB2; mb2++) {
70 for (int i = 0; i < M; i++) {
71 for (int j = 0; j < N; j++) {
// Start from the scaled bias term, or from zero when no C input was given.
72 d_data[i * N + j] = src.size() == 3 ? prm.beta * c_data[i * N + j] : 0;
74 for (int k = 0; k < K; k++) {
// Transposed operands are read with swapped row/column indexing.
75 size_t src0_off = prm.transposeA ? k * M + i : i * K + k;
76 size_t src1_off = prm.transposeB ? j * K + k : k * N + j;
77 d_data[i * N + j] += prm.alpha * a_data[src0_off] * b_data[src1_off];
// An input advances to its next matrix only when its inner batch size equals the
// output's; otherwise the same matrix is reused for every mb2.
// NOTE(review): no advance of d_data is visible in this excerpt — it is presumably
// on a line not shown; confirm.
81 a_data += prm.batches.MB2_A == MB2 ? M*K : 0;
82 b_data += prm.batches.MB2_B == MB2 ? K*N : 0;
83 c_data += prm.batches.MB2_C == MB2 ? M*N : 0;
// Same rule for the outer batch: step by a whole inner batch only when the input's
// outer batch size equals the output's.
87 src0_data += prm.batches.MB1_A == MB1 ? prm.batches.MB2_A*M*K : 0;
88 src1_data += prm.batches.MB1_B == MB1 ? prm.batches.MB2_B*K*N : 0;
89 src2_data += prm.batches.MB1_C == MB1 ? prm.batches.MB2_C*M*N : 0;
// The output always advances by one full inner batch of M*N matrices.
90 dst_data += prm.batches.MB2_D*M*N;
// Fixture for the GEMM layer with three FP32 inputs ("in1", "in2", "in3").
//
// getModel() fills the IR template held in model_t from the test parameters: the
// batch dimensions, M/N/K, and the GEMM attributes alpha, beta, transpose_a and
// transpose_b. The batch placeholders (e.g. "_MB1_A_") contain "_A_"/"_B_" as
// substrings and are substituted before "_A_"/"_B_" are.
// NOTE(review): presumably that ordering is required by REPLACE_WITH_NUM — confirm.
//
// SetUp() carries the whole check:
//  1. reads the generated model and builds an MKLDNNGraphTestClass graph from it;
//  2. for every Gemm node, asserts that the number of supported primitive descriptors
//     equals num_prim_desc, runs the comp checkers on them, and asserts that a
//     selected descriptor exists and reports implementation type selectedType;
//  3. creates the three input blobs (NCHW, sized from the batch dims and M/N/K),
//     fills them with fill_data, and runs graph.Infer into a blob built from the
//     first output's tensor descriptor;
//  4. computes the expected result with ref_gemm and compares it with the output.
// NOTE(review): dims_dst is not used in the visible lines, and the try that matches
// the final catch is not visible in this excerpt.
94 class MKLDNNGraphGemmTests: public TestsCommon,
95 public WithParamInterface<gemm_test_params> {
96 std::string model_t = R"V0G0N(
97 <net name="gemmOnly" version="2" precision="FP32" batch="1">
99 <layer name="in1" type="Input" precision="FP32" id="1">
109 <layer name="in2" type="Input" precision="FP32" id="2">
119 <layer name="in3" type="Input" precision="FP32" id="3">
129 <layer name="gemm" id="4" type="GEMM" precision="FP32">
130 <data alpha="_A_" beta="_B_" transpose_a="_TA_" transpose_b="_TB_"/>
162 <edge from-layer="1" from-port="1" to-layer="4" to-port="1"/>
163 <edge from-layer="2" from-port="1" to-layer="4" to-port="2"/>
164 <edge from-layer="3" from-port="1" to-layer="4" to-port="3"/>
170 std::string getModel(gemm_test_params p) {
171 std::string model = model_t;
174 REPLACE_WITH_NUM(model, "_MB1_A_", p.batches.MB1_A);
175 REPLACE_WITH_NUM(model, "_MB2_A_", p.batches.MB2_A);
176 REPLACE_WITH_NUM(model, "_MB1_B_", p.batches.MB1_B);
177 REPLACE_WITH_NUM(model, "_MB2_B_", p.batches.MB2_B);
178 REPLACE_WITH_NUM(model, "_MB1_C_", p.batches.MB1_C);
179 REPLACE_WITH_NUM(model, "_MB2_C_", p.batches.MB2_C);
180 REPLACE_WITH_NUM(model, "_MB1_D_", p.batches.MB1_D);
181 REPLACE_WITH_NUM(model, "_MB2_D_", p.batches.MB2_D);
183 REPLACE_WITH_NUM(model, "_M_", p.M);
184 REPLACE_WITH_NUM(model, "_N_", p.N);
185 REPLACE_WITH_NUM(model, "_K_", p.K);
187 REPLACE_WITH_NUM(model, "_A_", p.alpha);
188 REPLACE_WITH_NUM(model, "_B_", p.beta);
189 REPLACE_WITH_NUM(model, "_TA_", p.transposeA);
190 REPLACE_WITH_NUM(model, "_TB_", p.transposeB);
195 virtual void TearDown() {
198 virtual void SetUp() {
200 TestsCommon::SetUp();
201 gemm_test_params p = ::testing::WithParamInterface<gemm_test_params>::GetParam();
202 std::string model = getModel(p);
204 InferenceEngine::CNNNetReader net_reader;
205 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
207 MKLDNNGraphTestClass graph;
208 graph.CreateGraph(net_reader.getNetwork());
210 auto& nodes = graph.getNodes();
211 for (int i = 0; i < nodes.size(); i++) {
212 if (nodes[i]->getType() == MKLDNNPlugin::Gemm) {
213 ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
214 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
215 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
217 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
218 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
222 InferenceEngine::SizeVector dims_src1 = {p.batches.MB1_A, p.batches.MB2_A, p.M, p.K};
223 InferenceEngine::SizeVector dims_src2 = {p.batches.MB1_B, p.batches.MB2_B, p.K, p.N};
224 InferenceEngine::SizeVector dims_src3 = {p.batches.MB1_C, p.batches.MB2_C, p.M, p.N};
225 InferenceEngine::SizeVector dims_dst = {p.batches.MB1_D, p.batches.MB2_D, p.M, p.N};
227 InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
229 InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
230 if (srcPtr1 == nullptr)
231 FAIL() << "Cannot cast blob to TBlob<float>.";
232 fill_data(src1->buffer(), src1->size());
234 InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
236 InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
237 if (srcPtr2 == nullptr)
238 FAIL() << "Cannot cast blob to TBlob<float>.";
239 fill_data(src2->buffer(), src2->size());
241 InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src3);
243 InferenceEngine::TBlob<float>* srcPtr3 = dynamic_cast<InferenceEngine::TBlob<float>*>(src3.get());
244 if (srcPtr3 == nullptr)
245 FAIL() << "Cannot cast blob to TBlob<float>.";
246 fill_data(src3->buffer(), src3->size());
248 InferenceEngine::BlobMap srcs;
249 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
250 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
251 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
253 InferenceEngine::OutputsDataMap out;
254 out = net_reader.getNetwork().getOutputsInfo();
255 InferenceEngine::BlobMap outputBlobs;
257 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
259 InferenceEngine::TBlob<float>::Ptr output;
260 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
262 outputBlobs[item.first] = output;
264 graph.Infer(srcs, outputBlobs);
266 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
269 std::vector<InferenceEngine::TBlob<float>> src_vec = {*srcPtr1, *srcPtr2, *srcPtr3};
271 ref_gemm(src_vec, dst_ref, p);
273 compare(*output, dst_ref);
274 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all checks run inside the fixture's SetUp().
280 TEST_P(MKLDNNGraphGemmTests, TestsGemm) {}
// Parameter sets for MKLDNNGraphGemmTests.
// NOTE(review): the positional initializers appear to be {eight batch dims}, three
// matrix sizes, alpha, beta, transposeA, transposeB, num_prim_desc, selectedType and
// an optional comp list — the struct's full field order is not visible in this
// excerpt; confirm against gemm_test_params.
// The first six cases also pass a comp checker asserting that the descriptor is
// gemm_any with 3 input configs and 1 output config, all in NCHW layout; the
// remaining cases omit comp and vary alpha/beta, the transpose flags and batch dims.
282 INSTANTIATE_TEST_CASE_P(
283 TestsGemm, MKLDNNGraphGemmTests,
285 gemm_test_params{{2, 1, 2, 1, 2, 1, 2, 1}, 3, 3, 2, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
286 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
287 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
288 ASSERT_EQ(3, impl.getConfig().inConfs.size());
289 ASSERT_EQ(1, impl.getConfig().outConfs.size());
290 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
291 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
292 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
293 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
296 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 8, 5, 4, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
297 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
298 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
299 ASSERT_EQ(3, impl.getConfig().inConfs.size());
300 ASSERT_EQ(1, impl.getConfig().outConfs.size());
301 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
302 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
303 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
304 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
307 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 16, 10, 12, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
308 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
309 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
310 ASSERT_EQ(3, impl.getConfig().inConfs.size());
311 ASSERT_EQ(1, impl.getConfig().outConfs.size());
312 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
313 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
314 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
315 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
318 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 11, 10, 20, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
319 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
320 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
321 ASSERT_EQ(3, impl.getConfig().inConfs.size());
322 ASSERT_EQ(1, impl.getConfig().outConfs.size());
323 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
324 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
325 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
326 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
329 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 13, 2, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
330 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
331 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
332 ASSERT_EQ(3, impl.getConfig().inConfs.size());
333 ASSERT_EQ(1, impl.getConfig().outConfs.size());
334 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
335 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
336 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
337 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
340 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 15, 10, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any, {
341 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
342 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::gemm_any, impl.getImplementationType());
343 ASSERT_EQ(3, impl.getConfig().inConfs.size());
344 ASSERT_EQ(1, impl.getConfig().outConfs.size());
345 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
346 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
347 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
348 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
351 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 6, 7, 2, 0, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
352 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 5, 6, 7, 0, 2, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
353 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 3, 7, 4, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
354 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 3, 4, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
355 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
356 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 3, 7, 4, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
357 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 3, 4, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
358 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
359 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 3, 7, 4, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
360 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 3, 4, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
361 gemm_test_params{{3, 2, 3, 2, 3, 2, 3, 2}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
362 gemm_test_params{{1, 3, 2, 3, 2, 3, 2, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
363 gemm_test_params{{1, 3, 2, 3, 1, 3, 2, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
364 gemm_test_params{{2, 3, 1, 3, 1, 3, 2, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
365 gemm_test_params{{5, 3, 5, 1, 5, 3, 5, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
366 gemm_test_params{{5, 3, 5, 1, 5, 1, 5, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
367 gemm_test_params{{5, 1, 5, 1, 5, 3, 5, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
368 gemm_test_params{{1, 1, 5, 3, 5, 3, 5, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
369 gemm_test_params{{1, 1, 1, 1, 5, 3, 5, 3}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
370 gemm_test_params{{5, 4, 1, 1, 1, 1, 5, 4}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any}
// Dynamic-batch variant of the GEMM fixture. It reuses getModel() from
// MKLDNNGraphGemmTests and overrides SetUp() to run the graph with dynamic batching
// enabled instead of comparing against ref_gemm.
// NOTE(review): the try that matches the final catch is not visible in this excerpt.
373 class MKLDNNGraphDynBatchGemmTests: public MKLDNNGraphGemmTests {
375 virtual void SetUp() {
377 TestsCommon::SetUp();
378 gemm_test_params p = ::testing::WithParamInterface<gemm_test_params>::GetParam();
379 std::string model = getModel(p);
// The output's outer batch size is used as the batch for the network and all inputs.
380 size_t MB = p.batches.MB1_D;
384 InferenceEngine::CNNNetReader net_reader;
385 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
386 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
// setBatchSizeReshape is called on the CNNNetworkImpl, hence the downcast.
387 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
388 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
389 InferenceEngine::ResponseDesc resp;
390 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
391 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
393 MKLDNNGraphTestClass graph;
// Dynamic batching is switched on before the graph is created.
394 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
395 graph.CreateGraph(net_reader.getNetwork());
// Unlike the base fixture, every blob uses MB as its first dimension.
// NOTE(review): dims_dst is not used in the visible lines.
397 InferenceEngine::SizeVector dims_src1 = {MB, p.batches.MB2_A, p.M, p.K};
398 InferenceEngine::SizeVector dims_src2 = {MB, p.batches.MB2_B, p.K, p.N};
399 InferenceEngine::SizeVector dims_src3 = {MB, p.batches.MB2_C, p.M, p.N};
400 InferenceEngine::SizeVector dims_dst = {MB, p.batches.MB2_D, p.M, p.N};
402 InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
404 InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
405 if (srcPtr1 == nullptr)
406 FAIL() << "Cannot cast blob to TBlob<float>.";
407 fill_data(src1->buffer(), src1->size());
409 InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
411 InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
412 if (srcPtr2 == nullptr)
413 FAIL() << "Cannot cast blob to TBlob<float>.";
414 fill_data(src2->buffer(), src2->size());
416 InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src3);
418 InferenceEngine::TBlob<float>* srcPtr3 = dynamic_cast<InferenceEngine::TBlob<float>*>(src3.get());
419 if (srcPtr3 == nullptr)
420 FAIL() << "Cannot cast blob to TBlob<float>.";
421 fill_data(src3->buffer(), src3->size());
423 InferenceEngine::BlobMap srcs;
424 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
425 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
426 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
428 InferenceEngine::OutputsDataMap out;
429 out = net_reader.getNetwork().getOutputsInfo();
430 InferenceEngine::BlobMap outputBlobs;
// Only the first output of the network is used.
432 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
434 InferenceEngine::TBlob<float>::Ptr output;
435 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
437 outputBlobs[item.first] = output;
// Predicate handed to checkDynBatch: true for Gemm nodes only.
439 auto check = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
440 return node->getType() == MKLDNNPlugin::Gemm;
// NOTE(review): presumably the third argument is the batch to run with and the fourth
// the maximum batch, i.e. a full-batch run followed by a batch-1 run — confirm against
// MKLDNNGraphTestClass::checkDynBatch.
443 graph.checkDynBatch(srcs, outputBlobs, MB, MB, check);
444 graph.checkDynBatch(srcs, outputBlobs, 1, MB, check);
445 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all checks run inside the fixture's SetUp().
451 TEST_P(MKLDNNGraphDynBatchGemmTests, TestsDynBatchGemm) {}
// Parameter sets for the dynamic-batch fixture. Both use false for the two transpose
// flags and 1 for the two scalars before them; the second case differs in the fourth
// batch value and in the three matrix sizes.
// NOTE(review): the meaning of each positional initializer depends on the field order
// of gemm_test_params, which is not fully visible in this excerpt; confirm.
453 INSTANTIATE_TEST_CASE_P(
454 TestsDynBatchGemm, MKLDNNGraphDynBatchGemmTests,
456 gemm_test_params{{1, 3, 1, 3, 1, 3, 1, 3}, 3, 3, 3, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
457 gemm_test_params{{1, 3, 1, 1, 1, 3, 1, 3}, 16, 15, 12, 1, 1, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any}
// Fixture for the GEMM layer with two FP32 inputs ("in1", "in2") that each have a
// single batch dimension.
//
// getModel() fills the IR template held in model_t: the "_MB_A_", "_MB_B_" and
// "_MB_D_" placeholders take the MB2_A, MB2_B and MB2_D parameter values, followed by
// M/N/K and the GEMM attributes alpha, beta, transpose_a and transpose_b.
//
// SetUp() carries the whole check:
//  1. reads the generated model and builds an MKLDNNGraphTestClass graph from it;
//  2. for every Gemm node, asserts that the number of supported primitive descriptors
//     equals num_prim_desc, runs the comp checkers on them, and asserts that a
//     selected descriptor exists and reports implementation type selectedType;
//  3. creates the two 3-D input blobs (CHW layout), fills them with fill_data, and
//     runs graph.Infer into a blob built from the first output's tensor descriptor;
//  4. computes the expected result with ref_gemm using only two sources, so the
//     reference starts each output element from 0 rather than from beta * C, and
//     compares it with the output.
// NOTE(review): dims_dst is not used in the visible lines, and the try that matches
// the final catch is not visible in this excerpt.
460 class MKLDNNGraphSingleBatchDimGemmTests: public TestsCommon,
461 public WithParamInterface<gemm_test_params> {
462 std::string model_t = R"V0G0N(
463 <net name="gemmOnly" version="2" precision="FP32" batch="1">
465 <layer name="in1" type="Input" precision="FP32" id="1">
474 <layer name="in2" type="Input" precision="FP32" id="2">
483 <layer name="gemm" id="3" type="GEMM" precision="FP32">
484 <data alpha="_A_" beta="_B_" transpose_a="_TA_" transpose_b="_TB_"/>
507 <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
508 <edge from-layer="2" from-port="1" to-layer="3" to-port="2"/>
514 std::string getModel(gemm_test_params p) {
515 std::string model = model_t;
518 REPLACE_WITH_NUM(model, "_MB_A_", p.batches.MB2_A);
519 REPLACE_WITH_NUM(model, "_MB_B_", p.batches.MB2_B);
520 REPLACE_WITH_NUM(model, "_MB_D_", p.batches.MB2_D);
522 REPLACE_WITH_NUM(model, "_M_", p.M);
523 REPLACE_WITH_NUM(model, "_N_", p.N);
524 REPLACE_WITH_NUM(model, "_K_", p.K);
526 REPLACE_WITH_NUM(model, "_A_", p.alpha);
527 REPLACE_WITH_NUM(model, "_B_", p.beta);
528 REPLACE_WITH_NUM(model, "_TA_", p.transposeA);
529 REPLACE_WITH_NUM(model, "_TB_", p.transposeB);
534 virtual void TearDown() {
537 virtual void SetUp() {
539 TestsCommon::SetUp();
540 gemm_test_params p = ::testing::WithParamInterface<gemm_test_params>::GetParam();
541 std::string model = getModel(p);
543 InferenceEngine::CNNNetReader net_reader;
544 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
546 MKLDNNGraphTestClass graph;
547 graph.CreateGraph(net_reader.getNetwork());
549 auto& nodes = graph.getNodes();
550 for (int i = 0; i < nodes.size(); i++) {
551 if (nodes[i]->getType() == MKLDNNPlugin::Gemm) {
552 ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
553 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
554 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
556 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
557 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
561 InferenceEngine::SizeVector dims_src1 = {p.batches.MB2_A, p.M, p.K};
562 InferenceEngine::SizeVector dims_src2 = {p.batches.MB2_B, p.K, p.N};
563 InferenceEngine::SizeVector dims_dst = {p.batches.MB2_D, p.M, p.N};
565 InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::CHW, dims_src1);
567 InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
568 if (srcPtr1 == nullptr)
569 FAIL() << "Cannot cast blob to TBlob<float>.";
570 fill_data(src1->buffer(), src1->size());
572 InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::CHW, dims_src2);
574 InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
575 if (srcPtr2 == nullptr)
576 FAIL() << "Cannot cast blob to TBlob<float>.";
577 fill_data(src2->buffer(), src2->size());
579 InferenceEngine::BlobMap srcs;
580 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
581 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
583 InferenceEngine::OutputsDataMap out;
584 out = net_reader.getNetwork().getOutputsInfo();
585 InferenceEngine::BlobMap outputBlobs;
587 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
589 InferenceEngine::TBlob<float>::Ptr output;
590 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
592 outputBlobs[item.first] = output;
594 graph.Infer(srcs, outputBlobs);
596 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
599 std::vector<InferenceEngine::TBlob<float>> src_vec = {*srcPtr1, *srcPtr2};
601 ref_gemm(src_vec, dst_ref, p);
603 compare(*output, dst_ref);
604 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all checks run inside the fixture's SetUp().
610 TEST_P(MKLDNNGraphSingleBatchDimGemmTests, TestsGemm) {}
// Parameter sets for the single-batch-dimension fixture: three batch configurations,
// each repeated for all four combinations of the two transpose flags.
// The fixture reads only MB2_A, MB2_B and MB2_D from the batch list.
// NOTE(review): which of the eight positional batch values those are depends on the
// field order of gemm_test_params, which is not fully visible in this excerpt; confirm.
612 INSTANTIATE_TEST_CASE_P(
613 TestsGemm, MKLDNNGraphSingleBatchDimGemmTests,
615 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
616 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
617 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
618 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
619 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
620 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, false, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
621 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
622 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
623 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, false, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
624 gemm_test_params{{1, 1, 1, 1, 1, 1, 1, 1}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
625 gemm_test_params{{1, 3, 1, 3, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any},
626 gemm_test_params{{1, 3, 1, 1, 1, 1, 1, 3}, 7, 4, 3, 2, 3, true, true, 1, MKLDNNPlugin::impl_desc_type::gemm_any}