d224fd172e965ec51f3d4be8b3378480747d91fb
[platform/core/ml/nnfw.git] / compiler / luci / service / src / Validate.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/Validate.h"
18
19 #include <luci/IR/Nodes/CircleOutput.h>
20 #include <luci/Log.h>
21
22 #include <loco/IR/NodeShape.h>
23 #include <loco/Service/ShapeInference.h>
24 #include <loco/Service/TypeInference.h>
25
26 #include <cassert>
27 #include <vector>
28
29 namespace
30 {
31
32 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
33 {
34   os << "[";
35   for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
36   {
37     if (r)
38       os << ",";
39     os << tensor_shape.dim(r).value();
40   }
41   os << "]";
42   return os;
43 }
44
45 /**
46  * @brief  returns a node that is CircleOutput with index is out_index in nodes
47  */
48 luci::CircleOutput *find_node(std::vector<loco::Node *> nodes, loco::GraphOutputIndex out_index)
49 {
50   for (auto node : nodes)
51   {
52     auto circle_output = dynamic_cast<luci::CircleOutput *>(node);
53     if (circle_output != nullptr)
54     {
55       if (circle_output->indexed() && circle_output->index() == out_index)
56         return circle_output;
57     }
58   }
59   return nullptr;
60 }
61
62 bool validate_shape_dtype(loco::Graph *g)
63 {
64   LOGGER(l);
65
66   auto output_nodes = loco::output_nodes(g);
67
68   auto count = g->outputs()->size();
69   for (uint32_t out = 0; out < count; ++out)
70   {
71     auto graph_out = g->outputs()->at(out);
72     auto out_index = graph_out->index();
73
74     auto circle_output = find_node(output_nodes, out_index);
75     assert(circle_output != nullptr);
76     assert(circle_output->from() != nullptr);
77     auto circle_node = loco::must_cast<luci::CircleNode *>(circle_output->from());
78
79     // Shape and dtype validation for CiecleOutputExclude is not needed
80     if (dynamic_cast<luci::CircleOutputExclude *>(circle_node))
81       continue;
82
83     assert(loco::shape_known(circle_node));
84
85     // check if output node shape is same as graph output shape
86     auto co_tensor_shape = loco::shape_get(circle_node).as<loco::TensorShape>();
87     auto go_tensor_shape = graph_out->shape();
88     assert(go_tensor_shape);
89     if (!(co_tensor_shape == *go_tensor_shape))
90     {
91       INFO(l) << "[luci] Shape for output #" << out_index << " not same " << std::endl;
92       INFO(l) << "[luci]    " << circle_node->name() << " " << co_tensor_shape << " vs "
93               << *go_tensor_shape << std::endl;
94       return false;
95     }
96
97     // check if data type match
98     assert(loco::dtype_known(circle_node));
99     if (graph_out->dtype() != loco::dtype_get(circle_node))
100     {
101       INFO(l) << "[luci] Type for output #" << out_index << " not same " << std::endl;
102       return false;
103     }
104   }
105
106   return true;
107 }
108
109 } // namespace
110
111 namespace luci
112 {
113
114 bool validate(loco::Graph *g)
115 {
116   if (!loco::valid(g))
117     return false;
118
119   if (!validate_shape_dtype(g))
120     return false;
121
122   // TODO add more validation
123
124   return true;
125 }
126
127 } // namespace luci