From 3791e6937dd617bddc28122e24860070b0a6cbb4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Tue, 4 Dec 2018 13:42:20 +0900 Subject: [PATCH] [tfldump] Add Reader with OpCode reader (#2475) * [tfldump] Add Reader with OpCode reader This will introduce Reader class that holds OpCodes Dump will dump operator codes that are used in the Model Signed-off-by: SaeHie Park * opcode_name as a function * move TFliteSubGraphs_t inside class * show number legend * add some more --- contrib/tfldump/src/Dump.cpp | 21 +++++++++++-- contrib/tfldump/src/Read.cpp | 71 ++++++++++++++++++++++++++++++++++++++++++++ contrib/tfldump/src/Read.h | 59 ++++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 2 deletions(-) create mode 100644 contrib/tfldump/src/Read.cpp create mode 100644 contrib/tfldump/src/Read.h diff --git a/contrib/tfldump/src/Dump.cpp b/contrib/tfldump/src/Dump.cpp index aedac6d..678387c 100644 --- a/contrib/tfldump/src/Dump.cpp +++ b/contrib/tfldump/src/Dump.cpp @@ -16,6 +16,8 @@ #include +#include "Read.h" + #include namespace tfldump @@ -23,10 +25,25 @@ namespace tfldump void dump_model(std::ostream &os, const tflite::Model *model) { - // TODO place reader + tflread::Reader reader(model); + + assert(reader.num_subgraph() == 1); + + auto opcodes = reader.opcodes(); // dump operator_codes - os << "Operator Codes:" << std::endl; + os << "Operator Codes: [order] OpCodeName (OpCode Enum)" << std::endl; + int32_t opcode_index = 0; + for (auto opcode : opcodes) + { + tflite::BuiltinOperator op_code = opcode->builtin_code(); + auto op_name = tflread::opcode_name(opcode); + + os << "[" << opcode_index << "] " << op_name << " (code: " << op_code << ")" << std::endl; + + opcode_index++; + } + os << std::endl; // dump buffer os << "Buffers:" << std::endl; diff --git a/contrib/tfldump/src/Read.cpp b/contrib/tfldump/src/Read.cpp new file mode 100644 index 0000000..594d557 --- /dev/null +++ b/contrib/tfldump/src/Read.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Read.h" + +#include +#include + +namespace tflread +{ + +bool is_valid(const tflite::OperatorCode *opcode) +{ + tflite::BuiltinOperator code = opcode->builtin_code(); + return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX); +} + +bool is_custom(const tflite::OperatorCode *opcode) +{ + tflite::BuiltinOperator code = opcode->builtin_code(); + return (code == tflite::BuiltinOperator_CUSTOM); +} + +std::string opcode_name(const tflite::OperatorCode *opcode) +{ + assert(opcode); + + if (!is_valid(opcode)) + { + std::ostringstream oss; + oss << "(invalid)"; + return oss.str(); + } + + if (is_custom(opcode)) + { + if (!opcode->custom_code()) + return "(invalid custom)"; + + return opcode->custom_code()->c_str(); + } + + tflite::BuiltinOperator code = opcode->builtin_code(); + return tflite::EnumNameBuiltinOperator(code); +} + +Reader::Reader(const tflite::Model *model) +{ + _subgraphs = model->subgraphs(); + + auto opcodes = model->operator_codes(); + for (const ::tflite::OperatorCode *opcode : *opcodes) + { + _op_codes.push_back(opcode); + } +} + +} // namespace tflread diff --git a/contrib/tfldump/src/Read.h b/contrib/tfldump/src/Read.h new file mode 100644 index 0000000..1328c00 --- /dev/null +++ b/contrib/tfldump/src/Read.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TFLREAD_READ_H__ +#define __TFLREAD_READ_H__ + +#include + +#include +#include +#include + +namespace tflread +{ + +bool is_valid(const tflite::OperatorCode *opcode); +bool is_custom(const tflite::OperatorCode *opcode); +std::string opcode_name(const tflite::OperatorCode *opcode); + +/** + * @brief Loads TF lite file and provides helpers to access attributes + */ +class Reader +{ +private: + using TFliteSubGraphs_t = flatbuffers::Vector>; + +public: + Reader(const tflite::Model *model); + + Reader() = delete; + +public: + const std::vector &opcodes() { return _op_codes; } + + uint32_t num_subgraph() const { return _subgraphs->Length(); } + +private: + const TFliteSubGraphs_t *_subgraphs; + + std::vector _op_codes; +}; + +} // namespace tflread + +#endif // __TFLREAD_READ_H__ -- 2.7.4