[tflite_loader] Model file verification (#8132)
author오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 15 Oct 2019 03:00:25 +0000 (12:00 +0900)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Tue, 15 Oct 2019 03:00:25 +0000 (12:00 +0900)
Verification check before tflite model loading

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
runtimes/neurun/frontend/base_loader/base_loader.h
runtimes/neurun/frontend/circle/circle_loader.cc
runtimes/neurun/frontend/tflite/tflite_loader.cc

index 553de79..9c6e239 100644 (file)
@@ -125,11 +125,13 @@ protected:
 template <typename LoaderDomain, typename SpecificLoader>
 void BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::loadFromFile(const char *file_path)
 {
-  std::ifstream stream(file_path, std::fstream::ate | std::fstream::binary);
+  std::ifstream stream(file_path, std::fstream::in | std::fstream::binary);
+
+  stream.seekg(0, stream.end);
   auto size = stream.tellg();
-  stream.seekg(0);
+  stream.seekg(0, stream.beg);
 
-  _buffer.reserve(size);
+  _buffer.resize(size);
   stream.read(_buffer.data(), size);
 
   stream.close();
@@ -793,6 +795,7 @@ void BaseLoader<LoaderDomain, SpecificLoader>::loadConstantTensor(const Buffer *
 template <typename LoaderDomain, typename SpecificLoader>
 void BaseLoader<LoaderDomain, SpecificLoader>::loadModel()
 {
+  static_cast<SpecificLoader *>(this)->verify();
   _model = LoaderDomain::GetModel(_buffer.data());
   // Version unused
   // const auto version = _model->version();
index 7076ece..18669cf 100644 (file)
@@ -82,6 +82,14 @@ public:
 
     (void)subgraph->data_format();
   }
+
+  void verify()
+  {
+    flatbuffers::Verifier verifier(reinterpret_cast<const std::uint8_t *>(_buffer.data()),
+                                   _buffer.size());
+    if (!circle::VerifyModelBuffer(verifier))
+      throw std::runtime_error{"Invalid circle model"};
+  }
 };
 
 } // namespace
index 6b59369..40aa1dc 100644 (file)
@@ -80,6 +80,14 @@ public:
       loadOperation(op);
     }
   }
+
+  void verify()
+  {
+    flatbuffers::Verifier verifier(reinterpret_cast<const std::uint8_t *>(_buffer.data()),
+                                   _buffer.size());
+    if (!tflite::VerifyModelBuffer(verifier))
+      throw std::runtime_error{"Invalid tflite model"};
+  }
 };
 
 } // namespace