Remove Switch and Merge nodes from TensorFlow networks
[platform/upstream/opencv.git] / modules / dnn / src / tensorflow / tf_graph_simplifier.cpp
1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
7
8 #include "../precomp.hpp"
9
10 #ifdef HAVE_PROTOBUF
11
12 #include "tf_graph_simplifier.hpp"
13 #include <queue>
14
15 namespace cv { namespace dnn {
16 CV__DNN_EXPERIMENTAL_NS_BEGIN
17
18 using ::google::protobuf::RepeatedField;
19 using ::google::protobuf::MapPair;
20
21 class Subgraph  // Interface to match and replace TensorFlow subgraphs.
22 {
23 public:
24     virtual ~Subgraph() {}
25
26     // Add a node to be matched in the origin graph. Specify ids of nodes that
27     // are expected to be inputs. Returns id of a newly added node.
28     // TODO: Replace inputs to std::vector<int> in C++11
29     int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1,
30                        int input_2 = -1, int input_3 = -1)
31     {
32         int nodeInputs[] = {input_0, input_1, input_2, input_3};
33         int numInputs = 0;
34         for (int i = 0; i < 4; ++i)
35         {
36             numInputs += (int)(nodeInputs[i] != -1);
37         }
38         return addNodeToMatch(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
39     }
40
41     int addNodeToMatch(const std::string& op, const std::vector<int>& inputs_)
42     {
43         for (int i = 0; i < inputs_.size(); ++i)
44         {
45             CV_Assert(inputs_[i] < (int)nodes.size());
46         }
47         nodes.push_back(op);
48         inputs.push_back(inputs_);
49         return nodes.size() - 1;
50     }
51
52     // Specify resulting node. All the matched nodes in subgraph excluding
53     // input nodes will be fused into this single node.
54     // TODO: Replace inputs to std::vector<int> in C++11
55     void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1,
56                       int input_2 = -1, int input_3 = -1, int input_4 = -1,
57                       int input_5 = -1)
58     {
59         int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5};
60         int numInputs = 0;
61         for (int i = 0; i < 6; ++i)
62         {
63             CV_Assert(nodeInputs[i] < (int)nodes.size());
64             numInputs += (int)(nodeInputs[i] != -1);
65         }
66         setFusedNode(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
67     }
68
69     void setFusedNode(const std::string& op, const std::vector<int>& inputs_)
70     {
71         fusedNodeInputs = inputs_;
72         fusedNodeOp = op;
73         nodesToFuse.clear();
74         for (int i = 0; i < nodes.size(); ++i)
75         {
76             if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end() &&
77                 nodes[i] != "Const")
78                 nodesToFuse.push_back(i);
79         }
80     }
81
82     static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net,
83                                                    const tensorflow::NodeDef& node,
84                                                    int inpId)
85     {
86         CV_Assert(inpId < node.input_size());
87         std::string name = node.input(inpId);
88         // If operation produces several tensors, they are specified by index
89         // after ':' character. In example, "input:0".
90         name = name.substr(0, name.rfind(':'));
91         const int numNodes = net.node_size();
92         for (int i = 0; i < numNodes; ++i)
93         {
94             if (net.node(i).name() == name)
95                 return net.node(i);
96         }
97         CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
98     }
99
100     // Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
101     // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
102     virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds)
103     {
104         matchedNodesIds.clear();
105         matchedNodesIds.reserve(nodesToFuse.size());
106
107         int numNodes = net.node_size();
108         for (int i = 0; i < nodesToFuse.size(); ++i)
109         {
110             while (nodeId < numNodes && net.node(nodeId).op() == "Const")
111             {
112                 nodeId += 1;
113             }
114             if (nodeId > numNodes - 1)
115                 return false;
116
117             const tensorflow::NodeDef& node = net.node(nodeId);
118
119             if (node.op() != nodes[nodesToFuse[i]])
120                 return false;
121
122             std::vector<int>& inputNodes = inputs[nodesToFuse[i]];
123             if (inputNodes.size() != node.input_size())
124                 return false;
125             for (int j = 0; j < inputNodes.size(); ++j)
126             {
127                 if (nodes[inputNodes[j]].empty())  // Unknown input node type.
128                     continue;
129                 const tensorflow::NodeDef& inpNode = getInputNode(net, node, j);
130                 if (inpNode.op() != nodes[inputNodes[j]])
131                     return false;
132             }
133
134             matchedNodesIds.push_back(nodeId);
135             nodeId += 1;
136         }
137         return true;
138     }
139
140     // Fuse matched subgraph.
141     void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds)
142     {
143         // Extract names of input nodes.
144         std::vector<std::string> inputsNames(fusedNodeInputs.size());
145         for (int i = 0; i < fusedNodeInputs.size(); ++i)
146         {
147             std::string inpName;
148             // Find input node name looking at inputs of fused nodes.
149             for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
150             {
151                 const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]);
152                 std::vector<int>& inpIndices = inputs[nodesToFuse[j]];
153
154                 CV_Assert(node.input_size() == inpIndices.size());
155                 for (int k = 0; k < inpIndices.size(); ++k)
156                 {
157                     if (inpIndices[k] == fusedNodeInputs[i])
158                     {
159                         inpName = node.input(k);
160                         break;
161                     }
162                 }
163             }
164             CV_Assert(!inpName.empty());
165             inputsNames[i] = inpName;
166         }
167
168         // Remove matched nodes except the last one. Indices in ascending order are expected.
169         tensorflow::NodeDef* node = net.mutable_node(matchedNodesIds.back());
170         for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
171             net.mutable_node()->DeleteSubrange(matchedNodesIds[i], 1);
172
173         // Modify the last node to be a fused one.
174         node->set_op(fusedNodeOp);
175         node->clear_input();
176         for (int i = 0; i < inputsNames.size(); ++i)
177         {
178             node->add_input(inputsNames[i]);
179         }
180
181         std::vector<tensorflow::NodeDef*> inputNodes(inputsNames.size());
182         for (int i = 0; i < inputsNames.size(); ++i)
183         {
184             inputNodes[i] = (tensorflow::NodeDef*)&getInputNode(net, *node, i);
185         }
186         finalize(net, node, inputNodes);
187     }
188
189     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*,
190                           std::vector<tensorflow::NodeDef*>&) {}
191
192 private:
193     std::vector<std::string> nodes;         // Nodes to be matched in the origin graph.
194     std::vector<std::vector<int> > inputs;  // Connections of an every node to it's inputs.
195
196     std::string fusedNodeOp;           // Operation name of resulting fused node.
197     std::vector<int> nodesToFuse;      // Set of nodes to be fused.
198     std::vector<int> fusedNodeInputs;  // Inputs of fused node.
199 };
200
201 class BatchNormSubgraph : public Subgraph
202 {
203 public:
204     BatchNormSubgraph()
205     {
206         int input = addNodeToMatch("");
207         int epsilon = addNodeToMatch("Const");
208         int moving_variance = addNodeToMatch("Const");
209         int moving_mean = addNodeToMatch("Const");
210         int beta = addNodeToMatch("Const");
211         int gamma = addNodeToMatch("Const");
212         int add = addNodeToMatch("Add", moving_variance, epsilon);
213         int rsqrt = addNodeToMatch("Rsqrt", add);
214         int mul = addNodeToMatch("Mul", rsqrt, gamma);
215         int mul_1 = addNodeToMatch("Mul", input, mul);
216         int mul_2 = addNodeToMatch("Mul", moving_mean, mul);
217         int sub = addNodeToMatch("Sub", beta, mul_2);
218         addNodeToMatch("Add", mul_1, sub);
219
220         setFusedNode("FusedBatchNorm", input, gamma, beta, moving_mean, moving_variance, epsilon);
221     }
222
223     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
224                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
225     {
226         Mat epsMat = getTensorContent(inputNodes.back()->attr().at("value").tensor());
227         CV_CheckEQ(epsMat.total(), (size_t)1, ""); CV_CheckTypeEQ(epsMat.type(), CV_32FC1, "");
228
229         fusedNode->mutable_input()->RemoveLast();
230         fusedNode->clear_attr();
231         tensorflow::AttrValue epsilon;
232         epsilon.set_f(epsMat.at<float>(0));
233         fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
234     }
235 };
236
237 class BatchNormNoGammaSubgraph : public Subgraph
238 {
239 public:
240     BatchNormNoGammaSubgraph()
241     {
242         int input = addNodeToMatch("");
243         int epsilon = addNodeToMatch("Const");
244         int moving_variance = addNodeToMatch("Const");
245         int moving_mean = addNodeToMatch("Const");
246         int beta = addNodeToMatch("Const");
247         int add = addNodeToMatch("Add", moving_variance, epsilon);
248         int rsqrt = addNodeToMatch("Rsqrt", add);
249         int mul = addNodeToMatch("Mul", input, rsqrt);
250         int mul_1 = addNodeToMatch("Mul", moving_mean, rsqrt);
251         int sub = addNodeToMatch("Sub", beta, mul_1);
252         addNodeToMatch("Add", mul, sub);
253
254         // There is a fake reference to beta that will be replaced to a new gamma tensor.
255         setFusedNode("FusedBatchNorm", input, beta, beta, moving_mean, moving_variance, epsilon);
256     }
257
258     virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
259                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
260     {
261         Mat epsMat = getTensorContent(inputNodes.back()->attr().at("value").tensor());
262         CV_CheckEQ(epsMat.total(), (size_t)1, ""); CV_CheckTypeEQ(epsMat.type(), CV_32FC1, "");
263
264         fusedNode->mutable_input()->RemoveLast();
265         fusedNode->clear_attr();
266         tensorflow::AttrValue epsilon;
267         epsilon.set_f(epsMat.at<float>(0));
268         fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
269
270         tensorflow::NodeDef* gamma = net.add_node();
271         gamma->set_op("Const");
272         gamma->set_name(fusedNode->name() + "/gamma");
273         // Just put a single value to recognize this node as Const.
274         gamma->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", epsilon));
275         fusedNode->set_input(1, gamma->name());
276     }
277 };
278
279 // tf.contrib.layers.flatten
280 class FlattenSubgraph : public Subgraph
281 {
282 public:
283     FlattenSubgraph()
284     {
285         int input = addNodeToMatch("");
286         int shape = addNodeToMatch("Const");
287         int stack = addNodeToMatch("Const");
288         int stack_1 = addNodeToMatch("Const");
289         int stack_2 = addNodeToMatch("Const");
290         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
291         int shape_pack = addNodeToMatch("Const");
292         int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
293         addNodeToMatch("Reshape", input, pack);
294
295         setFusedNode("Flatten", input);
296     }
297 };
298
299 // tf.contrib.layers.flatten in case of unknown batch size
300 class FlattenShapeSubgraph : public Subgraph
301 {
302 public:
303     FlattenShapeSubgraph()
304     {
305         int input = addNodeToMatch("");
306         int shape = addNodeToMatch("Shape", input);
307         int stack = addNodeToMatch("Const");
308         int stack_1 = addNodeToMatch("Const");
309         int stack_2 = addNodeToMatch("Const");
310         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
311         int shape_pack = addNodeToMatch("Const");
312         int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
313         addNodeToMatch("Reshape", input, pack);
314
315         setFusedNode("Flatten", input);
316     }
317 };
318
319 // K.layers.Softmax
320 class SoftMaxKerasSubgraph : public Subgraph
321 {
322 public:
323     SoftMaxKerasSubgraph()
324     {
325         int input = addNodeToMatch("");
326         int maxReductionIndices = addNodeToMatch("Const");
327         int smMax = addNodeToMatch("Max", input, maxReductionIndices);
328         int smSub = addNodeToMatch("Sub", input, smMax);
329         int smExp = addNodeToMatch("Exp", smSub);
330         int sumReductionIndices = addNodeToMatch("Const");
331         int smSum = addNodeToMatch("Sum", smExp, sumReductionIndices);
332         addNodeToMatch("RealDiv", smExp, smSum);
333
334         setFusedNode("Softmax", input);
335     }
336 };
337
338 class ReLU6KerasSubgraph : public Subgraph
339 {
340 public:
341     ReLU6KerasSubgraph()
342     {
343         int input = addNodeToMatch("");
344         int relu = addNodeToMatch("Relu", input);
345         int maxValue = addNodeToMatch("Const");
346         int clipValue = addNodeToMatch("Const");
347         int minimum = addNodeToMatch("Minimum", relu, maxValue);
348         addNodeToMatch("Maximum", minimum, clipValue);
349
350         setFusedNode("Relu6", input);
351     }
352
353     virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE
354     {
355         if (!Subgraph::match(net, nodeId, matchedNodesIds))
356             return false;
357         Mat maxValue = getTensorContent(net.node(nodeId + 1).attr().at("value").tensor());
358         return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
359     }
360 };
361
362 // Keras' reshape stores output shape in separate Const nodes by one value.
363 // Need to merge them into a single Const node.
364 class ReshapeKerasSubgraph : public Subgraph
365 {
366 public:
367     ReshapeKerasSubgraph(int _numOutDims) : numOutDims(_numOutDims)
368     {
369         int input = addNodeToMatch("");
370         int shape = addNodeToMatch("Shape", input);
371         int stack = addNodeToMatch("Const");
372         int stack_1 = addNodeToMatch("Const");
373         int stack_2 = addNodeToMatch("Const");
374         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
375
376         std::vector<int> ids(1 + numOutDims);
377         ids[0] = strided_slice;
378         for (int i = 0; i < numOutDims; ++i)
379             ids[1 + i] = addNodeToMatch("Const");
380         int pack = addNodeToMatch("Pack", ids);
381         addNodeToMatch("Reshape", input, pack);
382
383         ids[0] = input;
384         setFusedNode("Reshape", ids);
385     }
386
387     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
388                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
389     {
390         std::vector<int> shape(numOutDims + 1);  // batch size in Keras is implicit.
391         shape[0] = -1;
392         for (int i = 0; i < numOutDims; ++i)
393         {
394             shape[1 + i] = inputNodes[1 + i]->attr().at("value").tensor().int_val(0);
395         }
396         tensorflow::TensorProto* shapeTensor = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
397         fusedNode->mutable_input()->DeleteSubrange(2, numOutDims - 1);
398
399         shapeTensor->clear_int_val();
400         for (int i = 0; i < shape.size(); ++i)
401         {
402             shapeTensor->add_int_val(shape[i]);
403         }
404     }
405
406 private:
407     int numOutDims;
408 };
409
410 class L2NormalizeSubgraph : public Subgraph
411 {
412 public:
413     L2NormalizeSubgraph()
414     {
415         int input = addNodeToMatch("");
416         int square = addNodeToMatch("Square", input);
417         int reductionIndices = addNodeToMatch("Const");
418         int sum = addNodeToMatch("Sum", square, reductionIndices);
419         int y = addNodeToMatch("Const");
420         int maximum = addNodeToMatch("Maximum", sum, y);
421         int rsqrt = addNodeToMatch("Rsqrt", maximum);
422         addNodeToMatch("Mul", input, rsqrt);
423         setFusedNode("L2Normalize", input, reductionIndices);
424     }
425 };
426
427 class DeconvolutionValidKerasSubgraph : public Subgraph
428 {
429 public:
430     DeconvolutionValidKerasSubgraph()
431     {
432         int input = addNodeToMatch("");
433         int shape = addNodeToMatch("Shape", input);
434         int kernel = addNodeToMatch("Const");
435
436         int stack = addNodeToMatch("Const");
437         int stack_1 = addNodeToMatch("Const");
438         int stack_2 = addNodeToMatch("Const");
439         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
440
441         stack = addNodeToMatch("Const");
442         stack_1 = addNodeToMatch("Const");
443         stack_2 = addNodeToMatch("Const");
444         int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
445
446         stack = addNodeToMatch("Const");
447         stack_1 = addNodeToMatch("Const");
448         stack_2 = addNodeToMatch("Const");
449         int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
450
451         int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
452         int add = addNodeToMatch("Add", mul, addNodeToMatch("Const"));
453
454         int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
455         int add_1 = addNodeToMatch("Add", mul_1, addNodeToMatch("Const"));
456         int pack = addNodeToMatch("Pack", strided_slice, add, add_1, addNodeToMatch("Const"));
457         addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
458         // Put any unused Const op to the first input.
459         setFusedNode("Conv2DBackpropInput", stack, kernel, input);
460     }
461
462     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
463                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
464     {
465         // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
466         // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
467         // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
468         // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
469         std::string padMode = fusedNode->attr().at("padding").s();
470         CV_Assert(padMode == "VALID");
471
472         const tensorflow::TensorShapeProto& kernelShape =
473             inputNodes[1]->mutable_attr()->at("value").tensor().tensor_shape();
474
475         CV_Assert(kernelShape.dim_size() == 4);
476         const int kernelHeight = kernelShape.dim(0).size();
477         const int kernelWidth = kernelShape.dim(1).size();
478
479         tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
480         outShape->clear_int_val();
481         outShape->add_int_val(-1);
482         outShape->add_int_val(kernelHeight);
483         outShape->add_int_val(kernelWidth);
484         outShape->add_int_val(-1);
485     }
486 };
487
488 class DeconvolutionSameKerasSubgraph : public Subgraph
489 {
490 public:
491     DeconvolutionSameKerasSubgraph()
492     {
493         int input = addNodeToMatch("");
494         int shape = addNodeToMatch("Shape", input);
495         int kernel = addNodeToMatch("Const");
496
497         int stack = addNodeToMatch("Const");
498         int stack_1 = addNodeToMatch("Const");
499         int stack_2 = addNodeToMatch("Const");
500         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
501
502         stack = addNodeToMatch("Const");
503         stack_1 = addNodeToMatch("Const");
504         stack_2 = addNodeToMatch("Const");
505         int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
506
507         stack = addNodeToMatch("Const");
508         stack_1 = addNodeToMatch("Const");
509         stack_2 = addNodeToMatch("Const");
510         int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
511
512         int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
513
514         int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
515         int pack = addNodeToMatch("Pack", strided_slice, mul, mul_1, addNodeToMatch("Const"));
516         addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
517         // Put any unused Const op to the first input.
518         setFusedNode("Conv2DBackpropInput", stack, kernel, input);
519     }
520
521     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
522                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
523     {
524         // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
525         // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
526         // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
527         // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
528         std::string padMode = fusedNode->attr().at("padding").s();
529         CV_Assert(padMode == "SAME");
530
531         const tensorflow::AttrValue_ListValue& strides = fusedNode->attr().at("strides").list();
532         CV_Assert(strides.i_size() == 4);
533
534         const int strideY = strides.i(1);
535         const int strideX = strides.i(2);
536
537         tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
538         outShape->clear_int_val();
539         outShape->add_int_val(-1);
540         outShape->add_int_val(strideY);
541         outShape->add_int_val(strideX);
542         outShape->add_int_val(-1);
543     }
544 };
545
546 // In case of resizing by factor.
547 class ResizeBilinearSubgraph : public Subgraph
548 {
549 public:
550     ResizeBilinearSubgraph()
551     {
552         int input = addNodeToMatch("");
553
554         int shape = addNodeToMatch("Shape", input);
555         int stack = addNodeToMatch("Const");
556         int stack_1 = addNodeToMatch("Const");
557         int stack_2 = addNodeToMatch("Const");
558         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
559         int factorY = addNodeToMatch("Const");
560         int mul = addNodeToMatch("Mul", strided_slice, factorY);
561
562         shape = addNodeToMatch("Shape", input);
563         stack = addNodeToMatch("Const");
564         stack_1 = addNodeToMatch("Const");
565         stack_2 = addNodeToMatch("Const");
566         strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
567         int factorX = addNodeToMatch("Const");
568         int mul_1 = addNodeToMatch("Mul", strided_slice, factorX);
569
570         int pack = addNodeToMatch("Pack", mul, mul_1);
571
572         addNodeToMatch("ResizeBilinear", input, pack);
573         setFusedNode("ResizeBilinear", input, factorY, factorX);
574     }
575 };
576
577 // In case of resizing by factor.
578 class UpsamplingKerasSubgraph : public Subgraph
579 {
580 public:
581     UpsamplingKerasSubgraph()
582     {
583         int input = addNodeToMatch("");
584         int shape = addNodeToMatch("Shape", input);
585         int stack = addNodeToMatch("Const");
586         int stack_1 = addNodeToMatch("Const");
587         int stack_2 = addNodeToMatch("Const");
588         int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
589         int factors = addNodeToMatch("Const");
590         int mul = addNodeToMatch("Mul", strided_slice, factors);
591         addNodeToMatch("ResizeNearestNeighbor", input, mul);
592         setFusedNode("ResizeNearestNeighbor", input, factors);
593     }
594
595     virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
596                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
597     {
598         Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
599         CV_CheckEQ(factorsMat.total(), (size_t)2, ""); CV_CheckTypeEQ(factorsMat.type(), CV_32SC1, "");
600
601         // Height scale factor
602         tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
603         factorY->clear_int_val();
604         factorY->clear_tensor_content();
605         factorY->add_int_val(factorsMat.at<int>(0, 0));
606
607         // Width scale factor.
608         tensorflow::NodeDef* factorXNode = net.add_node();
609         factorXNode->set_op("Const");
610         factorXNode->set_name(fusedNode->name() + "/factor_y");
611
612         tensorflow::AttrValue factorX;
613         factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32);
614         factorX.mutable_tensor()->add_int_val(factorsMat.at<int>(0, 1));
615         factorXNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", factorX));
616
617         fusedNode->add_input(factorXNode->name());
618     }
619 };
620
621 class ReshapeAsShapeSubgraph : public Subgraph
622 {
623 public:
624     ReshapeAsShapeSubgraph()
625     {
626         int input = addNodeToMatch("");
627         int shapeSrc = addNodeToMatch("");
628         int shape = addNodeToMatch("Shape", shapeSrc);
629         addNodeToMatch("Reshape", input, shape);
630         setFusedNode("Reshape", input, shapeSrc);
631     }
632 };
633
634 class SoftMaxSlimSubgraph : public Subgraph
635 {
636 public:
637     SoftMaxSlimSubgraph()
638     {
639         int input = addNodeToMatch("");
640         int shape = addNodeToMatch("Const");
641         int shapeOp = addNodeToMatch("Shape", input);
642         int reshape = addNodeToMatch("Reshape", input, shape);
643         int softmax = addNodeToMatch("Softmax", reshape);
644         addNodeToMatch("Reshape", softmax, shapeOp);
645         setFusedNode("Softmax", input);
646     }
647 };
648
649 void simplifySubgraphs(tensorflow::GraphDef& net)
650 {
651     std::vector<Ptr<Subgraph> > subgraphs;
652     subgraphs.push_back(Ptr<Subgraph>(new BatchNormSubgraph()));
653     subgraphs.push_back(Ptr<Subgraph>(new BatchNormNoGammaSubgraph()));
654     subgraphs.push_back(Ptr<Subgraph>(new FlattenSubgraph()));
655     subgraphs.push_back(Ptr<Subgraph>(new FlattenShapeSubgraph()));
656     subgraphs.push_back(Ptr<Subgraph>(new SoftMaxKerasSubgraph()));
657     subgraphs.push_back(Ptr<Subgraph>(new ReLU6KerasSubgraph()));
658     subgraphs.push_back(Ptr<Subgraph>(new ReshapeKerasSubgraph(3)));
659     subgraphs.push_back(Ptr<Subgraph>(new L2NormalizeSubgraph()));
660     subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
661     subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
662     subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
663     subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
664     subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
665     subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
666
667     int numNodes = net.node_size();
668     std::vector<int> matchedNodesIds;
669     for (int i = 0; i < numNodes; ++i)
670     {
671         for (int j = 0; j < subgraphs.size(); ++j)
672         {
673             if (subgraphs[j]->match(net, i, matchedNodesIds))
674             {
675                 subgraphs[j]->replace(net, matchedNodesIds);
676                 numNodes -= matchedNodesIds.size() - 1;  // #matchedNodes removed and one added.
677                 break;
678             }
679         }
680     }
681 }
682
683 void RemoveIdentityOps(tensorflow::GraphDef& net)
684 {
685     typedef std::map<String, String>  IdentityOpsMap;
686     IdentityOpsMap identity_ops;
687
688     std::vector<int> identity_ops_idx;
689
690     int layersCount = net.node_size();
691     for (int li = 0; li < layersCount; li++)
692     {
693         const tensorflow::NodeDef &layer = net.node(li);
694         String type = layer.op();
695
696         if (type == "Identity" || type == "Dropout") {
697             identity_ops_idx.push_back(li);
698             identity_ops[layer.name()] = layer.input(0);
699         }
700     }
701
702     for (int li = 0; li < layersCount; li++)
703     {
704         tensorflow::NodeDef* layer = net.mutable_node(li);
705         for (int input_id = 0; input_id < layer->input_size(); input_id++) {
706             String input_op_name = layer->input(input_id);
707             IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
708
709             if (it != identity_ops.end()) {
710                 layer->set_input(input_id, it->second);
711             }
712         }
713     }
714
715     std::sort(identity_ops_idx.begin(), identity_ops_idx.end());
716
717     int removed_nodes = 0;
718     for(size_t i = 0; i < identity_ops_idx.size(); i++) {
719         int start_id = identity_ops_idx[i] - removed_nodes;
720         net.mutable_node()->DeleteSubrange(start_id, 1);
721         removed_nodes++;
722     }
723 }
724
725 Mat getTensorContent(const tensorflow::TensorProto &tensor)
726 {
727     const std::string& content = tensor.tensor_content();
728     switch (tensor.dtype())
729     {
730         case tensorflow::DT_FLOAT:
731         {
732             if (!content.empty())
733                 return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
734             else
735             {
736                 const RepeatedField<float>& field = tensor.float_val();
737                 CV_Assert(!field.empty());
738                 return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
739             }
740         }
741         case tensorflow::DT_DOUBLE:
742         {
743             if (!content.empty())
744                 return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
745             else
746             {
747                 const RepeatedField<double>& field = tensor.double_val();
748                 CV_Assert(!field.empty());
749                 return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
750             }
751         }
752         case tensorflow::DT_INT32:
753         {
754             if (!content.empty())
755                 return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
756             else
757             {
758                 const RepeatedField<int32_t>& field = tensor.int_val();
759                 CV_Assert(!field.empty());
760                 return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
761             }
762         }
763         case tensorflow::DT_HALF:
764         {
765             Mat halfs;
766             if (!content.empty())
767             {
768                 static const int kHalfSize = 2;
769                 halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
770             }
771             else
772             {
773                 const RepeatedField<int32_t>& field = tensor.half_val();
774                 CV_Assert(!field.empty());
775                 Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
776                 ints.convertTo(halfs, CV_16UC1);
777             }
778             // Reinterpret as a signed shorts just for a convertFp16 call.
779             Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
780             Mat floats(halfs.size(), CV_32FC1);
781             convertFp16(halfsSigned, floats);
782             return floats;
783         }
784         case tensorflow::DT_QUINT8:
785         {
786             CV_Assert(!content.empty());
787             return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
788         }
789         default:
790             CV_Error(Error::StsError, "Tensor's data type is not supported");
791             break;
792     }
793     return Mat();
794 }
795
796 void releaseTensor(tensorflow::TensorProto* tensor)
797 {
798     if (!tensor->mutable_tensor_content()->empty())
799     {
800         delete tensor->release_tensor_content();
801     }
802 }
803
804 static void permute(google::protobuf::RepeatedPtrField<tensorflow::NodeDef>* data,
805                     const std::vector<int>& indices)
806 {
807     const int num = data->size();
808     CV_Assert(num == indices.size());
809
810     std::vector<int> elemIdToPos(num);
811     std::vector<int> posToElemId(num);
812     for (int i = 0; i < num; ++i)
813     {
814         elemIdToPos[i] = i;
815         posToElemId[i] = i;
816     }
817     for (int i = 0; i < num; ++i)
818     {
819         int elemId = indices[i];
820         int pos = elemIdToPos[elemId];
821         if (pos != i)
822         {
823             data->SwapElements(i, pos);
824             const int swappedElemId = posToElemId[i];
825             elemIdToPos[elemId] = i;
826             elemIdToPos[swappedElemId] = pos;
827
828             posToElemId[i] = elemId;
829             posToElemId[pos] = swappedElemId;
830         }
831     }
832 }
833
834 // Is based on tensorflow::graph_transforms::SortByExecutionOrder
835 void sortByExecutionOrder(tensorflow::GraphDef& net)
836 {
837     // Maps node's name to index at net.node() list.
838     std::map<std::string, int> nodesMap;
839     std::map<std::string, int>::iterator nodesMapIt;
840     for (int i = 0; i < net.node_size(); ++i)
841     {
842         const tensorflow::NodeDef& node = net.node(i);
843         nodesMap.insert(std::make_pair(node.name(), i));
844     }
845
846     // Indices of nodes which use specific node as input.
847     std::vector<std::vector<int> > edges(nodesMap.size());
848     std::vector<int> numRefsToAdd(nodesMap.size(), 0);
849     std::vector<int> nodesToAdd;
850     for (int i = 0; i < net.node_size(); ++i)
851     {
852         const tensorflow::NodeDef& node = net.node(i);
853         for (int j = 0; j < node.input_size(); ++j)
854         {
855             std::string inpName = node.input(j);
856             inpName = inpName.substr(0, inpName.rfind(':'));
857             inpName = inpName.substr(inpName.find('^') + 1);
858
859             nodesMapIt = nodesMap.find(inpName);
860             CV_Assert(nodesMapIt != nodesMap.end());
861             edges[nodesMapIt->second].push_back(i);
862         }
863         if (node.input_size() == 0)
864             nodesToAdd.push_back(i);
865         else
866         {
867             if (node.op() == "Merge" || node.op() == "RefMerge")
868             {
869                 int numControlEdges = 0;
870                 for (int j = 0; j < node.input_size(); ++j)
871                     numControlEdges += node.input(j)[0] == '^';
872                 numRefsToAdd[i] = numControlEdges + 1;
873             }
874             else
875                 numRefsToAdd[i] = node.input_size();
876         }
877     }
878
879     std::vector<int> permIds;
880     permIds.reserve(net.node_size());
881     while (!nodesToAdd.empty())
882     {
883         int nodeToAdd = nodesToAdd.back();
884         nodesToAdd.pop_back();
885
886         permIds.push_back(nodeToAdd);
887
888         for (int i = 0; i < edges[nodeToAdd].size(); ++i)
889         {
890             int consumerId = edges[nodeToAdd][i];
891             if (numRefsToAdd[consumerId] > 0)
892             {
893                 if (numRefsToAdd[consumerId] == 1)
894                     nodesToAdd.push_back(consumerId);
895                 else
896                     CV_Assert(numRefsToAdd[consumerId] >= 0);
897                 numRefsToAdd[consumerId] -= 1;
898             }
899         }
900     }
901     CV_Assert(permIds.size() == net.node_size());
902     permute(net.mutable_node(), permIds);
903 }
904
905 // Remove training switches (Switch and Merge nodes and corresponding subgraphs).
906 void removePhaseSwitches(tensorflow::GraphDef& net)
907 {
908     std::vector<int> nodesToRemove;
909     std::map<std::string, int> nodesMap;
910     std::map<std::string, int>::iterator nodesMapIt;
911     std::queue<int> mergeOpSubgraphNodes;
912     for (int i = 0; i < net.node_size(); ++i)
913     {
914         const tensorflow::NodeDef& node = net.node(i);
915         nodesMap.insert(std::make_pair(node.name(), i));
916         if (node.op() == "Switch" || node.op() == "Merge")
917         {
918             CV_Assert(node.input_size() > 0);
919             // Replace consumers' inputs.
920             for (int j = 0; j < net.node_size(); ++j)
921             {
922                 tensorflow::NodeDef* consumer = net.mutable_node(j);
923                 for (int k = 0; k < consumer->input_size(); ++k)
924                 {
925                     std::string inpName = consumer->input(k);
926                     inpName = inpName.substr(0, inpName.rfind(':'));
927                     if (inpName == node.name())
928                     {
929                         consumer->set_input(k, node.input(0));
930                     }
931                 }
932             }
933             nodesToRemove.push_back(i);
934             if (node.op() == "Merge")
935                 mergeOpSubgraphNodes.push(i);
936         }
937     }
938
939     std::vector<int> numConsumers(net.node_size(), 0);
940     for (int i = 0; i < net.node_size(); ++i)
941     {
942         const tensorflow::NodeDef& node = net.node(i);
943         for (int j = 0; j < node.input_size(); ++j)
944         {
945             std::string inpName = node.input(j);
946             inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
947             nodesMapIt = nodesMap.find(inpName);
948             CV_Assert(nodesMapIt != nodesMap.end());
949             numConsumers[nodesMapIt->second] += 1;
950         }
951     }
952
953     // Remove subgraphs of unused nodes which are terminated by Merge nodes.
954     while (!mergeOpSubgraphNodes.empty())
955     {
956         const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front());
957         mergeOpSubgraphNodes.pop();
958         for (int i = 0; i < node.input_size(); ++i)
959         {
960             std::string inpName = node.input(i);
961             inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
962             nodesMapIt = nodesMap.find(inpName);
963             CV_Assert(nodesMapIt != nodesMap.end());
964
965             int inpNodeId = nodesMapIt->second;
966             if (numConsumers[inpNodeId] == 1)
967             {
968                 mergeOpSubgraphNodes.push(inpNodeId);
969                 nodesToRemove.push_back(inpNodeId);
970             }
971             else if (numConsumers[inpNodeId] > 0)
972                 numConsumers[inpNodeId] -= 1;
973         }
974     }
975     std::sort(nodesToRemove.begin(), nodesToRemove.end());
976     for (int i = nodesToRemove.size() - 1; i >= 0; --i)
977     {
978         if (nodesToRemove[i] < net.node_size())  // Ids might be repeated.
979             net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1);
980     }
981 }
982
983
984 CV__DNN_EXPERIMENTAL_NS_END
985 }}  // namespace dnn, namespace cv
986
987 #endif  // HAVE_PROTOBUF