Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / fluid / modules / gapi / src / compiler / gmodelbuilder.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 //
5 // Copyright (C) 2018-2019 Intel Corporation
6
7
8 ////////////////////////////////////////////////////////////////////////////////
9 //
10 //    FIXME: "I personally hate this file"
11 //                                        - Dmitry
12 //
13 ////////////////////////////////////////////////////////////////////////////////
14 #include "precomp.hpp"
15
16 #include <utility>              // tuple
17 #include <stack>                // stack
18 #include <vector>               // vector
19 #include <unordered_set>        // unordered_set
20 #include <type_traits>          // is_same
21
22 #include <ade/util/zip_range.hpp>   // util::indexed
23
24 #include "api/gapi_priv.hpp"    // GOrigin
25 #include "api/gproto_priv.hpp"  // descriptor_of and other GProtoArg-related
26 #include "api/gcall_priv.hpp"
27 #include "api/gnode_priv.hpp"
28
29 #include "compiler/gmodelbuilder.hpp"
30
31 namespace {
32
33
34 // TODO: move to helpers and cover with internal tests?
35 template<typename T> struct GVisited
36 {
37     typedef std::unordered_set<T> VTs;
38
39     bool visited(const T& t) const { return m_visited.find(t) != m_visited.end(); }
40     void visit  (const T& t)       { m_visited.insert(t); }
41     const VTs& visited()     const { return m_visited; }
42
43 private:
44     VTs m_visited;
45 };
46
47 template<typename T, typename U = T> struct GVisitedTracker: protected GVisited<T>
48 {
49     typedef std::vector<U> TUs;
50
51     void  visit(const T& t, const U& u) { GVisited<T>::visit(t); m_tracked.push_back(u); }
52     const TUs& tracked() const          { return m_tracked; }
53     using GVisited<T>::visited;
54
55 private:
56     TUs m_tracked;
57 };
58
59 } // namespace
60
61
62 cv::gimpl::Unrolled cv::gimpl::unrollExpr(const GProtoArgs &ins,
63                                           const GProtoArgs &outs)
64 {
65     // FIXME: Who's gonna check if ins/outs are not EMPTY?
66     // FIXME: operator== for GObjects? (test if the same object or not)
67     using GObjId = const cv::GOrigin*;
68
69     GVisitedTracker<const GNode::Priv*, cv::GNode> ops;
70     GVisited<GObjId> reached_sources;
71     cv::GOriginSet   origins;
72
73     // Cache input argument objects for a faster look-up
74     // While the only reliable way to identify a Data object is Origin
75     // (multiple data objects may refer to the same Origin as result of
76     // multuple yield() calls), input objects can be uniquely identified
77     // by its `priv` address. Here we rely on this to verify if the expression
78     // we unroll actually matches the protocol specified to us by user.
79     std::unordered_set<GObjId> in_objs_p;
80     for (const auto& in_obj : ins)
81     {
82         // Objects are guarnateed to remain alive while this method
83         // is working, so it is safe to keep pointers here and below
84         in_objs_p.insert(&proto::origin_of(in_obj));
85     }
86
87     // Recursive expression traversal
88     std::stack<cv::GProtoArg> data_objs(std::deque<cv::GProtoArg>(outs.begin(), outs.end()));
89     while (!data_objs.empty())
90     {
91         const auto  obj   = data_objs.top();
92         const auto &obj_p = proto::origin_of(obj);
93         data_objs.pop();
94
95         const auto &origin = obj_p;
96         origins.insert(origin); // TODO: Put Object description here later on
97
98         // If this Object is listed in the protocol, don't dive deeper (even
99         // if it is in fact a result of operation). Our computation is
100         // bounded by this data slot, so terminate this recursion path early.
101         if (in_objs_p.find(&obj_p) != in_objs_p.end())
102         {
103             reached_sources.visit(&obj_p);
104             continue;
105         }
106
107         const cv::GNode &node = origin.node;
108         switch (node.shape())
109         {
110         case cv::GNode::NodeShape::EMPTY:
111             // TODO: Own exception type?
112             util::throw_error(std::logic_error("Empty node reached!"));
113             break;
114
115         case cv::GNode::NodeShape::PARAM:
116         case cv::GNode::NodeShape::CONST_BOUNDED:
117             // No preceding operation to this data object - so the data object is either a GComputation
118             // parameter or a constant (compile-time) value
119             // Record it to check if protocol matches expression tree later
120             if (!reached_sources.visited(&obj_p))
121                 reached_sources.visit(&obj_p);
122             break;
123
124         case cv::GNode::NodeShape::CALL:
125             if (!ops.visited(&node.priv()))
126             {
127                 // This operation hasn't been visited yet - mark it so,
128                 // then add its operands to stack to continue recursion.
129                 ops.visit(&node.priv(), node);
130
131                 const cv::GCall&        call   = origin.node.call();
132                 const cv::GCall::Priv&  call_p = call.priv();
133
134                 // Put the outputs object description of the node
135                 // so that they are not lost if they are not consumed by other operations
136                 for (const auto &it : ade::util::indexed(call_p.m_k.outShapes))
137                 {
138                     std::size_t port  = ade::util::index(it);
139                     GShape shape      = ade::util::value(it);
140
141                     GOrigin org { shape, node, port};
142                     origins.insert(org);
143                 }
144
145                 for (const auto &arg : call_p.m_args)
146                 {
147                     if (proto::is_dynamic(arg))
148                     {
149                         data_objs.push(proto::rewrap(arg)); // Dive deeper
150                     }
151                 }
152             }
153             break;
154
155         default:
156             // Unsupported node shape
157             GAPI_Assert(false);
158             break;
159         }
160     }
161
162     // Check if protocol mentions data_objs which weren't reached during traversal
163     const auto missing_reached_sources = [&reached_sources](GObjId p) {
164         return reached_sources.visited().find(p) == reached_sources.visited().end();
165     };
166     if (ade::util::any_of(in_objs_p, missing_reached_sources))
167     {
168         // TODO: Own exception type or a return code?
169       util::throw_error(std::logic_error("Data object listed in Protocol "
170                                      "wasn\'t reached during unroll"));
171     }
172
173     // Check if there endpoint (parameter) data_objs which are not listed in protocol
174     const auto missing_in_proto = [&in_objs_p](GObjId p) {
175         return p->node.shape() != cv::GNode::NodeShape::CONST_BOUNDED &&
176                in_objs_p.find(p) == in_objs_p.end();
177     };
178     if (ade::util::any_of(reached_sources.visited(), missing_in_proto))
179     {
180         // TODO: Own exception type or a return code?
181       util::throw_error(std::logic_error("Data object reached during unroll "
182                                      "wasn\'t found in Protocol"));
183     }
184
185     return cv::gimpl::Unrolled{ops.tracked(), origins};
186 }
187
188
189 cv::gimpl::GModelBuilder::GModelBuilder(ade::Graph &g)
190     : m_g(g)
191 {
192 }
193
194 cv::gimpl::GModelBuilder::ProtoSlots
195 cv::gimpl::GModelBuilder::put(const GProtoArgs &ins, const GProtoArgs &outs)
196 {
197     const auto unrolled = cv::gimpl::unrollExpr(ins, outs);
198
199     // First, put all operations and its arguments into graph.
200     for (const auto &op_expr_node : unrolled.all_ops)
201     {
202         GAPI_Assert(op_expr_node.shape() == GNode::NodeShape::CALL);
203         const GCall&        call    = op_expr_node.call();
204         const GCall::Priv&  call_p  = call.priv();
205         ade::NodeHandle     call_h  = put_OpNode(op_expr_node);
206
207         for (const auto &it : ade::util::indexed(call_p.m_args))
208         {
209             const auto  in_port = ade::util::index(it);
210             const auto& in_arg  = ade::util::value(it);
211
212             if (proto::is_dynamic(in_arg))
213             {
214                 ade::NodeHandle data_h = put_DataNode(proto::origin_of(in_arg));
215                 cv::gimpl::GModel::linkIn(m_g, call_h, data_h, in_port);
216             }
217         }
218     }
219
220     // Then iterate via all "origins", instantiate (if not yet) Data graph nodes
221     // and connect these nodes with their producers in graph
222     for (const auto &origin : unrolled.all_data)
223     {
224         const cv::GNode& prod = origin.node;
225         GAPI_Assert(prod.shape() != cv::GNode::NodeShape::EMPTY);
226
227         ade::NodeHandle data_h = put_DataNode(origin);
228         if (prod.shape() == cv::GNode::NodeShape::CALL)
229         {
230             ade::NodeHandle call_h = put_OpNode(prod);
231             cv::gimpl::GModel::linkOut(m_g, call_h, data_h, origin.port);
232         }
233     }
234
235     // Mark graph data nodes as INPUTs and OUTPUTs respectively (according to the protocol)
236     for (const auto &arg : ins)
237     {
238         ade::NodeHandle nh = put_DataNode(proto::origin_of(arg));
239         m_g.metadata(nh).get<Data>().storage = Data::Storage::INPUT;
240     }
241     for (const auto &arg : outs)
242     {
243         ade::NodeHandle nh = put_DataNode(proto::origin_of(arg));
244         m_g.metadata(nh).get<Data>().storage = Data::Storage::OUTPUT;
245     }
246
247     // And, finally, store data object layout in meta
248     m_g.metadata().set(Layout{m_graph_data});
249
250     // After graph is generated, specify which data objects are actually
251     // computation entry/exit points.
252     using NodeDescr = std::pair<std::vector<RcDesc>,
253                                 std::vector<ade::NodeHandle> >;
254
255     const auto get_proto_slots = [&](const GProtoArgs &proto) -> NodeDescr
256     {
257         NodeDescr slots;
258
259         slots.first.reserve(proto.size());
260         slots.second.reserve(proto.size());
261
262         for (const auto &arg : proto)
263         {
264             ade::NodeHandle nh = put_DataNode(proto::origin_of(arg));
265             const auto &desc = m_g.metadata(nh).get<Data>();
266             //These extra empty {} are to please GCC (-Wmissing-field-initializers)
267             slots.first.push_back(RcDesc{desc.rc, desc.shape, {}});
268             slots.second.push_back(nh);
269         }
270         return slots;
271     };
272
273     auto in_slots  = get_proto_slots(ins);
274     auto out_slots = get_proto_slots(outs);
275     return ProtoSlots{in_slots.first,  out_slots.first,
276                       in_slots.second, out_slots.second};
277 }
278
279 ade::NodeHandle cv::gimpl::GModelBuilder::put_OpNode(const cv::GNode &node)
280 {
281     const auto& node_p = node.priv();
282     const auto  it     = m_graph_ops.find(&node_p);
283     if (it == m_graph_ops.end())
284     {
285         GAPI_Assert(node.shape() == GNode::NodeShape::CALL);
286         const auto &call_p = node.call().priv();
287         auto nh = cv::gimpl::GModel::mkOpNode(m_g, call_p.m_k, call_p.m_args, node_p.m_island);
288         m_graph_ops[&node_p] = nh;
289         return nh;
290     }
291     else return it->second;
292 }
293
294 // FIXME: rename to get_DataNode (and same for Op)
295 ade::NodeHandle cv::gimpl::GModelBuilder::put_DataNode(const GOrigin &origin)
296 {
297     const auto it = m_graph_data.find(origin);
298     if (it == m_graph_data.end())
299     {
300         auto nh = cv::gimpl::GModel::mkDataNode(m_g, origin);
301         m_graph_data[origin] = nh;
302         return nh;
303     }
304     else return it->second;
305 }