[tfl-inspect] Introduce DumpConv2DWeight (#8611)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 30 Oct 2019 10:05:10 +0000 (19:05 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 30 Oct 2019 10:05:10 +0000 (19:05 +0900)
* [tfl-inspect] Introduce DumpConv2DWeight

This will introduce DumpConv2DWeight to dump Conv2D series node weight input type

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* fix namespace

* update comment

* use buildin_code

compiler/tfl-inspect/src/Dump.cpp
compiler/tfl-inspect/src/Dump.h

index 5e586d1..16dcb56 100644 (file)
@@ -17,7 +17,9 @@
 #include "Dump.h"
 #include "Reader.h"
 
+#include <string>
 #include <ostream>
+#include <stdexcept>
 
 namespace tflinspect
 {
@@ -43,3 +45,98 @@ void DumpOperators::run(std::ostream &os, const tflite::Model *model)
 }
 
 } // namespace tflinspect
+
+namespace
+{
+
+const tflite::Operator *operator_match_output(tflinspect::Reader &reader, const int32_t tensor)
+{
+  auto ops = reader.operators();
+
+  for (uint32_t i = 0; i < ops->Length(); ++i)
+  {
+    const auto op = ops->Get(i);
+
+    const std::vector<int32_t> &outputs = tflinspect::as_index_vector(op->outputs());
+
+    for (auto output : outputs)
+    {
+      if (output == tensor)
+        return op;
+    }
+  }
+  return nullptr;
+}
+
+size_t tensor_buffer_size(tflinspect::Reader &reader, const int32_t tensor_id)
+{
+  auto tensors = reader.tensors();
+
+  if (tensor_id < 0 || tensor_id >= tensors->Length())
+  {
+    throw std::runtime_error("Invalid Tensor ID");
+  }
+
+  auto tensor = tensors->Get(tensor_id);
+
+  auto buffers = reader.buffers();
+  auto buffer_id = tensor->buffer();
+
+  const uint8_t *buff_data;
+  size_t size = reader.buffer_info(buffer_id, &buff_data);
+
+  (void)buff_data;
+
+  return size;
+}
+
+} // namespace
+
+namespace tflinspect
+{
+
+void DumpConv2DWeight::run(std::ostream &os, const tflite::Model *model)
+{
+  tflinspect::Reader reader(model);
+
+  assert(reader.num_subgraph() == 1);
+  reader.select_subgraph(0);
+
+  auto ops = reader.operators();
+
+  // dump Conv2D, DepthwiseConv2D and its weight input operator
+  for (uint32_t i = 0; i < ops->Length(); ++i)
+  {
+    const auto op = ops->Get(i);
+    auto bc = reader.builtin_code(op);
+
+    if (bc == tflite::BuiltinOperator_CONV_2D || bc == tflite::BuiltinOperator_DEPTHWISE_CONV_2D)
+    {
+      const std::vector<int32_t> &inputs = tflinspect::as_index_vector(op->inputs());
+      if (inputs.size() < 2)
+      {
+        throw std::runtime_error("Operator has invalid input");
+      }
+      auto weight_input = inputs[1]; // Tensor ID of weight input
+
+      const auto op_weight = operator_match_output(reader, weight_input);
+      const auto buffer_size = tensor_buffer_size(reader, weight_input);
+
+      std::string weight_op_name = "?";
+
+      if (op_weight == nullptr && buffer_size > 0)
+      {
+        weight_op_name = "CONST";
+      }
+      else if (op_weight != nullptr)
+      {
+        weight_op_name = reader.opcode_name(op_weight);
+      }
+
+      auto op_name = reader.opcode_name(op);
+      os << op_name << "," << weight_op_name << std::endl;
+    }
+  }
+}
+
+} // namespace tflinspect
index 436c781..798c1db 100644 (file)
@@ -42,6 +42,15 @@ public:
   void run(std::ostream &os, const tflite::Model *model);
 };
 
+class DumpConv2DWeight final : public DumpInterface
+{
+public:
+  DumpConv2DWeight() = default;
+
+public:
+  void run(std::ostream &os, const tflite::Model *model);
+};
+
 } // namespace tflinspect
 
 #endif // __DUMP_H__