#ifndef __NNKIT_SUPPORT_TF_BACKEND_H__
#define __NNKIT_SUPPORT_TF_BACKEND_H__
-#include "nnkit/Backend.h"
+#include "nnkit/support/tf/TensorDataMap.h"
+#include "nnkit/support/tf/TensorContext.h"
+#include "nnkit/support/tf/ParsedTensor.h"
+#include "nnkit/support/tf/Runner.h"
+
+#include <nnkit/Backend.h>
+
+#include <memory>
+#include <vector>
namespace nnkit
{
class Backend final : public nnkit::Backend
{
public:
+ Backend() = delete;
+ Backend(const Backend &) = delete;
+ Backend(Backend &&) = delete;
+
Backend(const char *pb_path, const char *info_path);
void prepare(const std::function<void(nnkit::TensorContext &)> &f) override;
void run(void) override;
- void teardown(const std::function<void(nnkit::TensorContext &)> &f);
+ void teardown(const std::function<void(nnkit::TensorContext &)> &f) override;
+
+private:
+ std::vector<std::unique_ptr<ParsedTensor>> _inputs;
+ std::vector<std::unique_ptr<ParsedTensor>> _outputs;
+
+ TensorDataMap _data_map;
+
+ Runner _tf_runner;
};
} // namespace tf
#include "nnkit/support/tf/Backend.h"

#include "nnkit/support/tf/ParsedTensor.h"
#include "nnkit/support/tf/Runner.h"
#include "nnkit/support/tf/TensorContext.h"
#include "nnkit/support/tf/TensorDataMap.h"
#include "nnkit/support/tf/TensorInfoParser.h"

#include <nnkit/Backend.h>

#include <cassert>
#include <cstdint>
#include <cstring> // memcpy

namespace nnkit
{
namespace support
namespace tf
{
-Backend::Backend(const char *pb_path, const char *info_path)
+Backend::Backend(const char *pb_path, const char *info_path) : _tf_runner(pb_path)
{
- throw new std::runtime_error("NYI");
+ auto parsed_tensors = parse(info_path);
+
+ for (auto &parsed_tensor : parsed_tensors)
+ {
+ if (parsed_tensor->kind() == ParsedTensor::Kind::Input)
+ _inputs.emplace_back(std::move(parsed_tensor));
+ else
+ _outputs.emplace_back(std::move(parsed_tensor));
+ }
}
// Allocates buffers for every input tensor, lets the caller fill them via
// the TensorContext callback, then hands inputs/outputs to the TF runner.
void Backend::prepare(const std::function<void(nnkit::TensorContext &)> &f)
{
  assert(_inputs.size() == 1); // TODO support more than 1

  for (const auto &input_tensor : _inputs)
    _data_map.allocate(input_tensor.get());

  TensorContext ctx(_inputs, _data_map);
  f(ctx); // fill values

  _tf_runner.prepareInputs(_inputs, _data_map);
  _tf_runner.prepareOutputs(_outputs);
}
-void Backend::run(void) { throw new std::runtime_error("NYI"); }
+void Backend::run(void)
+{
+ _tf_runner.run();
+
+ // get result
+ assert(_outputs.size() == 1); // TODO support more than 1
+
+ for (const auto &output_tensor : _outputs)
+ {
+ const TF_Tensor *output = _tf_runner.output();
+
+ const size_t byte_size = TF_TensorByteSize(output);
+ const uint8_t *tf_data = reinterpret_cast<const uint8_t *>(TF_TensorData(output));
+
+ uint8_t *dest = _data_map.allocate(output_tensor.get());
+
+ std::memcpy(dest, tf_data, byte_size);
+ }
+}
// Exposes the output buffers (filled by run()) to the caller via the
// TensorContext callback, e.g. for result verification or dumping.
void Backend::teardown(const std::function<void(nnkit::TensorContext &)> &f)
{
  TensorContext ctx(_outputs, _data_map);
  f(ctx);
}
} // namespace tf