From 38278bd201dab0dd276ae5823a310f25ca062684 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 15 Oct 2019 09:15:30 +0900 Subject: [PATCH] [moco-tf] Graph build Identity as TF dialect (#8139) This will update graph builder not to produce only TF dialect for Identity node - related Importer test is also modified to work with this change Signed-off-by: SaeHie Park --- compiler/moco-tf/src/Importer.test.cpp | 45 +------------------- compiler/moco-tf/src/Op/Identity.cpp | 75 +--------------------------------- compiler/moco-tf/src/Op/Identity.h | 19 +-------- 3 files changed, 4 insertions(+), 135 deletions(-) diff --git a/compiler/moco-tf/src/Importer.test.cpp b/compiler/moco-tf/src/Importer.test.cpp index 770984b..6fc3e82 100644 --- a/compiler/moco-tf/src/Importer.test.cpp +++ b/compiler/moco-tf/src/Importer.test.cpp @@ -80,44 +80,6 @@ node { } // namespace -TEST(TensorFlowImport, load_model_withio) -{ - moco::tf::ModelSignature signature; - - signature.add_input(moco::tf::TensorName("Placeholder", 0)); - signature.add_output(moco::tf::TensorName("output/identity", 0)); - - tensorflow::GraphDef graph_def; - EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def)); - - using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl; - - moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; - r.add("Identity", stdex::make_unique()); - moco::tf::Importer importer{&r}; - - std::unique_ptr graph = importer.import(signature, graph_def); - - // what to test: - // - import reads Pull - // - import reads Forward - // - attribute values should match - - auto pull = find_first_node_bytype(graph.get()); - ASSERT_NE(pull, nullptr); - auto forward = find_first_node_bytype(graph.get()); - ASSERT_NE(forward, nullptr); - - ASSERT_EQ(pull->dtype(), loco::DataType::FLOAT32); - ASSERT_EQ(pull->rank(), 4); - loco::Dimension dim1 = 1; - loco::Dimension dim2 = 2; - ASSERT_EQ(pull->dim(0).value(), dim1.value()); - ASSERT_EQ(pull->dim(1).value(), dim2.value()); - ASSERT_EQ(pull->dim(2).value(), dim1.value()); - ASSERT_EQ(pull->dim(3).value(), dim2.value()); -} - TEST(TensorFlowImport, load_model_withio_tf) { moco::tf::ModelSignature signature; @@ -128,12 +90,7 @@ TEST(TensorFlowImport, load_model_withio_tf) tensorflow::GraphDef graph_def; EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def)); - using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl; - - moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; - // TODO add Placeholder - r.add("Identity", stdex::make_unique()); - moco::tf::Importer importer{&r}; + moco::tf::Importer importer; std::unique_ptr graph = importer.import(signature, graph_def); diff --git a/compiler/moco-tf/src/Op/Identity.cpp b/compiler/moco-tf/src/Op/Identity.cpp index 03dac6d..c67aea8 100644 --- a/compiler/moco-tf/src/Op/Identity.cpp +++ b/compiler/moco-tf/src/Op/Identity.cpp @@ -38,21 +38,6 @@ namespace using namespace moco::tf; -class IdentityGraphUpdate final : public GraphUpdate -{ -public: - IdentityGraphUpdate(loco::Forward *node, const std::vector &names) - : _node(node), _names(names) - { - } - - void input(const SymbolTable *) const override; - -private: - loco::Forward *_node; - const std::vector _names; -}; - class TFIdentityGraphUpdate final : public GraphUpdate { public: @@ -68,15 +53,6 @@ private: const std::vector _names; }; -void IdentityGraphUpdate::input(const SymbolTable *tensor_names) const -{ - for (auto &name : _names) - { - loco::Node *target = tensor_names->node(name); - _node->input(target); - } -} - void TFIdentityGraphUpdate::input(const SymbolTable *tensor_names) const { for (auto &name : _names) @@ -93,16 +69,7 @@ namespace moco namespace tf { -/** - * @brief GraphBuilder for Identity node - */ -class IdentityGraphBuilder final : public IdentityGraphBuilderBase -{ -public: - void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; -}; - -bool IdentityGraphBuilderBase::validate(const tensorflow::NodeDef &node) const +bool IdentityGraphBuilder::validate(const tensorflow::NodeDef &node) const { if (node.input_size() < 1) // from TensorFlow lite toco return false; @@ -115,46 +82,6 @@ void IdentityGraphBuilder::build(const tensorflow::NodeDef &node, { assert(context != nullptr); - if (moco::tf::get()) - { - IdentityGraphBuilderImpl builder; - return builder.build(node, context); - } - else - { - IdentityGraphBuilderImpl builder; - return builder.build(node, context); - } -} - -void IdentityGraphBuilderImpl::build(const tensorflow::NodeDef &node, - GraphBuilderContext *context) const -{ - loco::Graph *graph = context->graph(); - SymbolTable *tensor_names = context->tensor_names(); - UpdateQueue *updates = context->updates(); - - // Create a "Forward" node for Identity - auto forward_node = graph->nodes()->create(); - - // register string-name to node - TensorName output_name(node.name(), 0); - tensor_names->enroll(output_name, forward_node); - - // Queue node input update - // TODO: Check if we really need multiple input handlings - std::vector names; - for (int i = 0; i < node.input_size(); ++i) - { - names.emplace_back(TensorName(node.input(i))); - } - auto update = stdex::make_unique(forward_node, names); - updates->enroll(std::move(update)); -} - -void IdentityGraphBuilderImpl::build(const tensorflow::NodeDef &node, - GraphBuilderContext *context) const -{ loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); UpdateQueue *updates = context->updates(); diff --git a/compiler/moco-tf/src/Op/Identity.h b/compiler/moco-tf/src/Op/Identity.h index 55da007..e900a53 100644 --- a/compiler/moco-tf/src/Op/Identity.h +++ b/compiler/moco-tf/src/Op/Identity.h @@ -18,31 +18,16 @@ #define __OP_IDENTITY_H__ #include "GraphBuilder.h" -#include "ImportTarget.h" namespace moco { namespace tf { -struct IdentityGraphBuilderBase : public GraphBuilder +class IdentityGraphBuilder : public GraphBuilder { - virtual ~IdentityGraphBuilderBase() = default; - +public: bool validate(const tensorflow::NodeDef &) const final; -}; - -template class IdentityGraphBuilderImpl; - -template <> -struct IdentityGraphBuilderImpl final : public IdentityGraphBuilderBase -{ - void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; -}; - -template <> -struct IdentityGraphBuilderImpl final : public IdentityGraphBuilderBase -{ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; }; -- 2.7.4