Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / net_pass.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "net_pass.h"
6 #include "blob_factory.hpp"
7 #include "ie_memcpy.h"
8 #include "details/ie_cnn_network_tools.h"
9 #include "graph_tools.hpp"
10
11 #include <string>
12 #include <utility>
13 #include <algorithm>
14 #include <memory>
15 #include <tuple>
16 #include <set>
17 #include <unordered_map>
18 #include <unordered_set>
19
20 namespace InferenceEngine {
21 namespace NetPass {
22
23 template <typename T, typename P>
24 inline bool one_of(T val, P item) { return val == item; }
25 template <typename T, typename P, typename... Args>
26 inline bool one_of(T val, P item, Args... item_others) {
27     return val == item || one_of(val, item_others...);
28 }
29
30 /************************************************************/
31 /****  TI Utils  ********************************************/
32 /************************************************************/
33
34 static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr> &heads) {
35     CNNLayerSet inputLayers;
36     std::unordered_set<CNNLayer*> allLayers;
37
38     // Define all start layers
39     for (const auto & data : heads) {
40         auto &secondLayers = data->getInputTo();
41
42         details::UnorderedDFS(allLayers, secondLayers.begin()->second, [&](CNNLayerPtr layer) {
43             if (layer->insData.empty()) {
44                 inputLayers.insert(layer);
45             }
46         }, false);
47     }
48
49     std::vector<DataPtr> res = heads;
50     // Add fake input data to point on not achievable
51     // layers from head (like const placeholders)
52     for (auto &starter : inputLayers) {
53         DataPtr holder(new Data(starter->name + ":input_holder", starter->precision));
54         holder->inputTo[starter->name] = starter;
55         res.push_back(holder);
56     }
57
58     return res;
59 }
60
61 static std::vector<CNNLayerPtr> SortTopologically(const TensorIterator::Body &body) {
62     std::vector<CNNLayerPtr> all_layers;
63
64     auto all_input_layers = getAllInputs(body.inputs);
65     CNNNetForestDFS(all_input_layers, [&](CNNLayerPtr  current){
66         all_layers.push_back(current);
67     }, false);
68     std::reverse(all_layers.begin(), all_layers.end());
69     return all_layers;
70 }
71
72 static TensorIterator::Body CopyTIBody(ICNNNetwork &net, const TensorIterator::Body &body, std::string suffix = "") {
73     struct NoneStruct {};
74     auto cp = [&](CNNLayerPtr lp) {
75         return injectData<NoneStruct>(lp);
76     };
77
78     const auto all_orig = SortTopologically(body);
79     auto num = all_orig.size();
80
81     std::unordered_map<CNNLayer*, CNNLayerPtr> old2new_l;
82     for (int i = 0; i < num; i++) {
83         auto &orig = all_orig[i];
84         old2new_l[orig.get()] = cp(orig);
85     }
86
87     std::unordered_map<Data*, DataPtr> old2new_d;
88     for (auto &in : body.inputs) {
89         auto new_data = std::make_shared<Data>(*in.get());
90         for (auto &to : new_data->getInputTo())
91             to.second = old2new_l[to.second.get()];
92
93         old2new_d[in.get()] = new_data;
94     }
95
96     for (const auto &old : all_orig) {
97         auto &new_one = old2new_l[old.get()];
98         // remap output data
99         for (int i = 0; i != old->outData.size(); i++) {
100             auto old_data = old->outData[i];
101             auto new_data = new_one->outData[i];
102             new_data->getCreatorLayer() = CNNLayerWeakPtr(new_one);
103             old2new_d[old_data.get()] = new_data;
104
105             for (auto &to : new_data->getInputTo())
106                 to.second = old2new_l[to.second.get()];
107         }
108         // remap input data
109         for (int i = 0; i != old->insData.size(); i++) {
110             auto old_data = old->insData[i].lock();
111             auto new_data = old2new_d.at(old_data.get());
112             new_one->insData[i] = new_data;
113         }
114     }
115
116     // Add suffix
117     if (!suffix.empty()) {
118         for (auto &kvp : old2new_l) {
119             auto layer = kvp.second;
120             auto old_name = layer->name;
121             layer->name += suffix;
122             for (auto &ins : layer->insData) {
123                 ins.lock()->inputTo.erase(old_name);
124                 ins.lock()->inputTo[layer->name] = layer;
125             }
126
127             // And also hold newly created layer in parent network.
128             // TI body may contain isolated constant placeholder layers
129             // which are not achievable from body inputs.
130             net.addLayer(layer);
131         }
132         for (auto &kvp : old2new_d) kvp.second->name += suffix;
133     }
134
135     TensorIterator::Body res;
136     for (auto &in : body.inputs)
137         res.inputs.emplace_back(old2new_d[in.get()]);
138
139     for (auto &out : body.outputs)
140         res.outputs.emplace_back(old2new_d[out.get()]);
141
142     return res;
143 }
144
145 /************************************************************/
146 /****  TI rule helpers  *************************************/
147 /************************************************************/
148
149 inline bool is_full_ranged(const TensorIterator::PortMap& rule, const DataPtr &data) {
150     if (!data)
151         THROW_IE_EXCEPTION << "Internal error. data == nullptr";
152
153     if (rule.axis == -1 || !one_of(rule.stride, 1, -1))
154         return false;
155
156     auto &shape = data->getDims();
157     int size = shape[rule.axis];
158
159     int begin = rule.start >= 0 ? rule.start : size + rule.start + 1;
160     int end = rule.end >= 0 ? rule.end : size + rule.end + 1;
161
162     return (rule.stride == 1)
163         ? begin == 0 && end == size
164         : begin == size && end == 0;
165 }
166
167 inline int get_num_iteration(const std::shared_ptr<TensorIterator> &ti) {
168     int iter_num = 1;  // 1 means no iteration
169
170     for (auto & rule : ti->input_port_map) {
171         if (rule.axis == -1) continue;
172
173         auto data = ti->insData[rule.from].lock();
174         IE_ASSERT(data);
175
176         auto shape = data->getDims();
177         size_t size = shape[rule.axis];
178         size_t step = std::abs(rule.stride);
179         size_t cur_iter_size = size / step;
180
181         if (iter_num == 1) {
182             iter_num = cur_iter_size;
183         } else {
184             if (iter_num != cur_iter_size)
185                 return -1;  // TI is inconsistent
186         }
187     }
188
189     for (auto & rule : ti->output_port_map) {
190         if (rule.axis == -1) continue;
191
192         auto data = ti->outData[rule.from];
193         auto shape = data->getDims();
194
195         size_t size = shape[rule.axis];
196         size_t step = std::abs(rule.stride);
197         size_t cur_iter_size = size / step;
198
199         if (iter_num == 1) {
200             iter_num = cur_iter_size;
201         } else {
202             if (iter_num != cur_iter_size)
203                 return -1;  // TI is inconsistent
204         }
205     }
206     return iter_num;
207 }
208
209 using RuleSet = std::vector<TensorIterator::PortMap>;
210
211 std::tuple<RuleSet, RuleSet, RuleSet> ClassifyInRules(const std::shared_ptr<TensorIterator> &ti) {
212     /*
213      * first_class  - which has iteration component
214      * second_class - which has no iteration and there are no backedge connection to the same port
215      * third_class  - which has no iteration and has corresponding backedge
216      */
217     RuleSet first_class_rules, second_class_rules, third_class_rules;
218
219     std::set<int> ports_with_backedge;
220     for (const auto &back_edge : ti->back_edges) ports_with_backedge.insert(back_edge.to);
221
222     for (const auto &rule : ti->input_port_map) {
223         if (rule.axis != -1)
224             first_class_rules.push_back(rule);
225
226         else if (!ports_with_backedge.count(rule.to))
227             second_class_rules.push_back(rule);
228
229         else
230             third_class_rules.push_back(rule);
231     }
232     return std::tuple<RuleSet, RuleSet, RuleSet> {first_class_rules, second_class_rules, third_class_rules};
233 }
234
235 std::tuple<RuleSet, RuleSet, RuleSet> ClassifyOutRules(const std::shared_ptr<TensorIterator> &ti) {
236     /*
237      * first_class  - which has iteration component
238      * second_class - which has no iteration and there are no backedge connection to the same port
239      * third_class  - which has no iteration and has corresponding backedge
240      */
241     RuleSet first_class_rules, second_class_rules, third_class_rules;
242
243     std::set<int> ports_with_backedge;
244     for (const auto &back_edge : ti->back_edges) ports_with_backedge.insert(back_edge.from);
245
246     for (const auto &rule : ti->output_port_map) {
247         if (rule.axis != -1)
248             first_class_rules.push_back(rule);
249
250         else if (!ports_with_backedge.count(rule.to))
251             second_class_rules.push_back(rule);
252
253         else
254             third_class_rules.push_back(rule);
255     }
256     return std::tuple<RuleSet, RuleSet, RuleSet> {first_class_rules, second_class_rules, third_class_rules};
257 }
258
259 /**
260  * Merge slave connections into master
261  * @param master
262  * @param slave
263  */
264 void CombineData(DataPtr &master, DataPtr &slave) {
265     for (auto &kvp : slave->inputTo) {
266         auto &slave_layer = kvp.second;
267         for (auto &slv_ins_wptr : slave_layer->insData) {
268             auto slv_ins = slv_ins_wptr.lock();
269             // Replace slave ptr with master
270             if (slv_ins == slave) slv_ins_wptr = master;
271         }
272         master->inputTo[slave_layer->name] = slave_layer;
273     }
274 }
275
276 /************************************************************/
277 /****  Converter Passes  ************************************/
278 /************************************************************/
279
280 static RNNSequenceLayer::CellType cell_type_from_name(std::string &layer_type) {
281     RNNSequenceLayer::CellType res;
282     if (layer_type == "LSTMCell")
283         res = RNNSequenceLayer::LSTM;
284     else if (layer_type == "GRUCell")
285         res = RNNSequenceLayer::GRU;
286     else if (layer_type == "RNNCell")
287         res = RNNSequenceLayer::GRU;
288     else
289         THROW_IE_EXCEPTION << "Unknown Cell type (" << layer_type << "). Expected LSTMCell|GRUCell|RNNCell";
290     return res;
291 }
292
293 static std::string cell_name(RNNSequenceLayer::CellType type) {
294     std::string res;
295     if (type == RNNSequenceLayer::LSTM)
296         res = "LSTM";
297     else if (type == RNNSequenceLayer::GRU)
298         res = "GRU";
299     else if (type == RNNSequenceLayer::GRU)
300         res = "GRU";
301     else
302         THROW_IE_EXCEPTION << "Unknown Cell type (enum index: " << type << "). Expected LSTM|GRU|RNN";
303     return res;
304 }
305
306
307 bool convertToRNNSeq(CNNLayerPtr cur, ICNNNetwork &net) {
308     if (cur->type != "TensorIterator") return true;
309
310     auto ti = std::dynamic_pointer_cast<TensorIterator>(cur);
311     IE_ASSERT(ti) << "Cannot cast object with type TensorIterator to TensorIterator object";
312
313     auto all_body_layers = SortTopologically(ti->body);
314
315     // Check if body is:  squeeze -> lstm_cell -> unsqueeze
316     if (all_body_layers.size() != 3
317         || all_body_layers[0]->type != "Reshape"
318         || !one_of(all_body_layers[1]->type, "GRUCell", "RNNCell", "LSTMCell")
319         || all_body_layers[2]->type != "Reshape")
320         return false;
321
322     auto rsp1 = std::dynamic_pointer_cast<ReshapeLayer>(all_body_layers[0]);
323     auto cell = std::dynamic_pointer_cast<RNNCellBase>(all_body_layers[1]);
324     auto rsp2 = std::dynamic_pointer_cast<ReshapeLayer>(all_body_layers[2]);
325
326     auto cell_type = cell_type_from_name(all_body_layers[1]->type);
327
328     int NS = cell_type == RNNSequenceLayer::LSTM ? 2 : 1;  // number of states
329
330     IE_ASSERT(cell->insData.size() == NS + 1);  // {data, state1, [state2]}
331     IE_ASSERT(cell->outData.size() == NS);  // {state1, [state2]}
332
333     if (cell->insData[0].lock()->creatorLayer.lock() != rsp1 ||
334         cell->outData[0]->inputTo.begin()->second != rsp2)
335         return false;
336
337     // Check port mapping
338     auto _indx_in = [&] (const std::vector<DataPtr> &scope,  const DataPtr &data) {
339         int indx = std::find(scope.begin(), scope.end(), data) - scope.begin();
340         return indx == scope.size() ? -1 : indx;
341     };
342
343     int in_dt_idx = _indx_in(ti->body.inputs, rsp1->insData[0].lock());
344     int in_hs_idx = _indx_in(ti->body.inputs, cell->insData[1].lock());
345     int in_cs_idx = NS == 2 ? _indx_in(ti->body.inputs, cell->insData[2].lock()) : -1;
346
347     int out_dt_idx = _indx_in(ti->body.outputs, rsp2->outData[0]);
348     int out_hs_idx = _indx_in(ti->body.outputs, cell->outData[0]);
349     int out_cs_idx = NS == 2 ? _indx_in(ti->body.outputs, cell->outData[1]) : -1;
350
351     // indexes should be [0,1,2] : sum == 3 or [0,1,-1] : sum == 0
352     int sum = (NS - 1) * 3;
353     if (in_hs_idx + in_cs_idx + in_dt_idx != sum || out_hs_idx + out_cs_idx + out_dt_idx != sum)
354         return false;
355
356     std::map<int, TensorIterator::PortMap> i2map, o2map, be2map;
357     for (auto &m : ti->input_port_map) i2map[m.to] = m;
358     for (auto &m : ti->output_port_map) o2map[m.to] = m;
359     for (auto &m : ti->back_edges) be2map[m.to] = m;
360
361     if (!one_of(i2map.size(), NS + 1, 1) ||
362         !one_of(o2map.size(), NS + 1, 1) ||
363         !one_of(be2map.size(), 2))
364         return false;
365
366     auto in_iter_rule = i2map[in_dt_idx];
367     auto in_iter_data = ti->insData[in_iter_rule.from].lock();
368
369     auto out_iter_rule = o2map[out_dt_idx];
370     auto out_iter_data = ti->outData[out_iter_rule.from];
371
372     // TI iterates only for full range of tensor
373     if (!is_full_ranged(in_iter_rule, in_iter_data) ||
374         !is_full_ranged(out_iter_rule, out_iter_data))
375         return false;
376
377     // supported only same axis and strides for in/out data tensors
378     if (in_iter_rule.axis != out_iter_rule.axis ||
379         in_iter_rule.stride != out_iter_rule.stride)
380         return false;
381
382     // supported only firs and second dim for LSTM-Sequence
383     if (!one_of(in_iter_rule.axis, 0, 1))
384         return false;
385
386     bool no_init_state = i2map.size() == 1;
387     bool no_last_state = o2map.size() == 1;
388
389     if (!no_init_state && ( i2map[in_hs_idx].axis != -1 || (NS == 2 && i2map[in_cs_idx].axis != -1) ))
390         return false;
391     if (!no_last_state && ( o2map[out_hs_idx].axis != -1 || (NS == 2 && o2map[out_cs_idx].axis != -1) ))
392         return false;
393
394     std::vector<int> i_order {i2map[in_dt_idx].from };
395     if (!no_init_state)
396         i_order.push_back(i2map[in_hs_idx].from);
397     if (!no_init_state && NS == 2)
398         i_order.push_back(i2map[in_cs_idx].from);
399
400     std::vector<int> o_order {o2map[out_dt_idx].from};
401     if (!no_last_state)
402         o_order.push_back(o2map[out_hs_idx].from);
403     if (!no_last_state && NS == 2)
404         o_order.push_back(o2map[out_cs_idx].from);
405
406     // need swap an i/o ports if it is not in natural order
407     std::string name = cell->name + "_sequence";
408     auto rnn  = std::make_shared<RNNSequenceLayer>(LayerParams{ name, cell_name(cell_type) + "Sequence", cell->precision});
409     rnn->cellType = cell_type;
410     rnn->axis = in_iter_rule.axis;
411     rnn->direction = in_iter_rule.stride == 1
412             ? RNNSequenceLayer::FWD
413             : RNNSequenceLayer::BWD;
414
415     // copy base RNN cell fields
416     rnn->_weights = cell->_weights;
417     rnn->_biases = cell->_biases;
418     rnn->blobs = cell->blobs;
419     rnn->activations = cell->activations;
420     rnn->activation_alpha = cell->activation_alpha;
421     rnn->activation_beta = cell->activation_beta;
422     rnn->hidden_size = cell->hidden_size;
423     rnn->clip = cell->clip;
424
425     for (int i : i_order) {
426         auto in_data = ti->insData[i].lock();
427         in_data->inputTo.erase(ti->name);
428         in_data->inputTo[rnn->name] = rnn;
429         rnn->insData.push_back(in_data);
430     }
431     for (int i : o_order) {
432         rnn->outData.push_back(ti->outData[i]);
433         rnn->outData.back()->creatorLayer = rnn;
434     }
435
436     return true;
437 }
438
439 bool unrollTI(CNNLayerPtr cur, ICNNNetwork &net) {
440     if (cur->type != "TensorIterator")
441         return true;
442
443     auto ti = std::dynamic_pointer_cast<TensorIterator>(cur);
444     IE_ASSERT(ti) << "Cannot cast object with type TensorIterator to TensorIterator object";
445
446     int num = get_num_iteration(ti);  // -1 means inconsistent TI
447     if (num == -1) return false;  // TODO: better to throw exception
448
449     const auto &body = ti->body;
450
451     std::vector<TensorIterator::Body> body_list(num);
452     for (int i = 0; i < num; i++) {
453         // copy with additional suffix to each object name
454         body_list[i] = CopyTIBody(net, body, ":" + std::to_string(i));
455     }
456
457     RuleSet first_class, second_class, third_class;
458     std::tie(first_class, second_class, third_class) = ClassifyInRules(ti);
459
460     /** Clean links on TI */
461     for (auto &ins : ti->insData)
462         ins.lock()->inputTo.erase(ti->name);
463     for (auto &outs : ti->outData)
464         outs->creatorLayer.reset();
465
466     /** FIRST class comes */
467     for (int i = 0; i < first_class.size(); i++) {
468         auto &rule = first_class[i];
469         auto in_data = ti->insData[rule.from].lock();
470
471         std::string name = ti->name + ":in_split_" + std::to_string(i);
472         auto split = std::make_shared<SplitLayer>(LayerParams{ name, "Split", cur->precision });
473         split->_axis = rule.axis;
474         split->outData.resize(num);
475         split->insData.emplace_back(in_data);
476         in_data->inputTo[split->name] = split;
477
478         for (int j = 0; j < num; j++) {
479             auto body_idx = rule.stride == 1 ? j : num - 1 - j;
480             auto &chunk = body_list[body_idx].inputs[rule.to];
481             chunk->creatorLayer = split;
482             split->outData[j] = chunk;
483         }
484     }
485
486     /** SECOND class come on */
487     for (const auto &rule : second_class) {
488         auto in_data = ti->insData[rule.from].lock();
489
490         for (int j = 0; j < num; j++) {
491             auto &chunk = body_list[j].inputs[rule.to];
492             CombineData(in_data, chunk);
493         }
494     }
495
496     /** BACK EDGES that's your time */
497     for (const auto &rule : ti->back_edges) {
498         for (int i = 1; i < num; i++) {
499             auto &from_data = body_list[i-1].outputs[rule.from];
500             auto &to_data = body_list[i].inputs[rule.to];
501             CombineData(from_data, to_data);
502         }
503     }
504
505     /** THIRD class end up */
506     for (const auto &rule : third_class) {
507         // first iteration
508         auto from_data = ti->insData[rule.from].lock();
509         auto &to_data = body_list[0].inputs[rule.to];;
510         CombineData(from_data, to_data);
511     }
512
513     /** And the same actions for outputs connections */
514     std::tie(first_class, second_class, third_class) = ClassifyOutRules(ti);
515
516     /** FIRST class comes */
517     for (int i = 0; i < first_class.size(); i++) {
518         auto &rule = first_class[i];
519         auto out_data = ti->outData[rule.from];
520
521         std::string name = ti->name + ":out_concat_" + std::to_string(i);
522         auto concat = std::make_shared<ConcatLayer>(LayerParams{ name, "Concat", cur->precision });
523         concat->_axis = rule.axis;
524         concat->insData.resize(num);
525         concat->outData.emplace_back(out_data);
526         out_data->creatorLayer = concat;
527
528         for (int j = 0; j < num; j++) {
529             auto body_idx = rule.stride == 1 ? j : num - 1 - j;
530             auto &chunk = body_list[body_idx].outputs[rule.to];
531             chunk->inputTo[concat->name] = concat;
532             concat->insData[j] = chunk;
533         }
534     }
535
536     /** SECOND class come on */
537     for (const auto &rule : second_class) {
538         auto out_data = ti->outData[rule.from];
539
540         for (int j = 0; j < num; j++) {
541             auto &chunk = body_list[j].outputs[rule.to];
542             CombineData(chunk, out_data);
543         }
544     }
545
546     /** THIRD class end up */
547     for (const auto &rule : third_class) {
548         // first iteration
549         auto &from_data = ti->outData[rule.from];
550         auto &to_data = body_list[num-1].outputs[rule.to];
551
552         auto parent = to_data->creatorLayer.lock();
553         std::replace(parent->outData.begin(), parent->outData.end(), to_data, from_data);
554         from_data->creatorLayer = parent;
555
556         CombineData(from_data, to_data);
557     }
558     return true;
559 }
560
561 /************************************************************/
562 /****  Builder helpers   ************************************/
563 /************************************************************/
564
565 static CNNLayerPtr _concat(std::string name, Precision prc, SizeVector dims, int num) {
566     auto res = std::make_shared<ConcatLayer>(LayerParams{name, "Concat", prc});
567     res->_axis = 1;
568
569     res->insData.resize(num);
570     res->outData.resize(1);
571
572     auto out_data = DataPtr(new Data(name,
573             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
574     out_data->creatorLayer = res;
575
576     res->outData[0] = out_data;
577     return res;
578 }
579
580 static CNNLayerPtr _split(std::string name, Precision prc, SizeVector dims, int num) {
581     auto res = std::make_shared<SplitLayer>(LayerParams{name, "Split", prc});
582     res->_axis = 1;
583     res->params["axis"] = res->_axis;
584
585     res->insData.resize(1);
586     res->outData.resize(num);
587
588     for (int i = 0; i < num; i++) {
589         auto out_data = DataPtr(new Data(name + "_part_" + std::to_string(i),
590                 TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
591         out_data->creatorLayer = res;
592
593         res->outData[i] = out_data;
594     }
595     return res;
596 }
597
598 static CNNLayerPtr _fc(std::string name, Precision prc, SizeVector dims, Blob::Ptr &W, Blob::Ptr &B) {
599     auto res = std::make_shared<FullyConnectedLayer>(LayerParams{name, "FullyConnected", prc});
600
601     res->_weights = W;
602     res->_biases = B;
603     res->_out_num = dims[1];
604     res->blobs["weights"] = W;
605     res->blobs["biases"] = B;
606     res->params["out-size"] = std::to_string(dims[1]);
607
608     res->insData.resize(1);
609     res->outData.resize(1);
610
611     auto out_data = DataPtr(new Data(name,
612             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
613     out_data->creatorLayer = res;
614
615     res->outData[0] = out_data;
616     return res;
617 }
618
619 static CNNLayerPtr _act(std::string name, Precision prc, SizeVector dims, std::string type) {
620     auto res = std::make_shared<CNNLayer>(LayerParams{name, "Activation", prc});
621
622     res->params["type"] = type;
623
624     res->insData.resize(1);
625     res->outData.resize(1);
626
627     auto out_data = DataPtr(new Data(name,
628             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
629     out_data->creatorLayer = res;
630
631     res->outData[0] = out_data;
632     return res;
633 }
634
635 static CNNLayerPtr _pwr(std::string name, Precision prc, SizeVector dims, float scale, float shift) {
636     auto res = std::make_shared<PowerLayer>(LayerParams{name, "Power", prc});
637
638     res->power = 1.0;
639     res->scale = scale;
640     res->offset = shift;
641     res->params["power"] = res->power;
642     res->params["scale"] = res->scale;
643     res->params["shift"] = res->offset;
644
645     res->insData.resize(1);
646     res->outData.resize(1);
647
648     auto out_data = DataPtr(new Data(name,
649             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
650     out_data->creatorLayer = res;
651
652     res->outData[0] = out_data;
653     return res;
654 }
655
656
657 static CNNLayerPtr _eltw(std::string name, Precision prc, SizeVector dims, std::string type) {
658     auto res = std::make_shared<EltwiseLayer>(LayerParams{name, "Eltwise", prc});
659
660     res->params["operation"] = type;
661     res->_operation = type == "sum" ? EltwiseLayer::Sum : EltwiseLayer::Prod;
662
663     res->insData.resize(2);
664     res->outData.resize(1);
665
666     auto out_data = DataPtr(new Data(name,
667             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
668     out_data->creatorLayer = res;
669
670     res->outData[0] = out_data;
671     return res;
672 }
673
674 static std::shared_ptr<ReshapeLayer> _resh(std::string name, Precision prc, SizeVector dims) {
675     auto res = std::make_shared<ReshapeLayer>(LayerParams{name, "Reshape", prc});
676
677     res->insData.resize(1);
678     res->outData.resize(1);
679
680     auto out_data = DataPtr(new Data(name,
681             TensorDesc { prc, dims, TensorDesc::getLayoutByDims(dims) }));
682     out_data->creatorLayer = res;
683
684     res->outData[0] = out_data;
685     return res;
686 }
687
688 static std::shared_ptr<RNNCellBase> _cell(std::string name, Precision prc, SizeVector data_dims, SizeVector state_dims, RNNSequenceLayer::CellType type) {
689     std::shared_ptr<RNNCellBase> res;
690     size_t NS = 1;
691     switch (type) {
692         case RNNSequenceLayer::LSTM:
693             res = std::make_shared<LSTMCell>(LayerParams{name, "LSTMCell", prc}); NS = 2;
694             break;
695         case RNNSequenceLayer::GRU:
696         case RNNSequenceLayer::GRU_LBR:
697             res = std::make_shared<GRUCell>(LayerParams{name, "GRUCell", prc});
698             break;
699         case RNNSequenceLayer::RNN:
700             res = std::make_shared<RNNCell>(LayerParams{name, "RNNCell", prc});
701             break;
702     }
703
704     res->cellType = type;
705     res->insData.resize(1 + NS);
706     res->outData.resize(NS);
707
708     auto out_data = DataPtr(new Data(name + ":out_data",
709             TensorDesc { prc, data_dims, TensorDesc::getLayoutByDims(data_dims) }));
710     out_data->creatorLayer = res;
711     res->outData[0] = out_data;
712
713     for (size_t i = 0; i < NS; i++) {
714         auto out_state = DataPtr(new Data(name + ":out_state_" + std::to_string(i),
715                 TensorDesc { prc, state_dims, TensorDesc::getLayoutByDims(state_dims) }));
716         out_state->creatorLayer = res;
717         res->outData[i] = out_state;
718     }
719
720     return res;
721 }
722
723 static std::shared_ptr<TensorIterator> _ti(std::string name, Precision prc, size_t NS) {
724     auto res = std::make_shared<TensorIterator>(LayerParams{name, "TensorIterator", prc});
725
726     res->insData.resize(1 + NS);
727     res->outData.resize(1 + NS);
728
729     return res;
730 }
731
732 static void _link(CNNLayerPtr src, CNNLayerPtr dst, size_t src_port = 0, size_t dst_port = 0) {
733     auto data = src->outData[src_port];
734     data->inputTo[dst->name] = dst;
735     dst->insData[dst_port] = data;
736 }
737
738 static void _link(DataPtr &data, CNNLayerPtr dst, size_t dst_port = 0) {
739     data->inputTo[dst->name] = dst;
740     dst->insData[dst_port] = data;
741 }
742
743 /** Link nodes with clipping data if required (clip_val != 0.0) */
744 static void _link_with_clip(CNNLayerPtr src, CNNLayerPtr dst, const float clip_val,
745         size_t src_port = 0, size_t dst_port = 0) {
746     if (clip_val == 0.0f) {
747         _link(src, dst, src_port, dst_port);
748     } else {
749         auto clip_name = dst->name + "_clip";
750         auto clip_prc = dst->precision;
751         auto clip_shape = src->outData[src_port]->getTensorDesc().getDims();
752         auto clip = _act(clip_name, clip_prc, clip_shape, "clamp");
753         clip->params["min"] = std::to_string(-clip_val);
754         clip->params["max"] = std::to_string(clip_val);
755
756         _link(src, clip, src_port, 0);
757         _link(clip, dst, 0, dst_port);
758     }
759 }
760
761
762 static Blob::Ptr make_partial_copy(Blob::Ptr src, size_t off, size_t size) {
763     auto res = make_plain_blob(src->precision(), {size});
764     res->allocate();
765
766     size_t elem_size = src->precision().size();
767     auto src_ptr = src->buffer().as<uint8_t*>();
768     auto dst_ptr = res->buffer().as<uint8_t*>();
769
770     ie_memcpy(dst_ptr, res->byteSize(), src_ptr + off * elem_size,  size * elem_size);
771
772     return res;
773 }
774
775 static Blob::Ptr wrap_as_tensor(Blob::Ptr src, SizeVector dims) {
776     auto res = make_blob_with_precision(
777             TensorDesc { src->precision(), dims, plain_layout(dims) },
778             src->buffer());
779     IE_ASSERT(src->size() == res->size());
780     return res;
781 }
782
783 static Blob::Ptr make_region_copy(Blob::Ptr src, SizeVector region, SizeVector offset) {
784     IE_ASSERT(region.size() == offset.size());
785     IE_ASSERT(region.size() == src->dims().size());
786
787     auto res = make_plain_blob(src->precision(), region);
788     res->allocate();
789
790     size_t elem_size = src->precision().size();
791     auto src_ptr = src->buffer().as<uint8_t*>();
792     auto dst_ptr = res->buffer().as<uint8_t*>();
793
794     auto &dd = src->getTensorDesc().getDims();
795     SizeVector src_dims {1, 1, 1};
796     std::copy(dd.begin(), dd.end(), src_dims.end() - dd.size());
797
798     SizeVector dims {1, 1, 1};
799     std::copy(region.begin(), region.end(), dims.end() - region.size());
800
801     SizeVector off {0, 0, 0};
802     std::copy(offset.begin(), offset.end(), off.end() - offset.size());
803
804     const auto D1 = dims[0];
805     const auto D2 = dims[1];
806     const auto D3 = dims[2];
807     const auto off1 = off[0];
808     const auto off2 = off[1];
809     const auto off3 = off[2];
810     const auto str1 = src_dims[1]*src_dims[2];
811     const auto str2 = src_dims[2];
812
813     for (size_t d1 = 0; d1 < D1; d1++)
814     for (size_t d2 = 0; d2 < D2; d2++) {
815         auto off_src = (off1 + d1)*str1 + (off2 + d2)*str2 + off3;
816         auto off_dst = d1*D2*D3 + d2*D3;
817         ie_memcpy(dst_ptr + off_dst * elem_size, res->byteSize(), src_ptr + off_src * elem_size,  D3 * elem_size);
818     }
819
820     return res;
821 }
822
823
824 static bool unrollRNNCellBody(CNNLayerPtr cur) {
825     if (cur->type != "RNNCell")
826         return true;
827
828     auto cell = std::dynamic_pointer_cast<RNNCellBase>(cur);
829     IE_ASSERT(cell) << "Cannot cast object with type ***Cell to WeightableLayer object";
830
831     auto name = cell->name;
832
833     auto in_data = cell->insData[0].lock();
834     auto in_h_state = cell->insData[1].lock();
835     auto out_h_state = cell->outData[0];
836
837     auto d_dims = in_data->getTensorDesc().getDims();
838     auto s_dims = in_h_state->getTensorDesc().getDims();
839
840     size_t N = d_dims[0];
841     size_t D = d_dims[1];
842     size_t S = s_dims[1];
843
844     auto prc = cell->precision;
845
846     /** Release links on TI */
847     for (auto &ins : cell->insData)
848         ins.lock()->inputTo.erase(cell->name);
849     for (auto &outs : cell->outData)
850         outs->creatorLayer.reset();
851
852     // operations
853     auto concat = _concat(name + ":concat", prc, {N, D+S}, 2);
854     auto fc = _fc(name + ":fc", prc, {N, S}, cell->_weights, cell->_biases);
855     auto act = _act(name + ":act", prc, {N, S}, cell->activations[0]);
856
857     // Connection
858     _link(in_data, concat, 0);
859     _link(in_h_state, concat, 1);
860     _link(concat, fc);
861     _link_with_clip(fc, act, cell->clip);
862
863     // Output
864     act->outData[0] = out_h_state;
865     out_h_state->creatorLayer = act;
866
867     return true;
868 }
869
870 static bool unrollLSTMCellBody(CNNLayerPtr cur) {
871     if (cur->type != "LSTMCell")
872         return true;
873
874     auto cell = std::dynamic_pointer_cast<RNNCellBase>(cur);
875     IE_ASSERT(cell) << "Cannot cast object with type ***Cell to WeightableLayer object";
876
877     auto name = cell->name;
878
879     auto in_data = cell->insData[0].lock();
880     auto in_h_state = cell->insData[1].lock();
881     auto in_c_state = cell->insData[2].lock();
882     auto out_h_state = cell->outData[0];
883     auto out_c_state = cell->outData[1];
884
885     auto d_dims = in_data->getTensorDesc().getDims();
886     auto s_dims = in_h_state->getTensorDesc().getDims();
887
888     size_t N = d_dims[0];
889     size_t D = d_dims[1];
890     size_t S = s_dims[1];
891     size_t G = 4;
892
893     auto prc = cell->precision;
894
895     /** Release links on TI */
896     for (auto &ins : cell->insData)
897         ins.lock()->inputTo.erase(cell->name);
898     for (auto &outs : cell->outData)
899         outs->creatorLayer.reset();
900
901     // operations
902     auto concat = _concat(name + ":concat", prc, {N, D+S}, 2);
903     auto split = _split(name + ":split", prc, {N, S}, G);
904     auto fc = _fc(name + ":fc", prc, {N, S*G}, cell->_weights, cell->_biases);
905
906     const std::string _f = cell->activations[0], _g = cell->activations[1], _h = cell->activations[2];
907
908     auto act_f = _act(name + ":act_f", prc, {N, S}, _f);
909     auto act_i = _act(name + ":act_i", prc, {N, S}, _f);
910     auto act_c = _act(name + ":act_c", prc, {N, S}, _g);
911     auto act_o = _act(name + ":act_o", prc, {N, S}, _f);
912     auto act_x = _act(name + ":act_x", prc, {N, S}, _h);
913
914     auto mul_ic = _eltw(name + ":mul_ic", prc, {N, S}, "mul");
915     auto mul_f  = _eltw(name + ":mul_f" , prc, {N, S}, "mul");
916     auto sum    = _eltw(name + ":sum"   , prc, {N, S}, "sum");
917     auto mul    = _eltw(name + ":mul"   , prc, {N, S}, "mul");
918
919     // Connection
920     _link(in_data, concat, 0);
921     _link(in_h_state, concat, 1);
922     _link(concat, fc);
923
924     _link_with_clip(fc, split, cell->clip);
925
926     _link(split, act_f, 0, 0);
927     _link(split, act_i, 1, 0);
928     _link(split, act_c, 2, 0);
929     _link(split, act_o, 3, 0);
930
931     _link(act_i, mul_ic, 0, 0);
932     _link(act_c, mul_ic, 0, 1);
933
934     _link(act_f, mul_f, 0, 0);
935     _link(in_c_state, mul_f, 1);
936
937     _link(mul_f,  sum, 0, 0);
938     _link(mul_ic, sum, 0, 1);
939
940     _link(sum, act_x);
941
942     _link(act_x, mul, 0, 0);
943     _link(act_o, mul, 0, 1);
944
945     // Output
946     mul->outData[0] = out_h_state;
947     out_h_state->creatorLayer = mul;
948
949     CombineData(out_c_state, sum->outData[0]);
950     sum->outData[0] = out_c_state;
951     out_c_state->creatorLayer = sum;
952
953     return true;
954 }
955
956 static bool unrollGRUCellBody(CNNLayerPtr cur, bool linear_before_reset = false) {
957     if (cur->type != "GRUCell")
958         return true;
959
960     auto cell = std::dynamic_pointer_cast<GRUCell>(cur);
961     IE_ASSERT(cell) << "Cannot cast object with type ***Cell to WeightableLayer object";
962
963     auto name = cell->name;
964
965     auto in_data = cell->insData[0].lock();
966     auto in_h_state = cell->insData[1].lock();
967     auto out_h_state = cell->outData[0];
968
969     auto d_dims = in_data->getTensorDesc().getDims();
970     auto s_dims = in_h_state->getTensorDesc().getDims();
971
972     size_t N = d_dims[0];
973     size_t D = d_dims[1];
974     size_t S = s_dims[1];
975
976     // Split weights UR and O gates. Original gates are URO
977     size_t bG = linear_before_reset ? 4 : 3;
978     auto orig_W = wrap_as_tensor(cell->_weights, {3, S, D+S});
979     auto orig_B = wrap_as_tensor(cell->_biases, {bG, S});
980
981     auto ur_W = make_region_copy(orig_W, {2, S, D+S}, {0, 0, 0});
982     auto o_W  = make_region_copy(orig_W, {1, S, D+S}, {2, 0, 0});
983     auto ur_B = make_region_copy(orig_B, {2, S}, {0, 0});
984     auto o_B  = make_region_copy(orig_B, {1, S}, {2, 0});
985
986     auto prc = cell->precision;
987
988     /** Release links on TI */
989     for (auto &ins : cell->insData)
990         ins.lock()->inputTo.erase(cell->name);
991     for (auto &outs : cell->outData)
992         outs->creatorLayer.reset();
993
994     // operations
995     auto concat = _concat(name + ":concat", prc, {N, D+S}, 2);
996     auto split = _split(name + ":split", prc, {N, S}, 2);
997     auto fc_ur = _fc(name + ":fc_ur", prc, {N, S*2}, ur_W, ur_B);
998
999     const std::string _f = cell->activations[0], _g = cell->activations[1];
1000
1001     auto act_ur = _act(name + ":act_ur", prc, {N, 2*S}, _f);
1002     auto act_o = _act(name + ":act_o", prc, {N, S}, _g);
1003
1004     auto mul_u = _eltw(name + ":mul_u", prc, {N, S}, "mul");
1005     auto mul_r = _eltw(name + ":mul_r", prc, {N, S}, "mul");
1006
1007     auto pwr_m1 = _pwr(name + ":pwr", prc, {N, S}, -1.0, 1.0);
1008
1009     auto mul = _eltw(name + ":mul"   , prc, {N, S}, "mul");
1010     auto sum = _eltw(name + ":sum"   , prc, {N, S}, "sum");
1011
1012     /**
1013      * - zt = _f(Wz*[Xt + Ht-1] + Bz)
1014      * - rt = _f(Wr*[Xt + Ht-1] + Br)
1015      * - ht = _g(Wh*[Xt + (rt (.) Ht-1)] + Bh)    # default, when linear_before_reset = 0
1016      * - ht = _g(Whw*Xt + Bhw + (rt (.) (Whr*Ht-1 + Bhr))) # when linear_before_reset != 0
1017      * - Ht = (1 - zt) (.) ht + zt (.) Ht-1
1018      */
1019     _link(in_data, concat, 0);
1020     _link(in_h_state, concat, 1);
1021     _link(concat, fc_ur);
1022     _link_with_clip(fc_ur, act_ur, cell->clip);
1023     _link(act_ur, split);  // split[0] - zt,  split[1] - rt
1024
1025     if (linear_before_reset) {
1026         auto lbr_B = wrap_as_tensor(orig_B, {4, S});
1027
1028         auto whw_W = make_region_copy(o_W, {1, S, D}, {0, 0, 0});
1029         auto whr_W = make_region_copy(o_W, {1, S, S}, {0, 0, D});
1030         auto whw_B = make_region_copy(lbr_B, {1, S}, {2, 0});
1031         auto whr_B = make_region_copy(lbr_B, {1, S}, {3, 0});
1032
1033         auto fc_whr = _fc(name + ":fc_whr", prc, {N, S}, whr_W, whr_B);
1034         auto fc_whw = _fc(name + ":fc_whw", prc, {N, S}, whw_W, whw_B);
1035         auto sum_h  = _eltw(name + ":sum_h", prc, {N, S}, "sum");
1036
1037         _link(in_h_state, fc_whr);                  //                            Whr*Ht-1 + Bhr
1038         _link(fc_whr, mul_r, 0);                    //
1039         _link(split, mul_r, 1, 1);                  //                    rt (.) (Whr*Ht-1 + Bhr)
1040         _link(in_data, fc_whw);                     //    Whw*Xt + Bhw
1041         _link(fc_whw, sum_h, 0, 0);                 //
1042         _link(mul_r, sum_h, 0, 1);                  //    Whw*Xt + Bhw + (rt (.) (Whr*Ht-1 + Bhr))
1043         _link_with_clip(sum_h, act_o, cell->clip);  // _g(Whw*Xt + Bhw + (rt (.) (Whr*Ht-1 + Bhr)))
1044     } else {
1045         auto fc_wh = _fc(name + ":fc_o", prc, {N, S}, o_W, o_B);
1046         auto concat_h = _concat(name + ":concat_h", prc, {N, D+S}, 2);
1047
1048         _link(split, mul_r, 1, 0);                  //
1049         _link(in_h_state, mul_r, 1);                //              rt (.) Ht-1
1050         _link(in_data, concat_h, 0);                //
1051         _link(mul_r, concat_h, 0, 1);               //       [Xt + (rt (.) Ht-1)]
1052         _link(concat_h, fc_wh);                     //    Wh*[Xt + (rt (.) Ht-1)] + Bh
1053         _link_with_clip(fc_wh, act_o, cell->clip);  // _g(Wh*[Xt + (rt (.) Ht-1)] + Bh)
1054     }
1055
1056     _link(split, pwr_m1, 0, 0);   //  1 - zt
1057     _link(act_o, mul, 0, 0);      //
1058     _link(pwr_m1, mul, 0, 1);     // (1 - zt) (.) ht
1059     _link(split, mul_u, 0, 0);    //
1060     _link(in_h_state, mul_u, 1);  //                   zt (.) Ht-1
1061     _link(mul, sum, 0, 0);        //
1062     _link(mul_u, sum, 0, 1);      // (1 - zt) (.) ht + zt (.) Ht-1
1063
1064     // Output
1065     sum->outData[0] = out_h_state;
1066     out_h_state->creatorLayer = sum;
1067
1068     return true;
1069 }
1070
1071 static bool unrollCell(CNNLayerPtr cur, ICNNNetwork &net) {
1072     auto cell = std::dynamic_pointer_cast<RNNCellBase>(cur);
1073     switch (cell->cellType) {
1074         case RNNCellBase::LSTM:    return unrollLSTMCellBody(cur);
1075         case RNNCellBase::GRU:     return unrollGRUCellBody(cur);
1076         case RNNCellBase::GRU_LBR: return unrollGRUCellBody(cur, true);
1077         case RNNCellBase::RNN:     return unrollRNNCellBody(cur);
1078     }
1079     return false;
1080 }
1081
1082 static bool unrollSeq(CNNLayerPtr cur, ICNNNetwork &net) {
1083     if (!one_of(cur->type, "LSTMSequence", "GRUSequence", "RNNSequence"))
1084     return true;
1085
1086     auto seq = std::dynamic_pointer_cast<RNNSequenceLayer>(cur);
1087     IE_ASSERT(seq) << "Cannot cast object with type ***Sequence to RNNSequenceLayer object";
1088
1089     auto name = seq->name;
1090
1091     auto in_data = seq->insData[0].lock();
1092     auto in_h_state = seq->insData[1].lock();
1093     auto out_data = seq->outData[0];
1094
1095     auto in_d_dims = in_data->getTensorDesc().getDims();
1096     auto state_dims = in_h_state->getTensorDesc().getDims();
1097     auto out_d_dims = out_data->getTensorDesc().getDims();
1098
1099     const int axis = seq->axis;
1100     const auto direct = seq->direction;
1101     const auto prc = seq->precision;
1102
1103     /** Release links on Seq */
1104     for (auto &ins : seq->insData)
1105     ins.lock()->inputTo.erase(seq->name);
1106     for (auto &outs : seq->outData)
1107     outs->creatorLayer.reset();
1108
1109     /** Body subgraph*/
1110     auto in_d_body_dims = in_d_dims;
1111     in_d_body_dims[axis] = 1;
1112
1113     auto in_d_body_squeeze_dims = in_d_dims;
1114     in_d_body_squeeze_dims.erase(in_d_body_squeeze_dims.begin() + axis);
1115
1116     auto out_d_body_dims = out_d_dims;
1117     out_d_body_dims[axis] = 1;
1118
1119     auto out_d_body_squeeze_dims = out_d_dims;
1120     out_d_body_squeeze_dims.erase(out_d_body_squeeze_dims.begin() + axis);
1121
1122     auto body_in_data = DataPtr(new Data(name + ":data_in",
1123             TensorDesc { prc, in_d_body_dims, TensorDesc::getLayoutByDims(in_d_body_dims) }));
1124
1125     auto resh1 = _resh(name + ":resh1", prc, in_d_body_squeeze_dims);
1126     auto cell  = _cell(name + ":cell", prc, out_d_body_squeeze_dims, state_dims, seq->cellType);
1127     auto resh2 = _resh(name + ":resh2", prc, out_d_body_dims);
1128
1129     _link(body_in_data, resh1);
1130     _link(resh1, cell);
1131     _link(cell, resh2);
1132
1133     cell->_weights = seq->_weights;
1134     cell->_biases = seq->_biases;
1135     cell->hidden_size = seq->hidden_size;
1136     cell->clip = seq->clip;
1137     cell->activations = seq->activations;
1138     cell->activation_alpha = seq->activation_alpha;
1139     cell->activation_beta = seq->activation_beta;
1140
1141     const size_t NS = cell->outData.size();  // num of state
1142
1143     /** TI layer */
1144     auto ti = _ti(name + ":ti", prc, NS);
1145     _link(in_data, ti, 0);
1146
1147     ti->outData[0] = out_data;
1148     out_data->creatorLayer = ti;
1149
1150     ti->body.inputs.push_back(body_in_data);
1151     ti->body.outputs.push_back(resh2->outData[0]);
1152
1153     int start = direct == RNNSequenceLayer::FWD ? 0 : -1;
1154     int end = direct == RNNSequenceLayer::FWD ? -1 : 0;
1155     int step = direct == RNNSequenceLayer::FWD ? 1 : -1;
1156     ti->input_port_map.push_back({0, 0, axis, step, start, end, 1});
1157     ti->output_port_map.push_back({0, 0, axis, step, start, end, 1});
1158
1159     for (size_t i = 0; i < NS; i++) {
1160         auto in_state = seq->insData[1 + i].lock();
1161         _link(in_state, ti, 1 + i);
1162
1163         auto out_state = seq->outData[1 + i];
1164         ti->outData[1 + i] = out_state;
1165         out_state->creatorLayer = ti;
1166
1167         auto body_in_state = DataPtr(new Data(name + ":state_in_" + std::to_string(i),
1168                 TensorDesc { prc, state_dims, TensorDesc::getLayoutByDims(state_dims) }));
1169
1170         _link(body_in_state, cell, 1 + i);
1171
1172         ti->body.inputs.push_back(body_in_state);
1173         ti->body.outputs.push_back(cell->outData[i]);
1174
1175         const int ii = 1 + static_cast<int>(i);
1176         ti->input_port_map.push_back({ii, ii, -1, 0, 0, 0, 0});
1177         ti->output_port_map.push_back({ii, ii, -1, 0, 0, 0, 0});
1178         ti->back_edges.push_back({ii, ii, -1, 0, 0, 0, 0});
1179     }
1180
1181     unrollTI(ti, net);
1182
1183     return true;
1184 }
1185
1186 /************************************************************/
1187 /****  Converter API  ***************************************/
1188 /************************************************************/
1189
1190 template <typename T>
1191 bool ApplyForAll(ICNNNetwork &net, T action) {
1192     auto all_layers = details::CNNNetSortTopologically(net);
1193     bool sts = true;
1194
1195     for (auto &layer : all_layers)
1196         sts &= action(layer, net);
1197
1198     return sts;
1199 }
1200
1201 template <typename T, typename P>
1202 bool ApplyForAll_if(ICNNNetwork &net, T action, P pred) {
1203     auto all_layers = details::CNNNetSortTopologically(net);
1204     bool sts = true;
1205
1206     for (auto &layer : all_layers)
1207         if (pred(layer))
1208             sts &= action(layer, net);
1209
1210     return sts;
1211 }
1212
1213 bool CombineRNNSeq(ICNNNetwork &net) {
1214     return ApplyForAll(net, convertToRNNSeq);
1215 }
1216
1217 bool UnrollTI(ICNNNetwork &net) {
1218     return ApplyForAll(net, unrollTI);
1219 }
1220
1221 bool UnrollRNN_if(ICNNNetwork &net, const std::function<bool(const RNNCellBase&)> pred) {
1222     // Filter layers by RNN specific type
1223     auto _seq_pred = [&] (CNNLayerPtr layer) {
1224         auto rnn = std::dynamic_pointer_cast<RNNSequenceLayer>(layer);
1225         if (!rnn) return false;
1226         return pred(*rnn.get());
1227     };
1228     auto _cell_pred = [&] (CNNLayerPtr layer) {
1229         auto rnn = std::dynamic_pointer_cast<RNNCellBase>(layer);
1230         if (!rnn || !one_of(rnn->type, "LSTMCell", "GRUCell", "RNNCell")) return false;
1231         return pred(*rnn.get());
1232     };
1233
1234     bool res = true;
1235     res &= ApplyForAll_if(net, unrollSeq, _seq_pred);
1236     res &= ApplyForAll_if(net, unrollCell, _cell_pred);
1237     return res;
1238 }
1239
1240 }  // namespace NetPass
1241 }  // namespace InferenceEngine
1242