[tfkit] unpack float data type for Const nodes (#3125)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 28 Mar 2019 00:20:27 +0000 (09:20 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 28 Mar 2019 00:20:27 +0000 (09:20 +0900)
* [tfkit] unpack float data type for Const nodes

This will unpack all the float values in Const nodes in the graph that will change 'tensor_content' to 'float_val'

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* remove unused header

* fix namespace

contrib/tfkit/src/UnpackCommand.cpp

index 1df9e13..13d6e1e 100644 (file)
@@ -15,6 +15,7 @@
  */
 
 #include "UnpackCommand.hpp"
+#include "Support.hpp"
 
 #include <tensorflow/core/framework/graph.pb.h>
 
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include <google/protobuf/text_format.h>
 
+#include <cassert>
 #include <iostream>
 #include <stdexcept>
 
+namespace
+{
+
+template <typename T> void unpack(tensorflow::TensorProto *);
+
+template <> void unpack<float>(tensorflow::TensorProto *input_tensor)
+{
+  const auto &input_shape = input_tensor->tensor_shape();
+  assert(input_shape.dim_size() <= 6);
+  int input_flat_size = tfkit::tf::GetElementCount(input_shape);
+
+  assert(input_tensor->tensor_content().size() == input_flat_size * sizeof(float));
+
+  input_tensor->clear_float_val();
+
+  const float *tensor_content =
+      reinterpret_cast<const float *>(input_tensor->tensor_content().data());
+  for (int i = 0; i < input_flat_size; i++)
+  {
+    input_tensor->add_float_val(tensor_content[i]);
+  }
+  input_tensor->clear_tensor_content();
+}
+
+void unpack(tensorflow::GraphDef &graph_def)
+{
+  auto nodes = graph_def.mutable_node();
+  for (int i = 0; i < nodes->size(); ++i)
+  {
+    tensorflow::NodeDef *n = nodes->Mutable(i);
+    // TODO: handle for other operators
+    if (n->op() == "Const")
+    {
+      const auto dtype = tfkit::tf::GetDataTypeAttr(*n, "dtype");
+      tensorflow::TensorProto *tensor = tfkit::tf::GetTensorAttr(*n, "value");
+
+      switch (dtype)
+      {
+      case tensorflow::DT_FLOAT:
+        unpack<float>(tensor);
+        break;
+      default:
+        throw std::runtime_error{"Unsupported dtype"};
+      }
+    }
+  }
+}
+
+} // namespace
+
 namespace tfkit
 {
 
@@ -41,8 +93,8 @@ int UnpackCommand::run(int, const char *const *) const
     return 255;
   }
 
-  // TODO: add convert tensor_content to float_val
-  throw std::runtime_error{"unpack command under development"};
+  // convert tensor_content to float_val
+  unpack(graph_def);
 
   // Write text into standard output
   google::protobuf::io::OstreamOutputStream os{&std::cout};