Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / OpSequences.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ir/OpSequences.h"
18 #include "util/logging.h"
19 #include <memory>
20
21 #include <cassert>
22 #include <string>
23
24 namespace onert
25 {
26 namespace ir
27 {
28
29 OpSequenceIndex OpSequences::emplace(const OperationIndex &index, Layout layout)
30 {
31   std::unique_ptr<OpSequence> op_seq = std::make_unique<OpSequence>(layout);
32   op_seq->appendOperation(index);
33   const OpSequenceIndex &seq_index = push(std::move(op_seq));
34   cacheSequenceIndex(seq_index, index);
35   return seq_index;
36 }
37
38 OpSequenceIndex OpSequences::emplace(std::unique_ptr<OpSequence> &&op_seq)
39 {
40   auto &operations = op_seq->operations();
41   const OpSequenceIndex &seq_index = push(std::move(op_seq));
42   for (const auto &op_idx : operations)
43   {
44     cacheSequenceIndex(seq_index, op_idx);
45   }
46   return seq_index;
47 }
48
49 void OpSequences::cacheSequenceIndex(const OpSequenceIndex &seq_index,
50                                      const OperationIndex &op_index) const
51 {
52   _seq_indexes.emplace(op_index, seq_index);
53 }
54
55 OpSequenceIndex *OpSequences::findSequenceIndex(const OperationIndex &operation_index) const
56 {
57   // If opration_index is cached, return sequence_index from cache
58   if (_seq_indexes.count(operation_index))
59   {
60     auto &op_seq_index = _seq_indexes.at(operation_index);
61     if (_objects.count(op_seq_index) && _objects.at(op_seq_index)->exist(operation_index))
62     {
63       return &op_seq_index;
64     }
65     else
66     {
67       _seq_indexes.erase(operation_index);
68       return nullptr;
69     }
70   }
71   return nullptr;
72 }
73
74 bool OpSequences::containsOperation(const OperationIndex &operation_index) const
75 {
76   return findOperation(operation_index).valid();
77 }
78
79 OpSequenceIndex OpSequences::getOperation(const OperationIndex &operation_index) const
80 {
81   OpSequenceIndex ret = findOperation(operation_index);
82   assert(ret.valid());
83   return ret;
84 }
85
86 void OpSequences::removeFromOpSequence(const OperationIndex &operation_index)
87 {
88   const auto op_seq_index = findOperation(operation_index);
89   auto &op_seq = at(op_seq_index);
90   _seq_indexes.erase(operation_index);
91   op_seq.remove(operation_index);
92   if (op_seq.size() == 0)
93   {
94     remove(op_seq_index);
95   }
96 }
97
98 OpSequenceIndex OpSequences::findOperation(const OperationIndex &operation_index) const
99 {
100   if (OpSequenceIndex *op_seq_index = findSequenceIndex(operation_index))
101     return *op_seq_index;
102
103   for (auto &e : _objects)
104   {
105     OpSequence &object = *e.second;
106     auto it = find(object.operations().begin(), object.operations().end(), operation_index);
107     if (it != object.operations().end())
108     {
109       cacheSequenceIndex(e.first, operation_index);
110       return e.first;
111     }
112   }
113   throw std::runtime_error("Operation not found");
114 }
115
116 void dumpOpSequences(const OpSequences &op_seqs, const Operations &operations)
117 {
118   op_seqs.iterate([&](const OpSequenceIndex &idx, const OpSequence &op_seq) {
119     VERBOSE(OpSequences) << idx.value() << "] " << getStrFromOpSeq(op_seq, operations) << std::endl;
120   });
121 }
122
123 } // namespace ir
124 } // namespace onert