PReLU layer for multidimensional input
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 23 Oct 2017 11:30:40 +0000 (14:30 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 23 Oct 2017 13:13:03 +0000 (16:13 +0300)
modules/dnn/src/layers/elementwise_layers.cpp
modules/dnn/src/layers/fully_connected_layer.cpp
modules/dnn/src/layers/prior_box_layer.cpp
modules/dnn/test/test_layers.cpp

index eb93363..1e3d2de 100644 (file)
@@ -80,20 +80,19 @@ public:
 
         void operator()(const Range &r) const
         {
-            int nstripes = nstripes_, nsamples, outCn;
-            size_t planeSize;
+            int nstripes = nstripes_, nsamples = 1, outCn = 1;
+            size_t planeSize = 1;
 
-            if( src_->dims == 4 )
+            if (src_->dims > 1)
             {
                 nsamples = src_->size[0];
                 outCn = src_->size[1];
-                planeSize = (size_t)src_->size[2]*src_->size[3];
             }
             else
-            {
-                nsamples = outCn = 1;
-                planeSize = (size_t)src_->total();
-            }
+                outCn = src_->size[0];
+
+            for (int i = 2; i < src_->dims; ++i)
+                planeSize *= src_->size[i];
 
             size_t stripeSize = (planeSize + nstripes - 1)/nstripes;
             size_t stripeStart = r.start*stripeSize;
index 6067b3f..6fa9ed6 100644 (file)
@@ -242,9 +242,8 @@ public:
                     }
                 }
 
-                // TODO: check whether this is correct in the case of ChannelsPReLU.
                 if(activ)
-                    activ->forwardSlice(dptr, dptr, nw, 0, 0, 1);
+                    activ->forwardSlice(dptr, dptr, 1, 1, delta, delta + nw);
 
                 ofs += nw;
             }
index 3ca0835..009789d 100644 (file)
@@ -177,7 +177,7 @@ public:
         : _boxWidth(0), _boxHeight(0)
     {
         setParamsFrom(params);
-        _minSize = getParameter<unsigned>(params, "min_size");
+        _minSize = getParameter<float>(params, "min_size");
         CV_Assert(_minSize > 0);
 
         _flip = getParameter<bool>(params, "flip");
index ac36d0e..9becaca 100644 (file)
@@ -282,6 +282,7 @@ TEST(Layer_Test_Eltwise, Accuracy)
 TEST(Layer_Test_PReLU, Accuracy)
 {
     testLayerUsingCaffeModels("layer_prelu", DNN_TARGET_CPU, true);
+    testLayerUsingCaffeModels("layer_prelu_fc", DNN_TARGET_CPU, true, false);
 }
 
 //template<typename XMat>