From f74b82a77205e0583ac9484a28a06176e789cfa9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Mon, 12 Nov 2018 18:38:35 +0900 Subject: [PATCH] [enco/frontend] Introduce TflOpCodeContext class (#2201) * [enco/frontend] Introduce TflOpCodeContext class This will introduce TflOpCodeContext class that holds operator codes of the TF lite model and related methods Signed-off-by: SaeHie Park * apply comment as add static * add note for opcode_name * move to tflimport namespace region * add const and change comment --- contrib/enco/frontend/tflite/src/Frontend.cpp | 71 +++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/contrib/enco/frontend/tflite/src/Frontend.cpp b/contrib/enco/frontend/tflite/src/Frontend.cpp index ffb90a7..c8cd818 100644 --- a/contrib/enco/frontend/tflite/src/Frontend.cpp +++ b/contrib/enco/frontend/tflite/src/Frontend.cpp @@ -23,6 +23,7 @@ #include #include +#include using namespace nncc::core::ADT; @@ -168,6 +169,76 @@ void set_module_outputs(coco::Module *m, TensorContext &ctx, TensorBags &bags, } /** + * @brief Class that holds operator codes and related methods + */ +class TflOpCodeContext +{ +public: + TflOpCodeContext(const flatbuffers::Vector> *opcodes) + { + for (const tflite::OperatorCode *opcode : *opcodes) + { + _opcodes.push_back(opcode); + } + } + + /** + * @brief Returns BuiltinOperator value of the operator + */ + tflite::BuiltinOperator builtin_code(const tflite::Operator *op) const + { + uint32_t index = op->opcode_index(); + assert(index < _opcodes.size()); + const tflite::OperatorCode *opcode = _opcodes.at(index); + return opcode->builtin_code(); + } + + /** + * @brief Returns human readable name of the operator code of the operator + * + * @note TF lite InterpreterBuilder sets an error state and returns error code + * for invalid opcode. Here we just return human readable message as + * this method returns a name for the operator code. + */ + const char *opcode_name(const tflite::Operator *op) const + { + uint32_t index = op->opcode_index(); + assert(index < _opcodes.size()); + const tflite::OperatorCode *opcode = _opcodes.at(index); + + if (!is_valid(opcode)) + return "(invalid)"; + + if (is_custom(opcode)) + { + if (!opcode->custom_code()) + return "(invalid custom)"; + + return opcode->custom_code()->c_str(); + } + + tflite::BuiltinOperator code = opcode->builtin_code(); + return EnumNameBuiltinOperator(code); + } + +public: + static bool is_valid(const tflite::OperatorCode *opcode) + { + tflite::BuiltinOperator code = opcode->builtin_code(); + return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX); + } + + static bool is_custom(const tflite::OperatorCode *opcode) + { + tflite::BuiltinOperator code = opcode->builtin_code(); + return (code == tflite::BuiltinOperator_CUSTOM); + } + +private: + std::vector _opcodes; +}; + +/** * @brief Class to read and provide buffer information of tflite */ class TflBufferContext -- 2.7.4