From b306872e6e10f7427b1656356033059c7440eed5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dmitry=20Mozolev/SRR-AI=20Tools=20Lab/=2E/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 23 May 2018 02:01:52 +0300 Subject: [PATCH] Add TFLite model visiting mechanism (#170) * Add TFLite model visiting mechanism Two main things added: - TFLite model visitor interface - Code for iterating over the model contents (walker) Signed-off-by: Dmitry Mozolev * Change class names --- .../libs/frontend/tflite/include/tflite_visitor.h | 33 ++++++++++ .../libs/frontend/tflite/include/tflite_walker.h | 71 ++++++++++++++++++++++ .../nnc/libs/frontend/tflite/src/tflite_walker.cpp | 33 ++++++++++ 3 files changed, 137 insertions(+) create mode 100644 contrib/nnc/libs/frontend/tflite/include/tflite_visitor.h create mode 100644 contrib/nnc/libs/frontend/tflite/include/tflite_walker.h create mode 100644 contrib/nnc/libs/frontend/tflite/src/tflite_walker.cpp diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_visitor.h b/contrib/nnc/libs/frontend/tflite/include/tflite_visitor.h new file mode 100644 index 0000000..9ef9e8e --- /dev/null +++ b/contrib/nnc/libs/frontend/tflite/include/tflite_visitor.h @@ -0,0 +1,33 @@ +#ifndef NNCC_TFLITE_VISITOR_H +#define NNCC_TFLITE_VISITOR_H + +#include "schema_v3_generated.h" + +using namespace tflite; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +class Visitor +{ +public: + virtual void visit(const Model *) = 0; + virtual void visit(const SubGraph *) = 0; + virtual void visit(const Tensor *) = 0; + virtual void visit(const OperatorCode *) = 0; + virtual void visit(const Operator *) = 0; + virtual void visit(const Buffer *) = 0; +}; + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc + +#endif // NNCC_TFLITE_VISITOR_H diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_walker.h b/contrib/nnc/libs/frontend/tflite/include/tflite_walker.h new file mode 100644 index 0000000..e1425e0 --- /dev/null +++ b/contrib/nnc/libs/frontend/tflite/include/tflite_walker.h @@ -0,0 +1,71 @@ +#ifndef NNCC_TFLITE_WALKER_H +#define NNCC_TFLITE_WALKER_H + +#include "flatbuffers/flatbuffers.h" + +#include +#include + +#include "schema_v3_generated.h" +#include "tflite_visitor.h" + +using namespace tflite; + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +class ModelWalker +{ +public: + explicit ModelWalker(std::vector actions) : actions(std::move(actions)) {}; + + template void walk(T *elem); + template void performActions(T *elem); + template void walkVector(const flatbuffers::Vector> *v); + +private: + void walkContents(const Model *); + void walkContents(const Tensor *); + void walkContents(const OperatorCode *); + void walkContents(const Operator *); + void walkContents(const SubGraph *); + void walkContents(const Buffer *); + + std::vector actions; +}; + +template void ModelWalker::walk(T *elem) +{ + performActions(elem); + walkContents(elem); +} + +template void ModelWalker::performActions(T *elem) +{ + for (auto action : actions) + { + action->visit(elem); + } +} + +template +void ModelWalker::walkVector(const flatbuffers::Vector> *v) +{ + for (auto it = v->begin(); it != v->end(); ++it) + { + walk(*it); + } +} + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc + +#endif // NNCC_TFLITE_WALKER_H diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_walker.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_walker.cpp new file mode 100644 index 0000000..03b397d --- /dev/null +++ b/contrib/nnc/libs/frontend/tflite/src/tflite_walker.cpp @@ -0,0 +1,33 @@ +#include "tflite_walker.h" + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +void ModelWalker::walkContents(const Model *m) +{ + walkVector(m->operator_codes()); + walkVector(m->buffers()); + walkVector(m->subgraphs()); +} + +void ModelWalker::walkContents(const SubGraph *s) +{ + walkVector(s->tensors()); + walkVector(s->operators()); +} + +void ModelWalker::walkContents(const Tensor *t) {} +void ModelWalker::walkContents(const Buffer *b) {} +void ModelWalker::walkContents(const OperatorCode *oc) {} +void ModelWalker::walkContents(const Operator *) {} + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc -- 2.7.4