tensroflow support maxpoolgrad
authorgal0is <zhipeng.job@hotmail.com>
Sat, 15 Jun 2019 15:51:13 +0000 (23:51 +0800)
committerVonChenPlus <VonChenPlus@gmail.com>
Wed, 3 Jul 2019 01:53:17 +0000 (09:53 +0800)
modules/dnn/src/layers/max_unpooling_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index b9c1f2d..2978509 100644 (file)
@@ -43,12 +43,18 @@ public:
                          std::vector<MatShape> &outputs,
                          std::vector<MatShape> &internals) const CV_OVERRIDE
     {
-        CV_Assert(inputs.size() == 2);
+        CV_Assert(inputs.size() == 2 || inputs.size() == 3);
         CV_Assert(total(inputs[0]) == total(inputs[1]));
 
-        MatShape outShape = inputs[0];
-        outShape[2] = (outShape[2] - 1) * poolStride.height + poolKernel.height - 2 * poolPad.height;
-        outShape[3] = (outShape[3] - 1) * poolStride.width + poolKernel.width - 2 * poolPad.width;
+        MatShape outShape;
+        if (inputs.size() == 2)
+        {
+            outShape = inputs[0];
+            outShape[2] = (outShape[2] - 1) * poolStride.height + poolKernel.height - 2 * poolPad.height;
+            outShape[3] = (outShape[3] - 1) * poolStride.width + poolKernel.width - 2 * poolPad.width;
+        }
+        else
+            outShape = inputs[2];
 
         outputs.clear();
         outputs.push_back(outShape);
@@ -71,7 +77,7 @@ public:
         inputs_arr.getMatVector(inputs);
         outputs_arr.getMatVector(outputs);
 
-        CV_Assert(inputs.size() == 2);
+        CV_Assert(inputs.size() == 2 || inputs.size() == 3);
         Mat& input = inputs[0];
         Mat& indices = inputs[1];
 
index 3a5bd34..c38b250 100644 (file)
@@ -1370,6 +1370,24 @@ void TFImporter::populateNet(Net dstNet)
 
             connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
         }
+        else if (type == "MaxPoolGrad")
+        {
+            CV_Assert(layer.input_size() == 3);
+
+            layerParams.set("pool_k_h", 0);
+            layerParams.set("pool_k_w", 0);
+            layerParams.set("pool_stride_h", 0);
+            layerParams.set("pool_stride_w", 0);
+            layerParams.set("pool_pad_h", 0);
+            layerParams.set("pool_pad_w", 0);
+
+            int id = dstNet.addLayer(name, "MaxUnpool", layerParams);
+            layer_id[name] = id;
+
+            connect(layer_id, dstNet, parsePin(layer.input(2)), id, 0);
+            connect(layer_id, dstNet, parsePin(layer.input(1) + ":1"), id, 1);
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 2);
+        }
         else if (type == "Placeholder")
         {
             if (!hasLayerAttr(layer, "dtype") ||
index dd5d871..0687c0b 100644 (file)
@@ -218,6 +218,13 @@ TEST_P(Test_TensorFlow_layers, pooling)
     runTensorFlowNet("reduce_mean");  // an average pooling over all spatial dimensions.
 }
 
+TEST_P(Test_TensorFlow_layers, max_pool_grad)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
+    runTensorFlowNet("max_pool_grad");
+}
+
 // TODO: fix tests and replace to pooling
 TEST_P(Test_TensorFlow_layers, ave_pool_same)
 {