1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
5 // Copyright (C) 2018, Intel Corporation, all rights reserved.
6 // Third party copyrights are property of their respective owners.
8 #include "../precomp.hpp"
12 #include "tf_graph_simplifier.hpp"
15 namespace cv { namespace dnn {
16 CV__DNN_EXPERIMENTAL_NS_BEGIN
18 using ::google::protobuf::RepeatedField;
19 using ::google::protobuf::MapPair;
21 class Subgraph // Interface to match and replace TensorFlow subgraphs.
24 virtual ~Subgraph() {}
26 // Add a node to be matched in the origin graph. Specify ids of nodes that
27 // are expected to be inputs. Returns id of a newly added node.
28 // TODO: Replace inputs to std::vector<int> in C++11
29 int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1,
30 int input_2 = -1, int input_3 = -1)
32 int nodeInputs[] = {input_0, input_1, input_2, input_3};
34 for (int i = 0; i < 4; ++i)
36 numInputs += (int)(nodeInputs[i] != -1);
38 return addNodeToMatch(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
41 int addNodeToMatch(const std::string& op, const std::vector<int>& inputs_)
43 for (int i = 0; i < inputs_.size(); ++i)
45 CV_Assert(inputs_[i] < (int)nodes.size());
48 inputs.push_back(inputs_);
49 return nodes.size() - 1;
52 // Specify resulting node. All the matched nodes in subgraph excluding
53 // input nodes will be fused into this single node.
54 // TODO: Replace inputs to std::vector<int> in C++11
55 void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1,
56 int input_2 = -1, int input_3 = -1, int input_4 = -1,
59 int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5};
61 for (int i = 0; i < 6; ++i)
63 CV_Assert(nodeInputs[i] < (int)nodes.size());
64 numInputs += (int)(nodeInputs[i] != -1);
66 setFusedNode(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
69 void setFusedNode(const std::string& op, const std::vector<int>& inputs_)
71 fusedNodeInputs = inputs_;
74 for (int i = 0; i < nodes.size(); ++i)
76 if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end() &&
78 nodesToFuse.push_back(i);
82 static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net,
83 const tensorflow::NodeDef& node,
86 CV_Assert(inpId < node.input_size());
87 std::string name = node.input(inpId);
88 // If operation produces several tensors, they are specified by index
89 // after ':' character. In example, "input:0".
90 name = name.substr(0, name.rfind(':'));
91 const int numNodes = net.node_size();
92 for (int i = 0; i < numNodes; ++i)
94 if (net.node(i).name() == name)
97 CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
100 // Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
101 // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
102 virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds)
104 matchedNodesIds.clear();
105 matchedNodesIds.reserve(nodesToFuse.size());
107 int numNodes = net.node_size();
108 for (int i = 0; i < nodesToFuse.size(); ++i)
110 while (nodeId < numNodes && net.node(nodeId).op() == "Const")
114 if (nodeId > numNodes - 1)
117 const tensorflow::NodeDef& node = net.node(nodeId);
119 if (node.op() != nodes[nodesToFuse[i]])
122 std::vector<int>& inputNodes = inputs[nodesToFuse[i]];
123 if (inputNodes.size() != node.input_size())
125 for (int j = 0; j < inputNodes.size(); ++j)
127 if (nodes[inputNodes[j]].empty()) // Unknown input node type.
129 const tensorflow::NodeDef& inpNode = getInputNode(net, node, j);
130 if (inpNode.op() != nodes[inputNodes[j]])
134 matchedNodesIds.push_back(nodeId);
140 // Fuse matched subgraph.
141 void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds)
143 // Extract names of input nodes.
144 std::vector<std::string> inputsNames(fusedNodeInputs.size());
145 for (int i = 0; i < fusedNodeInputs.size(); ++i)
148 // Find input node name looking at inputs of fused nodes.
149 for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
151 const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]);
152 std::vector<int>& inpIndices = inputs[nodesToFuse[j]];
154 CV_Assert(node.input_size() == inpIndices.size());
155 for (int k = 0; k < inpIndices.size(); ++k)
157 if (inpIndices[k] == fusedNodeInputs[i])
159 inpName = node.input(k);
164 CV_Assert(!inpName.empty());
165 inputsNames[i] = inpName;
168 // Remove matched nodes except the last one. Indices in ascending order are expected.
169 tensorflow::NodeDef* node = net.mutable_node(matchedNodesIds.back());
170 for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
171 net.mutable_node()->DeleteSubrange(matchedNodesIds[i], 1);
173 // Modify the last node to be a fused one.
174 node->set_op(fusedNodeOp);
176 for (int i = 0; i < inputsNames.size(); ++i)
178 node->add_input(inputsNames[i]);
181 std::vector<tensorflow::NodeDef*> inputNodes(inputsNames.size());
182 for (int i = 0; i < inputsNames.size(); ++i)
184 inputNodes[i] = (tensorflow::NodeDef*)&getInputNode(net, *node, i);
186 finalize(net, node, inputNodes);
189 virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*,
190 std::vector<tensorflow::NodeDef*>&) {}
193 std::vector<std::string> nodes; // Nodes to be matched in the origin graph.
194 std::vector<std::vector<int> > inputs; // Connections of an every node to it's inputs.
196 std::string fusedNodeOp; // Operation name of resulting fused node.
197 std::vector<int> nodesToFuse; // Set of nodes to be fused.
198 std::vector<int> fusedNodeInputs; // Inputs of fused node.
201 class BatchNormSubgraph : public Subgraph
206 int input = addNodeToMatch("");
207 int epsilon = addNodeToMatch("Const");
208 int moving_variance = addNodeToMatch("Const");
209 int moving_mean = addNodeToMatch("Const");
210 int beta = addNodeToMatch("Const");
211 int gamma = addNodeToMatch("Const");
212 int add = addNodeToMatch("Add", moving_variance, epsilon);
213 int rsqrt = addNodeToMatch("Rsqrt", add);
214 int mul = addNodeToMatch("Mul", rsqrt, gamma);
215 int mul_1 = addNodeToMatch("Mul", input, mul);
216 int mul_2 = addNodeToMatch("Mul", moving_mean, mul);
217 int sub = addNodeToMatch("Sub", beta, mul_2);
218 addNodeToMatch("Add", mul_1, sub);
220 setFusedNode("FusedBatchNorm", input, gamma, beta, moving_mean, moving_variance, epsilon);
223 virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
224 std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
226 Mat epsMat = getTensorContent(inputNodes.back()->attr().at("value").tensor());
227 CV_CheckEQ(epsMat.total(), (size_t)1, ""); CV_CheckTypeEQ(epsMat.type(), CV_32FC1, "");
229 fusedNode->mutable_input()->RemoveLast();
230 fusedNode->clear_attr();
231 tensorflow::AttrValue epsilon;
232 epsilon.set_f(epsMat.at<float>(0));
233 fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
237 class BatchNormNoGammaSubgraph : public Subgraph
240 BatchNormNoGammaSubgraph()
242 int input = addNodeToMatch("");
243 int epsilon = addNodeToMatch("Const");
244 int moving_variance = addNodeToMatch("Const");
245 int moving_mean = addNodeToMatch("Const");
246 int beta = addNodeToMatch("Const");
247 int add = addNodeToMatch("Add", moving_variance, epsilon);
248 int rsqrt = addNodeToMatch("Rsqrt", add);
249 int mul = addNodeToMatch("Mul", input, rsqrt);
250 int mul_1 = addNodeToMatch("Mul", moving_mean, rsqrt);
251 int sub = addNodeToMatch("Sub", beta, mul_1);
252 addNodeToMatch("Add", mul, sub);
254 // There is a fake reference to beta that will be replaced to a new gamma tensor.
255 setFusedNode("FusedBatchNorm", input, beta, beta, moving_mean, moving_variance, epsilon);
258 virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
259 std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
261 Mat epsMat = getTensorContent(inputNodes.back()->attr().at("value").tensor());
262 CV_CheckEQ(epsMat.total(), (size_t)1, ""); CV_CheckTypeEQ(epsMat.type(), CV_32FC1, "");
264 fusedNode->mutable_input()->RemoveLast();
265 fusedNode->clear_attr();
266 tensorflow::AttrValue epsilon;
267 epsilon.set_f(epsMat.at<float>(0));
268 fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
270 tensorflow::NodeDef* gamma = net.add_node();
271 gamma->set_op("Const");
272 gamma->set_name(fusedNode->name() + "/gamma");
273 // Just put a single value to recognize this node as Const.
274 gamma->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", epsilon));
275 fusedNode->set_input(1, gamma->name());
279 // tf.contrib.layers.flatten
280 class FlattenSubgraph : public Subgraph
285 int input = addNodeToMatch("");
286 int shape = addNodeToMatch("Const");
287 int stack = addNodeToMatch("Const");
288 int stack_1 = addNodeToMatch("Const");
289 int stack_2 = addNodeToMatch("Const");
290 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
291 int shape_pack = addNodeToMatch("Const");
292 int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
293 addNodeToMatch("Reshape", input, pack);
295 setFusedNode("Flatten", input);
299 // tf.contrib.layers.flatten in case of unknown batch size
300 class FlattenShapeSubgraph : public Subgraph
303 FlattenShapeSubgraph()
305 int input = addNodeToMatch("");
306 int shape = addNodeToMatch("Shape", input);
307 int stack = addNodeToMatch("Const");
308 int stack_1 = addNodeToMatch("Const");
309 int stack_2 = addNodeToMatch("Const");
310 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
311 int shape_pack = addNodeToMatch("Const");
312 int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
313 addNodeToMatch("Reshape", input, pack);
315 setFusedNode("Flatten", input);
320 class SoftMaxKerasSubgraph : public Subgraph
323 SoftMaxKerasSubgraph()
325 int input = addNodeToMatch("");
326 int maxReductionIndices = addNodeToMatch("Const");
327 int smMax = addNodeToMatch("Max", input, maxReductionIndices);
328 int smSub = addNodeToMatch("Sub", input, smMax);
329 int smExp = addNodeToMatch("Exp", smSub);
330 int sumReductionIndices = addNodeToMatch("Const");
331 int smSum = addNodeToMatch("Sum", smExp, sumReductionIndices);
332 addNodeToMatch("RealDiv", smExp, smSum);
334 setFusedNode("Softmax", input);
338 class ReLU6KerasSubgraph : public Subgraph
343 int input = addNodeToMatch("");
344 int relu = addNodeToMatch("Relu", input);
345 int maxValue = addNodeToMatch("Const");
346 int clipValue = addNodeToMatch("Const");
347 int minimum = addNodeToMatch("Minimum", relu, maxValue);
348 addNodeToMatch("Maximum", minimum, clipValue);
350 setFusedNode("Relu6", input);
353 virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE
355 if (!Subgraph::match(net, nodeId, matchedNodesIds))
357 Mat maxValue = getTensorContent(net.node(nodeId + 1).attr().at("value").tensor());
358 return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
362 // Keras' reshape stores output shape in separate Const nodes by one value.
363 // Need to merge them into a single Const node.
364 class ReshapeKerasSubgraph : public Subgraph
367 ReshapeKerasSubgraph(int _numOutDims) : numOutDims(_numOutDims)
369 int input = addNodeToMatch("");
370 int shape = addNodeToMatch("Shape", input);
371 int stack = addNodeToMatch("Const");
372 int stack_1 = addNodeToMatch("Const");
373 int stack_2 = addNodeToMatch("Const");
374 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
376 std::vector<int> ids(1 + numOutDims);
377 ids[0] = strided_slice;
378 for (int i = 0; i < numOutDims; ++i)
379 ids[1 + i] = addNodeToMatch("Const");
380 int pack = addNodeToMatch("Pack", ids);
381 addNodeToMatch("Reshape", input, pack);
384 setFusedNode("Reshape", ids);
387 virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
388 std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
390 std::vector<int> shape(numOutDims + 1); // batch size in Keras is implicit.
392 for (int i = 0; i < numOutDims; ++i)
394 shape[1 + i] = inputNodes[1 + i]->attr().at("value").tensor().int_val(0);
396 tensorflow::TensorProto* shapeTensor = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
397 fusedNode->mutable_input()->DeleteSubrange(2, numOutDims - 1);
399 shapeTensor->clear_int_val();
400 for (int i = 0; i < shape.size(); ++i)
402 shapeTensor->add_int_val(shape[i]);
410 class L2NormalizeSubgraph : public Subgraph
413 L2NormalizeSubgraph()
415 int input = addNodeToMatch("");
416 int square = addNodeToMatch("Square", input);
417 int reductionIndices = addNodeToMatch("Const");
418 int sum = addNodeToMatch("Sum", square, reductionIndices);
419 int y = addNodeToMatch("Const");
420 int maximum = addNodeToMatch("Maximum", sum, y);
421 int rsqrt = addNodeToMatch("Rsqrt", maximum);
422 addNodeToMatch("Mul", input, rsqrt);
423 setFusedNode("L2Normalize", input, reductionIndices);
427 class DeconvolutionValidKerasSubgraph : public Subgraph
430 DeconvolutionValidKerasSubgraph()
432 int input = addNodeToMatch("");
433 int shape = addNodeToMatch("Shape", input);
434 int kernel = addNodeToMatch("Const");
436 int stack = addNodeToMatch("Const");
437 int stack_1 = addNodeToMatch("Const");
438 int stack_2 = addNodeToMatch("Const");
439 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
441 stack = addNodeToMatch("Const");
442 stack_1 = addNodeToMatch("Const");
443 stack_2 = addNodeToMatch("Const");
444 int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
446 stack = addNodeToMatch("Const");
447 stack_1 = addNodeToMatch("Const");
448 stack_2 = addNodeToMatch("Const");
449 int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
451 int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
452 int add = addNodeToMatch("Add", mul, addNodeToMatch("Const"));
454 int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
455 int add_1 = addNodeToMatch("Add", mul_1, addNodeToMatch("Const"));
456 int pack = addNodeToMatch("Pack", strided_slice, add, add_1, addNodeToMatch("Const"));
457 addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
458 // Put any unused Const op to the first input.
459 setFusedNode("Conv2DBackpropInput", stack, kernel, input);
462 virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
463 std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
465 // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
466 // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
467 // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
468 // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
469 std::string padMode = fusedNode->attr().at("padding").s();
470 CV_Assert(padMode == "VALID");
472 const tensorflow::TensorShapeProto& kernelShape =
473 inputNodes[1]->mutable_attr()->at("value").tensor().tensor_shape();
475 CV_Assert(kernelShape.dim_size() == 4);
476 const int kernelHeight = kernelShape.dim(0).size();
477 const int kernelWidth = kernelShape.dim(1).size();
479 tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
480 outShape->clear_int_val();
481 outShape->add_int_val(-1);
482 outShape->add_int_val(kernelHeight);
483 outShape->add_int_val(kernelWidth);
484 outShape->add_int_val(-1);
488 class DeconvolutionSameKerasSubgraph : public Subgraph
491 DeconvolutionSameKerasSubgraph()
493 int input = addNodeToMatch("");
494 int shape = addNodeToMatch("Shape", input);
495 int kernel = addNodeToMatch("Const");
497 int stack = addNodeToMatch("Const");
498 int stack_1 = addNodeToMatch("Const");
499 int stack_2 = addNodeToMatch("Const");
500 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
502 stack = addNodeToMatch("Const");
503 stack_1 = addNodeToMatch("Const");
504 stack_2 = addNodeToMatch("Const");
505 int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
507 stack = addNodeToMatch("Const");
508 stack_1 = addNodeToMatch("Const");
509 stack_2 = addNodeToMatch("Const");
510 int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
512 int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
514 int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
515 int pack = addNodeToMatch("Pack", strided_slice, mul, mul_1, addNodeToMatch("Const"));
516 addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
517 // Put any unused Const op to the first input.
518 setFusedNode("Conv2DBackpropInput", stack, kernel, input);
521 virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
522 std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
524 // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
525 // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
526 // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
527 // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
528 std::string padMode = fusedNode->attr().at("padding").s();
529 CV_Assert(padMode == "SAME");
531 const tensorflow::AttrValue_ListValue& strides = fusedNode->attr().at("strides").list();
532 CV_Assert(strides.i_size() == 4);
534 const int strideY = strides.i(1);
535 const int strideX = strides.i(2);
537 tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
538 outShape->clear_int_val();
539 outShape->add_int_val(-1);
540 outShape->add_int_val(strideY);
541 outShape->add_int_val(strideX);
542 outShape->add_int_val(-1);
546 // In case of resizing by factor.
547 class ResizeBilinearSubgraph : public Subgraph
550 ResizeBilinearSubgraph()
552 int input = addNodeToMatch("");
554 int shape = addNodeToMatch("Shape", input);
555 int stack = addNodeToMatch("Const");
556 int stack_1 = addNodeToMatch("Const");
557 int stack_2 = addNodeToMatch("Const");
558 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
559 int factorY = addNodeToMatch("Const");
560 int mul = addNodeToMatch("Mul", strided_slice, factorY);
562 shape = addNodeToMatch("Shape", input);
563 stack = addNodeToMatch("Const");
564 stack_1 = addNodeToMatch("Const");
565 stack_2 = addNodeToMatch("Const");
566 strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
567 int factorX = addNodeToMatch("Const");
568 int mul_1 = addNodeToMatch("Mul", strided_slice, factorX);
570 int pack = addNodeToMatch("Pack", mul, mul_1);
572 addNodeToMatch("ResizeBilinear", input, pack);
573 setFusedNode("ResizeBilinear", input, factorY, factorX);
577 // In case of resizing by factor.
578 class UpsamplingKerasSubgraph : public Subgraph
581 UpsamplingKerasSubgraph()
583 int input = addNodeToMatch("");
584 int shape = addNodeToMatch("Shape", input);
585 int stack = addNodeToMatch("Const");
586 int stack_1 = addNodeToMatch("Const");
587 int stack_2 = addNodeToMatch("Const");
588 int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
589 int factors = addNodeToMatch("Const");
590 int mul = addNodeToMatch("Mul", strided_slice, factors);
591 addNodeToMatch("ResizeNearestNeighbor", input, mul);
592 setFusedNode("ResizeNearestNeighbor", input, factors);
595 virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
596 std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
598 Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
599 CV_CheckEQ(factorsMat.total(), (size_t)2, ""); CV_CheckTypeEQ(factorsMat.type(), CV_32SC1, "");
601 // Height scale factor
602 tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
603 factorY->clear_int_val();
604 factorY->clear_tensor_content();
605 factorY->add_int_val(factorsMat.at<int>(0, 0));
607 // Width scale factor.
608 tensorflow::NodeDef* factorXNode = net.add_node();
609 factorXNode->set_op("Const");
610 factorXNode->set_name(fusedNode->name() + "/factor_y");
612 tensorflow::AttrValue factorX;
613 factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32);
614 factorX.mutable_tensor()->add_int_val(factorsMat.at<int>(0, 1));
615 factorXNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", factorX));
617 fusedNode->add_input(factorXNode->name());
621 class ReshapeAsShapeSubgraph : public Subgraph
624 ReshapeAsShapeSubgraph()
626 int input = addNodeToMatch("");
627 int shapeSrc = addNodeToMatch("");
628 int shape = addNodeToMatch("Shape", shapeSrc);
629 addNodeToMatch("Reshape", input, shape);
630 setFusedNode("Reshape", input, shapeSrc);
634 class SoftMaxSlimSubgraph : public Subgraph
637 SoftMaxSlimSubgraph()
639 int input = addNodeToMatch("");
640 int shape = addNodeToMatch("Const");
641 int shapeOp = addNodeToMatch("Shape", input);
642 int reshape = addNodeToMatch("Reshape", input, shape);
643 int softmax = addNodeToMatch("Softmax", reshape);
644 addNodeToMatch("Reshape", softmax, shapeOp);
645 setFusedNode("Softmax", input);
649 void simplifySubgraphs(tensorflow::GraphDef& net)
651 std::vector<Ptr<Subgraph> > subgraphs;
652 subgraphs.push_back(Ptr<Subgraph>(new BatchNormSubgraph()));
653 subgraphs.push_back(Ptr<Subgraph>(new BatchNormNoGammaSubgraph()));
654 subgraphs.push_back(Ptr<Subgraph>(new FlattenSubgraph()));
655 subgraphs.push_back(Ptr<Subgraph>(new FlattenShapeSubgraph()));
656 subgraphs.push_back(Ptr<Subgraph>(new SoftMaxKerasSubgraph()));
657 subgraphs.push_back(Ptr<Subgraph>(new ReLU6KerasSubgraph()));
658 subgraphs.push_back(Ptr<Subgraph>(new ReshapeKerasSubgraph(3)));
659 subgraphs.push_back(Ptr<Subgraph>(new L2NormalizeSubgraph()));
660 subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
661 subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
662 subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
663 subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
664 subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
665 subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
667 int numNodes = net.node_size();
668 std::vector<int> matchedNodesIds;
669 for (int i = 0; i < numNodes; ++i)
671 for (int j = 0; j < subgraphs.size(); ++j)
673 if (subgraphs[j]->match(net, i, matchedNodesIds))
675 subgraphs[j]->replace(net, matchedNodesIds);
676 numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
683 void RemoveIdentityOps(tensorflow::GraphDef& net)
685 typedef std::map<String, String> IdentityOpsMap;
686 IdentityOpsMap identity_ops;
688 std::vector<int> identity_ops_idx;
690 int layersCount = net.node_size();
691 for (int li = 0; li < layersCount; li++)
693 const tensorflow::NodeDef &layer = net.node(li);
694 String type = layer.op();
696 if (type == "Identity" || type == "Dropout") {
697 identity_ops_idx.push_back(li);
698 identity_ops[layer.name()] = layer.input(0);
702 for (int li = 0; li < layersCount; li++)
704 tensorflow::NodeDef* layer = net.mutable_node(li);
705 for (int input_id = 0; input_id < layer->input_size(); input_id++) {
706 String input_op_name = layer->input(input_id);
707 IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
709 if (it != identity_ops.end()) {
710 layer->set_input(input_id, it->second);
715 std::sort(identity_ops_idx.begin(), identity_ops_idx.end());
717 int removed_nodes = 0;
718 for(size_t i = 0; i < identity_ops_idx.size(); i++) {
719 int start_id = identity_ops_idx[i] - removed_nodes;
720 net.mutable_node()->DeleteSubrange(start_id, 1);
725 Mat getTensorContent(const tensorflow::TensorProto &tensor)
727 const std::string& content = tensor.tensor_content();
728 switch (tensor.dtype())
730 case tensorflow::DT_FLOAT:
732 if (!content.empty())
733 return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
736 const RepeatedField<float>& field = tensor.float_val();
737 CV_Assert(!field.empty());
738 return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
741 case tensorflow::DT_DOUBLE:
743 if (!content.empty())
744 return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
747 const RepeatedField<double>& field = tensor.double_val();
748 CV_Assert(!field.empty());
749 return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
752 case tensorflow::DT_INT32:
754 if (!content.empty())
755 return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
758 const RepeatedField<int32_t>& field = tensor.int_val();
759 CV_Assert(!field.empty());
760 return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
763 case tensorflow::DT_HALF:
766 if (!content.empty())
768 static const int kHalfSize = 2;
769 halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
773 const RepeatedField<int32_t>& field = tensor.half_val();
774 CV_Assert(!field.empty());
775 Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
776 ints.convertTo(halfs, CV_16UC1);
778 // Reinterpret as a signed shorts just for a convertFp16 call.
779 Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
780 Mat floats(halfs.size(), CV_32FC1);
781 convertFp16(halfsSigned, floats);
784 case tensorflow::DT_QUINT8:
786 CV_Assert(!content.empty());
787 return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
790 CV_Error(Error::StsError, "Tensor's data type is not supported");
796 void releaseTensor(tensorflow::TensorProto* tensor)
798 if (!tensor->mutable_tensor_content()->empty())
800 delete tensor->release_tensor_content();
804 static void permute(google::protobuf::RepeatedPtrField<tensorflow::NodeDef>* data,
805 const std::vector<int>& indices)
807 const int num = data->size();
808 CV_Assert(num == indices.size());
810 std::vector<int> elemIdToPos(num);
811 std::vector<int> posToElemId(num);
812 for (int i = 0; i < num; ++i)
817 for (int i = 0; i < num; ++i)
819 int elemId = indices[i];
820 int pos = elemIdToPos[elemId];
823 data->SwapElements(i, pos);
824 const int swappedElemId = posToElemId[i];
825 elemIdToPos[elemId] = i;
826 elemIdToPos[swappedElemId] = pos;
828 posToElemId[i] = elemId;
829 posToElemId[pos] = swappedElemId;
834 // Is based on tensorflow::graph_transforms::SortByExecutionOrder
835 void sortByExecutionOrder(tensorflow::GraphDef& net)
837 // Maps node's name to index at net.node() list.
838 std::map<std::string, int> nodesMap;
839 std::map<std::string, int>::iterator nodesMapIt;
840 for (int i = 0; i < net.node_size(); ++i)
842 const tensorflow::NodeDef& node = net.node(i);
843 nodesMap.insert(std::make_pair(node.name(), i));
846 // Indices of nodes which use specific node as input.
847 std::vector<std::vector<int> > edges(nodesMap.size());
848 std::vector<int> numRefsToAdd(nodesMap.size(), 0);
849 std::vector<int> nodesToAdd;
850 for (int i = 0; i < net.node_size(); ++i)
852 const tensorflow::NodeDef& node = net.node(i);
853 for (int j = 0; j < node.input_size(); ++j)
855 std::string inpName = node.input(j);
856 inpName = inpName.substr(0, inpName.rfind(':'));
857 inpName = inpName.substr(inpName.find('^') + 1);
859 nodesMapIt = nodesMap.find(inpName);
860 CV_Assert(nodesMapIt != nodesMap.end());
861 edges[nodesMapIt->second].push_back(i);
863 if (node.input_size() == 0)
864 nodesToAdd.push_back(i);
867 if (node.op() == "Merge" || node.op() == "RefMerge")
869 int numControlEdges = 0;
870 for (int j = 0; j < node.input_size(); ++j)
871 numControlEdges += node.input(j)[0] == '^';
872 numRefsToAdd[i] = numControlEdges + 1;
875 numRefsToAdd[i] = node.input_size();
879 std::vector<int> permIds;
880 permIds.reserve(net.node_size());
881 while (!nodesToAdd.empty())
883 int nodeToAdd = nodesToAdd.back();
884 nodesToAdd.pop_back();
886 permIds.push_back(nodeToAdd);
888 for (int i = 0; i < edges[nodeToAdd].size(); ++i)
890 int consumerId = edges[nodeToAdd][i];
891 if (numRefsToAdd[consumerId] > 0)
893 if (numRefsToAdd[consumerId] == 1)
894 nodesToAdd.push_back(consumerId);
896 CV_Assert(numRefsToAdd[consumerId] >= 0);
897 numRefsToAdd[consumerId] -= 1;
901 CV_Assert(permIds.size() == net.node_size());
902 permute(net.mutable_node(), permIds);
905 // Remove training switches (Switch and Merge nodes and corresponding subgraphs).
906 void removePhaseSwitches(tensorflow::GraphDef& net)
908 std::vector<int> nodesToRemove;
909 std::map<std::string, int> nodesMap;
910 std::map<std::string, int>::iterator nodesMapIt;
911 std::queue<int> mergeOpSubgraphNodes;
912 for (int i = 0; i < net.node_size(); ++i)
914 const tensorflow::NodeDef& node = net.node(i);
915 nodesMap.insert(std::make_pair(node.name(), i));
916 if (node.op() == "Switch" || node.op() == "Merge")
918 CV_Assert(node.input_size() > 0);
919 // Replace consumers' inputs.
920 for (int j = 0; j < net.node_size(); ++j)
922 tensorflow::NodeDef* consumer = net.mutable_node(j);
923 for (int k = 0; k < consumer->input_size(); ++k)
925 std::string inpName = consumer->input(k);
926 inpName = inpName.substr(0, inpName.rfind(':'));
927 if (inpName == node.name())
929 consumer->set_input(k, node.input(0));
933 nodesToRemove.push_back(i);
934 if (node.op() == "Merge")
935 mergeOpSubgraphNodes.push(i);
939 std::vector<int> numConsumers(net.node_size(), 0);
940 for (int i = 0; i < net.node_size(); ++i)
942 const tensorflow::NodeDef& node = net.node(i);
943 for (int j = 0; j < node.input_size(); ++j)
945 std::string inpName = node.input(j);
946 inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
947 nodesMapIt = nodesMap.find(inpName);
948 CV_Assert(nodesMapIt != nodesMap.end());
949 numConsumers[nodesMapIt->second] += 1;
953 // Remove subgraphs of unused nodes which are terminated by Merge nodes.
954 while (!mergeOpSubgraphNodes.empty())
956 const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front());
957 mergeOpSubgraphNodes.pop();
958 for (int i = 0; i < node.input_size(); ++i)
960 std::string inpName = node.input(i);
961 inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
962 nodesMapIt = nodesMap.find(inpName);
963 CV_Assert(nodesMapIt != nodesMap.end());
965 int inpNodeId = nodesMapIt->second;
966 if (numConsumers[inpNodeId] == 1)
968 mergeOpSubgraphNodes.push(inpNodeId);
969 nodesToRemove.push_back(inpNodeId);
971 else if (numConsumers[inpNodeId] > 0)
972 numConsumers[inpNodeId] -= 1;
975 std::sort(nodesToRemove.begin(), nodesToRemove.end());
976 for (int i = nodesToRemove.size() - 1; i >= 0; --i)
978 if (nodesToRemove[i] < net.node_size()) // Ids might be repeated.
979 net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1);
984 CV__DNN_EXPERIMENTAL_NS_END
985 }} // namespace dnn, namespace cv
987 #endif // HAVE_PROTOBUF