Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / samples / common / format_reader / MnistUbyte.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <fstream>
6 #include <iostream>
7 #include <string>
8 #include <MnistUbyte.h>
9
10 using namespace FormatReader;
11
12 int MnistUbyte::reverseInt(int i) {
13     unsigned char ch1, ch2, ch3, ch4;
14     ch1 = (unsigned char) (i & 255);
15     ch2 = (unsigned char) ((i >> 8) & 255);
16     ch3 = (unsigned char) ((i >> 16) & 255);
17     ch4 = (unsigned char) ((i >> 24) & 255);
18     return (static_cast<int>(ch1) << 24) + (static_cast<int>(ch2) << 16) + (static_cast<int>(ch3) << 8) + ch4;
19 }
20
21 MnistUbyte::MnistUbyte(const std::string &filename) {
22     std::ifstream file(filename, std::ios::binary);
23     if (!file.is_open()) {
24         return;
25     }
26     int magic_number = 0;
27     int number_of_images = 0;
28     int n_rows = 0;
29     int n_cols = 0;
30     file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));
31     magic_number = reverseInt(magic_number);
32     if (magic_number != 2051) {
33         return;
34     }
35     file.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images));
36     number_of_images = reverseInt(number_of_images);
37     file.read(reinterpret_cast<char *>(&n_rows), sizeof(n_rows));
38     n_rows = reverseInt(n_rows);
39     _height = (size_t) n_rows;
40     file.read(reinterpret_cast<char *>(&n_cols), sizeof(n_cols));
41     n_cols = reverseInt(n_cols);
42     _width = (size_t) n_cols;
43     if (number_of_images > 1) {
44         std::cout << "[MNIST] Warning: number_of_images  in mnist file equals " << number_of_images
45                   << ". Only a first image will be read." << std::endl;
46     }
47
48     size_t size = _width * _height * 1;
49
50     _data.reset(new unsigned char[size], std::default_delete<unsigned char[]>());
51     size_t count = 0;
52     if (0 < number_of_images) {
53         for (int r = 0; r < n_rows; ++r) {
54             for (int c = 0; c < n_cols; ++c) {
55                 unsigned char temp = 0;
56                 file.read(reinterpret_cast<char *>(&temp), sizeof(temp));
57                 _data.get()[count++] = temp;
58             }
59         }
60     }
61
62     file.close();
63 }