Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / tfkit / src / PackCommand.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PackCommand.hpp"
18 #include "Support.hpp"
19
20 #include <tensorflow/core/framework/graph.pb.h>
21
22 #include <google/protobuf/io/coded_stream.h>
23 #include <google/protobuf/io/zero_copy_stream_impl.h>
24 #include <google/protobuf/text_format.h>
25
26 #include <cassert>
27 #include <stdexcept>
28 #include <vector>
29
30 namespace
31 {
32
33 template <typename T> void pack(tensorflow::TensorProto *);
34
35 template <> void pack<float>(tensorflow::TensorProto *input_tensor)
36 {
37   const auto &input_shape = input_tensor->tensor_shape();
38   assert(input_shape.dim_size() <= 6);
39   int input_flat_size = tfkit::tf::GetElementCount(input_shape);
40
41   // Adjust where shape is not set but actual value exist
42   if (input_tensor->float_val().size() > 0 && input_flat_size == -1)
43   {
44     input_flat_size = input_tensor->float_val().size();
45   }
46
47   if (input_tensor->float_val().size() == 0)
48   {
49     // There may be tensor_content and we don't need to do anything as it is
50     // already packed format
51   }
52   else if (input_tensor->float_val().size() == input_flat_size)
53   {
54     input_tensor->clear_tensor_content();
55
56     std::vector<float> tensor_content;
57     for (int i = 0; i < input_flat_size; ++i)
58     {
59       tensor_content.push_back(input_tensor->float_val(i));
60     }
61
62     input_tensor->set_tensor_content(std::string(
63         reinterpret_cast<const char *>(tensor_content.data()), sizeof(float) * input_flat_size));
64
65     input_tensor->clear_float_val();
66   }
67   else
68   {
69     throw std::runtime_error{"Number of elements mismatch in pack<float>."};
70     // TODO: support for these
71   }
72 }
73
74 template <> void pack<int32_t>(tensorflow::TensorProto *input_tensor)
75 {
76   const auto &input_shape = input_tensor->tensor_shape();
77   assert(input_shape.dim_size() <= 6);
78   int input_flat_size = tfkit::tf::GetElementCount(input_shape);
79
80   // Adjust where shape is not set but actual value exist
81   if (input_tensor->int_val().size() > 0 && input_flat_size == -1)
82   {
83     input_flat_size = input_tensor->int_val().size();
84   }
85
86   if (input_tensor->int_val().size() == 0)
87   {
88     // There may be tensor_content and we don't need to do anything as it is
89     // already packed format
90   }
91   else if (input_tensor->int_val().size() == input_flat_size)
92   {
93     input_tensor->clear_tensor_content();
94
95     std::vector<int32_t> tensor_content;
96     for (int i = 0; i < input_flat_size; ++i)
97     {
98       tensor_content.push_back(input_tensor->int_val(i));
99     }
100
101     input_tensor->set_tensor_content(std::string(
102         reinterpret_cast<const char *>(tensor_content.data()), sizeof(int32_t) * input_flat_size));
103
104     input_tensor->clear_int_val();
105   }
106   else
107   {
108     throw std::runtime_error{"Number of elements mismatch in pack<int32_t>."};
109     // TODO: support for these
110   }
111 }
112
113 void pack(tensorflow::GraphDef &graph_def)
114 {
115   auto nodes = graph_def.mutable_node();
116   for (int i = 0; i < nodes->size(); ++i)
117   {
118     tensorflow::NodeDef *n = nodes->Mutable(i);
119     // TODO: handle for other operators
120     if (n->op() == "Const")
121     {
122       const auto dtype = tfkit::tf::GetDataTypeAttr(*n, "dtype");
123       tensorflow::TensorProto *tensor = tfkit::tf::GetTensorAttr(*n, "value");
124
125       switch (dtype)
126       {
127         case tensorflow::DT_FLOAT:
128           pack<float>(tensor);
129           break;
130         case tensorflow::DT_INT32:
131           pack<int32_t>(tensor);
132           break;
133         default:
134           throw std::runtime_error{"Unsupported dtype"};
135       }
136     }
137   }
138 }
139
140 } // namespace
141
142 namespace tfkit
143 {
144
145 int PackCommand::run(int argc, const char *const *argv) const
146 {
147   tensorflow::GraphDef graph_def;
148
149   CmdArguments cmdargs(argc, argv);
150
151   auto ioconfig = make_ioconfig(cmdargs);
152
153   google::protobuf::io::IstreamInputStream is{ioconfig->in()};
154
155   if (!google::protobuf::TextFormat::Parse(&is, &graph_def))
156   {
157     std::cerr << "ERROR: Failed to parse prototxt" << std::endl;
158     return 255;
159   }
160
161   // convert float_val to tensor_content
162   pack(graph_def);
163
164   google::protobuf::io::OstreamOutputStream os{ioconfig->out()};
165   google::protobuf::TextFormat::Print(graph_def, &os);
166
167   return 0;
168 }
169
170 } // namespace tfkit