2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "tflchef/ModelChef.h"
19 #include <google/protobuf/io/coded_stream.h>
20 #include <google/protobuf/io/zero_copy_stream_impl.h>
21 #include <google/protobuf/text_format.h>
23 #include <arser/arser.h>
28 int entry(int argc, char **argv)
31 arser.add_argument("recipe")
32 .type(arser::DataType::STR)
33 .help("Source recipe file path to convert");
34 arser.add_argument("tflite").type(arser::DataType::STR).help("Target tflite file path");
38 arser.parse(argc, argv);
40 catch (const std::runtime_error &err)
42 std::cout << err.what() << std::endl;
47 int32_t model_version = 1;
49 ::tflchef::ModelRecipe model_recipe;
51 std::string recipe_path = arser.get<std::string>("recipe");
52 // Load model recipe from a file
54 std::ifstream is{recipe_path};
55 google::protobuf::io::IstreamInputStream iis{&is};
56 if (!google::protobuf::TextFormat::Parse(&iis, &model_recipe))
58 std::cerr << "ERROR: Failed to parse recipe '" << recipe_path << "'" << std::endl;
62 if (model_recipe.has_version())
64 model_version = model_recipe.version();
68 if (model_version > 1)
70 std::cerr << "ERROR: Unsupported recipe version: " << model_version << ", '" << argv[1] << "'"
75 auto generated_model = tflchef::cook(model_recipe);
77 std::string tflite_path = arser.get<std::string>("tflite");
78 // Dump generated model into a file
80 std::ofstream os{tflite_path, std::ios::binary};
81 os.write(generated_model.base(), generated_model.size());