2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * @file TensorLogger.h
19 * @brief This file contains TensorLogger class
20 * @ingroup COM_AI_RUNTIME
23 #ifndef __NNFW_TFLITE_TENSOR_LOGGER_H__
24 #define __NNFW_TFLITE_TENSOR_LOGGER_H__
#include "misc/tensor/IndexIterator.h"
#include "tflite/TensorView.h"

#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/context.h>

#include <cmath>
40 * @brief Class to write input and output value / shape into a file in python form
41 * @note This is a utility to write input and output value / shape into a file in python form.\n
42 * any python app can load this value by running the python code below:\n
43 * exec(open(filename).read())\n
44 * generated python code looks like the following: \n
 * tensor_shape_gen = []\n
 * tensor_value_gen = []\n\n
 * tensor_shape_gen.append('{2, 1, 2}')\n
 * tensor_value_gen.append([1, 2, 3, 4])\n\n
 * tensor_shape_gen.append('{2}')\n
 * tensor_value_gen.append([1, 2])\n\n
 * tensor_shape_gen.append('{2, 1, 2}')\n
 * tensor_value_gen.append([1, 4, 3, 8])\n
57 std::ofstream _outfile;
61 * @brief Get TensorLogger instance
62 * @return The TensorLogger instance
64 static TensorLogger &get()
66 static TensorLogger instance;
71 * @brief Save the tensor details to file from interpreter
72 * @param[in] path The file path to save
73 * @param[in] interp The TfLite interpreter
75 void save(const std::string &path, ::tflite::Interpreter &interp)
80 for (const auto id : interp.inputs())
82 _outfile << "# input tensors" << std::endl;
83 printTensor(interp, id, log_index++);
85 for (const auto id : interp.outputs())
87 _outfile << "# output tensors" << std::endl;
88 printTensor(interp, id, log_index++);
94 void open(const std::string &path)
96 if (!_outfile.is_open())
97 _outfile.open(path, std::ios_base::out);
99 _outfile << "# ------ file: " << path << " ------" << std::endl
100 << "tensor_shape_gen = []" << std::endl
101 << "tensor_value_gen = []" << std::endl
105 void printTensor(::tflite::Interpreter &interp, const int id, const int log_index)
107 const TfLiteTensor *tensor = interp.tensor(id);
109 _outfile << "# tensor name: " << tensor->name << std::endl;
110 _outfile << "# tflite::interpreter.tensor(" << id << ") -> tensor_value_gen[" << log_index
113 if (tensor->type == kTfLiteInt32)
115 printTensorShape(tensor);
116 printTensorValue<int32_t>(tensor, tensor->data.i32);
118 else if (interp.tensor(id)->type == kTfLiteUInt8)
120 printTensorShape(tensor);
121 printTensorValue<uint8_t>(tensor, tensor->data.uint8);
123 else if (tensor->type == kTfLiteFloat32)
125 printTensorShape(tensor);
126 printTensorValue<float>(tensor, tensor->data.f);
130 void printTensorShape(const TfLiteTensor *tensor)
132 _outfile << "tensor_shape_gen.append('{";
135 for (; r < tensor->dims->size - 1; r++)
137 _outfile << tensor->dims->data[r] << ", ";
139 _outfile << tensor->dims->data[r];
141 _outfile << "}')" << std::endl;
144 template <typename T> void printTensorValue(const TfLiteTensor *tensor, T *tensor_data_ptr)
146 _outfile << "tensor_value_gen.append([";
148 _outfile << std::fixed << std::setprecision(10);
150 const T *end = reinterpret_cast<const T *>(tensor->data.raw_const + tensor->bytes);
151 for (T *ptr = tensor_data_ptr; ptr < end; ptr++)
152 _outfile << *ptr << ", ";
154 _outfile << "])" << std::endl << std::endl;
159 _outfile << "# --------- tensor shape and value defined above ---------" << std::endl;
164 } // namespace tflite
167 #endif // __NNFW_TFLITE_TENSOR_LOGGER_H__