a87d31a9fd215a214c6ab37e6147f87331d84344
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / OpSequences.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ir/OpSequences.h"
18 #include "util/logging.h"
19 #include <memory>
20
21 #include <cassert>
22 #include <string>
23
24 namespace onert
25 {
26 namespace ir
27 {
28
29 OpSequenceIndex OpSequences::emplace(const OperationIndex &index, Layout layout)
30 {
31   std::unique_ptr<OpSequence> op_seq = std::make_unique<OpSequence>(layout);
32   op_seq->appendOperation(index);
33   const OpSequenceIndex &seq_index = push(std::move(op_seq));
34   cacheSequenceIndex(seq_index, index);
35   return seq_index;
36 }
37
38 OpSequenceIndex OpSequences::emplace(std::unique_ptr<OpSequence> &&op_seq)
39 {
40   auto &operations = op_seq->operations();
41   const OpSequenceIndex &seq_index = push(std::move(op_seq));
42   for (const auto &op_idx : operations)
43   {
44     cacheSequenceIndex(seq_index, op_idx);
45   }
46   return seq_index;
47 }
48
49 void OpSequences::cacheSequenceIndex(const OpSequenceIndex &seq_index,
50                                      const OperationIndex &op_index) const
51 {
52   _seq_indexes.emplace(op_index, seq_index);
53 }
54
55 OpSequenceIndex *OpSequences::findSequenceIndex(const OperationIndex &operation_index) const
56 {
57   // If opration_index is cached, return sequence_index from cache
58   if (_seq_indexes.count(operation_index))
59   {
60     auto &op_seq_index = _seq_indexes.at(operation_index);
61     if (_objects.count(op_seq_index) && _objects.at(op_seq_index)->exist(operation_index))
62     {
63       return &op_seq_index;
64     }
65     else
66     {
67       _seq_indexes.erase(operation_index);
68       return nullptr;
69     }
70   }
71   return nullptr;
72 }
73
74 bool OpSequences::containsOperation(const OperationIndex &operation_index) const
75 {
76   return findOperation(operation_index).valid();
77 }
78
79 OpSequenceIndex OpSequences::getOperation(const OperationIndex &operation_index) const
80 {
81   OpSequenceIndex ret = findOperation(operation_index);
82   assert(ret.valid());
83   return ret;
84 }
85
86 // TODO: Extract this into external helper function
87 void OpSequences::dump(const std::string &msg, const Operations &operations) const
88 {
89   VERBOSE(OpSequences) << "OpSequences(" << msg << ")" << std::endl;
90   iterate([&](const OpSequenceIndex &idx, const OpSequence &op_seq) {
91     VERBOSE(OpSequences) << idx.value() << "] " << getStrFromOpSeq(op_seq, operations) << std::endl;
92   });
93 }
94
95 void OpSequences::removeFromOpSequence(const OperationIndex &operation_index)
96 {
97   const auto op_seq_index = findOperation(operation_index);
98   auto &op_seq = at(op_seq_index);
99   _seq_indexes.erase(operation_index);
100   op_seq.remove(operation_index);
101   if (op_seq.size() == 0)
102   {
103     remove(op_seq_index);
104   }
105 }
106
107 OpSequenceIndex OpSequences::findOperation(const OperationIndex &operation_index) const
108 {
109   if (OpSequenceIndex *op_seq_index = findSequenceIndex(operation_index))
110     return *op_seq_index;
111
112   for (auto &e : _objects)
113   {
114     OpSequence &object = *e.second;
115     auto it = find(object.operations().begin(), object.operations().end(), operation_index);
116     if (it != object.operations().end())
117     {
118       cacheSequenceIndex(e.first, operation_index);
119       return e.first;
120     }
121   }
122   throw std::runtime_error("Operation not found");
123 }
124
125 } // namespace ir
126 } // namespace onert