namespace TFLiteRun
{
-TensorDumper::TensorDumper(tflite::Interpreter &interpreter) : _interpreter(interpreter)
+// Default-construct an empty dumper. Tensor contents are captured later via
+// addTensors(), so the dumper no longer needs to keep a reference to the
+// interpreter alive for its whole lifetime.
+TensorDumper::TensorDumper()
{
// DO NOTHING
}
+// Snapshot the current contents of the tensors identified by `indices` so they
+// can be serialized later by dump(). Copies are taken immediately, so this must
+// be called while the interpreter's tensor buffers hold the values of interest
+// (e.g. inputs before Invoke(), outputs after Invoke()).
+void TensorDumper::addTensors(tflite::Interpreter &interpreter, const std::vector<int> &indices)
+{
+  for (const auto &o : indices)
+  {
+    const TfLiteTensor *tensor = interpreter.tensor(o);
+    // TfLiteTensor::bytes is a size_t; keep it as size_t instead of narrowing to int.
+    size_t size = tensor->bytes;
+    std::vector<char> buffer(size);
+    memcpy(buffer.data(), tensor->data.raw, size);
+    _tensors.emplace_back(o, std::move(buffer));
+  }
+}
+
+// Serialize the captured tensors to `filename` in a simple binary layout:
+//   [uint32_t tensor count][int index, one per tensor][raw bytes, one run per tensor]
void TensorDumper::dump(const std::string &filename) const
{
// TODO Handle file open/write error
std::ofstream file(filename, std::ios::out | std::ios::binary);
-  // Concat inputs and outputs
-  std::vector<int> indices = _interpreter.inputs();
-  indices.insert(indices.end(), _interpreter.outputs().begin(), _interpreter.outputs().end());
-
// Write number of tensors
-  uint32_t num_tensors = static_cast<uint32_t>(indices.size());
+  uint32_t num_tensors = static_cast<uint32_t>(_tensors.size());
file.write(reinterpret_cast<const char *>(&num_tensors), sizeof(num_tensors));
// Write tensor indices
-  file.write(reinterpret_cast<const char *>(indices.data()), sizeof(int) * num_tensors);
+  for (const auto &t : _tensors)
+  {
+    file.write(reinterpret_cast<const char *>(&t._index), sizeof(int));
+  }
+  // Write data: each tensor's bytes were copied at addTensors() time, so the
+  // interpreter no longer needs to be alive (or unmodified) here.
-  for (const auto &o : indices)
+  for (const auto &t : _tensors)
{
-    const TfLiteTensor *tensor = _interpreter.tensor(o);
-    file.write(tensor->data.raw, tensor->bytes);
+    file.write(t._data.data(), t._data.size());
}
+
file.close();
}
#ifndef __TFLITE_RUN_TENSOR_DUMPER_H__
#define __TFLITE_RUN_TENSOR_DUMPER_H__
+#include <memory>
#include <string>
+#include <vector>
namespace tflite
{
+// Captures copies of selected interpreter tensors and writes them to a file.
class TensorDumper
{
+private:
+  // One captured tensor: its interpreter index plus a byte-for-byte copy of
+  // its contents taken at addTensors() time.
+  struct Tensor
+  {
+    int _index;
+    std::vector<char> _data;
+
+    Tensor(int index, std::vector<char> &&data) : _index(index), _data(std::move(data)) {}
+  };
+
public:
-  TensorDumper(tflite::Interpreter &interpreter);
+  TensorDumper();
+  // Copies the contents of the tensors at `indices` out of `interpreter`;
+  // call while the tensor buffers hold the values to be dumped.
+  void addTensors(tflite::Interpreter &interpreter, const std::vector<int> &indices);
+  // Writes all captured tensors to `filename` in binary form.
void dump(const std::string &filename) const;
private:
-  tflite::Interpreter &_interpreter;
+  std::vector<Tensor> _tensors;
};
} // end of namespace TFLiteRun
}
}
+ TFLiteRun::TensorDumper tensor_dumper;
+ // Must be called before `interpreter->Invoke()`
+ tensor_dumper.addTensors(*interpreter, interpreter->inputs());
+
std::cout << "input tensor indices = [";
for (const auto &o : interpreter->inputs())
{
assert(status == kTfLiteOk);
};
+ // Must be called after `interpreter->Invoke()`
+ tensor_dumper.addTensors(*interpreter, interpreter->outputs());
+
std::cout << "output tensor indices = [";
for (const auto &o : interpreter->outputs())
{
if (!args.getDumpFilename().empty())
{
const std::string &dump_filename = args.getDumpFilename();
- TFLiteRun::TensorDumper tensor_dumper(*interpreter);
tensor_dumper.dump(dump_filename);
- std::cout << "Output tensors have been dumped to file \"" << dump_filename << "\"."
+ std::cout << "Input/output tensors have been dumped to file \"" << dump_filename << "\"."
<< std::endl;
}