Faster-RCNN object detection models from TensorFlow
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 3 Apr 2018 15:28:05 +0000 (18:28 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 30 May 2018 14:12:36 +0000 (17:12 +0300)
modules/dnn/include/opencv2/dnn/all_layers.hpp
modules/dnn/src/init.cpp
modules/dnn/src/layers/crop_and_resize_layer.cpp [new file with mode: 0644]
modules/dnn/src/layers/detection_output_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp
samples/dnn/README.md
samples/dnn/tf_text_graph_faster_rcnn.py [new file with mode: 0644]

index f2124dd..ffb09a2 100644 (file)
@@ -581,6 +581,12 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         static Ptr<ProposalLayer> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS CropAndResizeLayer : public Layer
+    {
+    public:
+        static Ptr<Layer> create(const LayerParams& params);
+    };
+
 //! @}
 //! @}
 CV__DNN_EXPERIMENTAL_NS_END
index 28759da..2bff16c 100644 (file)
@@ -84,6 +84,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Reshape,        ReshapeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Flatten,        FlattenLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ResizeNearestNeighbor, ResizeNearestNeighborLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(CropAndResize,  CropAndResizeLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Convolution,    ConvolutionLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Deconvolution,  DeconvolutionLayer);
diff --git a/modules/dnn/src/layers/crop_and_resize_layer.cpp b/modules/dnn/src/layers/crop_and_resize_layer.cpp
new file mode 100644 (file)
index 0000000..3f92a84
--- /dev/null
@@ -0,0 +1,108 @@
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+namespace cv { namespace dnn {
+
+class CropAndResizeLayerImpl CV_FINAL : public CropAndResizeLayer
+{
+public:
+    CropAndResizeLayerImpl(const LayerParams& params)
+    {
+        CV_Assert(params.has("width"), params.has("height"));
+        outWidth = params.get<float>("width");
+        outHeight = params.get<float>("height");
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() == 2, inputs[0].size() == 4);
+        if (inputs[0][0] != 1)
+            CV_Error(Error::StsNotImplemented, "");
+        outputs.resize(1, MatShape(4));
+        outputs[0][0] = inputs[1][2];  // Number of bounding boxes.
+        outputs[0][1] = inputs[0][1];  // Number of channels.
+        outputs[0][2] = outHeight;
+        outputs[0][3] = outWidth;
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
+    }
+
+    void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        Mat& inp = *inputs[0];
+        Mat& out = outputs[0];
+        Mat boxes = inputs[1]->reshape(1, inputs[1]->total() / 7);
+        const int numChannels = inp.size[1];
+        const int inpHeight = inp.size[2];
+        const int inpWidth = inp.size[3];
+        const int inpSpatialSize = inpHeight * inpWidth;
+        const int outSpatialSize = outHeight * outWidth;
+        CV_Assert(inp.isContinuous(), out.isContinuous());
+
+        for (int b = 0; b < boxes.rows; ++b)
+        {
+            float* outDataBox = out.ptr<float>(b);
+            float left = boxes.at<float>(b, 3);
+            float top = boxes.at<float>(b, 4);
+            float right = boxes.at<float>(b, 5);
+            float bottom = boxes.at<float>(b, 6);
+            float boxWidth = right - left;
+            float boxHeight = bottom - top;
+
+            float heightScale = boxHeight * static_cast<float>(inpHeight - 1) / (outHeight - 1);
+            float widthScale = boxWidth * static_cast<float>(inpWidth - 1) / (outWidth - 1);
+            for (int y = 0; y < outHeight; ++y)
+            {
+                float input_y = top * (inpHeight - 1) + y * heightScale;
+                int y0 = static_cast<int>(input_y);
+                const float* inpData_row0 = (float*)inp.data + y0 * inpWidth;
+                const float* inpData_row1 = (y0 + 1 < inpHeight) ? (inpData_row0 + inpWidth) : inpData_row0;
+                for (int x = 0; x < outWidth; ++x)
+                {
+                    float input_x = left * (inpWidth - 1) + x * widthScale;
+                    int x0 = static_cast<int>(input_x);
+                    int x1 = std::min(x0 + 1, inpWidth - 1);
+
+                    float* outData = outDataBox + y * outWidth + x;
+                    const float* inpData_row0_c = inpData_row0;
+                    const float* inpData_row1_c = inpData_row1;
+                    for (int c = 0; c < numChannels; ++c)
+                    {
+                        *outData = inpData_row0_c[x0] +
+                            (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
+                            (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
+                            (input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0]));
+
+                        inpData_row0_c += inpSpatialSize;
+                        inpData_row1_c += inpSpatialSize;
+                        outData += outSpatialSize;
+                    }
+                }
+            }
+        }
+    }
+
+private:
+    int outWidth, outHeight;
+};
+
+Ptr<Layer> CropAndResizeLayer::create(const LayerParams& params)
+{
+    return Ptr<CropAndResizeLayer>(new CropAndResizeLayerImpl(params));
+}
+
+}  // namespace dnn
+}  // namespace cv
index 44f7b32..ee1ad95 100644 (file)
@@ -208,8 +208,9 @@ public:
         CV_Assert(inputs[0][0] == inputs[1][0]);
 
         int numPriors = inputs[2][2] / 4;
-        CV_Assert((numPriors * _numLocClasses * 4) == inputs[0][1]);
-        CV_Assert(int(numPriors * _numClasses) == inputs[1][1]);
+        CV_Assert((numPriors * _numLocClasses * 4) == total(inputs[0], 1));
+        CV_Assert(int(numPriors * _numClasses) == total(inputs[1], 1));
+        CV_Assert(inputs[2][1] == 1 + (int)(!_varianceEncodedInTarget));
 
         // num() and channels() are 1.
         // Since the number of bboxes to be kept is unknown before nms, we manually
index bca150e..f19daf9 100644 (file)
@@ -1094,9 +1094,9 @@ void TFImporter::populateNet(Net dstNet)
             CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
                       sizes.type() == CV_32SC1);
 
-            if (begins.total() == 4)
+            if (begins.total() == 4 && data_layouts[name] == DATA_LAYOUT_NHWC)
             {
-                // Perhabs, we have an NHWC order. Swap it to NCHW.
+                // Swap NHWC parameters' order to NCHW.
                 std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
                 std::swap(*begins.ptr<int32_t>(0, 1), *begins.ptr<int32_t>(0, 2));
                 std::swap(*sizes.ptr<int32_t>(0, 2), *sizes.ptr<int32_t>(0, 3));
@@ -1176,6 +1176,9 @@ void TFImporter::populateNet(Net dstNet)
                        layers_to_ignore.insert(next_layers[0].first);
                    }
 
+                    if (hasLayerAttr(layer, "axis"))
+                        layerParams.set("axis", getLayerAttr(layer, "axis").i());
+
                     id = dstNet.addLayer(name, "Scale", layerParams);
                 }
                 layer_id[name] = id;
@@ -1547,6 +1550,10 @@ void TFImporter::populateNet(Net dstNet)
                 layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
             if (hasLayerAttr(layer, "loc_pred_transposed"))
                 layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
+            if (hasLayerAttr(layer, "clip"))
+                layerParams.set("clip", getLayerAttr(layer, "clip").b());
+            if (hasLayerAttr(layer, "variance_encoded_in_target"))
+                layerParams.set("variance_encoded_in_target", getLayerAttr(layer, "variance_encoded_in_target").b());
 
             int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
             layer_id[name] = id;
@@ -1563,6 +1570,26 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
         }
+        else if (type == "CropAndResize")
+        {
+            // op: "CropAndResize"
+            // input: "input"
+            // input: "boxes"
+            // input: "sizes"
+            CV_Assert(layer.input_size() == 3);
+
+            Mat cropSize = getTensorContent(getConstBlob(layer, value_id, 2));
+            CV_Assert(cropSize.type() == CV_32SC1, cropSize.total() == 2);
+
+            layerParams.set("height", cropSize.at<int>(0));
+            layerParams.set("width", cropSize.at<int>(1));
+
+            int id = dstNet.addLayer(name, "CropAndResize", layerParams);
+            layer_id[name] = id;
+
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+        }
         else if (type == "Mean")
         {
             Mat indices = getTensorContent(getConstBlob(layer, value_id, 1));
index b090fd7..84205f7 100644 (file)
@@ -270,6 +270,22 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
     normAssertDetections(ref, out, "", 0.5);
 }
 
+TEST_P(Test_TensorFlow_nets, Inception_v2_Faster_RCNN)
+{
+    std::string proto = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt", false);
+    std::string model = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pb", false);
+
+    Net net = readNetFromTensorflow(model, proto);
+    Mat img = imread(findDataFile("dnn/dog416.png", false));
+    Mat blob = blobFromImage(img, 1.0f / 127.5, Size(800, 600), Scalar(127.5, 127.5, 127.5), true, false);
+
+    net.setInput(blob);
+    Mat out = net.forward();
+
+    Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/faster_rcnn_inception_v2_coco_2018_01_28.detection_out.npy"));
+    normAssertDetections(ref, out, "", 0.3);
+}
+
 TEST_P(Test_TensorFlow_nets, opencv_face_detector_uint8)
 {
     std::string proto = findDataFile("dnn/opencv_face_detector.pbtxt", false);
index c438bb0..9072ddb 100644 (file)
 | [SSDs from TensorFlow](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
 | [YOLO](https://pjreddie.com/darknet/yolo/) | `0.00392 (1/255)` | `416x416` | `0 0 0` | RGB |
 | [VGG16-SSD](https://github.com/weiliu89/caffe/tree/ssd) | `1.0` | `300x300` | `104 117 123` | BGR |
-| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
+| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
 | [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
+| [Faster-RCNN, ResNet backbone](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `1.0` | `300x300` | `103.939 116.779 123.68` | RGB |
+| [Faster-RCNN, InceptionV2 backbone](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
 
 #### Face detection
 [An origin model](https://github.com/opencv/opencv/tree/master/samples/dnn/face_detector)
diff --git a/samples/dnn/tf_text_graph_faster_rcnn.py b/samples/dnn/tf_text_graph_faster_rcnn.py
new file mode 100644 (file)
index 0000000..7ad5de2
--- /dev/null
@@ -0,0 +1,291 @@
+import argparse
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.core.framework.node_def_pb2 import NodeDef
+from tensorflow.tools.graph_transforms import TransformGraph
+from google.protobuf import text_format
+
+parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
+                                             'SSD model from TensorFlow Object Detection API. '
+                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
+parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
+parser.add_argument('--output', required=True, help='Path to output text graph.')
+parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
+parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
+                    help='Hyper-parameter of grid_anchor_generator from a config file.')
+parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
+                    help='Hyper-parameter of grid_anchor_generator from a config file.')
+parser.add_argument('--features_stride', default=16, type=float, nargs='+',
+                    help='Hyper-parameter from a config file.')
+args = parser.parse_args()
+
+scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
+                'FirstStageBoxPredictor/BoxEncodingPredictor',
+                'FirstStageBoxPredictor/ClassPredictor',
+                'CropAndResize',
+                'MaxPool2D',
+                'SecondStageFeatureExtractor',
+                'SecondStageBoxPredictor',
+                'image_tensor')
+
+scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
+                  'FirstStageFeatureExtractor/Shape',
+                  'FirstStageFeatureExtractor/strided_slice',
+                  'FirstStageFeatureExtractor/GreaterEqual',
+                  'FirstStageFeatureExtractor/LogicalAnd')
+
+unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
+               'Index', 'Tperm', 'is_training', 'Tpaddings']
+
+# Read the graph.
+with tf.gfile.FastGFile(args.input, 'rb') as f:
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+
+# Removes Identity nodes
+def removeIdentity():
+    identities = {}
+    for node in graph_def.node:
+        if node.op == 'Identity':
+            identities[node.name] = node.input[0]
+            graph_def.node.remove(node)
+
+    for node in graph_def.node:
+        for i in range(len(node.input)):
+            if node.input[i] in identities:
+                node.input[i] = identities[node.input[i]]
+
+removeIdentity()
+
+removedNodes = []
+
+for i in reversed(range(len(graph_def.node))):
+    op = graph_def.node[i].op
+    name = graph_def.node[i].name
+
+    if op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep):
+        if op != 'Const':
+            removedNodes.append(name)
+
+        del graph_def.node[i]
+    else:
+        for attr in unusedAttrs:
+            if attr in graph_def.node[i].attr:
+                del graph_def.node[i].attr[attr]
+
+# Remove references to removed nodes except Const nodes.
+for node in graph_def.node:
+    for i in reversed(range(len(node.input))):
+        if node.input[i] in removedNodes:
+            del node.input[i]
+
+
+# Connect input node to the first layer
+assert(graph_def.node[0].op == 'Placeholder')
+graph_def.node[1].input.insert(0, graph_def.node[0].name)
+
+# Temporarily remove top nodes.
+topNodes = []
+while True:
+    node = graph_def.node.pop()
+    topNodes.append(node)
+    if node.op == 'CropAndResize':
+        break
+
+def tensorMsg(values):
+    if all([isinstance(v, float) for v in values]):
+        dtype = 'DT_FLOAT'
+        field = 'float_val'
+    elif all([isinstance(v, int) for v in values]):
+        dtype = 'DT_INT32'
+        field = 'int_val'
+    else:
+        raise Exception('Wrong values types')
+
+    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
+    for value in values:
+        msg += '%s: %s ' % (field, str(value))
+    return msg + '}'
+
+def addSlice(inp, out, begins, sizes):
+    beginsNode = NodeDef()
+    beginsNode.name = out + '/begins'
+    beginsNode.op = 'Const'
+    text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
+    graph_def.node.extend([beginsNode])
+
+    sizesNode = NodeDef()
+    sizesNode.name = out + '/sizes'
+    sizesNode.op = 'Const'
+    text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
+    graph_def.node.extend([sizesNode])
+
+    sliced = NodeDef()
+    sliced.name = out
+    sliced.op = 'Slice'
+    sliced.input.append(inp)
+    sliced.input.append(beginsNode.name)
+    sliced.input.append(sizesNode.name)
+    graph_def.node.extend([sliced])
+
+def addReshape(inp, out, shape):
+    shapeNode = NodeDef()
+    shapeNode.name = out + '/shape'
+    shapeNode.op = 'Const'
+    text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
+    graph_def.node.extend([shapeNode])
+
+    reshape = NodeDef()
+    reshape.name = out
+    reshape.op = 'Reshape'
+    reshape.input.append(inp)
+    reshape.input.append(shapeNode.name)
+    graph_def.node.extend([reshape])
+
+def addSoftMax(inp, out):
+    softmax = NodeDef()
+    softmax.name = out
+    softmax.op = 'Softmax'
+    text_format.Merge('i: -1', softmax.attr['axis'])
+    softmax.input.append(inp)
+    graph_def.node.extend([softmax])
+
+addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
+           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2])
+
+addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
+           'FirstStageBoxPredictor/ClassPredictor/softmax')  # Compare with Reshape_4
+
+flatten = NodeDef()
+flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'  # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
+flatten.op = 'Flatten'
+flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
+graph_def.node.extend([flatten])
+
+proposals = NodeDef()
+proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
+proposals.op = 'PriorBox'
+proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
+proposals.input.append(graph_def.node[0].name)  # image_tensor
+
+text_format.Merge('b: false', proposals.attr["flip"])
+text_format.Merge('b: true', proposals.attr["clip"])
+text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
+text_format.Merge('f: 0.0', proposals.attr["offset"])
+text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
+
+widths = []
+heights = []
+for a in args.aspect_ratios:
+    for s in args.scales:
+        ar = np.sqrt(a)
+        heights.append((args.features_stride**2) * s / ar)
+        widths.append((args.features_stride**2) * s * ar)
+
+text_format.Merge(tensorMsg(widths), proposals.attr["width"])
+text_format.Merge(tensorMsg(heights), proposals.attr["height"])
+
+graph_def.node.extend([proposals])
+
+# Compare with Reshape_5
+detectionOut = NodeDef()
+detectionOut.name = 'detection_out'
+detectionOut.op = 'DetectionOutput'
+
+detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
+detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax')
+detectionOut.input.append('proposals')
+
+text_format.Merge('i: 2', detectionOut.attr['num_classes'])
+text_format.Merge('b: true', detectionOut.attr['share_location'])
+text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
+text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
+text_format.Merge('i: 6000', detectionOut.attr['top_k'])
+text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
+text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
+text_format.Merge('b: true', detectionOut.attr['clip'])
+text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
+
+graph_def.node.extend([detectionOut])
+
+# Save as text.
+for node in reversed(topNodes):
+    graph_def.node.extend([node])
+
+addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax')
+
+addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
+         'SecondStageBoxPredictor/Reshape_1/slice',
+         [0, 0, 1], [-1, -1, -1])
+
+addReshape('SecondStageBoxPredictor/Reshape_1/slice',
+          'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1])
+
+# Replace Flatten subgraph onto a single node.
+for i in reversed(range(len(graph_def.node))):
+    if graph_def.node[i].op == 'CropAndResize':
+        graph_def.node[i].input.insert(1, 'detection_out')
+
+    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
+        shapeNode = NodeDef()
+        shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2'
+        shapeNode.op = 'Const'
+        text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"])
+        graph_def.node.extend([shapeNode])
+
+        graph_def.node[i].input.pop()
+        graph_def.node[i].input.append(shapeNode.name)
+
+    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
+                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
+                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
+        del graph_def.node[i]
+
+for node in graph_def.node:
+    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
+        node.op = 'Flatten'
+        node.input.pop()
+        break
+
+################################################################################
+### Postprocessing
+################################################################################
+addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])
+
+variance = NodeDef()
+variance.name = 'proposals/variance'
+variance.op = 'Const'
+text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
+graph_def.node.extend([variance])
+
+varianceEncoder = NodeDef()
+varianceEncoder.name = 'variance_encoded'
+varianceEncoder.op = 'Mul'
+varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
+varianceEncoder.input.append(variance.name)
+text_format.Merge('i: 2', varianceEncoder.attr["axis"])
+graph_def.node.extend([varianceEncoder])
+
+addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1])
+
+detectionOut = NodeDef()
+detectionOut.name = 'detection_out_final'
+detectionOut.op = 'DetectionOutput'
+
+detectionOut.input.append('variance_encoded')
+detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
+detectionOut.input.append('detection_out/slice/reshape')
+
+text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
+text_format.Merge('b: false', detectionOut.attr['share_location'])
+text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
+text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
+text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
+text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
+text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
+text_format.Merge('b: true', detectionOut.attr['clip'])
+text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
+graph_def.node.extend([detectionOut])
+
+tf.train.write_graph(graph_def, "", args.output, as_text=True)