[tflite_run] Dump input and output tensor separately (#1214)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Tue, 15 May 2018 07:08:39 +0000 (16:08 +0900)
committer김정현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh0822.kim@samsung.com>
Tue, 15 May 2018 07:08:39 +0000 (16:08 +0900)
After `interpreter.Invoke()` input is no longer valid in tflite interpreter
so this commit revises to dump inputs before it and dump outputs after it.
Also the dumper has changed to have a copy of each tensor.

Resolve #1138

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
tools/tflite_run/src/tensor_dumper.cc
tools/tflite_run/src/tensor_dumper.h
tools/tflite_run/src/tflite_run.cc

index bb52b6c..2a27f94 100644 (file)
@@ -8,33 +8,45 @@
 namespace TFLiteRun
 {
 
-TensorDumper::TensorDumper(tflite::Interpreter &interpreter) : _interpreter(interpreter)
+TensorDumper::TensorDumper()
 {
   // DO NOTHING
 }
 
+void TensorDumper::addTensors(tflite::Interpreter &interpreter, const std::vector<int> &indices)
+{
+  for (const auto &o : indices)
+  {
+    const TfLiteTensor *tensor = interpreter.tensor(o);
+    int size = tensor->bytes;
+    std::vector<char> buffer;
+    buffer.resize(size);
+    memcpy(buffer.data(), tensor->data.raw, size);
+    _tensors.emplace_back(o, std::move(buffer));
+  }
+}
+
 void TensorDumper::dump(const std::string &filename) const
 {
   // TODO Handle file open/write error
   std::ofstream file(filename, std::ios::out | std::ios::binary);
 
-  // Concat inputs and outputs
-  std::vector<int> indices = _interpreter.inputs();
-  indices.insert(indices.end(), _interpreter.outputs().begin(), _interpreter.outputs().end());
-
   // Write number of tensors
-  uint32_t num_tensors = static_cast<uint32_t>(indices.size());
+  uint32_t num_tensors = static_cast<uint32_t>(_tensors.size());
   file.write(reinterpret_cast<const char *>(&num_tensors), sizeof(num_tensors));
 
   // Write tensor indices
-  file.write(reinterpret_cast<const char *>(indices.data()), sizeof(int) * num_tensors);
+  for (const auto &t : _tensors)
+  {
+    file.write(reinterpret_cast<const char *>(&t._index), sizeof(int));
+  }
 
   // Write data
-  for (const auto &o : indices)
+  for (const auto &t : _tensors)
   {
-    const TfLiteTensor *tensor = _interpreter.tensor(o);
-    file.write(tensor->data.raw, tensor->bytes);
+    file.write(t._data.data(), t._data.size());
   }
+
   file.close();
 }
 
index 80b4e0f..2805f10 100644 (file)
@@ -1,7 +1,9 @@
 #ifndef __TFLITE_RUN_TENSOR_DUMPER_H__
 #define __TFLITE_RUN_TENSOR_DUMPER_H__
 
+#include <memory>
 #include <string>
+#include <vector>
 
 namespace tflite
 {
@@ -13,12 +15,22 @@ namespace TFLiteRun
 
 class TensorDumper
 {
+private:
+  struct Tensor
+  {
+    int _index;
+    std::vector<char> _data;
+
+    Tensor(int index, std::vector<char> &&data) : _index(index), _data(std::move(data)) {}
+  };
+
 public:
-  TensorDumper(tflite::Interpreter &interpreter);
+  TensorDumper();
+  void addTensors(tflite::Interpreter &interpreter, const std::vector<int> &indices);
   void dump(const std::string &filename) const;
 
 private:
-  tflite::Interpreter &_interpreter;
+  std::vector<Tensor> _tensors;
 };
 
 } // end of namespace TFLiteRun
index c6cbee2..5018e77 100644 (file)
@@ -178,6 +178,10 @@ int main(const int argc, char **argv)
     }
   }
 
+  TFLiteRun::TensorDumper tensor_dumper;
+  // Must be called before `interpreter->Invoke()`
+  tensor_dumper.addTensors(*interpreter, interpreter->inputs());
+
   std::cout << "input tensor indices = [";
   for (const auto &o : interpreter->inputs())
   {
@@ -190,6 +194,9 @@ int main(const int argc, char **argv)
     assert(status == kTfLiteOk);
   };
 
+  // Must be called after `interpreter->Invoke()`
+  tensor_dumper.addTensors(*interpreter, interpreter->outputs());
+
   std::cout << "output tensor indices = [";
   for (const auto &o : interpreter->outputs())
   {
@@ -207,9 +214,8 @@ int main(const int argc, char **argv)
   if (!args.getDumpFilename().empty())
   {
     const std::string &dump_filename = args.getDumpFilename();
-    TFLiteRun::TensorDumper tensor_dumper(*interpreter);
     tensor_dumper.dump(dump_filename);
-    std::cout << "Output tensors have been dumped to file \"" << dump_filename << "\"."
+    std::cout << "Input/output tensors have been dumped to file \"" << dump_filename << "\"."
               << std::endl;
   }