2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Service/Validate.h"
19 #include <luci/IR/Nodes/CircleOutput.h>
22 #include <loco/IR/NodeShape.h>
23 #include <loco/Service/ShapeInference.h>
24 #include <loco/Service/TypeInference.h>
// Stream-insertion operator for loco::TensorShape.
// Visits every axis index r in [0, tensor_shape.rank()) in increasing order and
// writes the value of dimension r to os.
// NOTE(review): any brackets/separators emitted around the values and the final
// return are not visible in this view - presumably returns os; confirm.
32 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
35 for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
// The result of dim(r).value() is streamed as-is.
// NOTE(review): what value() yields for an unknown dimension is not shown here - verify.
39 os << tensor_shape.dim(r).value();
46 * @brief Returns the CircleOutput node in nodes whose index equals out_index
// Scans nodes for a luci::CircleOutput whose index() equals out_index.
//   nodes     - candidate nodes; taken by value, so the vector is copied on each call
//   out_index - graph output index to look for
// NOTE(review): the return statements are not visible in this view - presumably
// the matching node is returned, and nullptr when nothing matches; confirm.
48 luci::CircleOutput *find_node(std::vector<loco::Node *> nodes, loco::GraphOutputIndex out_index)
50 for (auto node : nodes)
// dynamic_cast yields nullptr for nodes that are not CircleOutput; those are skipped.
52 auto circle_output = dynamic_cast<luci::CircleOutput *>(node);
53 if (circle_output != nullptr)
// indexed() is tested first, so index() is only compared for outputs that report an index.
55 if (circle_output->indexed() && circle_output->index() == out_index)
// Checks, for every graph output of g, that the node feeding the matching
// CircleOutput has the same tensor shape and data type as the graph output
// declares. Mismatches are logged through INFO(l).
// NOTE(review): the declaration of the logger `l` and all return statements are
// not visible in this view - presumably false is returned on the first mismatch
// and true otherwise; confirm.
62 bool validate_shape_dtype(loco::Graph *g)
// Collect the output nodes of the graph once, before iterating over graph outputs.
66 auto output_nodes = loco::output_nodes(g);
68 auto count = g->outputs()->size();
69 for (uint32_t out = 0; out < count; ++out)
71 auto graph_out = g->outputs()->at(out);
72 auto out_index = graph_out->index();
// Find the CircleOutput node that carries this graph output's index.
74 auto circle_output = find_node(output_nodes, out_index);
// These preconditions are enforced with assert only, so they are not checked
// when the build defines NDEBUG.
75 assert(circle_output != nullptr);
76 assert(circle_output->from() != nullptr);
// circle_node is the producer connected to the output, not the CircleOutput itself.
77 auto circle_node = loco::must_cast<luci::CircleNode *>(circle_output->from());
79 // Shape and dtype validation for CircleOutputExclude is not needed
// NOTE(review): the statement guarded by this check is not visible here -
// presumably the rest of this iteration is skipped; confirm.
80 if (dynamic_cast<luci::CircleOutputExclude *>(circle_node))
83 assert(loco::shape_known(circle_node));
85 // check if output node shape is same as graph output shape
86 auto co_tensor_shape = loco::shape_get(circle_node).as<loco::TensorShape>();
// graph_out->shape() is dereferenced below, hence the non-null assert.
87 auto go_tensor_shape = graph_out->shape();
88 assert(go_tensor_shape);
89 if (!(co_tensor_shape == *go_tensor_shape))
// Log the output index, then the node name with both shapes for comparison.
91 INFO(l) << "[luci] Shape for output #" << out_index << " not same " << std::endl;
92 INFO(l) << "[luci] " << circle_node->name() << " " << co_tensor_shape << " vs "
93 << *go_tensor_shape << std::endl;
97 // check if data types match
98 assert(loco::dtype_known(circle_node));
99 if (graph_out->dtype() != loco::dtype_get(circle_node))
101 INFO(l) << "[luci] Type for output #" << out_index << " not same " << std::endl;
// Validation entry point for a loco::Graph.
// Runs validate_shape_dtype(g) and reacts when it reports failure.
// NOTE(review): statements before this check, the enclosing namespace and the
// return statements are not visible in this view - presumably false is returned
// when a check fails and true otherwise; confirm.
114 bool validate(loco::Graph *g)
119 if (!validate_shape_dtype(g))
122 // TODO add more validation