/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "PackCommand.hpp"
18 #include "Support.hpp"
20 #include <tensorflow/core/framework/graph.pb.h>
22 #include <google/protobuf/io/coded_stream.h>
23 #include <google/protobuf/io/zero_copy_stream_impl.h>
24 #include <google/protobuf/text_format.h>
33 template <typename T> void pack(tensorflow::TensorProto *);
35 template <> void pack<float>(tensorflow::TensorProto *input_tensor)
37 const auto &input_shape = input_tensor->tensor_shape();
38 assert(input_shape.dim_size() <= 6);
39 int input_flat_size = tfkit::tf::GetElementCount(input_shape);
41 // Adjust where shape is not set but actual value exist
42 if (input_tensor->float_val().size() > 0 && input_flat_size == -1)
44 input_flat_size = input_tensor->float_val().size();
47 if (input_tensor->float_val().size() == 0)
49 // There may be tensor_content and we don't need to do anything as it is
50 // already packed format
52 else if (input_tensor->float_val().size() == input_flat_size)
54 input_tensor->clear_tensor_content();
56 std::vector<float> tensor_content;
57 for (int i = 0; i < input_flat_size; ++i)
59 tensor_content.push_back(input_tensor->float_val(i));
62 input_tensor->set_tensor_content(std::string(
63 reinterpret_cast<const char *>(tensor_content.data()), sizeof(float) * input_flat_size));
65 input_tensor->clear_float_val();
69 throw std::runtime_error{"Number of elements mismatch in pack<float>."};
70 // TODO: support for these
74 template <> void pack<int32_t>(tensorflow::TensorProto *input_tensor)
76 const auto &input_shape = input_tensor->tensor_shape();
77 assert(input_shape.dim_size() <= 6);
78 int input_flat_size = tfkit::tf::GetElementCount(input_shape);
80 // Adjust where shape is not set but actual value exist
81 if (input_tensor->int_val().size() > 0 && input_flat_size == -1)
83 input_flat_size = input_tensor->int_val().size();
86 if (input_tensor->int_val().size() == 0)
88 // There may be tensor_content and we don't need to do anything as it is
89 // already packed format
91 else if (input_tensor->int_val().size() == input_flat_size)
93 input_tensor->clear_tensor_content();
95 std::vector<int32_t> tensor_content;
96 for (int i = 0; i < input_flat_size; ++i)
98 tensor_content.push_back(input_tensor->int_val(i));
101 input_tensor->set_tensor_content(std::string(
102 reinterpret_cast<const char *>(tensor_content.data()), sizeof(int32_t) * input_flat_size));
104 input_tensor->clear_int_val();
108 throw std::runtime_error{"Number of elements mismatch in pack<int32_t>."};
109 // TODO: support for these
113 void pack(tensorflow::GraphDef &graph_def)
115 auto nodes = graph_def.mutable_node();
116 for (int i = 0; i < nodes->size(); ++i)
118 tensorflow::NodeDef *n = nodes->Mutable(i);
119 // TODO: handle for other operators
120 if (n->op() == "Const")
122 const auto dtype = tfkit::tf::GetDataTypeAttr(*n, "dtype");
123 tensorflow::TensorProto *tensor = tfkit::tf::GetTensorAttr(*n, "value");
127 case tensorflow::DT_FLOAT:
130 case tensorflow::DT_INT32:
131 pack<int32_t>(tensor);
134 throw std::runtime_error{"Unsupported dtype"};
/**
 * @brief Run the "pack" command: read a GraphDef prototxt and write it back out.
 *
 * Parses a tensorflow::GraphDef in protobuf text format from ioconfig->in(), and prints
 * the graph in protobuf text format to ioconfig->out().
 *
 * @param argc number of command-line arguments (forwarded to CmdArguments)
 * @param argv command-line arguments (forwarded to CmdArguments)
 * @return presumably the process exit status — confirm
 *
 * NOTE(review): make_ioconfig() presumably selects the input/output streams from the
 * command-line arguments (files or stdin/stdout) — confirm against its definition.
 */
145 int PackCommand::run(int argc, const char *const *argv) const
147 tensorflow::GraphDef graph_def;
149 CmdArguments cmdargs(argc, argv);
151 auto ioconfig = make_ioconfig(cmdargs);
// Parse the whole input stream as a text-format (prototxt) GraphDef.
153 google::protobuf::io::IstreamInputStream is{ioconfig->in()};
155 if (!google::protobuf::TextFormat::Parse(&is, &graph_def))
// Report malformed input on stderr.
157 std::cerr << "ERROR: Failed to parse prototxt" << std::endl;
161 // convert float_val to tensor_content
// Serialize the graph back to protobuf text format on the output stream.
// NOTE(review): TextFormat::Print returns false when writing fails, but its result is
// ignored here, so output errors are not reported.
164 google::protobuf::io::OstreamOutputStream os{ioconfig->out()};
165 google::protobuf::TextFormat::Print(graph_def, &os);