[GNA] Aligning GNA2 and GNA Unit testing (#952)
authorKamil Magierski <kamil.magierski@intel.com>
Wed, 26 Aug 2020 10:43:43 +0000 (12:43 +0200)
committerGitHub <noreply@github.com>
Wed, 26 Aug 2020 10:43:43 +0000 (13:43 +0300)
Co-authored-by: kmagiers <kmagiers@intel.com>
inference-engine/tests_deprecated/unit/engines/gna/gna_api_stub.cpp
inference-engine/tests_deprecated/unit/engines/gna/gna_graph_aot_test.cpp
inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp
inference-engine/tests_deprecated/unit/engines/gna/gna_mock_api.hpp

index 3dd1cfb..be48e98 100644 (file)
@@ -13,7 +13,6 @@
 #include <gna2-inference-api.h>
 #include <gna2-model-export-api.h>
 #endif
-
 #include "gna_mock_api.hpp"
 
 static GNACppApi * current = nullptr;
@@ -47,27 +46,42 @@ GNA2_API enum Gna2Status Gna2MemoryAlloc(
 
 GNA2_API enum Gna2Status Gna2DeviceOpen(
     uint32_t deviceIndex) {
+    if (current != nullptr) {
+        return current->Gna2DeviceOpen(deviceIndex);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2DeviceSetNumberOfThreads(
     uint32_t deviceIndex,
     uint32_t numberOfThreads) {
+    if (current != nullptr) {
+        return current->Gna2DeviceSetNumberOfThreads(deviceIndex, numberOfThreads);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API Gna2Status Gna2DeviceClose(
     uint32_t deviceIndex) {
+    if (current != nullptr) {
+        return current->Gna2DeviceClose(deviceIndex);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2MemoryFree(
     void * memory) {
+    if (current != nullptr) {
+        return current->Gna2MemoryFree(memory);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2StatusGetMessage(enum Gna2Status status,
     char * messageBuffer, uint32_t messageBufferSize) {
+    if (current != nullptr) {
+        return current->Gna2StatusGetMessage(status, messageBuffer, messageBufferSize);
+    }
     return Gna2StatusSuccess;
 }
 
@@ -75,21 +89,34 @@ GNA2_API enum Gna2Status Gna2ModelCreate(
     uint32_t deviceIndex,
     struct Gna2Model const * model,
     uint32_t * modelId) {
+    if (current != nullptr) {
+        return current->Gna2ModelCreate(deviceIndex, model, modelId);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2ModelRelease(
     uint32_t modelId) {
+    if (current != nullptr) {
+        return current->Gna2ModelRelease(modelId);
+    }
     return Gna2StatusSuccess;
 }
 
-GNA2_API enum Gna2Status Gna2ModelGetLastError(struct Gna2ModelError* error) {
+GNA2_API enum Gna2Status Gna2ModelGetLastError(
+    struct Gna2ModelError* error) {
+    if (current != nullptr) {
+        return current->Gna2ModelGetLastError(error);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2RequestConfigCreate(
     uint32_t modelId,
     uint32_t * requestConfigId) {
+    if (current != nullptr) {
+        return current->Gna2RequestConfigCreate(modelId, requestConfigId);
+    }
     return Gna2StatusSuccess;
 }
 
@@ -98,41 +125,62 @@ GNA2_API enum Gna2Status Gna2RequestConfigEnableActiveList(
     uint32_t operationIndex,
     uint32_t numberOfIndices,
     uint32_t const * indices) {
+    if (current != nullptr) {
+        return current->Gna2RequestConfigEnableActiveList(requestConfigId, operationIndex, numberOfIndices, indices);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2RequestConfigEnableHardwareConsistency(
     uint32_t requestConfigId,
     enum Gna2DeviceVersion deviceVersion) {
+    if (current != nullptr) {
+        return current->Gna2RequestConfigEnableHardwareConsistency(requestConfigId, deviceVersion);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2RequestConfigSetAccelerationMode(
     uint32_t requestConfigId,
     enum Gna2AccelerationMode accelerationMode) {
+    if (current != nullptr) {
+        return current->Gna2RequestConfigSetAccelerationMode(requestConfigId, accelerationMode);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2RequestEnqueue(
     uint32_t requestConfigId,
     uint32_t * requestId) {
+    if (current != nullptr) {
+        return current->Gna2RequestEnqueue(requestConfigId, requestId);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2RequestWait(
     uint32_t requestId,
     uint32_t timeoutMilliseconds) {
+    if (current != nullptr) {
+        return current->Gna2RequestWait(requestId, timeoutMilliseconds);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2ModelExportConfigCreate(
     Gna2UserAllocator userAllocator,
     uint32_t * exportConfigId) {
+    if (current != nullptr) {
+        return current->Gna2ModelExportConfigCreate(userAllocator, exportConfigId);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2ModelExportConfigRelease(
     uint32_t exportConfigId) {
+    if (current != nullptr) {
+        return current->Gna2ModelExportConfigRelease(exportConfigId);
+    }
     return Gna2StatusSuccess;
 }
 
@@ -140,12 +188,18 @@ GNA2_API enum Gna2Status Gna2ModelExportConfigSetSource(
     uint32_t exportConfigId,
     uint32_t sourceDeviceIndex,
     uint32_t sourceModelId) {
+    if (current != nullptr) {
+        return current->Gna2ModelExportConfigSetSource(exportConfigId, sourceDeviceIndex, sourceModelId);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2ModelExportConfigSetTarget(
     uint32_t exportConfigId,
     enum Gna2DeviceVersion targetDeviceVersion) {
+    if (current != nullptr) {
+        return current->Gna2ModelExportConfigSetTarget(exportConfigId, targetDeviceVersion);
+    }
     return Gna2StatusSuccess;
 }
 
@@ -154,12 +208,18 @@ GNA2_API enum Gna2Status Gna2ModelExport(
     enum Gna2ModelExportComponent componentType,
     void ** exportBuffer,
     uint32_t * exportBufferSize) {
+    if (current != nullptr) {
+        return current->Gna2ModelExport(exportConfigId, componentType, exportBuffer, exportBufferSize);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2DeviceGetVersion(
     uint32_t deviceIndex,
     enum Gna2DeviceVersion * deviceVersion) {
+    if (current != nullptr) {
+        return current->Gna2DeviceGetVersion(deviceIndex,deviceVersion);
+    }
     *deviceVersion = Gna2DeviceVersionSoftwareEmulation;
     return Gna2StatusSuccess;
 }
@@ -169,12 +229,18 @@ GNA2_API enum Gna2Status Gna2InstrumentationConfigCreate(
     enum Gna2InstrumentationPoint* selectedInstrumentationPoints,
     uint64_t * results,
     uint32_t * instrumentationConfigId) {
+    if (current != nullptr) {
+        return current->Gna2InstrumentationConfigCreate(numberOfInstrumentationPoints, selectedInstrumentationPoints, results, instrumentationConfigId);
+    }
     return Gna2StatusSuccess;
 }
 
 GNA2_API enum Gna2Status Gna2InstrumentationConfigAssignToRequestConfig(
     uint32_t instrumentationConfigId,
     uint32_t requestConfigId) {
+    if (current != nullptr) {
+        return current->Gna2InstrumentationConfigAssignToRequestConfig(instrumentationConfigId, requestConfigId);
+    }
     return Gna2StatusSuccess;
 }
 
index 2097652..1b53e33 100644 (file)
@@ -41,14 +41,11 @@ TEST_F(GNAAOTTests, DISABLED_AffineWith2AffineOutputs_canbe_export_imported) {
         .inNotCompactMode().gna().propagate_forward().called().once();
 }
 
-
+TEST_F(GNAAOTTests, AffineWith2AffineOutputs_canbe_imported_verify_structure) {
 // Disabled because of random fails: Issue-23611
-TEST_F(GNAAOTTests, DISABLED_AffineWith2AffineOutputs_canbe_imported_verify_structure) {
-
-#if GNA_LIB_VER == 2
+#if GNA_LIB_VER == 1
     GTEST_SKIP();
 #endif
-
     auto & nnet_type = storage<intel_nnet_type_t>();
 
     // saving pointer to nnet - todo probably deep copy required
@@ -123,10 +120,6 @@ TEST_F(GNAAOTTests, PoolingModel_canbe_export_imported) {
 
 TEST_F(GNAAOTTests, CanConvertFromAOTtoSueModel) {
 
-#if GNA_LIB_VER == 2
-    GTEST_SKIP();
-#endif
-
     auto & nnet_type = storage<intel_nnet_type_t>();
 
     // saving pointer to nnet - todo probably deep copy required
index f9af555..573d178 100644 (file)
@@ -67,6 +67,43 @@ public:
         delete this;
     }
 };
+#if GNA_LIB_VER == 2
+void expect_enqueue_calls(GNACppApi &mockApi, bool enableHardwareConsistency = true){
+    EXPECT_CALL(mockApi, Gna2ModelCreate(_,_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([](
+        uint32_t deviceIndex,
+        struct Gna2Model const * model,
+        uint32_t * modelId) {
+            *modelId = 0;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2RequestConfigCreate(_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([](
+        uint32_t modelId,
+        uint32_t * requestConfigId) {
+            *requestConfigId = 0;
+            return Gna2StatusSuccess;
+        }));
+
+    if (enableHardwareConsistency) {
+        EXPECT_CALL(mockApi, Gna2RequestConfigEnableHardwareConsistency(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
+    }
+
+    EXPECT_CALL(mockApi, Gna2RequestConfigSetAccelerationMode(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
+
+    {
+        ::testing::InSequence enqueue_wait_sequence;
+        EXPECT_CALL(mockApi, Gna2RequestEnqueue(_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([](
+            uint32_t requestConfigId,
+            uint32_t * requestId) {
+                *requestId = 0;
+                return Gna2StatusSuccess;
+            }));
+        EXPECT_CALL(mockApi, Gna2RequestWait(_, _)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
+    }
+}
+#endif
 
 void GNAPropagateMatcher :: match() {
     try {
@@ -244,7 +281,7 @@ void GNAPropagateMatcher :: match() {
 
         if (_env.config[GNA_CONFIG_KEY(DEVICE_MODE)].compare(GNA_CONFIG_VALUE(SW_FP32)) != 0 &&
             !_env.matchThrows) {
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+#if GNA_LIB_VER == 1
             EXPECT_CALL(mockApi, GNAAlloc(_,_,_)).WillOnce(Invoke([&data](
                 intel_gna_handle_t nGNADevice,   // handle to GNA accelerator
                 uint32_t           sizeRequested,
@@ -267,7 +304,7 @@ void GNAPropagateMatcher :: match() {
             } else {
                 EXPECT_CALL(mockApi, gmmSetThreads(_)).Times(0);
             }
-#else
+#elif GNA_LIB_VER == 2
             EXPECT_CALL(mockApi, Gna2MemoryAlloc(_, _, _)).WillOnce(Invoke([&data](
                 uint32_t sizeRequested,
                 uint32_t *sizeGranted,
@@ -278,6 +315,25 @@ void GNAPropagateMatcher :: match() {
                 *memoryAddress = &data.front();
                 return Gna2StatusSuccess;
             }));
+
+            EXPECT_CALL(mockApi, Gna2DeviceGetVersion(_,_)).WillOnce(Invoke([](
+                uint32_t deviceIndex,
+                enum Gna2DeviceVersion * deviceVersion) {
+                    *deviceVersion = Gna2DeviceVersionSoftwareEmulation;
+                    return Gna2StatusSuccess;
+                }));
+
+            EXPECT_CALL(mockApi, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
+
+            EXPECT_CALL(mockApi, Gna2InstrumentationConfigCreate(_,_,_,_)).WillOnce(Return(Gna2StatusSuccess));
+
+
+
+            if(_env.is_setup_of_omp_theads_expected == true) {
+                EXPECT_CALL(mockApi, Gna2DeviceSetNumberOfThreads(_,_)).WillOnce(Return(Gna2StatusSuccess));
+            }
+#else
+#error "Unsupported GNA_LIB_VER"
 #endif
             std::unique_ptr<NNetComponentMatcher> combined(new NNetComponentMatcher());
 
@@ -287,9 +343,15 @@ void GNAPropagateMatcher :: match() {
                         combined->add(new NNetPrecisionMatcher(_env.nnet_precision, INTEL_AFFINE));
                         break;
                     case GnaPluginTestEnvironment::matchProcType :
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+#if GNA_LIB_VER == 1
                         EXPECT_CALL(mockApi, GNAPropagateForward(_, _, _, _, _, Eq(_env.proc_type)))
                             .WillOnce(Return(GNA_NOERROR));
+#elif GNA_LIB_VER == 2
+                        if(_env.proc_type == (GNA_SOFTWARE & GNA_HARDWARE)) {
+                            expect_enqueue_calls(mockApi);
+                        } else {
+                            expect_enqueue_calls(mockApi, false);
+                        }
 #endif
                         break;
                     case GnaPluginTestEnvironment::matchPwlInserted :
@@ -313,9 +375,11 @@ void GNAPropagateMatcher :: match() {
                         combined->add(new DiagLayerMatcher(_env.matchInserted, matchWhat.matchQuantity));
                         break;
                     case GnaPluginTestEnvironment::saveArgs :
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+#if GNA_LIB_VER == 1
                         EXPECT_CALL(mockApi, GNAPropagateForward(_, _, _, _, _, _))
                             .WillOnce(DoAll(SaveArgPointee<1>(savedNet), Return(GNA_NOERROR)));
+#elif GNA_LIB_VER == 2
+                        expect_enqueue_calls(mockApi);
 #endif
                         break;
                     case GnaPluginTestEnvironment::matchInputData :
@@ -337,16 +401,20 @@ void GNAPropagateMatcher :: match() {
                         SaveWeights(combined, _env.transposedData, _env.transposedArgsForSaving);
                         break;
                     default:
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+#if GNA_LIB_VER == 1
                         EXPECT_CALL(mockApi, GNAPropagateForward(_, _, _, _, _, _))
                             .WillOnce(Return(GNA_NOERROR));
+#elif GNA_LIB_VER == 2
+                        expect_enqueue_calls(mockApi);
 #endif
                         break;
                 }
             }
             if (combined && !combined->empty()) {
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+#if GNA_LIB_VER == 1
                 EXPECT_CALL(mockApi, GNAPropagateForward(_, ::testing::MakeMatcher(combined.release()), _, _, _,_)).WillOnce(Return(GNA_NOERROR));
+#elif GNA_LIB_VER == 2
+                expect_enqueue_calls(mockApi);
 #endif
             }
         }
@@ -459,21 +527,52 @@ void GNAPluginAOTMatcher :: match() {
     }
 
     GNACppApi mockApi;
-    std::vector<uint8_t> data(10000);
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+    std::vector<std::vector<uint8_t>> data;
+#if GNA_LIB_VER == 1
     EXPECT_CALL(mockApi, GNAAlloc(_,_,_)).WillOnce(DoAll(SetArgPointee<2>(10000), Return(&data.front())));
     EXPECT_CALL(mockApi, GNADeviceOpenSetThreads(_, _)).WillOnce(Return(1));
+#elif GNA_LIB_VER == 2
+    EXPECT_CALL(mockApi, Gna2MemoryAlloc(_, _, _)).Times(AtLeast(1)).WillRepeatedly(Invoke([&data](
+        uint32_t sizeRequested,
+        uint32_t *sizeGranted,
+        void **memoryAddress) {
+            data.push_back(std::vector<uint8_t>(sizeRequested));
+            *sizeGranted = sizeRequested;
+            *memoryAddress = data.back().data();
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2DeviceGetVersion(_,_)).WillOnce(Invoke([](
+        uint32_t deviceIndex,
+        enum Gna2DeviceVersion * deviceVersion) {
+            *deviceVersion = Gna2DeviceVersionSoftwareEmulation;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2InstrumentationConfigCreate(_,_,_,_)).WillOnce(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2ModelCreate(_,_,_)).WillOnce(Invoke([](
+        uint32_t deviceIndex,
+        struct Gna2Model const * model,
+        uint32_t * modelId) {
+            *modelId = 0;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2RequestConfigCreate(_,_)).WillOnce(Invoke([](
+        uint32_t modelId,
+        uint32_t * requestConfigId) {
+            *requestConfigId = 0;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2RequestConfigEnableHardwareConsistency(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
 #else
-    EXPECT_CALL(mockApi, Gna2MemoryAlloc(_, _, _)).WillOnce(Invoke([&data](
-            uint32_t sizeRequested,
-            uint32_t *sizeGranted,
-            void **memoryAddress
-    ) {
-        data.resize(sizeRequested);
-        *sizeGranted = sizeRequested;
-        *memoryAddress = &data.front();
-        return Gna2StatusSuccess;
-    }));
+#error "Not supported GNA_LIB_VER"
 #endif
     plugin.LoadNetwork(network);
     plugin.Export(_env.exportedModelFileName);
@@ -523,8 +622,9 @@ void GNADumpXNNMatcher::match() {
 
     GNACppApi mockApi;
     std::vector<uint8_t> data(10000);
+
+#if GNA_LIB_VER == 1
     if (!_env.matchThrows) {
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
         EXPECT_CALL(mockApi, GNAAlloc(_,_,_)).WillOnce(DoAll(SetArgPointee<2>(10000), Return(&data.front())));
         EXPECT_CALL(mockApi, GNADeviceOpenSetThreads(_, _)).WillOnce(Return(1));
         intel_gna_model_header header = {};
@@ -532,8 +632,72 @@ void GNADumpXNNMatcher::match() {
         EXPECT_CALL(mockApi, GNADumpXnn(_, _, _, _, _,_)).WillOnce(DoAll(SetArgPointee<3>(header), Return((void*)::operator new[](1))));
         EXPECT_CALL(mockApi, GNAFree(_)).WillOnce(Return(GNA_NOERROR));
         EXPECT_CALL(mockApi, GNADeviceClose(_)).WillOnce(Return(GNA_NOERROR));
-#endif
     }
+#elif GNA_LIB_VER == 2
+    if (!_env.matchThrows) {
+        EXPECT_CALL(mockApi, Gna2MemoryAlloc(_, _, _)).
+            WillOnce(DoAll(SetArgPointee<1>(10000), SetArgPointee<2>(&data.front()), Return(Gna2StatusSuccess)));
+
+        EXPECT_CALL(mockApi, Gna2DeviceGetVersion(_,_)).WillOnce(Invoke([](
+            uint32_t deviceIndex,
+            enum Gna2DeviceVersion * deviceVersion) {
+                *deviceVersion = Gna2DeviceVersionSoftwareEmulation;
+                return Gna2StatusSuccess;
+            }));
+
+        EXPECT_CALL(mockApi, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2InstrumentationConfigCreate(_,_,_,_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2ModelCreate(_,_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([](
+            uint32_t deviceIndex,
+            struct Gna2Model const * model,
+            uint32_t * modelId) {
+                *modelId = 0;
+                return Gna2StatusSuccess;
+            }));
+
+        EXPECT_CALL(mockApi, Gna2MemoryFree(_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2DeviceClose(_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2ModelExportConfigCreate(_,_)).WillOnce(DoAll(SetArgPointee<1>(0), Return(Gna2StatusSuccess)));
+
+        EXPECT_CALL(mockApi, Gna2ModelExportConfigSetSource(_,_,_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2ModelExportConfigSetTarget(_,_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2ModelExport(_,_,_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([] (
+            uint32_t exportConfigId,
+            enum Gna2ModelExportComponent componentType,
+            void ** exportBuffer,
+            uint32_t * exportBufferSize) {
+                *exportBufferSize = 64;
+                *exportBuffer = gnaUserAllocator(sizeof(Gna2ModelSueCreekHeader));
+                return Gna2StatusSuccess;
+            }));
+
+        EXPECT_CALL(mockApi, Gna2ModelExportConfigRelease(_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2ModelRelease(_)).WillOnce(Return(Gna2StatusSuccess));
+
+        EXPECT_CALL(mockApi, Gna2RequestConfigCreate(_,_)).WillOnce(Invoke([](
+            uint32_t modelId,
+            uint32_t * requestConfigId) {
+                *requestConfigId = 0;
+                return Gna2StatusSuccess;
+    }));
+
+        ON_CALL(mockApi, Gna2RequestConfigSetAccelerationMode(_,_)).WillByDefault(Return(Gna2StatusSuccess));
+
+        ON_CALL(mockApi, Gna2RequestConfigEnableHardwareConsistency(_,_)).WillByDefault(Return(Gna2StatusSuccess));
+
+        ON_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).WillByDefault(Return(Gna2StatusSuccess));
+    }
+#else
+#error "Not supported GNA_LIB_VER"
+#endif
+
 
     try {
         // matching gna DumpXNN forward call.
@@ -590,7 +754,7 @@ void GNAQueryStateMatcher :: match() {
         }
     };
 
-#if GNA_LIB_VER == 1 // TODO: GNA2: handle new API
+#if GNA_LIB_VER == 1
     EXPECT_CALL(mockApi, GNAAlloc(_,_,_)).WillOnce(DoAll(SetArgPointee<2>(10000), Return(&data.front())));
     EXPECT_CALL(mockApi, GNADeviceOpenSetThreads(_, _)).WillOnce(Return(1));
     EXPECT_CALL(mockApi, GNAFree(_)).WillOnce(Return(GNA_NOERROR));
@@ -598,6 +762,40 @@ void GNAQueryStateMatcher :: match() {
 #else
     EXPECT_CALL(mockApi, Gna2MemoryAlloc(_, _, _)).
         WillOnce(DoAll(SetArgPointee<1>(10000), SetArgPointee<2>(&data.front()), Return(Gna2StatusSuccess)));
+
+    EXPECT_CALL(mockApi, Gna2DeviceGetVersion(_,_)).WillOnce(Invoke([](
+        uint32_t deviceIndex,
+        enum Gna2DeviceVersion * deviceVersion) {
+            *deviceVersion = Gna2DeviceVersionSoftwareEmulation;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2DeviceOpen(_)).WillOnce(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2InstrumentationConfigCreate(_,_,_,_)).WillOnce(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2MemoryFree(_)).WillOnce(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2DeviceClose(_)).WillOnce(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2ModelCreate(_,_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([](
+        uint32_t deviceIndex,
+        struct Gna2Model const * model,
+        uint32_t * modelId) {
+            *modelId = 0;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2RequestConfigCreate(_,_)).Times(AtLeast(1)).WillRepeatedly(Invoke([](
+        uint32_t modelId,
+        uint32_t * requestConfigId) {
+            *requestConfigId = 0;
+            return Gna2StatusSuccess;
+        }));
+
+    EXPECT_CALL(mockApi, Gna2RequestConfigEnableHardwareConsistency(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
+
+    EXPECT_CALL(mockApi, Gna2InstrumentationConfigAssignToRequestConfig(_,_)).Times(AtLeast(1)).WillRepeatedly(Return(Gna2StatusSuccess));
 #endif
     try {
         loadNetwork();
index e716cf9..f93effc 100644 (file)
@@ -4,7 +4,15 @@
 
 #pragma once
 #include <gmock/gmock-generated-function-mockers.h>
-
+#if GNA_LIB_VER == 1
+#include <gna-api.h>
+#include <gna-api-instrumentation.h>
+#include <gna-api-dumper.h>
+#else
+#include <gna2-instrumentation-api.h>
+#include <gna2-inference-api.h>
+#include <gna2-model-export-api.h>
+#endif
 #if defined(_WIN32)
     #ifdef libGNAStubs_EXPORTS
         #define GNA_STUBS_EXPORT __declspec(dllexport)
@@ -83,5 +91,96 @@ class GNACppApi {
         uint32_t sizeRequested,
         uint32_t * sizeGranted,
         void ** memoryAddress));
+
+    MOCK_METHOD1(Gna2DeviceOpen, Gna2Status (
+        uint32_t deviceIndex));
+
+    MOCK_METHOD2(Gna2DeviceSetNumberOfThreads, Gna2Status(
+        uint32_t deviceIndex,
+        uint32_t numberOfThreads));
+
+    MOCK_METHOD1(Gna2DeviceClose, Gna2Status (
+        uint32_t deviceIndex));
+
+    MOCK_METHOD1(Gna2MemoryFree, Gna2Status (
+        void * memory));
+
+    MOCK_METHOD3(Gna2StatusGetMessage, Gna2Status (
+        enum Gna2Status status,
+        char * messageBuffer,
+        uint32_t messageBufferSize));
+
+    MOCK_METHOD3(Gna2ModelCreate, Gna2Status (
+        uint32_t deviceIndex,
+        struct Gna2Model const * model,
+        uint32_t * modelId));
+
+    MOCK_METHOD1(Gna2ModelRelease, Gna2Status (
+        uint32_t modelId));
+
+    MOCK_METHOD1(Gna2ModelGetLastError, Gna2Status (
+        struct Gna2ModelError* error));
+
+    MOCK_METHOD2(Gna2RequestConfigCreate, Gna2Status (
+        uint32_t modelId,
+        uint32_t * requestConfigId));
+
+    MOCK_METHOD4(Gna2RequestConfigEnableActiveList, Gna2Status (
+        uint32_t requestConfigId,
+        uint32_t operationIndex,
+        uint32_t numberOfIndices,
+        uint32_t const * indices));
+
+    MOCK_METHOD2(Gna2RequestConfigEnableHardwareConsistency, Gna2Status (
+        uint32_t requestConfigId,
+        enum Gna2DeviceVersion deviceVersion));
+
+    MOCK_METHOD2(Gna2RequestConfigSetAccelerationMode, Gna2Status (
+        uint32_t requestConfigId,
+        enum Gna2AccelerationMode accelerationMode));
+
+    MOCK_METHOD2(Gna2RequestEnqueue, Gna2Status (
+        uint32_t requestConfigId,
+        uint32_t * requestId));
+
+    MOCK_METHOD2(Gna2RequestWait, Gna2Status (
+        uint32_t requestId,
+        uint32_t timeoutMilliseconds));
+
+    MOCK_METHOD2(Gna2ModelExportConfigCreate, Gna2Status (
+        Gna2UserAllocator userAllocator,
+        uint32_t * exportConfigId));
+
+    MOCK_METHOD1(Gna2ModelExportConfigRelease, Gna2Status (
+        uint32_t exportConfigId));
+
+    MOCK_METHOD3(Gna2ModelExportConfigSetSource, Gna2Status (
+        uint32_t exportConfigId,
+        uint32_t sourceDeviceIndex,
+        uint32_t sourceModelId));
+
+    MOCK_METHOD2(Gna2ModelExportConfigSetTarget, Gna2Status (
+        uint32_t exportConfigId,
+        enum Gna2DeviceVersion targetDeviceVersion));
+
+    MOCK_METHOD4(Gna2ModelExport, Gna2Status (
+        uint32_t exportConfigId,
+        enum Gna2ModelExportComponent componentType,
+        void ** exportBuffer,
+        uint32_t * exportBufferSize));
+
+    MOCK_METHOD2(Gna2DeviceGetVersion, Gna2Status (
+        uint32_t deviceIndex,
+        enum Gna2DeviceVersion * deviceVersion));
+
+    MOCK_METHOD4(Gna2InstrumentationConfigCreate, Gna2Status (
+        uint32_t numberOfInstrumentationPoints,
+        enum Gna2InstrumentationPoint* selectedInstrumentationPoints,
+        uint64_t * results,
+        uint32_t * instrumentationConfigId));
+
+    MOCK_METHOD2(Gna2InstrumentationConfigAssignToRequestConfig, Gna2Status (
+        uint32_t instrumentationConfigId,
+        uint32_t requestConfigId));
 #endif
 };