[moco] import TensorFlow Const for one value data (#3491)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 16 May 2019 03:36:13 +0000 (12:36 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 16 May 2019 03:36:13 +0000 (12:36 +0900)
This will enable TensorFlow Const conversion where there is only one float value.
And add a test code for this case.

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/src/Op/Const.cpp
contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp

index e460529..bf003a8 100644 (file)
@@ -35,7 +35,14 @@ void read_value_float32(loco::ConstGen *const_node, int num_elements,
 
   int32_t input_elements = input_tensor.float_val_size();
 
-  if (input_elements == num_elements)
+  if (input_elements == 1)
+  {
+    for (int32_t i = 0; i < num_elements; i++)
+    {
+      const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(0);
+    }
+  }
+  else if (input_elements == num_elements)
   {
     for (int32_t i = 0; i < input_elements; i++)
     {
index 5afff9b..f90e894 100644 (file)
@@ -104,3 +104,74 @@ TEST(TensorFlowFrontend, const_float_01)
   ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
   ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
 }
+
+namespace
+{
+// Test case for "input_tensor.float_val_size() == 1"
+
+// clang-format off
+const char *const_float_02_pbtxtdata = STRING_CONTENT(
+node {
+  name: "const/float"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 1.1
+      }
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, const_float_02)
+{
+  moco::tf::Frontend frontend;
+  moco::tf::ModelSignature signature;
+
+  imemstream mempb(const_float_02_pbtxtdata, std::strlen(const_float_02_pbtxtdata));
+
+  signature.add_output("const/float");
+
+  std::unique_ptr<loco::Graph> graph =
+      frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+
+  loco::Graph::OutputContext *outputs = graph->outputs();
+  ASSERT_EQ(outputs->size(), 1);
+  loco::GraphOutput *output = outputs->at(0);
+  loco::Push *push = output->node();
+
+  loco::Graph::NodeContext *nodes = graph->nodes();
+  ASSERT_EQ(nodes->size(), 2);
+  loco::ConstGen *node0 = dynamic_cast<loco::ConstGen *>(nodes->at(0));
+  ASSERT_NE(node0, nullptr);
+  loco::Push *node1 = dynamic_cast<loco::Push *>(nodes->at(1));
+  ASSERT_EQ(node1, push);
+
+  ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 1.1f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 1.1f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 1.1f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 1.1f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 1.1f);
+}