a824c3411b24e32f8a96a69c498c366b4a489eda
[platform/core/ml/nnfw.git] / runtime / libs / tflite / include / tflite / TensorLogger.h
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * @file     TensorLogger.h
19  * @brief    This file contains TensorLogger class
20  * @ingroup  COM_AI_RUNTIME
21  */
22
23 #ifndef __NNFW_TFLITE_TENSOR_LOGGER_H__
24 #define __NNFW_TFLITE_TENSOR_LOGGER_H__
25
26 #include "misc/tensor/IndexIterator.h"
27 #include "tflite/TensorView.h"
28
29 #include <tensorflow/lite/interpreter.h>
30 #include <tensorflow/lite/context.h>
31 #include <fstream>
32 #include <iomanip>
33
34 namespace nnfw
35 {
36 namespace tflite
37 {
38
39 /**
40  * @brief Class to write input and output value / shape into a file in python form
41  * @note This is a utility to write input and output value / shape into a file in python form.\n
42  *       any python app can load this value by running the python code below:\n
43  *       exec(open(filename).read())\n
44  *       generated python code looks like the following: \n
45  *       tensor_shape_gen = []\n
46  *       tensor_value_gen = []\n\n
47  *       tensor_shape_gen.append("{2, 1, 2}")\n
48  *       tensor_value_gen.append([1, 2, 3, 4])\n\n
49  *       tensor_shape_gen.append("{2}")\n
50  *       tensor_value_gen.append([1, 2])\n\n
51  *       tensor_shape_gen.append("{2, 1, 2}")\n
52  *       tensor_value_gen.append([1, 4, 3, 8])\n
53  */
54 class TensorLogger
55 {
56 private:
57   std::ofstream _outfile;
58
59 public:
60   /**
61    * @brief Get TensorLogger instance
62    * @return The TensorLogger instance
63    */
64   static TensorLogger &get()
65   {
66     static TensorLogger instance;
67     return instance;
68   }
69
70   /**
71    * @brief Save the tensor details to file from interpreter
72    * @param[in] path The file path to save
73    * @param[in] interp The TfLite interpreter
74    */
75   void save(const std::string &path, ::tflite::Interpreter &interp)
76   {
77     open(path);
78
79     int log_index = 0;
80     for (const auto id : interp.inputs())
81     {
82       _outfile << "# input tensors" << std::endl;
83       printTensor(interp, id, log_index++);
84     }
85     for (const auto id : interp.outputs())
86     {
87       _outfile << "# output tensors" << std::endl;
88       printTensor(interp, id, log_index++);
89     }
90     close();
91   }
92
93 private:
94   void open(const std::string &path)
95   {
96     if (!_outfile.is_open())
97       _outfile.open(path, std::ios_base::out);
98
99     _outfile << "# ------ file: " << path << " ------" << std::endl
100              << "tensor_shape_gen = []" << std::endl
101              << "tensor_value_gen = []" << std::endl
102              << std::endl;
103   }
104
105   void printTensor(::tflite::Interpreter &interp, const int id, const int log_index)
106   {
107     const TfLiteTensor *tensor = interp.tensor(id);
108
109     _outfile << "# tensor name: " << tensor->name << std::endl;
110     _outfile << "# tflite::interpreter.tensor(" << id << ") -> "
111                                                          "tensor_value_gen["
112              << log_index << "]" << std::endl;
113
114     if (tensor->type == kTfLiteInt32)
115     {
116       printTensorShape(tensor);
117       printTensorValue<int32_t>(tensor, tensor->data.i32);
118     }
119     else if (interp.tensor(id)->type == kTfLiteUInt8)
120     {
121       printTensorShape(tensor);
122       printTensorValue<uint8_t>(tensor, tensor->data.uint8);
123     }
124     else if (tensor->type == kTfLiteFloat32)
125     {
126       printTensorShape(tensor);
127       printTensorValue<float>(tensor, tensor->data.f);
128     }
129   }
130
131   void printTensorShape(const TfLiteTensor *tensor)
132   {
133     _outfile << "tensor_shape_gen.append('{";
134
135     int r = 0;
136     for (; r < tensor->dims->size - 1; r++)
137     {
138       _outfile << tensor->dims->data[r] << ", ";
139     }
140     _outfile << tensor->dims->data[r];
141
142     _outfile << "}')" << std::endl;
143   }
144
145   template <typename T> void printTensorValue(const TfLiteTensor *tensor, T *tensor_data_ptr)
146   {
147     _outfile << "tensor_value_gen.append([";
148
149     _outfile << std::fixed << std::setprecision(10);
150
151     const T *end = reinterpret_cast<const T *>(tensor->data.raw_const + tensor->bytes);
152     for (T *ptr = tensor_data_ptr; ptr < end; ptr++)
153       _outfile << *ptr << ", ";
154
155     _outfile << "])" << std::endl << std::endl;
156   }
157
158   void close()
159   {
160     _outfile << "# --------- tensor shape and value defined above ---------" << std::endl;
161     _outfile.close();
162   }
163 };
164
165 } // namespace tflite
166 } // namespace nnfw
167
168 #endif // __NNFW_TFLITE_TENSOR_LOGGER_H__