From f51849eb432b3929db83f4fecc3939c6bc1fc229 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 15 Apr 2019 11:03:42 +0900 Subject: [PATCH] [nnkit/TF] Backend implementation for nnkit tensorflow backend (#3251) * [nnkit/TF] Backend.h for nnkit tensorflow backend This commit adds Backend.h which reads input, parse test.info, run tensorflow, and produce output. Signed-off-by: Hyun Sik Yoon * impl to cpp * remove useless header. add #include * headers --- .../support/tf/include/nnkit/support/tf/Backend.h | 24 +++++++++- contrib/nnkit/libs/support/tf/src/Backend.cpp | 56 ++++++++++++++++++++-- 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/contrib/nnkit/libs/support/tf/include/nnkit/support/tf/Backend.h b/contrib/nnkit/libs/support/tf/include/nnkit/support/tf/Backend.h index 2ae086b..906379f 100644 --- a/contrib/nnkit/libs/support/tf/include/nnkit/support/tf/Backend.h +++ b/contrib/nnkit/libs/support/tf/include/nnkit/support/tf/Backend.h @@ -17,7 +17,15 @@ #ifndef __NNKIT_SUPPORT_TF_BACKEND_H__ #define __NNKIT_SUPPORT_TF_BACKEND_H__ -#include "nnkit/Backend.h" +#include "nnkit/support/tf/TensorDataMap.h" +#include "nnkit/support/tf/TensorContext.h" +#include "nnkit/support/tf/ParsedTensor.h" +#include "nnkit/support/tf/Runner.h" + +#include + +#include +#include namespace nnkit { @@ -29,13 +37,25 @@ namespace tf class Backend final : public nnkit::Backend { public: + Backend() = delete; + Backend(const Backend &) = delete; + Backend(Backend &&) = delete; + Backend(const char *pb_path, const char *info_path); void prepare(const std::function &f) override; void run(void) override; - void teardown(const std::function &f); + void teardown(const std::function &f) override; + +private: + std::vector> _inputs; + std::vector> _outputs; + + TensorDataMap _data_map; + + Runner _tf_runner; }; } // namespace tf diff --git a/contrib/nnkit/libs/support/tf/src/Backend.cpp b/contrib/nnkit/libs/support/tf/src/Backend.cpp index 9786db7..b1e35d4 100644 --- a/contrib/nnkit/libs/support/tf/src/Backend.cpp +++ b/contrib/nnkit/libs/support/tf/src/Backend.cpp @@ -16,6 +16,16 @@ #include "nnkit/support/tf/Backend.h" +#include "nnkit/support/tf/ParsedTensor.h" +#include "nnkit/support/tf/TensorInfoParser.h" +#include "nnkit/support/tf/TensorDataMap.h" +#include "nnkit/support/tf/TensorContext.h" +#include "nnkit/support/tf/Runner.h" + +#include + +#include // memcpy + namespace nnkit { namespace support @@ -23,21 +33,57 @@ namespace support namespace tf { -Backend::Backend(const char *pb_path, const char *info_path) +Backend::Backend(const char *pb_path, const char *info_path) : _tf_runner(pb_path) { - throw new std::runtime_error("NYI"); + auto parsed_tensors = parse(info_path); + + for (auto &parsed_tensor : parsed_tensors) + { + if (parsed_tensor->kind() == ParsedTensor::Kind::Input) + _inputs.emplace_back(std::move(parsed_tensor)); + else + _outputs.emplace_back(std::move(parsed_tensor)); + } } void Backend::prepare(const std::function &f) { - throw new std::runtime_error("NYI"); + assert(_inputs.size() == 1); // TODO support more than 1 + + for (const auto &input_tensor : _inputs) + _data_map.allocate(input_tensor.get()); + + TensorContext ctx(_inputs, _data_map); + f(ctx); // fill values + + _tf_runner.prepareInputs(_inputs, _data_map); + _tf_runner.prepareOutputs(_outputs); } -void Backend::run(void) { throw new std::runtime_error("NYI"); } +void Backend::run(void) +{ + _tf_runner.run(); + + // get result + assert(_outputs.size() == 1); // TODO support more than 1 + + for (const auto &output_tensor : _outputs) + { + const TF_Tensor *output = _tf_runner.output(); + + const size_t byte_size = TF_TensorByteSize(output); + const uint8_t *tf_data = reinterpret_cast(TF_TensorData(output)); + + uint8_t *dest = _data_map.allocate(output_tensor.get()); + + std::memcpy(dest, tf_data, byte_size); + } +} void Backend::teardown(const std::function &f) { - throw new std::runtime_error("NYI"); + TensorContext ctx(_outputs, _data_map); + f(ctx); } } // namespace tf -- 2.7.4