Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / Interpreter.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci_interpreter/Interpreter.h"
18 #include "luci_interpreter/SimpleMemoryManager.h"
19
20 #include "loader/ModuleLoader.h"
21
22 #include <stdexcept>
23
24 namespace luci_interpreter
25 {
26
27 namespace
28 {
29
30 class EventNotifierImpl final : public EventNotifier
31 {
32 public:
33   EventNotifierImpl(const RuntimeToIR &runtime_to_ir,
34                     const std::vector<ExecutionObserver *> &observers)
35     : _runtime_to_ir(runtime_to_ir), _observers(observers)
36   {
37   }
38
39   void postTensorWrite(const Tensor *tensor) override
40   {
41     assert(tensor != nullptr);
42     for (const auto &observer : _observers)
43     {
44       observer->postTensorWrite(_runtime_to_ir.tensor_to_node.at(tensor), tensor);
45     }
46   }
47
48   void preOperatorExecute(const Kernel *kernel) override
49   {
50     assert(kernel != nullptr);
51     for (const auto &observer : _observers)
52     {
53       observer->preOperatorExecute(_runtime_to_ir.kernel_to_node.at(kernel));
54     }
55   }
56
57   void postOperatorExecute(const Kernel *kernel) override
58   {
59     assert(kernel != nullptr);
60     for (const auto &observer : _observers)
61     {
62       observer->postOperatorExecute(_runtime_to_ir.kernel_to_node.at(kernel));
63     }
64   }
65
66 private:
67   const RuntimeToIR &_runtime_to_ir;
68   const std::vector<ExecutionObserver *> &_observers;
69 };
70
71 } // namespace
72
73 Interpreter::Interpreter(const luci::Module *module,
74                          luci_interpreter::IMemoryManager *memory_manager)
75 {
76   _runtime_to_ir = std::make_unique<RuntimeToIR>();
77   _event_notifier = std::make_unique<EventNotifierImpl>(*_runtime_to_ir, _observers);
78   _runtime_module = std::make_unique<RuntimeModule>(_event_notifier.get());
79
80   if (memory_manager == nullptr)
81   {
82     _default_memory_manager = std::make_unique<SimpleMemoryManager>();
83     _memory_manager = _default_memory_manager.get();
84   }
85   else
86   {
87     _memory_manager = memory_manager;
88   }
89
90   ModuleLoader loader(module, _runtime_module.get(), *_runtime_to_ir, _node_to_tensor,
91                       _memory_manager);
92   loader.load();
93 }
94
95 Interpreter::~Interpreter() = default;
96
97 void Interpreter::writeInputTensor(const luci::CircleInput *input_node, const void *data,
98                                    size_t data_size)
99 {
100   Tensor *tensor = _runtime_module->getInputTensors()[input_node->index()];
101   if (tensor == nullptr)
102   {
103     const std::string &name = input_node->name();
104     throw std::runtime_error("Cannot find tensor for input node named \"" + name + "\".");
105   }
106   if (data != nullptr)
107     tensor->writeData(data, data_size);
108 }
109
110 void Interpreter::readOutputTensor(const luci::CircleOutput *output_node, void *data,
111                                    size_t data_size)
112 {
113   Tensor *tensor = _runtime_module->getOutputTensors()[output_node->index()];
114   if (tensor == nullptr)
115   {
116     const std::string &name = output_node->name();
117     throw std::runtime_error("Cannot find tensor for output node named \"" + name + "\".");
118   }
119   if (data != nullptr)
120     tensor->readData(data, data_size);
121 }
122
123 void Interpreter::interpret() { _runtime_module->execute(); }
124
125 void Interpreter::attachObserver(ExecutionObserver *observer)
126 {
127   if (std::find(_observers.cbegin(), _observers.cend(), observer) != _observers.cend())
128     throw std::runtime_error("Observer is already attached.");
129   _observers.push_back(observer);
130 }
131
132 ExecutionObserver::~ExecutionObserver() = default;
133
134 void ExecutionObserver::postTensorWrite(const luci::CircleNode *, const Tensor *) {}
135
136 void ExecutionObserver::preOperatorExecute(const luci::CircleNode *) {}
137
138 void ExecutionObserver::postOperatorExecute(const luci::CircleNode *) {}
139
140 } // namespace luci_interpreter