[moco] Set name for TFNode when importing (#8933)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 14 Nov 2019 03:22:02 +0000 (12:22 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 14 Nov 2019 03:22:02 +0000 (12:22 +0900)
This will set name for TFNode when importing from TensorFlow NodeDef

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
30 files changed:
compiler/moco/import/src/Nodes/Add.cpp
compiler/moco/import/src/Nodes/AvgPool.cpp
compiler/moco/import/src/Nodes/BiasAdd.cpp
compiler/moco/import/src/Nodes/Concat.cpp
compiler/moco/import/src/Nodes/Const.cpp
compiler/moco/import/src/Nodes/Conv2D.cpp
compiler/moco/import/src/Nodes/Conv2DBackpropInput.cpp
compiler/moco/import/src/Nodes/DepthwiseConv2dNative.cpp
compiler/moco/import/src/Nodes/FakeQuantWithMinMaxVars.cpp
compiler/moco/import/src/Nodes/FusedBatchNorm.cpp
compiler/moco/import/src/Nodes/Identity.cpp
compiler/moco/import/src/Nodes/MaxPool.cpp
compiler/moco/import/src/Nodes/Mean.cpp
compiler/moco/import/src/Nodes/Mul.cpp
compiler/moco/import/src/Nodes/Pad.cpp
compiler/moco/import/src/Nodes/Placeholder.cpp
compiler/moco/import/src/Nodes/RealDiv.cpp
compiler/moco/import/src/Nodes/Relu.cpp
compiler/moco/import/src/Nodes/Relu6.cpp
compiler/moco/import/src/Nodes/Reshape.cpp
compiler/moco/import/src/Nodes/Rsqrt.cpp
compiler/moco/import/src/Nodes/Shape.cpp
compiler/moco/import/src/Nodes/Softmax.cpp
compiler/moco/import/src/Nodes/Sqrt.cpp
compiler/moco/import/src/Nodes/SquaredDifference.cpp
compiler/moco/import/src/Nodes/Squeeze.cpp
compiler/moco/import/src/Nodes/StopGradient.cpp
compiler/moco/import/src/Nodes/Sub.cpp
compiler/moco/import/src/Nodes/Tanh.cpp
compiler/moco/import/src/TestHelper.test.cpp

index f7e599b..f0c4e2f 100644 (file)
@@ -72,6 +72,7 @@ void AddGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext
 
   // creating TF dialect Add node
   auto tf_add = graph->nodes()->create<TFAdd>();
+  tf_add->name(node.name());
 
   TensorName output_name(node.name(), 0);
   tensor_names->enroll(output_name, tf_add);
index 219940f..bdf56c9 100644 (file)
@@ -88,6 +88,7 @@ void AvgPoolGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
   // tensorflow data_format: one of NHWC or NCHW.
   auto data_layout = get_string_attr(node, "data_format");
   auto avgPool_node = graph->nodes()->create<TFAvgPool>();
+  avgPool_node->name(node.name());
   avgPool_node->data_layout(data_layout);
 
   // padding
index 4811cea..1402c94 100644 (file)
@@ -96,7 +96,7 @@ void BiasAddGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
   // tensorflow data_format: one of NHWC or NCHW.
   auto data_layout = plier::tf::get_string_attr(node, "data_format");
   auto tf_bias_add = graph->nodes()->create<TFBiasAdd>();
-
+  tf_bias_add->name(node.name());
   tf_bias_add->data_layout(data_layout);
 
   // To set the input node of encode_node with biasAdd_name
index 775a6c2..386f9b7 100644 (file)
@@ -91,6 +91,7 @@ void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node,
   const int num_inputs = node.input_size() - 1;
   std::vector<TensorName> input_names;
   auto concat_node = graph->nodes()->create<TFConcatV2>(num_inputs);
+  concat_node->name(node.name());
 
   for (int ni = 0; ni < num_inputs; ++ni)
   {
index 7454b35..2e5bc01 100644 (file)
@@ -116,6 +116,7 @@ void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
 
   // Create a "TFConstant" node for Const
   auto const_node = graph->nodes()->create<TFConst>();
+  const_node->name(node.name());
 
   // set dtype
   auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
index 4341273..46b55f5 100644 (file)
@@ -95,6 +95,7 @@ void Conv2DGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCont
   std::string conv2d_name = node.name();
 
   auto conv2d = graph->nodes()->create<TFConv2D>();
+  conv2d->name(node.name());
 
   // read attributes
   auto data_layout = plier::tf::get_string_attr(node, "data_format");
index a0036df..e0e5a3b 100644 (file)
@@ -101,6 +101,7 @@ void Conv2DBackpropInputGraphBuilder::build(const tensorflow::NodeDef &node,
   std::string conv2d_backprop_name = node.name();
 
   auto conv2d_backprop = graph->nodes()->create<TFConv2DBackpropInput>();
+  conv2d_backprop->name(node.name());
 
   // read attributes
   auto data_layout = plier::tf::get_string_attr(node, "data_format");
index 8ef44cc..a549365 100644 (file)
@@ -105,6 +105,7 @@ void DepthwiseConv2dNativeGraphBuilder::build(const tensorflow::NodeDef &node,
   UpdateQueue *updates = context->updates();
 
   auto depthwiseconv2d_native_node = graph->nodes()->create<TFDepthwiseConv2dNative>();
+  depthwiseconv2d_native_node->name(node.name());
 
   // read attributes
   auto data_layout = get_string_attr(node, "data_format");
index 15b07e7..c24fe9f 100644 (file)
@@ -88,6 +88,7 @@ void FakeQuantWithMinMaxVarsGraphBuilder::build(const tensorflow::NodeDef &node,
   UpdateQueue *updates = context->updates();
 
   auto fakequant_node = graph->nodes()->create<TFFakeQuantWithMinMaxVars>();
+  fakequant_node->name(node.name());
 
   // read optional attributes
   if (has_attr(node, "num_bits"))
index 3dd7f07..6644ae6 100644 (file)
@@ -82,6 +82,7 @@ void FusedBatchNormGraphBuilder::build(const tensorflow::NodeDef &node,
 
   // creating TF dialect FusedBatchNorm node
   auto tf_fbn = graph->nodes()->create<TFFusedBatchNorm>();
+  tf_fbn->name(node.name());
   tf_fbn->epsilon(epsilon);
 
   TensorName output_name(node.name(), 0);
index 650fa66..d6c4219 100644 (file)
@@ -75,6 +75,7 @@ void IdentityGraphBuilder::build(const tensorflow::NodeDef &node,
 
   // Create a Identity node
   auto identity_node = graph->nodes()->create<TFIdentity>();
+  identity_node->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 976284f..eab7029 100644 (file)
@@ -82,6 +82,7 @@ void MaxPoolGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
   // tensorflow data_format: one of NHWC or NCHW.
   auto data_layout = plier::tf::get_string_attr(node, "data_format");
   auto maxPool_node = graph->nodes()->create<TFMaxPool>();
+  maxPool_node->name(node.name());
   maxPool_node->data_layout(data_layout);
 
   // padding
index b7664b3..c9e3a26 100644 (file)
@@ -85,6 +85,7 @@ void MeanGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContex
 
   // creating TF dialect Mean node
   auto tf_mean = graph->nodes()->create<TFMean>();
+  tf_mean->name(node.name());
   tf_mean->keep_dims(plier::tf::get_bool_attr(node, "keep_dims"));
 
   TensorName output_name(node.name(), 0);
index 8176ee4..1b61551 100644 (file)
@@ -72,6 +72,7 @@ void MulGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext
 
   // creating TF dialect Mul node
   auto tf_mul = graph->nodes()->create<TFMul>();
+  tf_mul->name(node.name());
 
   TensorName output_name(node.name(), 0);
   tensor_names->enroll(output_name, tf_mul);
index 1798cac..4e48f8f 100644 (file)
@@ -73,6 +73,7 @@ void PadGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext
 
   // creating TF dialect Pad node
   auto tf_pad = graph->nodes()->create<TFPad>();
+  tf_pad->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 6b2e732..77a5be4 100644 (file)
@@ -51,7 +51,7 @@ void PlaceholderGraphBuilder::build(const tensorflow::NodeDef &node,
 
   // Create a "Placeholder" node as an input
   auto placeholder_node = graph->nodes()->create<moco::TFPlaceholder>();
-
+  placeholder_node->name(node.name());
   placeholder_node->dtype(dtype);
 
   // Setting shape info.
index d5114ec..237cff2 100644 (file)
@@ -73,6 +73,7 @@ void RealDivGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
 
   // creating TF dialect RealDiv node
   auto tf_div = graph->nodes()->create<TFRealDiv>();
+  tf_div->name(node.name());
 
   TensorName output_name(node.name(), 0);
   tensor_names->enroll(output_name, tf_div);
index 1736eba..f57fe6d 100644 (file)
@@ -72,6 +72,7 @@ void ReluGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContex
 
   // Create a "TFRelu" node for Relu
   auto relu_node = graph->nodes()->create<TFRelu>();
+  relu_node->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 9fbbf1d..b71d5e2 100644 (file)
@@ -66,6 +66,7 @@ void Relu6GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
 
   // Create a "TFRelu6" node for Relu
   auto relu_node = graph->nodes()->create<TFRelu6>();
+  relu_node->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 60bc72d..1e7f1a7 100644 (file)
@@ -83,6 +83,7 @@ void ReshapeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
   std::string reshape_name = node.name();
 
   auto reshape = graph->nodes()->create<TFReshape>();
+  reshape->name(node.name());
 
   // save the name for graph link updates
   TensorName output_name(reshape_name, 0);
index 87eccae..b1dcff7 100644 (file)
@@ -69,6 +69,7 @@ void RsqrtGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
 
   // creating TF dialect Rsqrt node
   auto tf_rsqrt = graph->nodes()->create<TFRsqrt>();
+  tf_rsqrt->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index b055b87..d295bcd 100644 (file)
@@ -73,6 +73,7 @@ void ShapeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
 
   // create TF dialect Shape node
   auto tf_shape = graph->nodes()->create<TFShape>();
+  tf_shape->name(node.name());
 
   if (plier::tf::has_attrs(node, {"out_type"}))
   {
index e606ef7..fa04797 100644 (file)
@@ -73,6 +73,7 @@ void SoftmaxGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
 
   // creating TF dialect Softmax node
   auto tf_softmax = graph->nodes()->create<TFSoftmax>();
+  tf_softmax->name(node.name());
 
   TensorName output_name(node.name(), 0);
   tensor_names->enroll(output_name, tf_softmax);
index 6880054..f8b97c0 100644 (file)
@@ -69,6 +69,7 @@ void SqrtGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContex
 
   // creating TF dialect Sqrt node
   auto tf_sqrt = graph->nodes()->create<TFSqrt>();
+  tf_sqrt->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 5f80128..c86d3c3 100644 (file)
@@ -76,6 +76,7 @@ void SquaredDifferenceGraphBuilder::build(const tensorflow::NodeDef &node,
 
   // creating TF dialect SquaredDifference node
   auto tf_sqdiff = graph->nodes()->create<TFSquaredDifference>();
+  tf_sqdiff->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 90d189b..249673f 100644 (file)
@@ -91,6 +91,7 @@ void SqueezeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon
 
   // creating TF dialect Squeeze node
   auto tf_squeeze = graph->nodes()->create<TFSqueeze>();
+  tf_squeeze->name(node.name());
   tf_squeeze->squeeze_dims(squeeze_dims);
 
   TensorName output_name(node.name(), 0);
index 9b924b8..fafdfa4 100644 (file)
@@ -71,6 +71,7 @@ void StopGradientGraphBuilder::build(const tensorflow::NodeDef &node,
 
   // creating TF dialect StopGradient node
   auto tf_stopgradient = graph->nodes()->create<TFStopGradient>();
+  tf_stopgradient->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 8cd4f08..d1cd4ba 100644 (file)
@@ -72,6 +72,7 @@ void SubGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext
 
   // creating TF dialect Sub node
   auto tf_sub = graph->nodes()->create<TFSub>();
+  tf_sub->name(node.name());
 
   TensorName output_name(node.name(), 0);
   tensor_names->enroll(output_name, tf_sub);
index 4ab2c4f..aecfeb7 100644 (file)
@@ -69,6 +69,7 @@ void TanhGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContex
 
   // creating TF dialect Tanh node
   auto tf_tanh = graph->nodes()->create<TFTanh>();
+  tf_tanh->name(node.name());
 
   // register string-name to node
   TensorName output_name(node.name(), 0);
index 6d60cd5..b32ee07 100644 (file)
@@ -74,6 +74,7 @@ void TFNodeBuildTester::run(tensorflow::NodeDef &nodedef, moco::GraphBuilder &gr
 
   auto tfnode = output();
   ASSERT_NE(tfnode, nullptr);
+  ASSERT_STREQ(tfnode->name().c_str(), _output);
 
   int idx = 0;
   ASSERT_EQ(tfnode->arity(), _inputs.size());