[moco-tf] rewrite graphDef (#8397)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Thu, 24 Oct 2019 02:59:52 +0000 (11:59 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 24 Oct 2019 02:59:52 +0000 (11:59 +0900)
* [moco-tf] rewrite graphDef

This commit adds rewrite graphDef function to Frontend.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
* split one function into two.

* change function name.

* add for supporting zero-dim-size.

compiler/moco-tf/src/Frontend.cpp

index 4fba5d8..4f8edd8 100644 (file)
@@ -106,6 +106,60 @@ void load_tf(std::istream *stream, moco::tf::Frontend::FileType type,
   }
 }
 
+// If Placeholder has no shape attribute, set unknown_rank property to true.
+void set_unknown_rank(tensorflow::GraphDef &tf_graph_def)
+{
+  for (auto &n : *tf_graph_def.mutable_node())
+  {
+    if (n.op().compare("Placeholder"))
+      continue;
+
+    auto iter = n.attr().find("shape");
+    if (iter == n.attr().end())
+    {
+      tensorflow::AttrValue attr;
+      attr.mutable_shape()->set_unknown_rank(true);
+      n.mutable_attr()->insert({"shape", attr});
+    }
+  }
+}
+
+// If Placeholder's shape has unknown dimension or unknown rank, set it according to signature.
+void set_input_shape(const moco::tf::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def)
+{
+  for (auto &n : *tf_graph_def.mutable_node())
+  {
+    if (n.op().compare("Placeholder"))
+      continue;
+
+    auto node_shape = n.mutable_attr()->at("shape").mutable_shape();
+    auto sig_shape = signature.shape(n.name() + ":0");
+    if (node_shape->unknown_rank() || !node_shape->dim_size())
+    {
+      node_shape->clear_unknown_rank();
+      for (uint32_t i = 0; i < sig_shape->rank(); i++)
+        node_shape->add_dim()->set_size(-1);
+    }
+    for (uint32_t d = 0; d < node_shape->dim_size(); d++)
+    {
+      if (node_shape->mutable_dim(d)->size() == -1)
+      {
+        node_shape->mutable_dim(d)->set_size(sig_shape->dim(d));
+      }
+      else
+      {
+        assert(node_shape->dim(d).size() == sig_shape->dim(d));
+      }
+    }
+  }
+}
+
+void transform_tf(const moco::tf::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def)
+{
+  set_unknown_rank(tf_graph_def);
+  set_input_shape(signature, tf_graph_def);
+}
+
 /**
  * @brief Returns GraphBuilderRegistry that looks up default registry and additions
  *        such as custom op
@@ -158,6 +212,8 @@ std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, con
 
   load_tf(modelfile, type, tf_graph_def);
 
+  transform_tf(signature, tf_graph_def);
+
   auto graph = import(signature, tf_graph_def);
 
   return std::move(graph);
@@ -170,6 +226,8 @@ std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, std
 
   load_tf(stream, type, tf_graph_def);
 
+  transform_tf(signature, tf_graph_def);
+
   auto graph = import(signature, tf_graph_def);
 
   return std::move(graph);