Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / circle-opselector / src / ModuleIO.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ModuleIO.h"
18
19 #include <foder/FileLoader.h>
20
21 #include <luci/Importer.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24
25 #include <iostream>
26
27 namespace opselector
28 {
29
30 std::unique_ptr<luci::Module> getModule(std::string &input_path)
31 {
32   // Load model from the file
33   foder::FileLoader file_loader{input_path};
34   std::vector<char> model_data = file_loader.load();
35
36   // Verify flatbuffers
37   flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
38   if (!circle::VerifyModelBuffer(verifier))
39   {
40     std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
41     exit(EXIT_FAILURE);
42   }
43
44   const circle::Model *circle_model = circle::GetModel(model_data.data());
45   if (circle_model == nullptr)
46   {
47     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
48     exit(EXIT_FAILURE);
49   }
50
51   // Import from input Circle file
52   luci::Importer importer;
53
54   return importer.importModule(circle_model);
55 }
56
57 bool exportModule(luci::Module *module, std::string &output_path)
58 {
59   luci::CircleExporter exporter;
60
61   luci::CircleFileExpContract contract(module, output_path);
62
63   if (!exporter.invoke(&contract))
64   {
65     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
66     return false;
67   }
68
69   return true;
70 }
71
72 } // namespace opselector