9721c5e9a8481d06cc6965ffa562e6a21f37477a
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / net_pass.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "net_pass.h"
6 #include "blob_factory.hpp"
7 #include "ie_memcpy.h"
8 #include "details/ie_cnn_network_tools.h"
9 #include "ie_layers_internal.hpp"
10 #include "graph_tools.hpp"
11
12 #include <string>
13 #include <utility>
14 #include <algorithm>
15 #include <memory>
16 #include <tuple>
17 #include <set>
18 #include <unordered_map>
19 #include <unordered_set>
20
21 namespace InferenceEngine {
22 namespace NetPass {
23
24 template <typename T, typename P>
25 inline bool one_of(T val, P item) { return val == item; }
26 template <typename T, typename P, typename... Args>
27 inline bool one_of(T val, P item, Args... item_others) {
28     return val == item || one_of(val, item_others...);
29 }
30
31 /************************************************************/
32 /****  TI Utils  ********************************************/
33 /************************************************************/
34
35 static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr> &heads) {
36     CNNLayerSet inputLayers;
37     std::unordered_set<CNNLayer*> allLayers;
38
39     // Define all start layers
40     for (const auto & data : heads) {
41         auto &secondLayers = data->getInputTo();
42
43         details::UnorderedDFS(allLayers, secondLayers.begin()->second, [&](CNNLayerPtr layer) {
44             if (layer->insData.empty()) {
45                 inputLayers.insert(layer);
46             }
47         }, false);
48     }
49
50     std::vector<DataPtr> res = heads;
51     // Add fake input data to point on not achievable
52     // layers from head (like const placeholders)
53     for (auto &starter : inputLayers) {
54         DataPtr holder(new Data(starter->name + ":input_holder", starter->precision));
55         holder->getInputTo()[starter->name] = starter;
56         res.push_back(holder);
57     }
58
59     return res;
60 }
61
62 std::vector<CNNLayerPtr> TIBodySortTopologically(const TensorIterator::Body &body) {
63     std::vector<CNNLayerPtr> all_layers;
64
65     auto all_input_layers = getAllInputs(body.inputs);
66     CNNNetForestDFS(all_input_layers, [&](CNNLayerPtr  current){
67         all_layers.push_back(current);
68     }, false);
69     std::reverse(all_layers.begin(), all_layers.end());
70     return all_layers;
71 }
72
73 TensorIterator::Body CopyTIBody(const TensorIterator::Body &body, std::string suffix) {
74     struct NoneStruct {};
75     auto cp = [&](CNNLayerPtr lp) {
76         return injectData<NoneStruct>(lp);
77     };
78
79     const auto all_orig = TIBodySortTopologically(body);
80     auto num = all_orig.size();
81
82     std::unordered_map<CNNLayer*, CNNLayerPtr> old2new_l;
83     for (int i = 0; i < num; i++) {
84         auto &orig = all_orig[i];
85         old2new_l[orig.get()] = cp(orig);
86     }
87
88     std::unordered_map<Data*, DataPtr> old2new_d;
89     for (auto &in : body.inputs) {
90         auto new_data = std::make_shared<Data>(*in.get());
91         for (auto &to : new_data->getInputTo())
92             to.second = old2new_l[to.second.get()];
93
94         old2new_d[in.get()] = new_data;
95     }
96
97     for (const auto &old : all_orig) {
98         auto &new_one = old2new_l[old.get()];
99         // remap output data
100         for (int i = 0; i != old->outData.size(); i++) {
101             auto old_data = old->outData[i];
102             auto new_data = new_one->outData[i];
103             new_data->getCreatorLayer() = CNNLayerWeakPtr(new_one);
104             old2new_d[old_data.get()] = new_data;
105
106             for (auto &to : new_data->getInputTo())
107                 to.second = old2new_l[to.second.get()];
108         }
109         // remap input data
110         for (int i = 0; i != old->insData.size(); i++) {
111             auto old_data = old->insData[i].lock();
112             auto new_data = old2new_d.at(old_data.get());
113             new_one->insData[i] = new_data;
114         }
115     }
116
117     // Add suffix
118     if (!suffix.empty()) {
119         for (auto &kvp : old2new_l) {
120             auto layer = kvp.second;
121             auto old_name = layer->name;
122             layer->name += suffix;
123             for (auto &ins : layer->insData) {
124                 ins.lock()->getInputTo().erase(old_name);
125                 ins.lock()->getInputTo()[layer->name] = layer;
126             }
127         }
128         for (auto &kvp : old2new_d) kvp.second->setName(kvp.second->getName() + suffix);
129     }
130
131     TensorIterator::Body res;
132     for (auto &in : body.inputs)
133         res.inputs.emplace_back(old2new_d[in.get()]);
134
135     for (auto &out : body.outputs)
136         res.outputs.emplace_back(old2new_d[out.get()]);
137
138     // Fake holder.
139     // The graph itself is a shared_ptr set where parent holds child.
140     // Res.inputs vector hold head of graph and all nodes should be
141     // achievable for oriented search started from that. But place
142     // const holder has no input and cannot be achieved. So we need
143     // to hold then in other way.
144     //
145     // Let's add one more Data object which has no representation in
146     // original network. It will hold all unreachable const placeholder
147     // nodes.
148     //
149     std::vector<CNNLayerPtr> to_hold;
150     for (auto &kvp : old2new_l) {
151         auto layer = kvp.second;
152         if (layer->insData.empty())
153             to_hold.emplace_back(layer);
154     }
155     if (!to_hold.empty()) {
156         auto holder = DataPtr(new Data("const_holder", Precision::UNSPECIFIED));
157         for (auto layer : to_hold) {
158             holder->getInputTo()[layer->name] = layer;
159         }
160         res.inputs.emplace_back(holder);
161     }
162
163     return res;
164 }
165
166 /************************************************************/
167 /****  TI rule helpers  *************************************/
168 /************************************************************/
169
170 inline bool is_full_ranged(const TensorIterator::PortMap& rule, const DataPtr &data) {
171     if (!data)
172         THROW_IE_EXCEPTION << "Internal error. data == nullptr";
173
174     if (rule.axis == -1 || !one_of(rule.stride, 1, -1))
175         return false;
176
177     auto &shape = data->getDims();
178     int size = shape[rule.axis];
179
180     int begin = rule.start >= 0 ? rule.start : size + rule.start + 1;
181     int end = rule.end >= 0 ? rule.end : size + rule.end + 1;
182
183     return (rule.stride == 1)
184         ? begin == 0 && end == size
185         : begin == size && end == 0;
186 }
187
188 using RuleSet = std::vector<TensorIterator::PortMap>;
189 using RuleClassSet = std::tuple<RuleSet, RuleSet, RuleSet>;
190
191 /**
192  * @brief Helper to split port mapping rules to three group
193  *
194  *   first_class  - which has iteration component
195  *   second_class - which has no iteration and there are no backedge connection to the same port
196  *   third_class  - which has no iteration and has corresponding backedge
197  *
198  * @param ti TensorIterator layer to analyze
199  * @return tuple with three classes of port map rule
200  */
201 static RuleClassSet classifyInputRules(const TensorIterator &ti) {
202     RuleSet first_class_rules, second_class_rules, third_class_rules;
203
204     std::set<int> ports_with_backedge;
205     for (const auto &back_edge : ti.back_edges) ports_with_backedge.insert(back_edge.to);
206
207     for (const auto &rule : ti.input_port_map) {
208         if (rule.axis != -1)
209             first_class_rules.push_back(rule);
210
211         else if (!ports_with_backedge.count(rule.to))
212             second_class_rules.push_back(rule);
213
214         else
215             third_class_rules.push_back(rule);
216     }
217     return RuleClassSet {first_class_rules, second_class_rules, third_class_rules};
218 }
219
220 static RuleClassSet classifyOutputRules(const TensorIterator &ti) {
221     RuleSet first_class_rules, second_class_rules, third_class_rules;
222
223     std::set<int> ports_with_backedge;
224     for (const auto &back_edge : ti.back_edges) ports_with_backedge.insert(back_edge.from);
225
226     for (const auto &rule : ti.output_port_map) {
227         if (rule.axis != -1)
228             first_class_rules.push_back(rule);
229
230         else if (!ports_with_backedge.count(rule.to))
231             second_class_rules.push_back(rule);
232
233         else
234             third_class_rules.push_back(rule);
235     }
236     return RuleClassSet {first_class_rules, second_class_rules, third_class_rules};
237 }
238
239 /**
240  * Merge slave connections into master
241  * @param master
242  * @param slave
243  */
244 void CombineData(DataPtr &master, DataPtr &slave) {
245     for (auto &kvp : slave->getInputTo()) {
246         auto &slave_layer = kvp.second;
247         for (auto &slv_ins_wptr : slave_layer->insData) {
248             auto slv_ins = slv_ins_wptr.lock();
249             // Replace slave ptr with master
250             if (slv_ins == slave) slv_ins_wptr = master;
251         }
252         master->getInputTo()[slave_layer->name] = slave_layer;
253     }
254 }
255
256 /************************************************************/
257 /****  Converter Passes  ************************************/
258 /************************************************************/
259
260 static RNNSequenceLayer::CellType cell_type_from_name(std::string &layer_type) {
261     RNNSequenceLayer::CellType res;
262     if (layer_type == "LSTMCell")
263         res = RNNSequenceLayer::LSTM;
264     else if (layer_type == "GRUCell")
265         res = RNNSequenceLayer::GRU;
266     else if (layer_type == "RNNCell")
267         res = RNNSequenceLayer::RNN;
268     else
269         THROW_IE_EXCEPTION << "Unknown Cell type (" << layer_type << "). Expected LSTMCell|GRUCell|RNNCell";
270     return res;
271 }
272
273 static std::string cell_name(RNNSequenceLayer::CellType type) {
274     std::string res;
275     switch (type) {
276         case RNNSequenceLayer::LSTM:
277             res = "LSTM";
278             break;
279         case RNNSequenceLayer::GRU:
280         case RNNSequenceLayer::GRU_LBR:
281             res = "GRU";
282             break;
283         case RNNSequenceLayer::RNN:
284             res = "RNN";
285             break;
286     }
287     return res;
288 }
289
290 template<typename N>
291 bool convertToRNNSeq(CNNLayerPtr cur, const N &net) {
292     if (cur->type != "TensorIterator") return true;
293
294     auto ti = std::dynamic_pointer_cast<TensorIterator>(cur);
295     IE_ASSERT(ti) << "Cannot cast object with type TensorIterator to TensorIterator object";
296
297     auto all_body_layers = TIBodySortTopologically(ti->body);
298
299     // Check if body is:  squeeze -> lstm_cell -> unsqueeze
300     if (all_body_layers.size() != 3
301         || all_body_layers[0]->type != "Reshape"
302         || !one_of(all_body_layers[1]->type, "GRUCell", "RNNCell", "LSTMCell")
303         || all_body_layers[2]->type != "Reshape")
304         return false;
305
306     auto rsp1 = std::dynamic_pointer_cast<ReshapeLayer>(all_body_layers[0]);
307     auto cell = std::dynamic_pointer_cast<RNNCellBase>(all_body_layers[1]);
308     auto rsp2 = std::dynamic_pointer_cast<ReshapeLayer>(all_body_layers[2]);
309
310     IE_ASSERT(rsp1);
311     IE_ASSERT(cell);
312     IE_ASSERT(rsp2);
313
314     int NS = (cell->cellType == RNNSequenceLayer::LSTM) ? 2 : 1;  // number of states
315
316     IE_ASSERT(cell->insData.size() == NS + 1);  // {data, state1, [state2]}
317     IE_ASSERT(cell->outData.size() == NS);  // {state1, [state2]}
318
319     if (cell->insData[0].lock()->getCreatorLayer().lock() != rsp1 ||
320         cell->outData[0]->getInputTo().begin()->second != rsp2)
321         return false;
322
323     // Check port mapping
324     auto _indx_in = [&] (const std::vector<DataPtr> &scope,  const DataPtr &data) {
325         int indx = std::find(scope.begin(), scope.end(), data) - scope.begin();
326         return indx == scope.size() ? -1 : indx;
327     };
328
329     int in_dt_idx = _indx_in(ti->body.inputs, rsp1->insData[0].lock());
330     int in_hs_idx = _indx_in(ti->body.inputs, cell->insData[1].lock());
331     int in_cs_idx = NS == 2 ? _indx_in(ti->body.inputs, cell->insData[2].lock()) : -1;
332
333     int out_dt_idx = _indx_in(ti->body.outputs, rsp2->outData[0]);
334     int out_hs_idx = _indx_in(ti->body.outputs, cell->outData[0]);
335     int out_cs_idx = NS == 2 ? _indx_in(ti->body.outputs, cell->outData[1]) : -1;
336
337     // indexes should be [0,1,2] : sum == 3 or [0,1,-1] : sum == 0
338     int sum = (NS - 1) * 3;
339     if (in_hs_idx + in_cs_idx + in_dt_idx != sum || out_hs_idx + out_cs_idx + out_dt_idx != sum)
340         return false;
341
342     std::map<int, TensorIterator::PortMap> i2map, o2map, be2map;
343     for (auto &m : ti->input_port_map) i2map[m.to] = m;
344     for (auto &m : ti->output_port_map) o2map[m.to] = m;
345     for (auto &m : ti->back_edges) be2map[m.to] = m;
346
347     if (!one_of(i2map.size(), NS + 1, 1) ||
348         !one_of(o2map.size(), NS + 1, 1) ||
349         !one_of(be2map.size(), NS))
350         return false;
351
352     auto in_iter_rule = i2map[in_dt_idx];
353     auto in_iter_data = ti->insData[in_iter_rule.from].lock();
354
355     auto out_iter_rule = o2map[out_dt_idx];
356     auto out_iter_data = ti->outData[out_iter_rule.from];
357
358     // TI iterates only for full range of tensor
359     if (!is_full_ranged(in_iter_rule, in_iter_data) ||
360         !is_full_ranged(out_iter_rule, out_iter_data))
361         return false;
362
363     // supported only same axis and strides for in/out data tensors
364     if (in_iter_rule.axis != out_iter_rule.axis ||
365         in_iter_rule.stride != out_iter_rule.stride)
366         return false;
367
368     // supported only firs and second dim for LSTM-Sequence
369     if (!one_of(in_iter_rule.axis, 0, 1))
370         return false;
371
372     bool no_init_state = i2map.size() == 1;
373     bool no_last_state = o2map.size() == 1;
374
375     if (!no_init_state && ( i2map[in_hs_idx].axis != -1 || (NS == 2 && i2map[in_cs_idx].axis != -1) ))
376         return false;
377     if (!no_last_state && ( o2map[out_hs_idx].axis != -1 || (NS == 2 && o2map[out_cs_idx].axis != -1) ))
378         return false;
379
380     std::vector<int> i_order {i2map[in_dt_idx].from };
381     if (!no_init_state)
382         i_order.push_back(i2map[in_hs_idx].from);
383     if (!no_init_state && NS == 2)
384         i_order.push_back(i2map[in_cs_idx].from);
385
386     std::vector<int> o_order {o2map[out_dt_idx].from};
387     if (!no_last_state)
388         o_order.push_back(o2map[out_hs_idx].from);
389     if (!no_last_state && NS == 2)
390         o_order.push_back(o2map[out_cs_idx].from);
391
392     // need swap an i/o ports if it is not in natural order
393     std::string name = cell->name + "_sequence";
394     std::string type = cell_name(cell->cellType) + "Sequence";
395
396     auto rnn  = std::make_shared<RNNSequenceLayer>(LayerParams{ name, type, cell->precision});
397     rnn->axis = in_iter_rule.axis;
398     rnn->direction = in_iter_rule.stride == 1
399             ? RNNSequenceLayer::FWD
400             : RNNSequenceLayer::BWD;
401
402     // copy base RNN cell fields
403     rnn->cellType = cell->cellType;
404     rnn->_weights = cell->_weights;
405     rnn->_biases = cell->_biases;
406     rnn->blobs["weights"] = rnn->_weights;
407     rnn->blobs["biases"] = rnn->_biases;
408     rnn->blobs = cell->blobs;
409     rnn->activations = cell->activations;
410     rnn->activation_alpha = cell->activation_alpha;
411     rnn->activation_beta = cell->activation_beta;
412     rnn->hidden_size = cell->hidden_size;
413     rnn->clip = cell->clip;
414
415     for (int i : i_order) {
416         auto in_data = ti->insData[i].lock();
417         in_data->getInputTo().erase(ti->name);
418         in_data->getInputTo()[rnn->name] = rnn;
419         rnn->insData.push_back(in_data);
420     }
421     for (int i : o_order) {
422         rnn->outData.push_back(ti->outData[i]);
423         rnn->outData.back()->getCreatorLayer() = rnn;
424     }
425
426     return true;
427 }
428
429 bool unrollTI(CNNLayerPtr cur, ICNNNetwork &net) {
430     if (cur->type != "TensorIterator")
431         return true;
432
433     auto ti = std::dynamic_pointer_cast<TensorIterator>(cur);
434     IE_ASSERT(ti) << "Cannot cast object with type TensorIterator to TensorIterator object";
435
436     int num = getNumIteration(*ti);  // -1 means inconsistent TI
437     if (num == -1) return false;  // TODO: better to throw exception
438
439     const auto &body = ti->body;
440
441     std::vector<TensorIterator::Body> body_list(num);
442     for (int i = 0; i < num; i++) {
443         // copy with additional suffix to each object name
444         body_list[i] = CopyTIBody(body, ":" + std::to_string(i));
445
446         auto holder = body_list[i].inputs.back();
447         if (holder->getPrecision() == Precision::UNSPECIFIED) {
448             for (auto kvp : holder->getInputTo())
449
450                 net.addLayer(kvp.second);
451         }
452     }
453
454     RuleSet first_class, second_class, third_class;
455     std::tie(first_class, second_class, third_class) = classifyInputRules(*ti);
456
457     /** Clean links on TI */
458     for (auto &ins : ti->insData)
459         ins.lock()->getInputTo().erase(ti->name);
460     for (auto &outs : ti->outData)
461         outs->getCreatorLayer().reset();
462
463     /** FIRST class comes */
464     for (int i = 0; i < first_class.size(); i++) {
465         auto &rule = first_class[i];
466         auto in_data = ti->insData[rule.from].lock();
467
468         std::string name = ti->name + ":in_split_" + std::to_string(i);
469         auto split = std::make_shared<SplitLayer>(LayerParams{ name, "Split", cur->precision });
470         split->_axis = rule.axis;
471         split->outData.resize(num);
472         split->insData.emplace_back(in_data);
473         in_data->getInputTo()[split->name] = split;
474
475         for (int j = 0; j < num; j++) {
476             auto body_idx = rule.stride == 1 ? j : num - 1 - j;
477             auto &chunk = body_list[body_idx].inputs[rule.to];
478             chunk->getCreatorLayer() = split;
479             split->outData[j] = chunk;
480         }
481     }
482
483     /** SECOND class come on */
484     for (const auto &rule : second_class) {
485         auto in_data = ti->insData[rule.from].lock();
486
487         for (int j = 0; j < num; j++) {
488             auto &chunk = body_list[j].inputs[rule.to];
489             CombineData(in_data, chunk);
490         }
491     }
492
493     /** BACK EDGES that's your time */
494     for (const auto &rule : ti->back_edges) {
495         for (int i = 1; i < num; i++) {
496             auto &from_data = body_list[i-1].outputs[rule.from];
497             auto &to_data = body_list[i].inputs[rule.to];
498             CombineData(from_data, to_data);
499         }
500     }
501
502     /** THIRD class end up */
503     for (const auto &rule : third_class) {
504         // first iteration
505         auto from_data = ti->insData[rule.from].lock();
506         auto &to_data = body_list[0].inputs[rule.to];;
507         CombineData(from_data, to_data);
508     }
509
510     /** And the same actions for outputs connections */
511     std::tie(first_class, second_class, third_class) = classifyOutputRules(*ti);
512
513     /** FIRST class comes */
514     for (int i = 0; i < first_class.size(); i++) {
515         auto &rule = first_class[i];
516         auto out_data = ti->outData[rule.from];
517
518         std::string name = ti->name + ":out_concat_" + std::to_string(i);
519         auto concat = std::make_shared<ConcatLayer>(LayerParams{ name, "Concat", cur->precision });
520         concat->_axis = rule.axis;
521         concat->insData.resize(num);
522         concat->outData.emplace_back(out_data);
523         out_data->getCreatorLayer() = concat;
524
525         for (int j = 0; j < num; j++) {
526             auto body_idx = rule.stride == 1 ? j : num - 1 - j;
527             auto &chunk = body_list[body_idx].outputs[rule.to];
528             chunk->getInputTo()[concat->name] = concat;
529             concat->insData[j] = chunk;
530         }
531     }
532
533     /** SECOND class come on */
534     for (const auto &rule : second_class) {
535         auto out_data = ti->outData[rule.from];
536
537         for (int j = 0; j < num; j++) {
538             auto &chunk = body_list[j].outputs[rule.to];
539             CombineData(chunk, out_data);
540         }
541     }
542
543     /** THIRD class end up */
544     for (const auto &rule : third_class) {
545         // first iteration
546         auto &from_data = ti->outData[rule.from];
547         auto &to_data = body_list[num-1].outputs[rule.to];
548
549         auto parent = to_data->getCreatorLayer().lock();
550         std::replace(parent->outData.begin(), parent->outData.end(), to_data, from_data);
551         from_data->getCreatorLayer() = parent;
552
553         CombineData(from_data, to_data);
554     }
555     return true;
556 }
557
558 /************************************************************/
559 /****  Builder helpers   ************************************/
560 /************************************************************/
561
562 static CNNLayerPtr _concat(std::string name, Precision prc, SizeVector dims, int num) {
563     auto res = std::make_shared<ConcatLayer>(LayerParams{name, "Concat", prc});
564     res->_axis = 1;
565
566     res->insData.resize(num);
567     res->outData.resize(1);
568
569     auto out_data = DataPtr(new Data(name,
570             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
571     out_data->getCreatorLayer() = res;
572
573     res->outData[0] = out_data;
574     return res;
575 }
576
577 static CNNLayerPtr _split(std::string name, Precision prc, SizeVector dims, int num) {
578     auto res = std::make_shared<SplitLayer>(LayerParams{name, "Split", prc});
579     res->_axis = 1;
580     res->params["axis"] = std::to_string(res->_axis);
581
582     res->insData.resize(1);
583     res->outData.resize(num);
584
585     for (int i = 0; i < num; i++) {
586         auto out_data = DataPtr(new Data(name + "_part_" + std::to_string(i),
587                 TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
588         out_data->getCreatorLayer() = res;
589
590         res->outData[i] = out_data;
591     }
592     return res;
593 }
594
595 static CNNLayerPtr _fc(std::string name, Precision prc, SizeVector dims, Blob::Ptr &W, Blob::Ptr &B) {
596     auto res = std::make_shared<FullyConnectedLayer>(LayerParams{name, "FullyConnected", prc});
597
598     res->_weights = W;
599     res->_biases = B;
600     res->_out_num = dims[1];
601     res->blobs["weights"] = W;
602     res->blobs["biases"] = B;
603     res->params["out-size"] = std::to_string(dims[1]);
604
605     res->insData.resize(1);
606     res->outData.resize(1);
607
608     auto out_data = DataPtr(new Data(name,
609             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
610     out_data->getCreatorLayer() = res;
611
612     res->outData[0] = out_data;
613     return res;
614 }
615
616 static CNNLayerPtr _act(std::string name, Precision prc, SizeVector dims, std::string type) {
617     auto res = std::make_shared<CNNLayer>(LayerParams{name, type, prc});
618
619     res->params["type"] = type;
620
621     res->insData.resize(1);
622     res->outData.resize(1);
623
624     auto out_data = DataPtr(new Data(name,
625             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
626     out_data->getCreatorLayer() = res;
627
628     res->outData[0] = out_data;
629     return res;
630 }
631
632 static CNNLayerPtr _pwr(std::string name, Precision prc, SizeVector dims, float scale, float shift) {
633     auto res = std::make_shared<PowerLayer>(LayerParams{name, "Power", prc});
634
635     res->power = 1.0;
636     res->scale = scale;
637     res->offset = shift;
638     res->params["power"] = std::to_string(res->power);
639     res->params["scale"] = std::to_string(res->scale);
640     res->params["shift"] = std::to_string(res->offset);
641
642     res->insData.resize(1);
643     res->outData.resize(1);
644
645     auto out_data = DataPtr(new Data(name,
646             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
647     out_data->getCreatorLayer() = res;
648
649     res->outData[0] = out_data;
650     return res;
651 }
652
653
654 static CNNLayerPtr _eltw(std::string name, Precision prc, SizeVector dims, std::string type) {
655     auto res = std::make_shared<EltwiseLayer>(LayerParams{name, "Eltwise", prc});
656
657     res->params["operation"] = type;
658     res->_operation = type == "sum" ? EltwiseLayer::Sum : EltwiseLayer::Prod;
659
660     res->insData.resize(2);
661     res->outData.resize(1);
662
663     auto out_data = DataPtr(new Data(name,
664             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
665     out_data->getCreatorLayer() = res;
666
667     res->outData[0] = out_data;
668     return res;
669 }
670
671 static std::shared_ptr<ReshapeLayer> _resh(std::string name, Precision prc, SizeVector dims) {
672     auto res = std::make_shared<ReshapeLayer>(LayerParams{name, "Reshape", prc});
673
674     res->insData.resize(1);
675     res->outData.resize(1);
676
677     auto out_data = DataPtr(new Data(name,
678             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
679     out_data->getCreatorLayer() = res;
680
681     res->outData[0] = out_data;
682     return res;
683 }
684
685 static std::shared_ptr<RNNCellBase> _cell(std::string name, Precision prc, SizeVector data_dims, SizeVector state_dims, RNNSequenceLayer::CellType type) {
686     std::shared_ptr<RNNCellBase> res;
687     size_t NS = 1;
688     switch (type) {
689         case RNNSequenceLayer::LSTM:
690             res = std::make_shared<LSTMCell>(LayerParams{name, "LSTMCell", prc}); NS = 2;
691             break;
692         case RNNSequenceLayer::GRU:
693         case RNNSequenceLayer::GRU_LBR:
694             res = std::make_shared<GRUCell>(LayerParams{name, "GRUCell", prc});
695             break;
696         case RNNSequenceLayer::RNN:
697             res = std::make_shared<RNNCell>(LayerParams{name, "RNNCell", prc});
698             break;
699     }
700
701     res->cellType = type;
702     res->insData.resize(1 + NS);
703     res->outData.resize(NS);
704
705     auto out_data = DataPtr(new Data(name + ":out_data",
706             TensorDesc { prc, data_dims, TensorDesc::getLayoutByDims(data_dims) }));
707     out_data->getCreatorLayer() = res;
708     res->outData[0] = out_data;
709
710     for (size_t i = 0; i < NS; i++) {
711         auto out_state = DataPtr(new Data(name + ":out_state_" + std::to_string(i),
712                 TensorDesc { prc, state_dims, TensorDesc::getLayoutByDims(state_dims) }));
713         out_state->getCreatorLayer() = res;
714         res->outData[i] = out_state;
715     }
716
717     return res;
718 }
719
720 static std::shared_ptr<TensorIterator> _ti(std::string name, Precision prc, size_t NS) {
721     auto res = std::make_shared<TensorIterator>(LayerParams{name, "TensorIterator", prc});
722
723     res->insData.resize(1 + NS);
724     res->outData.resize(1 + NS);
725
726     return res;
727 }
728
729 static void _link(CNNLayerPtr src, CNNLayerPtr dst, size_t src_port = 0, size_t dst_port = 0) {
730     auto data = src->outData[src_port];
731     data->getInputTo()[dst->name] = dst;
732     dst->insData[dst_port] = data;
733 }
734
735 static void _link(DataPtr &data, CNNLayerPtr dst, size_t dst_port = 0) {
736     data->getInputTo()[dst->name] = dst;
737     dst->insData[dst_port] = data;
738 }
739
740 /** Link nodes with clipping data if required (clip_val != 0.0) */
741 static void _link_with_clip(CNNLayerPtr src, CNNLayerPtr dst, const float clip_val,
742         size_t src_port = 0, size_t dst_port = 0) {
743     if (clip_val == 0.0f) {
744         _link(src, dst, src_port, dst_port);
745     } else {
746         auto clip_name = dst->name + "_clip";
747         auto clip_prc = dst->precision;
748         auto clip_shape = src->outData[src_port]->getTensorDesc().getDims();
749         auto clip = _act(clip_name, clip_prc, clip_shape, "clamp");
750         clip->params["min"] = std::to_string(-clip_val);
751         clip->params["max"] = std::to_string(clip_val);
752
753         _link(src, clip, src_port, 0);
754         _link(clip, dst, 0, dst_port);
755     }
756 }
757
758
759 static Blob::Ptr make_partial_copy(Blob::Ptr src, size_t off, size_t size) {
760     auto res = make_plain_blob(src->getTensorDesc().getPrecision(), {size});
761     res->allocate();
762
763     size_t elem_size = src->getTensorDesc().getPrecision().size();
764     auto src_ptr = src->buffer().as<uint8_t*>();
765     auto dst_ptr = res->buffer().as<uint8_t*>();
766
767     ie_memcpy(dst_ptr, res->byteSize(), src_ptr + off * elem_size,  size * elem_size);
768
769     return res;
770 }
771
772 static Blob::Ptr wrap_as_tensor(Blob::Ptr src, SizeVector dims) {
773     auto res = make_blob_with_precision(
774             TensorDesc { src->getTensorDesc().getPrecision(), dims, TensorDesc::getLayoutByDims(dims) },
775             src->buffer());
776     IE_ASSERT(src->size() == res->size());
777     return res;
778 }
779
780 static Blob::Ptr make_region_copy(Blob::Ptr src, SizeVector region, SizeVector offset) {
781     IE_ASSERT(region.size() == offset.size());
782     IE_ASSERT(region.size() == src->getTensorDesc().getDims().size());
783
784     auto res = make_plain_blob(src->getTensorDesc().getPrecision(), region);
785     res->allocate();
786
787     size_t elem_size = src->getTensorDesc().getPrecision().size();
788     auto src_ptr = src->buffer().as<uint8_t*>();
789     auto dst_ptr = res->buffer().as<uint8_t*>();
790
791     auto &dd = src->getTensorDesc().getDims();
792     SizeVector src_dims {1, 1, 1};
793     std::copy(dd.begin(), dd.end(), src_dims.end() - dd.size());
794
795     SizeVector dims {1, 1, 1};
796     std::copy(region.begin(), region.end(), dims.end() - region.size());
797
798     SizeVector off {0, 0, 0};
799     std::copy(offset.begin(), offset.end(), off.end() - offset.size());
800
801     const auto D1 = dims[0];
802     const auto D2 = dims[1];
803     const auto D3 = dims[2];
804     const auto off1 = off[0];
805     const auto off2 = off[1];
806     const auto off3 = off[2];
807     const auto str1 = src_dims[1]*src_dims[2];
808     const auto str2 = src_dims[2];
809
810     for (size_t d1 = 0; d1 < D1; d1++)
811     for (size_t d2 = 0; d2 < D2; d2++) {
812         auto off_src = (off1 + d1)*str1 + (off2 + d2)*str2 + off3;
813         auto off_dst = d1*D2*D3 + d2*D3;
814         ie_memcpy(dst_ptr + off_dst * elem_size, res->byteSize(), src_ptr + off_src * elem_size,  D3 * elem_size);
815     }
816
817     return res;
818 }
819
820
821 static bool unrollRNNCellBody(CNNLayerPtr cur) {
822     if (cur->type != "RNNCell")
823         return true;
824
825     auto cell = std::dynamic_pointer_cast<RNNCellBase>(cur);
826     IE_ASSERT(cell) << "Cannot cast object with type ***Cell to WeightableLayer object";
827
828     auto name = cell->name;
829
830     auto in_data = cell->insData[0].lock();
831     auto in_h_state = cell->insData[1].lock();
832     auto out_h_state = cell->outData[0];
833
834     auto d_dims = in_data->getTensorDesc().getDims();
835     auto s_dims = in_h_state->getTensorDesc().getDims();
836
837     size_t N = d_dims[0];
838     size_t D = d_dims[1];
839     size_t S = s_dims[1];
840
841     auto prc = cell->precision;
842
843     /** Release links on TI */
844     for (auto &ins : cell->insData)
845         ins.lock()->getInputTo().erase(cell->name);
846     for (auto &outs : cell->outData)
847         outs->getCreatorLayer().reset();
848
849     // operations
850     auto concat = _concat(name + ":concat", prc, {N, D+S}, 2);
851     auto fc = _fc(name + ":fc", prc, {N, S}, cell->_weights, cell->_biases);
852     auto act = _act(name + ":act", prc, {N, S}, cell->activations[0]);
853
854     // Connection
855     _link(in_data, concat, 0);
856     _link(in_h_state, concat, 1);
857     _link(concat, fc);
858     _link_with_clip(fc, act, cell->clip);
859
860     // Output
861     act->outData[0] = out_h_state;
862     out_h_state->getCreatorLayer() = act;
863
864     return true;
865 }
866
867 static bool unrollLSTMCellBody(CNNLayerPtr cur) {
868     if (cur->type != "LSTMCell")
869         return true;
870
871     auto cell = std::dynamic_pointer_cast<RNNCellBase>(cur);
872     IE_ASSERT(cell) << "Cannot cast object with type ***Cell to WeightableLayer object";
873
874     auto name = cell->name;
875
876     auto in_data = cell->insData[0].lock();
877     auto in_h_state = cell->insData[1].lock();
878     auto in_c_state = cell->insData[2].lock();
879     auto out_h_state = cell->outData[0];
880     auto out_c_state = cell->outData[1];
881
882     auto d_dims = in_data->getTensorDesc().getDims();
883     auto s_dims = in_h_state->getTensorDesc().getDims();
884
885     size_t N = d_dims[0];
886     size_t D = d_dims[1];
887     size_t S = s_dims[1];
888     size_t G = 4;
889
890     auto prc = cell->precision;
891
892     /** Release links on TI */
893     for (auto &ins : cell->insData)
894         ins.lock()->getInputTo().erase(cell->name);
895     for (auto &outs : cell->outData)
896         outs->getCreatorLayer().reset();
897
898     // operations
899     auto concat = _concat(name + ":concat", prc, {N, D+S}, 2);
900     auto split = _split(name + ":split", prc, {N, S}, G);
901     auto fc = _fc(name + ":fc", prc, {N, S*G}, cell->_weights, cell->_biases);
902
903     const std::string _f = cell->activations[0], _g = cell->activations[1], _h = cell->activations[2];
904
905     auto act_f = _act(name + ":act_f", prc, {N, S}, _f);
906     auto act_i = _act(name + ":act_i", prc, {N, S}, _f);
907     auto act_c = _act(name + ":act_c", prc, {N, S}, _g);
908     auto act_o = _act(name + ":act_o", prc, {N, S}, _f);
909     auto act_x = _act(name + ":act_x", prc, {N, S}, _h);
910
911     auto mul_ic = _eltw(name + ":mul_ic", prc, {N, S}, "mul");
912     auto mul_f  = _eltw(name + ":mul_f" , prc, {N, S}, "mul");
913     auto sum    = _eltw(name + ":sum"   , prc, {N, S}, "sum");
914     auto mul    = _eltw(name + ":mul"   , prc, {N, S}, "mul");
915
916     // Connection
917     _link(in_data, concat, 0);
918     _link(in_h_state, concat, 1);
919     _link(concat, fc);
920
921     _link_with_clip(fc, split, cell->clip);
922
923     _link(split, act_f, 0, 0);
924     _link(split, act_i, 1, 0);
925     _link(split, act_c, 2, 0);
926     _link(split, act_o, 3, 0);
927
928     _link(act_i, mul_ic, 0, 0);
929     _link(act_c, mul_ic, 0, 1);
930
931     _link(act_f, mul_f, 0, 0);
932     _link(in_c_state, mul_f, 1);
933
934     _link(mul_f,  sum, 0, 0);
935     _link(mul_ic, sum, 0, 1);
936
937     _link(sum, act_x);
938
939     _link(act_x, mul, 0, 0);
940     _link(act_o, mul, 0, 1);
941
942     // Output
943     mul->outData[0] = out_h_state;
944     out_h_state->getCreatorLayer() = mul;
945
946     CombineData(out_c_state, sum->outData[0]);
947     sum->outData[0] = out_c_state;
948     out_c_state->getCreatorLayer() = sum;
949
950     return true;
951 }
952
953 static bool unrollGRUCellBody(CNNLayerPtr cur, bool linear_before_reset = false) {
954     if (cur->type != "GRUCell")
955         return true;
956
957     auto cell = std::dynamic_pointer_cast<GRUCell>(cur);
958     IE_ASSERT(cell) << "Cannot cast object with type ***Cell to WeightableLayer object";
959
960     auto name = cell->name;
961
962     auto in_data = cell->insData[0].lock();
963     auto in_h_state = cell->insData[1].lock();
964     auto out_h_state = cell->outData[0];
965
966     auto d_dims = in_data->getTensorDesc().getDims();
967     auto s_dims = in_h_state->getTensorDesc().getDims();
968
969     size_t N = d_dims[0];
970     size_t D = d_dims[1];
971     size_t S = s_dims[1];
972
973     // Split weights UR and O gates. Original gates are URO
974     size_t bG = linear_before_reset ? 4 : 3;
975     auto orig_W = wrap_as_tensor(cell->_weights, {3, S, D+S});
976     auto orig_B = wrap_as_tensor(cell->_biases, {bG, S});
977
978     auto ur_W = make_region_copy(orig_W, {2, S, D+S}, {0, 0, 0});
979     auto o_W  = make_region_copy(orig_W, {1, S, D+S}, {2, 0, 0});
980     auto ur_B = make_region_copy(orig_B, {2, S}, {0, 0});
981     auto o_B  = make_region_copy(orig_B, {1, S}, {2, 0});
982
983     auto prc = cell->precision;
984
985     /** Release links on TI */
986     for (auto &ins : cell->insData)
987         ins.lock()->getInputTo().erase(cell->name);
988     for (auto &outs : cell->outData)
989         outs->getCreatorLayer().reset();
990
991     // operations
992     auto concat = _concat(name + ":concat", prc, {N, D+S}, 2);
993     auto split = _split(name + ":split", prc, {N, S}, 2);
994     auto fc_ur = _fc(name + ":fc_ur", prc, {N, S*2}, ur_W, ur_B);
995
996     const std::string _f = cell->activations[0], _g = cell->activations[1];
997
998     auto act_ur = _act(name + ":act_ur", prc, {N, 2*S}, _f);
999     auto act_o = _act(name + ":act_o", prc, {N, S}, _g);
1000
1001     auto mul_u = _eltw(name + ":mul_u", prc, {N, S}, "mul");
1002     auto mul_r = _eltw(name + ":mul_r", prc, {N, S}, "mul");
1003
1004     auto pwr_m1 = _pwr(name + ":pwr", prc, {N, S}, -1.0, 1.0);
1005
1006     auto mul = _eltw(name + ":mul"   , prc, {N, S}, "mul");
1007     auto sum = _eltw(name + ":sum"   , prc, {N, S}, "sum");
1008
1009     /**
1010      * - zt = _f(Wz*[Xt + Ht-1] + Bz)
1011      * - rt = _f(Wr*[Xt + Ht-1] + Br)
1012      * - ht = _g(Wh*[Xt + (rt (.) Ht-1)] + Bh)    # default, when linear_before_reset = 0
1013      * - ht = _g(Whw*Xt + Bhw + (rt (.) (Whr*Ht-1 + Bhr))) # when linear_before_reset != 0
1014      * - Ht = (1 - zt) (.) ht + zt (.) Ht-1
1015      */
1016     _link(in_data, concat, 0);
1017     _link(in_h_state, concat, 1);
1018     _link(concat, fc_ur);
1019     _link_with_clip(fc_ur, act_ur, cell->clip);
1020     _link(act_ur, split);  // split[0] - zt,  split[1] - rt
1021
1022     if (linear_before_reset) {
1023         auto lbr_B = wrap_as_tensor(orig_B, {4, S});
1024
1025         auto whw_W = make_region_copy(o_W, {1, S, D}, {0, 0, 0});
1026         auto whr_W = make_region_copy(o_W, {1, S, S}, {0, 0, D});
1027         auto whw_B = make_region_copy(lbr_B, {1, S}, {2, 0});
1028         auto whr_B = make_region_copy(lbr_B, {1, S}, {3, 0});
1029
1030         auto fc_whr = _fc(name + ":fc_whr", prc, {N, S}, whr_W, whr_B);
1031         auto fc_whw = _fc(name + ":fc_whw", prc, {N, S}, whw_W, whw_B);
1032         auto sum_h  = _eltw(name + ":sum_h", prc, {N, S}, "sum");
1033
1034         _link(in_h_state, fc_whr);                  //                            Whr*Ht-1 + Bhr
1035         _link(fc_whr, mul_r, 0);                    //
1036         _link(split, mul_r, 1, 1);                  //                    rt (.) (Whr*Ht-1 + Bhr)
1037         _link(in_data, fc_whw);                     //    Whw*Xt + Bhw
1038         _link(fc_whw, sum_h, 0, 0);                 //
1039         _link(mul_r, sum_h, 0, 1);                  //    Whw*Xt + Bhw + (rt (.) (Whr*Ht-1 + Bhr))
1040         _link_with_clip(sum_h, act_o, cell->clip);  // _g(Whw*Xt + Bhw + (rt (.) (Whr*Ht-1 + Bhr)))
1041     } else {
1042         auto fc_wh = _fc(name + ":fc_o", prc, {N, S}, o_W, o_B);
1043         auto concat_h = _concat(name + ":concat_h", prc, {N, D+S}, 2);
1044
1045         _link(split, mul_r, 1, 0);                  //
1046         _link(in_h_state, mul_r, 1);                //              rt (.) Ht-1
1047         _link(in_data, concat_h, 0);                //
1048         _link(mul_r, concat_h, 0, 1);               //       [Xt + (rt (.) Ht-1)]
1049         _link(concat_h, fc_wh);                     //    Wh*[Xt + (rt (.) Ht-1)] + Bh
1050         _link_with_clip(fc_wh, act_o, cell->clip);  // _g(Wh*[Xt + (rt (.) Ht-1)] + Bh)
1051     }
1052
1053     _link(split, pwr_m1, 0, 0);   //  1 - zt
1054     _link(act_o, mul, 0, 0);      //
1055     _link(pwr_m1, mul, 0, 1);     // (1 - zt) (.) ht
1056     _link(split, mul_u, 0, 0);    //
1057     _link(in_h_state, mul_u, 1);  //                   zt (.) Ht-1
1058     _link(mul, sum, 0, 0);        //
1059     _link(mul_u, sum, 0, 1);      // (1 - zt) (.) ht + zt (.) Ht-1
1060
1061     // Output
1062     sum->outData[0] = out_h_state;
1063     out_h_state->getCreatorLayer() = sum;
1064
1065     return true;
1066 }
1067
1068 static bool unrollCell(CNNLayerPtr cur) {
1069     auto cell = std::dynamic_pointer_cast<RNNCellBase>(cur);
1070     switch (cell->cellType) {
1071         case RNNCellBase::LSTM:    return unrollLSTMCellBody(cur);
1072         case RNNCellBase::GRU:     return unrollGRUCellBody(cur);
1073         case RNNCellBase::GRU_LBR: return unrollGRUCellBody(cur, true);
1074         case RNNCellBase::RNN:     return unrollRNNCellBody(cur);
1075     }
1076     return false;
1077 }
1078
1079 static bool unrollSeq(CNNLayerPtr cur) {
1080     if (!one_of(cur->type, "LSTMSequence", "GRUSequence", "RNNSequence"))
1081     return true;
1082
1083     auto seq = std::dynamic_pointer_cast<RNNSequenceLayer>(cur);
1084     IE_ASSERT(seq) << "Cannot cast object with type ***Sequence to RNNSequenceLayer object";
1085
1086     auto name = seq->name;
1087
1088     auto in_data = seq->insData[0].lock();
1089     auto in_h_state = seq->insData[1].lock();
1090     auto out_data = seq->outData[0];
1091
1092     auto in_d_dims = in_data->getTensorDesc().getDims();
1093     auto state_dims = in_h_state->getTensorDesc().getDims();
1094     auto out_d_dims = out_data->getTensorDesc().getDims();
1095
1096     const int axis = seq->axis;
1097     const auto direct = seq->direction;
1098     const auto prc = seq->precision;
1099
1100     /** Release links on Seq */
1101     for (auto &ins : seq->insData)
1102     ins.lock()->getInputTo().erase(seq->name);
1103     for (auto &outs : seq->outData)
1104     outs->getCreatorLayer().reset();
1105
1106     /** Body subgraph*/
1107     auto in_d_body_dims = in_d_dims;
1108     in_d_body_dims[axis] = 1;
1109
1110     auto in_d_body_squeeze_dims = in_d_dims;
1111     in_d_body_squeeze_dims.erase(in_d_body_squeeze_dims.begin() + axis);
1112
1113     auto out_d_body_dims = out_d_dims;
1114     out_d_body_dims[axis] = 1;
1115
1116     auto out_d_body_squeeze_dims = out_d_dims;
1117     out_d_body_squeeze_dims.erase(out_d_body_squeeze_dims.begin() + axis);
1118
1119     auto body_in_data = DataPtr(new Data(name + ":data_in",
1120             TensorDesc { prc, in_d_body_dims, TensorDesc::getLayoutByDims(in_d_body_dims) }));
1121
1122     auto resh1 = _resh(name + ":resh1", prc, in_d_body_squeeze_dims);
1123     auto cell  = _cell(name + ":cell", prc, out_d_body_squeeze_dims, state_dims, seq->cellType);
1124     auto resh2 = _resh(name + ":resh2", prc, out_d_body_dims);
1125
1126     _link(body_in_data, resh1);
1127     _link(resh1, cell);
1128     _link(cell, resh2);
1129
1130     cell->_weights = seq->_weights;
1131     cell->_biases = seq->_biases;
1132     cell->blobs["weights"] = cell->_weights;
1133     cell->blobs["biases"] = cell->_biases;
1134     cell->hidden_size = seq->hidden_size;
1135     cell->clip = seq->clip;
1136     cell->activations = seq->activations;
1137     cell->activation_alpha = seq->activation_alpha;
1138     cell->activation_beta = seq->activation_beta;
1139
1140     const size_t NS = cell->outData.size();  // num of state
1141
1142     /** TI layer */
1143     auto ti = _ti(name + ":ti", prc, NS);
1144     _link(in_data, ti, 0);
1145
1146     ti->outData[0] = out_data;
1147     out_data->getCreatorLayer() = ti;
1148
1149     ti->body.inputs.push_back(body_in_data);
1150     ti->body.outputs.push_back(resh2->outData[0]);
1151
1152     int start = direct == RNNSequenceLayer::FWD ? 0 : -1;
1153     int end = direct == RNNSequenceLayer::FWD ? -1 : 0;
1154     int step = direct == RNNSequenceLayer::FWD ? 1 : -1;
1155     ti->input_port_map.push_back({0, 0, axis, step, start, end, 1});
1156     ti->output_port_map.push_back({0, 0, axis, step, start, end, 1});
1157
1158     for (size_t i = 0; i < NS; i++) {
1159         auto in_state = seq->insData[1 + i].lock();
1160         _link(in_state, ti, 1 + i);
1161
1162         auto out_state = seq->outData[1 + i];
1163         ti->outData[1 + i] = out_state;
1164         out_state->getCreatorLayer() = ti;
1165
1166         auto body_in_state = DataPtr(new Data(name + ":state_in_" + std::to_string(i),
1167                 TensorDesc { prc, state_dims, TensorDesc::getLayoutByDims(state_dims) }));
1168
1169         _link(body_in_state, cell, 1 + i);
1170
1171         ti->body.inputs.push_back(body_in_state);
1172         ti->body.outputs.push_back(cell->outData[i]);
1173
1174         const int ii = 1 + static_cast<int>(i);
1175         ti->input_port_map.push_back({ii, ii, -1, 0, 0, 0, 0});
1176         ti->output_port_map.push_back({ii, ii, -1, 0, 0, 0, 0});
1177         ti->back_edges.push_back({ii, ii, -1, 0, 0, 0, 0});
1178     }
1179
1180     return true;
1181 }
1182
1183 /************************************************************/
1184 /****  Converter API  ***************************************/
1185 /************************************************************/
1186
1187 template <typename N>
1188 std::vector<CNNLayerPtr> TopolSort(const N &net);
1189
1190 template <>
1191 std::vector<CNNLayerPtr> TopolSort(const ICNNNetwork &net) {
1192     return details::CNNNetSortTopologically(net);
1193 }
1194
1195 template <>
1196 std::vector<CNNLayerPtr> TopolSort(const TensorIterator::Body &net) {
1197     return TIBodySortTopologically(net);
1198 }
1199
1200
1201 template <typename N, typename T>
1202 bool ApplyForAll(N &net, T action) {
1203     auto all_layers = TopolSort(net);
1204     bool sts = true;
1205
1206     for (auto &layer : all_layers)
1207         sts &= action(layer, net);
1208
1209     return sts;
1210 }
1211
1212
1213
1214 template <typename N, typename T, typename P>
1215 bool ApplyForAll_if(N &net, T action, P pred) {
1216     auto all_layers = TopolSort(net);
1217     bool sts = true;
1218
1219     for (auto &layer : all_layers)
1220         if (pred(layer))
1221             sts &= action(layer);
1222
1223     return sts;
1224 }
1225
1226 bool CombineRNNSeq(ICNNNetwork &net) {
1227     return ApplyForAll(net, convertToRNNSeq<ICNNNetwork>);
1228 }
1229 bool CombineRNNSeq(TensorIterator::Body &net) {
1230     return ApplyForAll(net, convertToRNNSeq<TensorIterator::Body>);
1231 }
1232
1233 bool UnrollTI(ICNNNetwork &net) {
1234     return ApplyForAll(net, unrollTI);
1235 }
1236
1237
1238 template <typename NET>
1239 bool UnrollRNN_if_impl(NET &net, const std::function<bool(const RNNCellBase&)> pred) {
1240     // Filter layers by RNN specific type
1241     auto _seq_pred = [&] (CNNLayerPtr layer) {
1242         auto rnn = std::dynamic_pointer_cast<RNNSequenceLayer>(layer);
1243         if (!rnn) return false;
1244         return pred(*rnn.get());
1245     };
1246     auto _cell_pred = [&] (CNNLayerPtr layer) {
1247         auto rnn = std::dynamic_pointer_cast<RNNCellBase>(layer);
1248         if (!rnn || !one_of(rnn->type, "LSTMCell", "GRUCell", "RNNCell")) return false;
1249         return pred(*rnn.get());
1250     };
1251
1252     bool res = true;
1253     res &= ApplyForAll_if(net, unrollSeq, _seq_pred);
1254     res &= ApplyForAll_if(net, unrollCell, _cell_pred);
1255     return res;
1256 }
1257
1258 bool UnrollRNN_if(ICNNNetwork &net, const std::function<bool(const RNNCellBase&)> pred) {
1259     return UnrollRNN_if_impl(net, pred);
1260 }
1261
1262 bool UnrollRNN_if(TensorIterator::Body &net, const std::function<bool(const RNNCellBase&)> pred) {
1263     return UnrollRNN_if_impl(net, pred);
1264 }
1265
1266
1267 }  // namespace NetPass
1268 }  // namespace InferenceEngine
1269