From 410bd4271bfc0d134aa0a6c56db3cb2e7a7d6375 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 28 Mar 2019 09:20:27 +0900 Subject: [PATCH] [tfkit] unpack float data type for Const nodes (#3125) * [tfkit] unpack float data type for Const nodes This will unpack all the float values in Const nodes in the graph that will change 'tensor_content' to 'float_val' Signed-off-by: SaeHie Park * remove unused header * fix namespace --- contrib/tfkit/src/UnpackCommand.cpp | 56 +++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/contrib/tfkit/src/UnpackCommand.cpp b/contrib/tfkit/src/UnpackCommand.cpp index 1df9e13..13d6e1e 100644 --- a/contrib/tfkit/src/UnpackCommand.cpp +++ b/contrib/tfkit/src/UnpackCommand.cpp @@ -15,6 +15,7 @@ */ #include "UnpackCommand.hpp" +#include "Support.hpp" #include @@ -22,9 +23,60 @@ #include #include +#include #include #include +namespace +{ + +template void unpack(tensorflow::TensorProto *); + +template <> void unpack(tensorflow::TensorProto *input_tensor) +{ + const auto &input_shape = input_tensor->tensor_shape(); + assert(input_shape.dim_size() <= 6); + int input_flat_size = tfkit::tf::GetElementCount(input_shape); + + assert(input_tensor->tensor_content().size() == input_flat_size * sizeof(float)); + + input_tensor->clear_float_val(); + + const float *tensor_content = + reinterpret_cast(input_tensor->tensor_content().data()); + for (int i = 0; i < input_flat_size; i++) + { + input_tensor->add_float_val(tensor_content[i]); + } + input_tensor->clear_tensor_content(); +} + +void unpack(tensorflow::GraphDef &graph_def) +{ + auto nodes = graph_def.mutable_node(); + for (int i = 0; i < nodes->size(); ++i) + { + tensorflow::NodeDef *n = nodes->Mutable(i); + // TODO: handle for other operators + if (n->op() == "Const") + { + const auto dtype = tfkit::tf::GetDataTypeAttr(*n, "dtype"); + tensorflow::TensorProto *tensor = tfkit::tf::GetTensorAttr(*n, "value"); + + switch (dtype) + { + case tensorflow::DT_FLOAT: + unpack(tensor); + break; + default: + throw std::runtime_error{"Unsupported dtype"}; + } + } + } +} + +} // namespace + namespace tfkit { @@ -41,8 +93,8 @@ int UnpackCommand::run(int, const char *const *) const return 255; } - // TODO: add convert tensor_content to float_val - throw std::runtime_error{"unpack command under development"}; + // convert tensor_content to float_val + unpack(graph_def); // Write text into standard output google::protobuf::io::OstreamOutputStream os{&std::cout}; -- 2.7.4