Imported Upstream version 1.4.0
[platform/core/ml/nnfw.git] / compiler / tfl-inspect / src / Model.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Model.h"
18
19 #include <fcntl.h>
20 #include <unistd.h>
21 #include <sys/stat.h>
22 #include <sys/mman.h>
23
24 namespace
25 {
26
27 class MemoryMappedModel final : public tflinspect::Model
28 {
29 public:
30   /**
31    * @require fd and data SHOULD be valid
32    */
33   explicit MemoryMappedModel(int fd, void *data, size_t size) : _fd{fd}, _data{data}, _size{size}
34   {
35     // DO NOTHING
36   }
37
38 public:
39   ~MemoryMappedModel()
40   {
41     munmap(_data, _size);
42     close(_fd);
43   }
44
45 public:
46   MemoryMappedModel(const MemoryMappedModel &) = delete;
47   MemoryMappedModel(MemoryMappedModel &&) = delete;
48
49 public:
50   const ::tflite::Model *model(void) const override { return ::tflite::GetModel(_data); }
51
52 private:
53   int _fd = -1;
54   void *_data = nullptr;
55   size_t _size = 0;
56 };
57
58 class FileDescriptor final
59 {
60 public:
61   FileDescriptor(int value) : _value{value}
62   {
63     // DO NOTHING
64   }
65
66 public:
67   // NOTE Copy is not allowed
68   FileDescriptor(const FileDescriptor &) = delete;
69
70 public:
71   // NOTE Move is allowed
72   FileDescriptor(FileDescriptor &&fd) { _value = fd.release(); }
73
74 public:
75   ~FileDescriptor()
76   {
77     if (_value != -1)
78     {
79       // Close on descturction
80       close(_value);
81     }
82   }
83
84 public:
85   int value(void) const { return _value; }
86
87 public:
88   int release(void)
89   {
90     auto res = _value;
91     _value = -1;
92     return res;
93   }
94
95 private:
96   int _value = -1;
97 };
98
99 } // namespace
100
101 namespace tflinspect
102 {
103
104 std::unique_ptr<Model> load_tflite(const std::string &path)
105 {
106   FileDescriptor fd = open(path.c_str(), O_RDONLY);
107
108   if (fd.value() == -1)
109   {
110     // Return nullptr on open failure
111     return nullptr;
112   }
113
114   struct stat st;
115   if (fstat(fd.value(), &st) == -1)
116   {
117     // Return nullptr on fstat failure
118     return nullptr;
119   }
120
121   auto size = st.st_size;
122   auto data = mmap(nullptr, size, PROT_READ, MAP_SHARED, fd.value(), 0);
123
124   if (data == MAP_FAILED)
125   {
126     // Return nullptr on mmap failure
127     return nullptr;
128   }
129
130   // Check if file is a valid Flatbuffer file
131   const uint8_t *u8data = reinterpret_cast<const uint8_t *>(data);
132   flatbuffers::Verifier verifier{u8data, static_cast<size_t>(size)};
133   if (!tflite::VerifyModelBuffer(verifier))
134   {
135     munmap(data, size);
136     close(fd.release());
137     return nullptr;
138   }
139
140   return std::unique_ptr<tflinspect::Model>{new MemoryMappedModel(fd.release(), data, size)};
141 }
142
143 } // namespace tflinspect