2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci_interpreter/Interpreter.h"
18 #include "luci_interpreter/SimpleMemoryManager.h"
20 #include "loader/ModuleLoader.h"
24 namespace luci_interpreter
30 class EventNotifierImpl final : public EventNotifier
33 EventNotifierImpl(const RuntimeToIR &runtime_to_ir,
34 const std::vector<ExecutionObserver *> &observers)
35 : _runtime_to_ir(runtime_to_ir), _observers(observers)
39 void postTensorWrite(const Tensor *tensor) override
41 assert(tensor != nullptr);
42 for (const auto &observer : _observers)
44 observer->postTensorWrite(_runtime_to_ir.tensor_to_node.at(tensor), tensor);
48 void preOperatorExecute(const Kernel *kernel) override
50 assert(kernel != nullptr);
51 for (const auto &observer : _observers)
53 observer->preOperatorExecute(_runtime_to_ir.kernel_to_node.at(kernel));
57 void postOperatorExecute(const Kernel *kernel) override
59 assert(kernel != nullptr);
60 for (const auto &observer : _observers)
62 observer->postOperatorExecute(_runtime_to_ir.kernel_to_node.at(kernel));
67 const RuntimeToIR &_runtime_to_ir;
68 const std::vector<ExecutionObserver *> &_observers;
// Constructs an interpreter for @p module.
//
// Creates the runtime-to-IR mapping, the event notifier that forwards runtime events to
// attached observers, and the runtime module; selects the memory manager; then hands
// everything to a ModuleLoader.
//
// @param module          luci module to load (NOTE(review): presumably borrowed, not
//                        owned - confirm against the header)
// @param memory_manager  memory manager to use; when nullptr, a SimpleMemoryManager owned
//                        by this Interpreter is created and used instead
73 Interpreter::Interpreter(const luci::Module *module,
74 luci_interpreter::IMemoryManager *memory_manager)
// Runtime tensor/kernel -> IR node mapping. The notifier reads it by reference, and it is
// also passed to the ModuleLoader below (presumably populated there - confirm).
76 _runtime_to_ir = std::make_unique<RuntimeToIR>();
// The notifier keeps references to *_runtime_to_ir and _observers, so observers added
// later through attachObserver() are notified too.
77 _event_notifier = std::make_unique<EventNotifierImpl>(*_runtime_to_ir, _observers);
78 _runtime_module = std::make_unique<RuntimeModule>(_event_notifier.get());
// No caller-supplied memory manager: fall back to one owned by this Interpreter.
80 if (memory_manager == nullptr)
82 _default_memory_manager = std::make_unique<SimpleMemoryManager>();
83 _memory_manager = _default_memory_manager.get();
// Caller-supplied manager is used as-is (NOTE(review): presumably not owned - confirm).
87 _memory_manager = memory_manager;
// NOTE(review): this statement continues on lines not shown in this view.
90 ModuleLoader loader(module, _runtime_module.get(), *_runtime_to_ir, _node_to_tensor,
// Destructor, defined out of line. NOTE(review): presumably kept in this file so that the
// smart-pointer members created in the constructor (via std::make_unique) are destroyed
// where their pointee types, e.g. EventNotifierImpl, are complete - confirm against the header.
95 Interpreter::~Interpreter() = default;
97 void Interpreter::writeInputTensor(const luci::CircleInput *input_node, const void *data,
100 Tensor *tensor = _runtime_module->getInputTensors()[input_node->index()];
101 if (tensor == nullptr)
103 const std::string &name = input_node->name();
104 throw std::runtime_error("Cannot find tensor for input node named \"" + name + "\".");
107 tensor->writeData(data, data_size);
110 void Interpreter::readOutputTensor(const luci::CircleOutput *output_node, void *data,
113 Tensor *tensor = _runtime_module->getOutputTensors()[output_node->index()];
114 if (tensor == nullptr)
116 const std::string &name = output_node->name();
117 throw std::runtime_error("Cannot find tensor for output node named \"" + name + "\".");
120 tensor->readData(data, data_size);
123 void Interpreter::interpret() { _runtime_module->execute(); }
125 void Interpreter::attachObserver(ExecutionObserver *observer)
127 if (std::find(_observers.cbegin(), _observers.cend(), observer) != _observers.cend())
128 throw std::runtime_error("Observer is already attached.");
129 _observers.push_back(observer);
132 ExecutionObserver::~ExecutionObserver() = default;
134 void ExecutionObserver::postTensorWrite(const luci::CircleNode *, const Tensor *) {}
136 void ExecutionObserver::preOperatorExecute(const luci::CircleNode *) {}
138 void ExecutionObserver::postOperatorExecute(const luci::CircleNode *) {}
140 } // namespace luci_interpreter