From 748d9641a83c33c8e692478402adf6969fa39ac3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 4 May 2018 10:13:47 +0900 Subject: [PATCH] Add 'tflitekit' (#192) This commit adds 'tflitekit' which provides various command related with TensorFlow Lite. Signed-off-by: Jonghyun Park --- contrib/tflitekit/CMakeLists.txt | 11 +++++ contrib/tflitekit/src/tflitekit.cpp | 90 +++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 contrib/tflitekit/CMakeLists.txt create mode 100644 contrib/tflitekit/src/tflitekit.cpp diff --git a/contrib/tflitekit/CMakeLists.txt b/contrib/tflitekit/CMakeLists.txt new file mode 100644 index 0000000..5f95cdf --- /dev/null +++ b/contrib/tflitekit/CMakeLists.txt @@ -0,0 +1,11 @@ +nncc_find_package(TensorFlowLite QUIET) + +if(NOT TensorFlowLite_FOUND) + return() +endif(NOT TensorFlowLite_FOUND) + +file(GLOB_RECURSE SOURCES "src/*.cpp") + +add_executable(tflitekit ${SOURCES}) +target_link_libraries(tflitekit nncc_foundation) +target_link_libraries(tflitekit tensorflowlite) diff --git a/contrib/tflitekit/src/tflitekit.cpp b/contrib/tflitekit/src/tflitekit.cpp new file mode 100644 index 0000000..1612153 --- /dev/null +++ b/contrib/tflitekit/src/tflitekit.cpp @@ -0,0 +1,90 @@ +struct Command +{ + virtual ~Command() = default; + + virtual int run(int argc, char **argv) const = 0; +}; + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +#include + +using namespace tflite; +using namespace tflite::ops::builtin; + +class RunCommand final : public Command +{ +public: + int run(int argc, char **argv) const override; +}; + +int RunCommand::run(int argc, char **argv) const +{ + // USAGE: HEADER run [.tflite] + const auto filename = argv[0]; + + StderrReporter error_reporter; + + auto model = FlatBufferModel::BuildFromFile(filename, &error_reporter); + + std::unique_ptr interpreter; + + TfLiteStatus status = kTfLiteError; + + BuiltinOpResolver resolver; + InterpreterBuilder builder(*model, resolver); + + status = builder(&interpreter); + assert(status == kTfLiteOk); + + interpreter->SetNumThreads(1); + + status = interpreter->AllocateTensors(); + assert(status == kTfLiteOk); + + status = interpreter->Invoke(); + assert(status == kTfLiteOk); + + std::cout << "# of outputs: " << interpreter->outputs().size() << std::endl; + + return 0; +} + +#include + +#include +#include +#include + +int main(int argc, char **argv) +{ + std::map> commands; + + commands["run"] = nncc::foundation::make_unique(); + + if (argc < 2) + { + std::cerr << "ERROR: COMMAND is not provided" << std::endl; + std::cerr << std::endl; + std::cerr << "USAGE: " << argv[0] << " [COMMAND] ..." << std::endl; + return 255; + } + + // USAGE: HEADER [command] ... + if (commands.find(argv[1]) == commands.end()) + { + std::cerr << "ERROR: '" << argv[1] << "' is not a valid command" << std::endl; + std::cerr << std::endl; + std::cerr << "USAGE: " << argv[0] << " [COMMAND] ..." << std::endl; + std::cerr << std::endl; + std::cerr << "SUPPORTED COMMANDS:" << std::endl; + for (auto it = commands.begin(); it != commands.end(); ++it) + { + std::cerr << " " << it->first << std::endl; + } + return 255; + } + + return commands.at(argv[1])->run(argc - 2, argv + 2); +} -- 2.7.4