From cd53ede592e58b5c335079f462581d1460359be7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 10 Jul 2019 17:33:12 +0900 Subject: [PATCH] [exo/tflite] Honor Graph Input/Output order (#4184) TFLExporter now creates TensorFlow Lite Graph Input/Output using graph-level specification. Signed-off-by: Jonghyun Park --- contrib/exo-tflite/src/OperationExporter.cpp | 4 ++-- contrib/exo-tflite/src/TFLExporterImpl.cpp | 30 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/contrib/exo-tflite/src/OperationExporter.cpp b/contrib/exo-tflite/src/OperationExporter.cpp index ada47f3..17e34ac 100644 --- a/contrib/exo-tflite/src/OperationExporter.cpp +++ b/contrib/exo-tflite/src/OperationExporter.cpp @@ -284,11 +284,11 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, } else if (dynamic_cast(node)) { - data._inputs.push_back(get_tensor_index(node)); + // DO NOTHING } else if (dynamic_cast(node)) { - data._outputs.push_back(get_tensor_index(node->arg(0))); + // DO NOTHING } else if (auto *encode = dynamic_cast(node)) { diff --git a/contrib/exo-tflite/src/TFLExporterImpl.cpp b/contrib/exo-tflite/src/TFLExporterImpl.cpp index 37e36ed..eb780f5 100644 --- a/contrib/exo-tflite/src/TFLExporterImpl.cpp +++ b/contrib/exo-tflite/src/TFLExporterImpl.cpp @@ -21,9 +21,35 @@ #include "OperationExporter.h" #include "ExporterUtils.h" +#include #include #include +namespace +{ + +void registerGraphInputTensors(loco::Graph *graph, SubGraphContext &ctx) +{ + for (uint32_t n = 0; n < graph->inputs()->size(); ++n) + { + auto node = graph->inputs()->at(n)->node(); + assert(node != nullptr); + ctx._inputs.push_back(get_tensor_index(node)); + } +} + +void registerGraphOutputTensors(loco::Graph *graph, SubGraphContext &ctx) +{ + for (uint32_t n = 0; n < graph->outputs()->size(); ++n) + { + auto node = graph->outputs()->at(n)->node()->from(); + assert(node != nullptr); + ctx._outputs.push_back(get_tensor_index(node)); + } +} + +} // namespace + namespace exo { using namespace tflite; @@ -75,6 +101,10 @@ void TFLExporter::Impl::exportGraph(loco::Graph *graph) // parse graph into SerializedModelData structure exportOpDefinedTensors(graph, _builder, gd); + // NOTE Invoke these register functions only after each node is annotated with its tensor_index + registerGraphInputTensors(graph, gd); + registerGraphOutputTensors(graph, gd); + exportNodes(graph->nodes(), _builder, gd); // excode operator codes -- 2.7.4