From 19b78574fc300971e84e82cf79e7b6891bcca15f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 8 Jul 2019 09:14:32 +0900 Subject: [PATCH] [exo/tflite] Use loco::output_nodes in exportOpDefinedTensors (#4118) exportOpDefinedTensors currently uses its own implementation of loco::output_nodes. Signed-off-by: Jonghyun Park --- contrib/exo-tflite/src/TFLExporterImpl.cpp | 2 +- contrib/exo-tflite/src/TensorExporter.cpp | 16 ++-------------- contrib/exo-tflite/src/TensorExporter.h | 6 +++--- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/contrib/exo-tflite/src/TFLExporterImpl.cpp b/contrib/exo-tflite/src/TFLExporterImpl.cpp index 7db8f7e..5910eb6 100644 --- a/contrib/exo-tflite/src/TFLExporterImpl.cpp +++ b/contrib/exo-tflite/src/TFLExporterImpl.cpp @@ -64,7 +64,7 @@ void TFLExporter::Impl::exportGraph(loco::Graph *graph) registerGraphIOName(graph, gd); // parse graph into SerializedModelData structure - exportOpDefinedTensors(graph->nodes(), _builder, gd); + exportOpDefinedTensors(graph, _builder, gd); exportNodes(graph->nodes(), _builder, gd); diff --git a/contrib/exo-tflite/src/TensorExporter.cpp b/contrib/exo-tflite/src/TensorExporter.cpp index 44af923..8314bf9 100644 --- a/contrib/exo-tflite/src/TensorExporter.cpp +++ b/contrib/exo-tflite/src/TensorExporter.cpp @@ -98,24 +98,12 @@ void exportOpDefinedTensor(NodeT *node, FlatBufferBuilder &builder, SerializedMo gd._tensors.push_back(tensor_offset); } -void exportOpDefinedTensors(loco::Graph::NodeContext *nodes, FlatBufferBuilder &builder, - SerializedModelData &gd) +void exportOpDefinedTensors(loco::Graph *g, FlatBufferBuilder &builder, SerializedModelData &gd) { - // find entrances of graph - std::vector roots; - for (uint32_t node_id = 0; node_id < nodes->size(); ++node_id) - { - loco::Node *node = nodes->at(node_id); - if (dynamic_cast(node)) - { - roots.push_back(node); - } - } - // Operations should be traversed in RPO because during processing of current operation // we need to know all attributes of previous operations, // like shape, type,tensor id related with previous operation - auto sequence = loco::postorder_traversal(roots); + auto sequence = loco::postorder_traversal(loco::output_nodes(g)); for (loco::Node *node : sequence) { if (auto *pull = dynamic_cast(node)) diff --git a/contrib/exo-tflite/src/TensorExporter.h b/contrib/exo-tflite/src/TensorExporter.h index a55a362..4c1902a 100644 --- a/contrib/exo-tflite/src/TensorExporter.h +++ b/contrib/exo-tflite/src/TensorExporter.h @@ -25,10 +25,10 @@ /** * @brief create Tensors corresponding to results of all nodes in graph - * @param nodes list of nodes in computational graph + * @param computational graph * @param gd information about serialized parts of model */ -void exportOpDefinedTensors(loco::Graph::NodeContext *nodes, - flatbuffers::FlatBufferBuilder &builder, SerializedModelData &gd); +void exportOpDefinedTensors(loco::Graph *g, flatbuffers::FlatBufferBuilder &builder, + SerializedModelData &gd); #endif // __TENSOR_EXPORTER_H__ -- 2.7.4