[moco/tf] Import as TFConv2D (#4223)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 12 Jul 2019 05:49:54 +0000 (14:49 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 12 Jul 2019 05:49:54 +0000 (14:49 +0900)
This will enable import as TFConv2D

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Op/Conv2D.cpp

index 3f32a03..df47a5a 100644 (file)
 #include "GraphBuilderContext.h"
 #include "Knob.h"
 
+#include "IR/TFConv2D.h"
+
 #include "Annotations/PaddingData.h"
+#include "Annotations/PadData.h"
 
 #include <moco/tf/Names.h>
 
@@ -78,6 +81,31 @@ void KernelUpdate::input(const SymbolTable *node_table) const
   _ker_enc->input(ker_node);
 }
 
+class TFConv2DGraphUpdate final : public GraphUpdate
+{
+public:
+  TFConv2DGraphUpdate(TFConv2D *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+  void input(const SymbolTable *) const override;
+
+private:
+  TFConv2D *_node;
+  std::vector<TensorName> _names;
+};
+
+void TFConv2DGraphUpdate::input(const SymbolTable *node_table) const
+{
+  assert(_names.size() == 2);
+
+  auto ifm_node = node_table->node(_names[0]);
+  auto ker_node = node_table->node(_names[1]);
+  assert(ifm_node != nullptr);
+  assert(ker_node != nullptr);
+
+  _node->ifm(ifm_node);
+  _node->ker(ker_node);
+}
+
 } // namespace
 
 namespace moco
@@ -241,7 +269,43 @@ void Conv2DGraphBuilder::buildCanonical(const tensorflow::NodeDef &node,
 void Conv2DGraphBuilder::buildTF(const tensorflow::NodeDef &node,
                                  GraphBuilderContext *context) const
 {
-  throw std::runtime_error("NYI");
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // name of loco nodes
+  std::string conv2d_name = node.name();
+
+  auto conv2d = graph->nodes()->create<TFConv2D>();
+
+  // read attributes
+  auto data_layout = get_string_attr(node, "data_format");
+  if (!(data_layout == "NHWC" || data_layout == "NCHW"))
+  {
+    throw std::runtime_error("Not yet supported");
+  }
+  conv2d->data_layout(data_layout);
+
+  auto tf_strides = get_list_attr(node, "strides");
+  auto strides = as_int64_list(tf_strides);
+  conv2d->strides(strides);
+
+  auto padding = moco::str_toupper(get_string_attr(node, "padding"));
+  assert(padding == "VALID" || padding == "SAME");
+  conv2d->padding(padding);
+
+  // save the name for graph link updates
+  TensorName output_name(conv2d_name, 0);
+  tensor_names->enroll(output_name, conv2d);
+
+  std::vector<TensorName> input_names;
+  input_names.push_back(TensorName(node.input(0))); // input
+  input_names.push_back(TensorName(node.input(1))); // kernel
+
+  // Record ifm inputs to featureEncode_node
+  auto tfconv2d_update = stdex::make_unique<TFConv2DGraphUpdate>(conv2d, input_names);
+
+  updates->enroll(std::move(tfconv2d_update));
 }
 
 } // namespace tf