Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / libs / tflite / include / tflite / TensorLogger.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file     TensorLogger.h
19  * @brief    This file contains TensorLogger class
20  * @ingroup  COM_AI_RUNTIME
21  */
22
23 #ifndef __NNFW_TFLITE_TENSOR_LOGGER_H__
24 #define __NNFW_TFLITE_TENSOR_LOGGER_H__
25
26 #include "misc/tensor/IndexIterator.h"
27 #include "tflite/TensorView.h"
28
29 #include <tensorflow/lite/interpreter.h>
30 #include <tensorflow/lite/context.h>
31 #include <fstream>
32 #include <iomanip>
33
34 namespace nnfw
35 {
36 namespace tflite
37 {
38
39 /**
40  * @brief Class to write input and output value / shape into a file in python form
41  * @note This is a utility to write input and output value / shape into a file in python form.\n
42  *       any python app can load this value by running the python code below:\n
43  *       exec(open(filename).read())\n
44  *       generated python code looks like the following: \n
45  *       tensor_shape_gen = []\n
46  *       tensor_value_gen = []\n\n
47  *       tensor_shape_gen.append("{2, 1, 2}")\n
48  *       tensor_value_gen.append([1, 2, 3, 4])\n\n
49  *       tensor_shape_gen.append("{2}")\n
50  *       tensor_value_gen.append([1, 2])\n\n
51  *       tensor_shape_gen.append("{2, 1, 2}")\n
52  *       tensor_value_gen.append([1, 4, 3, 8])\n
53  */
54 class TensorLogger
55 {
56 private:
57   std::ofstream _outfile;
58
59 public:
60   /**
61    * @brief Get TensorLogger instance
62    * @return The TensorLogger instance
63    */
64   static TensorLogger &get()
65   {
66     static TensorLogger instance;
67     return instance;
68   }
69
70   /**
71    * @brief Save the tensor details to file from interpreter
72    * @param[in] path The file path to save
73    * @param[in] interp The TfLite interpreter
74    */
75   void save(const std::string &path, ::tflite::Interpreter &interp)
76   {
77     open(path);
78
79     int log_index = 0;
80     for (const auto id : interp.inputs())
81     {
82       _outfile << "# input tensors" << std::endl;
83       printTensor(interp, id, log_index++);
84     }
85     for (const auto id : interp.outputs())
86     {
87       _outfile << "# output tensors" << std::endl;
88       printTensor(interp, id, log_index++);
89     }
90     close();
91   }
92
93 private:
94   void open(const std::string &path)
95   {
96     if (!_outfile.is_open())
97       _outfile.open(path, std::ios_base::out);
98
99     _outfile << "# ------ file: " << path << " ------" << std::endl
100              << "tensor_shape_gen = []" << std::endl
101              << "tensor_value_gen = []" << std::endl
102              << std::endl;
103   }
104
105   void printTensor(::tflite::Interpreter &interp, const int id, const int log_index)
106   {
107     const TfLiteTensor *tensor = interp.tensor(id);
108
109     _outfile << "# tensor name: " << tensor->name << std::endl;
110     _outfile << "# tflite::interpreter.tensor(" << id << ") -> tensor_value_gen[" << log_index
111              << "]" << std::endl;
112
113     if (tensor->type == kTfLiteInt32)
114     {
115       printTensorShape(tensor);
116       printTensorValue<int32_t>(tensor, tensor->data.i32);
117     }
118     else if (interp.tensor(id)->type == kTfLiteUInt8)
119     {
120       printTensorShape(tensor);
121       printTensorValue<uint8_t>(tensor, tensor->data.uint8);
122     }
123     else if (tensor->type == kTfLiteFloat32)
124     {
125       printTensorShape(tensor);
126       printTensorValue<float>(tensor, tensor->data.f);
127     }
128   }
129
130   void printTensorShape(const TfLiteTensor *tensor)
131   {
132     _outfile << "tensor_shape_gen.append('{";
133
134     int r = 0;
135     for (; r < tensor->dims->size - 1; r++)
136     {
137       _outfile << tensor->dims->data[r] << ", ";
138     }
139     _outfile << tensor->dims->data[r];
140
141     _outfile << "}')" << std::endl;
142   }
143
144   template <typename T> void printTensorValue(const TfLiteTensor *tensor, T *tensor_data_ptr)
145   {
146     _outfile << "tensor_value_gen.append([";
147
148     _outfile << std::fixed << std::setprecision(10);
149
150     const T *end = reinterpret_cast<const T *>(tensor->data.raw_const + tensor->bytes);
151     for (T *ptr = tensor_data_ptr; ptr < end; ptr++)
152       _outfile << *ptr << ", ";
153
154     _outfile << "])" << std::endl << std::endl;
155   }
156
157   void close()
158   {
159     _outfile << "# --------- tensor shape and value defined above ---------" << std::endl;
160     _outfile.close();
161   }
162 };
163
164 } // namespace tflite
165 } // namespace nnfw
166
167 #endif // __NNFW_TFLITE_TENSOR_LOGGER_H__