6c45295161531dee5f0d84061c41952f3c1eccf6
[platform/core/ml/nnfw.git] / compiler / tfl-inspect / src / Reader.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Reader.h"
18
19 #include <mio_tflite280/Helper.h>
20
21 #include <cassert>
22 #include <sstream>
23 #include <string>
24
25 namespace tflinspect
26 {
27
28 Reader::Reader(const tflite::Model *model)
29 {
30   _subgraphs = model->subgraphs();
31   _buffers = model->buffers();
32
33   auto opcodes = model->operator_codes();
34   for (const ::tflite::OperatorCode *opcode : *opcodes)
35   {
36     _op_codes.push_back(opcode);
37   }
38 }
39
40 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
41 {
42   if (buff_data != nullptr)
43   {
44     *buff_data = nullptr;
45   }
46
47   if (buf_idx == 0)
48     return 0;
49
50   if (auto *buffer = (*_buffers)[buf_idx])
51   {
52     if (auto *array = buffer->data())
53     {
54       if (size_t size = array->size())
55       {
56         if (buff_data != nullptr)
57         {
58           *buff_data = reinterpret_cast<const uint8_t *>(array->data());
59         }
60         return size;
61       }
62     }
63   }
64
65   return 0;
66 }
67
68 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
69 {
70   uint32_t index = op->opcode_index();
71   assert(index < _op_codes.size());
72   const tflite::OperatorCode *opcode = _op_codes.at(index);
73
74   return mio::tflite::builtin_code_neutral(opcode);
75 }
76
77 std::string Reader::opcode_name(const tflite::Operator *op) const
78 {
79   uint32_t index = op->opcode_index();
80   assert(index < _op_codes.size());
81   const tflite::OperatorCode *opcode = _op_codes.at(index);
82
83   if (!mio::tflite::is_valid(opcode))
84   {
85     std::ostringstream oss;
86     oss << "(invalid: " << index << ")";
87     return oss.str();
88   }
89
90   return mio::tflite::opcode_name(opcode);
91 }
92
93 bool Reader::select_subgraph(uint32_t sgindex)
94 {
95   _tensors = nullptr;
96   _operators = nullptr;
97
98   _inputs.clear();
99   _outputs.clear();
100
101   if (_subgraphs->Length() <= sgindex)
102   {
103     assert(false);
104     return false;
105   }
106
107   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
108
109   _tensors = subgraph->tensors();
110   _operators = subgraph->operators();
111
112   _inputs = as_index_vector(subgraph->inputs());
113   _outputs = as_index_vector(subgraph->outputs());
114
115   return true;
116 }
117
118 } // namespace tflinspect