From f4eddcb1614725b5ead6a9f801cc91c87da54ae2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Mon, 3 Dec 2018 16:25:21 +0900 Subject: [PATCH] [tfldump] Load tflite file (#2465) This will add tflite loader part of tfldump Signed-off-by: SaeHie Park --- contrib/tfldump/CMakeLists.txt | 4 +- contrib/tfldump/driver/Driver.cpp | 15 +++- contrib/tfldump/include/tflread/Model.h | 21 +++++ contrib/tfldump/src/Load.cpp | 133 ++++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 contrib/tfldump/src/Load.cpp diff --git a/contrib/tfldump/CMakeLists.txt b/contrib/tfldump/CMakeLists.txt index 231d776..731acd2 100644 --- a/contrib/tfldump/CMakeLists.txt +++ b/contrib/tfldump/CMakeLists.txt @@ -11,7 +11,9 @@ FlatBuffers_Target(tfldump_flatbuffer set(DRIVER "driver/Driver.cpp") -add_executable(tfldump ${DRIVER}) +file(GLOB_RECURSE SOURCES "src/*.cpp") + +add_executable(tfldump ${DRIVER} ${SOURCES}) target_include_directories(tfldump PRIVATE include) target_link_libraries(tfldump tfldump_flatbuffer) target_link_libraries(tfldump safemain) diff --git a/contrib/tfldump/driver/Driver.cpp b/contrib/tfldump/driver/Driver.cpp index bd022b6..2b95520 100644 --- a/contrib/tfldump/driver/Driver.cpp +++ b/contrib/tfldump/driver/Driver.cpp @@ -28,7 +28,20 @@ int entry(int argc, char **argv) return 255; } - // TODO load TFlite file + // Load TF lite model from a tflite file + std::unique_ptr model = tflread::load_tflite(argv[1]); + if (model == nullptr) + { + std::cerr << "ERROR: Failed to load tflite '" << argv[1] << "'" << std::endl; + return 255; + } + + const tflite::Model *tflmodel = model->model(); + if (tflmodel == nullptr) + { + std::cerr << "ERROR: Failed to load tflite '" << argv[1] << "'" << std::endl; + return 255; + } // TODO dump TFlite model diff --git a/contrib/tfldump/include/tflread/Model.h b/contrib/tfldump/include/tflread/Model.h index 3cf59c4..267e724 100644 --- a/contrib/tfldump/include/tflread/Model.h +++ b/contrib/tfldump/include/tflread/Model.h @@ -19,4 +19,25 @@ #include +#include + +namespace tflread +{ + +struct Model +{ + virtual ~Model() = default; + + virtual const ::tflite::Model *model(void) const = 0; +}; + +/** + * @brief Load TensorFlow Lite model (as a raw Model) from a given path + * + * @note May return a nullptr + */ +std::unique_ptr load_tflite(const std::string &path); + +} // namespace tflread + #endif // __TFLREAD_MODEL_H__ diff --git a/contrib/tfldump/src/Load.cpp b/contrib/tfldump/src/Load.cpp new file mode 100644 index 0000000..fe04a5d --- /dev/null +++ b/contrib/tfldump/src/Load.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +namespace +{ + +class MemoryMappedModel final : public tflread::Model +{ +public: + /** + * @require fd and data SHOULD be valid + */ + explicit MemoryMappedModel(int fd, void *data, size_t size) : _fd{fd}, _data{data}, _size{size} + { + // DO NOTHING + } + +public: + ~MemoryMappedModel() + { + munmap(_data, _size); + close(_fd); + } + +public: + MemoryMappedModel(const MemoryMappedModel &) = delete; + MemoryMappedModel(MemoryMappedModel &&) = delete; + +public: + const ::tflite::Model *model(void) const override { return ::tflite::GetModel(_data); } + +private: + int _fd = -1; + void *_data = nullptr; + size_t _size = 0; +}; + +class FileDescriptor final +{ +public: + FileDescriptor(int value) : _value{value} + { + // DO NOTHING + } + +public: + // NOTE Copy is not allowed + FileDescriptor(const FileDescriptor &) = delete; + +public: + // NOTE Move is allowed + FileDescriptor(FileDescriptor &&fd) { _value = fd.release(); } + +public: + ~FileDescriptor() + { + if (_value != -1) + { + // Close on descturction + close(_value); + } + } + +public: + int value(void) const { return _value; } + +public: + int release(void) + { + auto res = _value; + _value = -1; + return res; + } + +private: + int _value = -1; +}; + +} // namespace + +namespace tflread +{ + +std::unique_ptr load_tflite(const std::string &path) +{ + FileDescriptor fd = open(path.c_str(), O_RDONLY); + + if (fd.value() == -1) + { + // Return nullptr on open failure + return nullptr; + } + + struct stat st; + if (fstat(fd.value(), &st) == -1) + { + // Return nullptr on fstat failure + return nullptr; + } + + auto size = st.st_size; + auto data = mmap(nullptr, size, PROT_READ, MAP_SHARED, fd.value(), 0); + + if (data == MAP_FAILED) + { + // Return nullptr on mmap failure + return nullptr; + } + + return std::unique_ptr{new MemoryMappedModel(fd.release(), data, size)}; +} + +} // namespace tflread -- 2.7.4