From f99ce168b23255b0709930f28276122f7bb8ab2e Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 30 Oct 2019 19:05:10 +0900 Subject: [PATCH] [tfl-inspect] Introduce DumpConv2DWeight (#8611) * [tfl-inspect] Introduce DumpConv2DWeight This will introduce DumpConv2DWeight to dump Conv2D series node weight input type Signed-off-by: SaeHie Park * fix namespace * update comment * use buildin_code --- compiler/tfl-inspect/src/Dump.cpp | 97 +++++++++++++++++++++++++++++++++++++++ compiler/tfl-inspect/src/Dump.h | 9 ++++ 2 files changed, 106 insertions(+) diff --git a/compiler/tfl-inspect/src/Dump.cpp b/compiler/tfl-inspect/src/Dump.cpp index 5e586d1..16dcb56 100644 --- a/compiler/tfl-inspect/src/Dump.cpp +++ b/compiler/tfl-inspect/src/Dump.cpp @@ -17,7 +17,9 @@ #include "Dump.h" #include "Reader.h" +#include #include +#include namespace tflinspect { @@ -43,3 +45,98 @@ void DumpOperators::run(std::ostream &os, const tflite::Model *model) } } // namespace tflinspect + +namespace +{ + +const tflite::Operator *operator_match_output(tflinspect::Reader &reader, const int32_t tensor) +{ + auto ops = reader.operators(); + + for (uint32_t i = 0; i < ops->Length(); ++i) + { + const auto op = ops->Get(i); + + const std::vector &outputs = tflinspect::as_index_vector(op->outputs()); + + for (auto output : outputs) + { + if (output == tensor) + return op; + } + } + return nullptr; +} + +size_t tensor_buffer_size(tflinspect::Reader &reader, const int32_t tensor_id) +{ + auto tensors = reader.tensors(); + + if (tensor_id < 0 || tensor_id >= tensors->Length()) + { + throw std::runtime_error("Invalid Tensor ID"); + } + + auto tensor = tensors->Get(tensor_id); + + auto buffers = reader.buffers(); + auto buffer_id = tensor->buffer(); + + const uint8_t *buff_data; + size_t size = reader.buffer_info(buffer_id, &buff_data); + + (void)buff_data; + + return size; +} + +} // namespace + +namespace tflinspect +{ + +void DumpConv2DWeight::run(std::ostream &os, const tflite::Model *model) +{ + tflinspect::Reader reader(model); + + assert(reader.num_subgraph() == 1); + reader.select_subgraph(0); + + auto ops = reader.operators(); + + // dump Conv2D, DepthwiseConv2D and its weight input operator + for (uint32_t i = 0; i < ops->Length(); ++i) + { + const auto op = ops->Get(i); + auto bc = reader.builtin_code(op); + + if (bc == tflite::BuiltinOperator_CONV_2D || bc == tflite::BuiltinOperator_DEPTHWISE_CONV_2D) + { + const std::vector &inputs = tflinspect::as_index_vector(op->inputs()); + if (inputs.size() < 2) + { + throw std::runtime_error("Operator has invalid input"); + } + auto weight_input = inputs[1]; // Tensor ID of weight input + + const auto op_weight = operator_match_output(reader, weight_input); + const auto buffer_size = tensor_buffer_size(reader, weight_input); + + std::string weight_op_name = "?"; + + if (op_weight == nullptr && buffer_size > 0) + { + weight_op_name = "CONST"; + } + else if (op_weight != nullptr) + { + weight_op_name = reader.opcode_name(op_weight); + } + + auto op_name = reader.opcode_name(op); + os << op_name << "," << weight_op_name << std::endl; + } + } +} + +} // namespace tflinspect diff --git a/compiler/tfl-inspect/src/Dump.h b/compiler/tfl-inspect/src/Dump.h index 436c781..798c1db 100644 --- a/compiler/tfl-inspect/src/Dump.h +++ b/compiler/tfl-inspect/src/Dump.h @@ -42,6 +42,15 @@ public: void run(std::ostream &os, const tflite::Model *model); }; +class DumpConv2DWeight final : public DumpInterface +{ +public: + DumpConv2DWeight() = default; + +public: + void run(std::ostream &os, const tflite::Model *model); +}; + } // namespace tflinspect #endif // __DUMP_H__ -- 2.7.4