[moco-tf] Graph build Identity as TF dialect (#8139)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 15 Oct 2019 00:15:30 +0000 (09:15 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 15 Oct 2019 00:15:30 +0000 (09:15 +0900)
This will update graph builder not to produce only TF dialect for Identity node
- related Importer test is also modified to work with this change

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/moco-tf/src/Importer.test.cpp
compiler/moco-tf/src/Op/Identity.cpp
compiler/moco-tf/src/Op/Identity.h

index 770984b..6fc3e82 100644 (file)
@@ -80,44 +80,6 @@ node {
 
 } // namespace
 
-TEST(TensorFlowImport, load_model_withio)
-{
-  moco::tf::ModelSignature signature;
-
-  signature.add_input(moco::tf::TensorName("Placeholder", 0));
-  signature.add_output(moco::tf::TensorName("output/identity", 0));
-
-  tensorflow::GraphDef graph_def;
-  EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
-
-  using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::Canonical>;
-
-  moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
-  r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
-  moco::tf::Importer importer{&r};
-
-  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
-  // what to test:
-  // - import reads Pull
-  // - import reads Forward
-  // - attribute values should match
-
-  auto pull = find_first_node_bytype<loco::Pull>(graph.get());
-  ASSERT_NE(pull, nullptr);
-  auto forward = find_first_node_bytype<loco::Forward>(graph.get());
-  ASSERT_NE(forward, nullptr);
-
-  ASSERT_EQ(pull->dtype(), loco::DataType::FLOAT32);
-  ASSERT_EQ(pull->rank(), 4);
-  loco::Dimension dim1 = 1;
-  loco::Dimension dim2 = 2;
-  ASSERT_EQ(pull->dim(0).value(), dim1.value());
-  ASSERT_EQ(pull->dim(1).value(), dim2.value());
-  ASSERT_EQ(pull->dim(2).value(), dim1.value());
-  ASSERT_EQ(pull->dim(3).value(), dim2.value());
-}
-
 TEST(TensorFlowImport, load_model_withio_tf)
 {
   moco::tf::ModelSignature signature;
@@ -128,12 +90,7 @@ TEST(TensorFlowImport, load_model_withio_tf)
   tensorflow::GraphDef graph_def;
   EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
 
-  using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::TensorFlow>;
-
-  moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
-  // TODO add Placeholder
-  r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
-  moco::tf::Importer importer{&r};
+  moco::tf::Importer importer;
 
   std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
 
index 03dac6d..c67aea8 100644 (file)
@@ -38,21 +38,6 @@ namespace
 
 using namespace moco::tf;
 
-class IdentityGraphUpdate final : public GraphUpdate
-{
-public:
-  IdentityGraphUpdate(loco::Forward *node, const std::vector<TensorName> &names)
-      : _node(node), _names(names)
-  {
-  }
-
-  void input(const SymbolTable *) const override;
-
-private:
-  loco::Forward *_node;
-  const std::vector<TensorName> _names;
-};
-
 class TFIdentityGraphUpdate final : public GraphUpdate
 {
 public:
@@ -68,15 +53,6 @@ private:
   const std::vector<TensorName> _names;
 };
 
-void IdentityGraphUpdate::input(const SymbolTable *tensor_names) const
-{
-  for (auto &name : _names)
-  {
-    loco::Node *target = tensor_names->node(name);
-    _node->input(target);
-  }
-}
-
 void TFIdentityGraphUpdate::input(const SymbolTable *tensor_names) const
 {
   for (auto &name : _names)
@@ -93,16 +69,7 @@ namespace moco
 namespace tf
 {
 
-/**
- * @brief GraphBuilder for Identity node
- */
-class IdentityGraphBuilder final : public IdentityGraphBuilderBase
-{
-public:
-  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
-};
-
-bool IdentityGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+bool IdentityGraphBuilder::validate(const tensorflow::NodeDef &node) const
 {
   if (node.input_size() < 1) // from TensorFlow lite toco
     return false;
@@ -115,46 +82,6 @@ void IdentityGraphBuilder::build(const tensorflow::NodeDef &node,
 {
   assert(context != nullptr);
 
-  if (moco::tf::get<moco::tf::Knob::ImportAsTFIdentity>())
-  {
-    IdentityGraphBuilderImpl<ImportTarget::TensorFlow> builder;
-    return builder.build(node, context);
-  }
-  else
-  {
-    IdentityGraphBuilderImpl<ImportTarget::Canonical> builder;
-    return builder.build(node, context);
-  }
-}
-
-void IdentityGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
-                                                              GraphBuilderContext *context) const
-{
-  loco::Graph *graph = context->graph();
-  SymbolTable *tensor_names = context->tensor_names();
-  UpdateQueue *updates = context->updates();
-
-  // Create a "Forward" node for Identity
-  auto forward_node = graph->nodes()->create<loco::Forward>();
-
-  // register string-name to node
-  TensorName output_name(node.name(), 0);
-  tensor_names->enroll(output_name, forward_node);
-
-  // Queue node input update
-  // TODO: Check if we really need multiple input handlings
-  std::vector<TensorName> names;
-  for (int i = 0; i < node.input_size(); ++i)
-  {
-    names.emplace_back(TensorName(node.input(i)));
-  }
-  auto update = stdex::make_unique<IdentityGraphUpdate>(forward_node, names);
-  updates->enroll(std::move(update));
-}
-
-void IdentityGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
-                                                               GraphBuilderContext *context) const
-{
   loco::Graph *graph = context->graph();
   SymbolTable *tensor_names = context->tensor_names();
   UpdateQueue *updates = context->updates();
index 55da007..e900a53 100644 (file)
 #define __OP_IDENTITY_H__
 
 #include "GraphBuilder.h"
-#include "ImportTarget.h"
 
 namespace moco
 {
 namespace tf
 {
 
-struct IdentityGraphBuilderBase : public GraphBuilder
+class IdentityGraphBuilder : public GraphBuilder
 {
-  virtual ~IdentityGraphBuilderBase() = default;
-
+public:
   bool validate(const tensorflow::NodeDef &) const final;
-};
-
-template <ImportTarget T> class IdentityGraphBuilderImpl;
-
-template <>
-struct IdentityGraphBuilderImpl<ImportTarget::Canonical> final : public IdentityGraphBuilderBase
-{
-  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
-};
-
-template <>
-struct IdentityGraphBuilderImpl<ImportTarget::TensorFlow> final : public IdentityGraphBuilderBase
-{
   void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
 };