[moco] import Const attributes and value data (#3487)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 16 May 2019 01:54:15 +0000 (10:54 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 16 May 2019 01:54:15 +0000 (10:54 +0900)
This will enable import for Const attributes and data where there are same number of const value data.
Add test code for this case.

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco/lib/frontend/tf/src/Op/Const.cpp
contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp [new file with mode: 0644]

index d98e0f9..e460529 100644 (file)
 #include <cassert>
 #include <stdexcept>
 
+namespace
+{
+
+void read_value_float32(loco::ConstGen *const_node, int num_elements,
+                        const tensorflow::TensorProto &input_tensor)
+{
+  const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+  int32_t input_elements = input_tensor.float_val_size();
+
+  if (input_elements == num_elements)
+  {
+    for (int32_t i = 0; i < input_elements; i++)
+    {
+      const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+    }
+  }
+  else
+  {
+    throw std::runtime_error("Error: Invalid Const values");
+  }
+}
+
+} // namespace
+
 namespace moco
 {
 namespace tf
@@ -60,11 +85,55 @@ void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
   // Create a "ConstGen" node for Const
   auto const_node = graph->nodes()->create<loco::ConstGen>();
 
-  // TODO fill attributes
+  // set dtype
+  auto dtype = as_loco_datatype(get_datatype_attr(node, "dtype"));
+  const_node->dtype(dtype);
+
+  // import shape and value
+  const auto &input_tensor = get_tensor_attr(node, "value");
+  const auto &input_shape = input_tensor.tensor_shape();
+  const auto &input_dims = input_shape.dim();
+  assert(input_shape.dim_size() <= 6);
+  const_node->rank(input_shape.dim_size());
+  int index = 0;
+  bool zero_sized_shape = false;
+  for (auto &d : input_dims)
+  {
+    if (d.size() > std::numeric_limits<int>::max())
+      throw std::runtime_error("Shape element overflows");
+    if (d.size() == 0)
+      zero_sized_shape = true;
+
+    const_node->dim(index++) = loco::make_dimension(d.size());
+  }
+
+  int num_elements = 1;
+  if (zero_sized_shape)
+  {
+    const_node->rank(0);
+    num_elements = 0;
+  }
+  else
+  {
+    for (int d = 0; d < const_node->rank(); d++)
+    {
+      num_elements *= const_node->dim(d).value();
+    }
+  }
+
+  switch (dtype)
+  {
+  case loco::DataType::FLOAT32:
+    read_value_float32(const_node, num_elements, input_tensor);
+    break;
+
+  // TODO support other types
+
+  default:
+    throw std::runtime_error{"Error: Unsupported data type for " + node.name()};
+  }
 
   nodes->enroll(node.name(), const_node);
-
-  throw std::runtime_error{"NYI"};
 }
 
 } // namespace tf
diff --git a/contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp b/contrib/moco/lib/frontend/tf/src/Op/Const.test.cpp
new file mode 100644 (file)
index 0000000..5afff9b
--- /dev/null
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include <moco/tf/Frontend.h>
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// Test case for "input_tensor.float_val_size() == num_elements"
+
+// clang-format off
+const char *const_float_01_pbtxtdata = STRING_CONTENT(
+node {
+  name: "const/float"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 2
+          }
+          dim {
+            size: 3
+          }
+        }
+        float_val: 1.1
+        float_val: 2.2
+        float_val: 3.3
+        float_val: 4.4
+        float_val: 5.5
+        float_val: 6.6
+      }
+    }
+  }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, const_float_01)
+{
+  moco::tf::Frontend frontend;
+  moco::tf::ModelSignature signature;
+
+  imemstream mempb(const_float_01_pbtxtdata, std::strlen(const_float_01_pbtxtdata));
+
+  signature.add_output("const/float");
+
+  std::unique_ptr<loco::Graph> graph =
+      frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+
+  loco::Graph::OutputContext *outputs = graph->outputs();
+  ASSERT_EQ(outputs->size(), 1);
+  loco::GraphOutput *output = outputs->at(0);
+  loco::Push *push = output->node();
+
+  loco::Graph::NodeContext *nodes = graph->nodes();
+  ASSERT_EQ(nodes->size(), 2);
+  loco::ConstGen *node0 = dynamic_cast<loco::ConstGen *>(nodes->at(0));
+  ASSERT_NE(node0, nullptr);
+  loco::Push *node1 = dynamic_cast<loco::Push *>(nodes->at(1));
+  ASSERT_EQ(node1, push);
+
+  ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
+  ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
+}