2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "loco/IR/Graph.h"
19 #include <stdex/Memory.h>
26 std::unique_ptr<loco::TensorShape> make_tensor_shape(std::initializer_list<loco::Dimension> dims)
28 auto tensor_shape = stdex::make_unique<loco::TensorShape>();
30 tensor_shape->rank(dims.size());
33 for (auto it = dims.begin(); it != dims.end(); ++it)
35 tensor_shape->dim(axis++) = *it;
37 assert(axis == dims.size());
48 void Mixin<Trait::TensorShaped>::shape(std::initializer_list<Dimension> dims)
50 shape(make_tensor_shape(dims));
53 GraphInput *Graph::InputContext::create(void)
55 return take(stdex::make_unique<GraphInput>(size()));
58 GraphOutput *Graph::OutputContext::create(void)
60 return take(stdex::make_unique<GraphOutput>(size()));
63 std::set<loco::Node *> all_nodes(loco::Graph *g)
65 std::set<loco::Node *> res;
67 for (uint32_t n = 0; n < g->nodes()->size(); ++n)
69 res.insert(g->nodes()->at(n));
75 std::vector<Node *> input_nodes(const Graph *g)
77 std::map<GraphInputIndex, loco::Node *> table;
79 for (uint32_t n = 0; n < g->nodes()->size(); ++n)
81 auto node = g->nodes()->at(n);
83 if (auto service = node->dialect()->service<GraphInputIndexQueryService>())
85 if (service->associated(node))
87 auto input_index = service->index(node);
88 assert(table.find(input_index) == table.end());
89 table[input_index] = node;
94 std::vector<loco::Node *> res;
96 for (uint32_t n = 0; n < g->inputs()->size(); ++n)
98 auto it = table.find(n);
99 res.emplace_back(it == table.end() ? nullptr : it->second);
105 std::vector<loco::Node *> output_nodes(loco::Graph *g)
107 std::map<GraphOutputIndex, loco::Node *> table;
109 for (uint32_t n = 0; n < g->nodes()->size(); ++n)
111 auto node = g->nodes()->at(n);
113 if (auto service = node->dialect()->service<GraphOutputIndexQueryService>())
115 if (service->associated(node))
117 auto output_index = service->index(node);
118 assert(table.find(output_index) == table.end());
119 table[output_index] = node;
124 std::vector<loco::Node *> res;
126 for (uint32_t n = 0; n < g->outputs()->size(); ++n)
128 auto it = table.find(n);
129 res.emplace_back(it == table.end() ? nullptr : it->second);
135 std::unique_ptr<Graph> make_graph(void) { return std::unique_ptr<Graph>{new Graph}; }