From 3bc5abf29d58bf1bf068383f5896f4cdb2176b84 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 22 Apr 2019 09:50:12 +0900 Subject: [PATCH] [moco] convert two nodes (#3283) This will fill two Placeholder and Identity node converters Signed-off-by: SaeHie Park --- contrib/moco/lib/frontend/tf/src/Op/Identity.cpp | 26 +++++++++++++++-- .../moco/lib/frontend/tf/src/Op/Placeholder.cpp | 34 ++++++++++++++++++++-- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/contrib/moco/lib/frontend/tf/src/Op/Identity.cpp b/contrib/moco/lib/frontend/tf/src/Op/Identity.cpp index c5edfa4..48bcdaa 100644 --- a/contrib/moco/lib/frontend/tf/src/Op/Identity.cpp +++ b/contrib/moco/lib/frontend/tf/src/Op/Identity.cpp @@ -18,6 +18,8 @@ #include "GraphBuilderContext.h" +#include + #include #include @@ -28,16 +30,34 @@ namespace moco namespace tf { -bool IdentityGraphBuilder::validate(const tensorflow::NodeDef &node) const { return true; } +bool IdentityGraphBuilder::validate(const tensorflow::NodeDef &node) const +{ + if (node.input_size() < 1) // from TensorFlow lite toco + return false; + + return true; +} void IdentityGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const { assert(context != nullptr); - // TODO implement + loco::Graph *graph = context->graph(); + SymbolTable *nodes = context->nodes(); + SymbolTable *input_names = context->input_names(); + + // Create a "Forward" node for Identity + auto forward_node = graph->nodes()->create(); + + nodes->enroll(node.name(), forward_node); - throw std::runtime_error{"IdentityGraphBuilder NYI"}; + // Record all inputs to forward_node + for (int i = 0; i < node.input_size(); ++i) + { + const auto &input_name = node.input(i); + input_names->list(forward_node, input_name); + } } } // namespace tf diff --git a/contrib/moco/lib/frontend/tf/src/Op/Placeholder.cpp b/contrib/moco/lib/frontend/tf/src/Op/Placeholder.cpp index c6caf8f..5028f94 100644 --- a/contrib/moco/lib/frontend/tf/src/Op/Placeholder.cpp +++ b/contrib/moco/lib/frontend/tf/src/Op/Placeholder.cpp @@ -16,6 +16,7 @@ #include "Placeholder.h" +#include "Convert.h" #include "GraphBuilderContext.h" #include @@ -28,16 +29,43 @@ namespace moco namespace tf { -bool PlaceholderGraphBuilder::validate(const tensorflow::NodeDef &node) const { return true; } +bool PlaceholderGraphBuilder::validate(const tensorflow::NodeDef &node) const +{ + if (!node.attr().count("dtype")) + return false; + if (!node.attr().count("shape")) + return false; + + return true; +} void PlaceholderGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const { assert(context != nullptr); - // TODO implement + loco::Graph *graph = context->graph(); + SymbolTable *nodes = context->nodes(); + + loco::DataType dtype = as_loco_datatype(get_datatype_attr(node, "dtype")); + const auto &shape = get_shape_attr(node, "shape"); + int64_t num_dims = shape.dim_size(); + + // TODO support other types + assert(dtype == loco::DataType::FLOAT32); + + // Create a "pull" node as an input + auto pull_node = graph->nodes()->create(); + + pull_node->dtype(dtype); + + pull_node->rank(num_dims); + for (int64_t d = 0; d < num_dims; d++) + { + pull_node->dim(d) = loco::make_dimension(shape.dim(d).size()); + } - throw std::runtime_error{"PlaceholderGraphBuilder NYI"}; + nodes->enroll(node.name(), pull_node); } } // namespace tf -- 2.7.4