1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "inference_engine.hpp"
11 class TestsCommonFunc {
13 InferenceEngine::Blob::Ptr readBMP(std::string path, unsigned batch) {
15 std::ifstream input(path, std::ios::binary);
16 if (!input) return nullptr;
18 unsigned char bmpFileHeader[14];
19 input.read((char*)bmpFileHeader, sizeof(bmpFileHeader));
20 if(bmpFileHeader[0]!='B' || bmpFileHeader[1]!='M') return nullptr;
21 if(bmpFileHeader[11]!=0 || bmpFileHeader[12]!=0 || bmpFileHeader[13]!=0 ) return nullptr;
23 unsigned char bmpInfoHeader[40];
24 input.read((char*)bmpInfoHeader, sizeof(bmpInfoHeader));
25 if(bmpInfoHeader[14]!=24) return nullptr; // bits per pixel
26 if(bmpInfoHeader[16]!=0) return nullptr; // compression is not supported
28 bool rowsReversed = (*(int32_t*)(bmpInfoHeader + 8)) < 0;
29 uint32_t width = *(int32_t*)(bmpInfoHeader + 4);
30 uint32_t height = abs(*(int32_t*)(bmpInfoHeader + 8));
32 size_t padSize = width & 3;
35 InferenceEngine::Blob::Ptr blob(new InferenceEngine::TBlob<float>( InferenceEngine::Precision::FP32, InferenceEngine::NCHW, {batch, 3, width, height} ));
37 float *blob_ptr = (float*)(void*)blob->buffer();
39 unsigned int offset = *(unsigned int *)(bmpFileHeader + 10);
40 for (int b = 0; b < batch; b++) {
41 int b_off = 3*width*height*b;
42 input.seekg(offset, std::ios::beg);
43 //reading by rows in invert vertically
44 for (uint32_t i = 0; i < height; i++) {
45 int storeAt = rowsReversed ? i : height - 1 - i;
47 for (uint32_t j = 0; j < width; j++) {
48 unsigned char RGBA[3];
49 input.read((char *) RGBA, sizeof(RGBA));
51 blob_ptr[b_off + j + storeAt * width] = RGBA[0];
52 blob_ptr[b_off + j + storeAt * width + height * width * 1] = RGBA[1];
53 blob_ptr[b_off + j + storeAt * width + height * width * 2] = RGBA[2];
55 input.read(pad, padSize);
/**
 * @brief Reverses the byte order of every complete 4-byte word in a buffer, in place.
 *
 * @param ptr  Start of the buffer (may be null when @p size is 0).
 * @param size Buffer length in bytes. A trailing remainder of 1-3 bytes is left untouched
 *             rather than swapped with bytes beyond the end of the buffer.
 */
inline void bswap_32(char* ptr, size_t size) {
    for (size_t i = 0; i + 4 <= size; i += 4) {
        std::swap(ptr[i], ptr[i + 3]);
        std::swap(ptr[i + 1], ptr[i + 2]);
    }
}
71 InferenceEngine::Blob::Ptr readUbyte(std::string path, unsigned batch) {
73 std::ifstream input(path, std::ios::binary);
75 uint32_t magic_number;
81 input.read((char *) &hdr, sizeof(hdr));
82 bswap_32((char *) &hdr, sizeof(hdr));
83 if (hdr.magic_number != 2051) return nullptr; // Invalid MNIST image file
85 InferenceEngine::Blob::Ptr blob(new InferenceEngine::TBlob<float>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW,
86 {batch, hdr.n_images, hdr.n_cols, hdr.n_rows }));
88 float *blob_ptr = (float*)(void*)blob->buffer();
89 for (int b = 0; b < batch; b++) {
90 input.seekg(sizeof(hdr), std::ios::beg);
91 int b_off = b*hdr.n_images*hdr.n_rows*hdr.n_cols;
92 for (uint32_t i = 0; i < hdr.n_images; ++i) {
93 for (uint32_t r = 0; r < hdr.n_rows; ++r) {
94 for (uint32_t c = 0; c < hdr.n_cols; ++c) {
95 unsigned char temp = 0;
96 input.read((char *) &temp, sizeof(temp));
97 blob_ptr[b_off + i * hdr.n_rows * hdr.n_cols + r * hdr.n_cols + c] = temp;
107 InferenceEngine::Blob::Ptr readInput(std::string path, int batch = 1) {
108 if ( path.substr(path.rfind('.') + 1) == "bmp" ) return readBMP(path, batch);
109 if ( path.substr(path.rfind('-') + 1) == "ubyte" ) return readUbyte(path, batch);
113 bool compareTop(InferenceEngine::Blob& blob, std::vector<std::pair<int, float>> &ref_top, int batch_to_compare = 0, float threshold = 0.005f) {
114 if (blob.dims()[0] == 7)
115 return compareTopLikeObjDetection(blob, ref_top, batch_to_compare);
117 return compareTopLikeClassification(blob, ref_top, batch_to_compare, threshold);
120 bool compareTopLikeObjDetection (InferenceEngine::Blob& blob, std::vector<std::pair<int, float>> &ref_top,
121 int batch_to_compare = 0) {
122 assert(blob.dims()[0] == 7);
124 const int box_info_size = 7;
126 int top_num = (int)ref_top.size();
127 float *data_ptr = blob.buffer().as<float*>();
128 const int data_size = blob.size();
129 if (data_size/box_info_size < top_num) {
130 EXPECT_TRUE(data_size/box_info_size >= top_num) << "Dst blob contains less data then expected";
134 for (int i=0; i<top_num; i++) {
135 int lable = data_ptr[i*box_info_size + 1];
136 float confidence = data_ptr[i*box_info_size + 2];
138 if (lable != ref_top[i].first) {
139 EXPECT_EQ(lable , ref_top[i].first) << "Label mismatch";
143 if (fabs(confidence - ref_top[i].second)/ref_top[i].second > 0.005) {
144 EXPECT_NEAR(confidence, ref_top[i].second, ref_top[i].second * 0.005);
152 bool compareTopLikeClassification (InferenceEngine::Blob& blob, std::vector<std::pair<int, float>> &ref_top,
153 int batch_to_compare = 0, float threshold = 0.005f) {
154 int top_num = (int)ref_top.size();
156 size_t data_size = blob.size();
157 float *data_ptr = (float*)(void*)blob.buffer();
159 int batch_size = blob.dims()[blob.dims().size() - 1];
160 assert(batch_size > batch_to_compare);
161 data_size /= batch_size;
162 data_ptr += data_size*batch_to_compare;
164 std::vector<int> top(data_size);
166 for (size_t i = 0; i < data_size; i++) top[i] = (int)i;
167 std::partial_sort (top.begin(), top.begin()+top_num, top.end(),
168 [&](int l, int r) -> bool { return data_ptr[l] > data_ptr[r]; } );
170 for (int i = 0 ; i < top_num; i++) {
171 if (top[i] != ref_top[i].first) {
172 EXPECT_EQ(top[i] , ref_top[i].first);
176 if (fabs(data_ptr[top[i]] - ref_top[i].second)/ref_top[i].second > threshold) {
177 EXPECT_NEAR(data_ptr[top[i]] , ref_top[i].second , ref_top[i].second * threshold);
184 void zeroing (InferenceEngine::Blob& blob) {
185 size_t data_size = blob.size();
186 float *data_ptr = (float *) (void *) blob.buffer();
189 *(float*)&zero = 0.0f;
190 memset(data_ptr, zero, data_size * sizeof(float));
193 bool isZeroFilled (InferenceEngine::Blob& blob, int batch_to_compare = 0) {
194 size_t data_size = blob.size();
195 float *data_ptr = (float*)(void*)blob.buffer();
197 int batch_size = blob.dims()[blob.dims().size() - 1];
198 assert(batch_size > batch_to_compare);
199 data_size /= batch_size;
200 data_ptr += data_size*batch_to_compare;
202 for (int i = 0 ; i < data_size; i++) {
203 EXPECT_FLOAT_EQ(data_ptr[i] , 0.0f);