[moco/tf] Import as TFConst (#4261)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 15 Jul 2019 07:41:28 +0000 (16:41 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 15 Jul 2019 07:41:28 +0000 (16:41 +0900)
This will introduce Knbo for import as TFConst and the changes to import Const as TFConst node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
contrib/moco-tf/src/Knob.lst
contrib/moco-tf/src/Op/Const.cpp

index 1e9abae..234d4b6 100644 (file)
@@ -6,6 +6,7 @@
 
 // Imports
 KNOB_BOOL(ImportAsTFBiasAdd, false, Import BiasAdd node as TFBiasAdd node)
+KNOB_BOOL(ImportAsTFConst, false, Import Const node as TFConst node)
 KNOB_BOOL(ImportAsTFConv2D, false, Import Conv2D node as TFConv2D node)
 
 // TensorFlow dialect transforms
index 36035a7..5047f47 100644 (file)
@@ -17,6 +17,9 @@
 #include "Convert.h"
 #include "GraphBuilder.h"
 #include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFConst.h"
 
 #include <moco/tf/Names.h>
 #include <loco.h>
@@ -100,6 +103,79 @@ void read_value_float32(loco::ConstGen *const_node, int num_elements,
 
 } // namespace
 
+namespace
+{
+
+void read_value_int32(moco::tf::TFConst *const_node, int num_elements,
+                      const tensorflow::TensorProto &input_tensor)
+{
+  const_node->size<loco::DataType::S32>(num_elements);
+
+  int32_t input_elements = input_tensor.int_val_size();
+
+  if (input_tensor.tensor_content().size() == num_elements * sizeof(int32_t))
+  {
+    const std::string &str_content = input_tensor.tensor_content();
+    const int32_t *s32_ptr = reinterpret_cast<const int32_t *>(str_content.c_str());
+    for (int32_t i = 0; i < num_elements; i++)
+    {
+      const_node->at<loco::DataType::S32>(i) = *(s32_ptr + i);
+    }
+  }
+  else if (0 < input_elements && input_elements <= num_elements)
+  {
+    for (int32_t i = 0; i < input_elements; i++)
+    {
+      const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(i);
+    }
+
+    for (int32_t i = input_elements; i < num_elements; i++)
+    {
+      const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(input_elements - 1);
+    }
+  }
+  else
+  {
+    throw std::runtime_error("Error: Invalid Const values");
+  }
+}
+
+void read_value_float32(moco::tf::TFConst *const_node, int num_elements,
+                        const tensorflow::TensorProto &input_tensor)
+{
+  const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+  int32_t input_elements = input_tensor.float_val_size();
+
+  if (input_tensor.tensor_content().size() == num_elements * sizeof(float))
+  {
+    const std::string &str_content = input_tensor.tensor_content();
+    const float *float_ptr = reinterpret_cast<const float *>(str_content.c_str());
+    for (int32_t i = 0; i < num_elements; i++)
+    {
+      const_node->at<loco::DataType::FLOAT32>(i) = *(float_ptr + i);
+    }
+  }
+  else if (0 < input_elements && input_elements <= num_elements)
+  {
+    for (int32_t i = 0; i < input_elements; i++)
+    {
+      const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+    }
+
+    for (int32_t i = input_elements; i < num_elements; i++)
+    {
+      const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(input_elements - 1);
+    }
+  }
+  else
+  {
+    throw std::runtime_error("Error: Invalid Const values");
+  }
+}
+
+} // namespace
+
 namespace moco
 {
 namespace tf
@@ -113,6 +189,10 @@ class ConstGraphBuilder final : public GraphBuilder
 public:
   bool validate(const tensorflow::NodeDef &) const override;
   void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+
+private:
+  void buildCanonical(const tensorflow::NodeDef &, GraphBuilderContext *) const;
+  void buildTF(const tensorflow::NodeDef &, GraphBuilderContext *) const;
 };
 
 bool ConstGraphBuilder::validate(const tensorflow::NodeDef &node) const
@@ -124,6 +204,15 @@ void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
 {
   assert(context != nullptr);
 
+  if (moco::tf::get<moco::tf::Knob::ImportAsTFConst>())
+    buildTF(node, context);
+  else
+    buildCanonical(node, context);
+}
+
+void ConstGraphBuilder::buildCanonical(const tensorflow::NodeDef &node,
+                                       GraphBuilderContext *context) const
+{
   loco::Graph *graph = context->graph();
   SymbolTable *tensor_names = context->tensor_names();
 
@@ -190,6 +279,74 @@ void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderConte
   tensor_names->enroll(output_name, const_node);
 }
 
+void ConstGraphBuilder::buildTF(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+
+  // Create a "TFConstant" node for Const
+  auto const_node = graph->nodes()->create<moco::tf::TFConst>();
+
+  // set dtype
+  auto dtype = as_loco_datatype(get_datatype_attr(node, "dtype"));
+  const_node->dtype(dtype);
+
+  // import shape and value
+  const auto &input_tensor = get_tensor_attr(node, "value");
+  const auto &input_shape = input_tensor.tensor_shape();
+  const auto &input_dims = input_shape.dim();
+  assert(input_shape.dim_size() <= 6);
+  const_node->rank(input_shape.dim_size());
+  int index = 0;
+  bool zero_sized_shape = false;
+  for (auto &d : input_dims)
+  {
+    if (d.size() > std::numeric_limits<int>::max())
+      throw std::runtime_error("Shape element overflows");
+    if (d.size() == 0)
+      zero_sized_shape = true;
+
+    if (d.size() >= 0)
+      const_node->dim(index++) = d.size();
+    else
+      throw std::runtime_error{"Error: Unknown dim size for " + node.name()};
+  }
+
+  int num_elements = 1;
+  if (zero_sized_shape)
+  {
+    const_node->rank(0);
+    num_elements = 0;
+  }
+  else
+  {
+    for (int d = 0; d < const_node->rank(); d++)
+    {
+      num_elements *= const_node->dim(d).value();
+    }
+  }
+
+  switch (dtype)
+  {
+  case loco::DataType::S32:
+    read_value_int32(const_node, num_elements, input_tensor);
+    break;
+
+  case loco::DataType::FLOAT32:
+    read_value_float32(const_node, num_elements, input_tensor);
+    break;
+
+  // TODO support other types
+
+  default:
+    throw std::runtime_error{"Error: Unsupported data type for " + node.name()};
+  }
+
+  // register string-name to node
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, const_node);
+}
+
 } // namespace tf
 } // namespace moco