Fuse top layers into batch normalization
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 9 Jun 2018 15:06:53 +0000 (18:06 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 9 Jun 2018 15:06:53 +0000 (18:06 +0300)
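
This change lets the batch normalization layer absorb a following ("top")
layer that applies a per-channel scale and/or shift, as reported by
Layer::getScaleShift(). Batch normalization is already folded to
y = weights_ .* x + bias_ per channel, so fusing a top scale w and shift b
reduces to the element-wise updates weights_ .*= w and bias_ = bias_ .* w + b.

A minimal standalone sketch of that arithmetic (illustrative 3-channel
values only, not part of the patch):

    #include <opencv2/core.hpp>
    #include <cstdio>

    int main()
    {
        // Folded batch norm parameters of a hypothetical 3-channel layer.
        cv::Mat weights = (cv::Mat_<float>(1, 3) << 0.5f, 1.0f, 2.0f);
        cv::Mat bias    = (cv::Mat_<float>(1, 3) << 0.1f, 0.2f, 0.3f);
        // Per-channel scale/shift of the top layer, as getScaleShift() would report them.
        cv::Mat w = (cv::Mat_<float>(1, 3) << 2.0f, 3.0f, 4.0f);
        cv::Mat b = (cv::Mat_<float>(1, 3) << 1.0f, 1.0f, 1.0f);

        // The same element-wise updates tryFuse() below performs.
        cv::multiply(weights, w, weights);  // weights .*= w
        cv::multiply(bias, w, bias);        // bias .*= w
        cv::add(bias, b, bias);             // bias += b

        for (int c = 0; c < 3; ++c)
            std::printf("channel %d: weight=%g bias=%g\n",
                        c, weights.at<float>(c), bias.at<float>(c));
        return 0;
    }
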
modules/dnn/src/layers/batch_norm_layer.cpp

index d42face..3b47232 100644
@@ -96,6 +96,50 @@ public:
         shift = bias_;
     }
 
+    virtual bool tryFuse(Ptr<Layer>& top) CV_OVERRIDE
+    {
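+        // Ask the top (next) layer for the per-channel scale w and shift b it applies.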
+        Mat w, b;
+        top->getScaleShift(w, b);
+        if (w.empty() && b.empty())
+            return false;
+
+        const int numChannels = (int)weights_.total();
+        const int numFusedWeights = (int)w.total();
+        const int numFusedBias = (int)b.total();
+
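+        // Fusion is possible only if each provided factor is either per-channel or a scalar.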
+        if ((!w.empty() && numFusedWeights != numChannels && numFusedWeights != 1) ||
+            (!b.empty() && numFusedBias != numChannels && numFusedBias != 1))
+            return false;
+
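+        // Absorb the scale: weights_ *= w and bias_ *= w (a single value broadcasts to all channels).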
+        if (!w.empty())
+        {
+            w = w.reshape(1, 1);
+            if (numFusedWeights == 1)
+            {
+                multiply(weights_, w.at<float>(0), weights_);
+                multiply(bias_, w.at<float>(0), bias_);
+            }
+            else
+            {
+                multiply(weights_, w, weights_);
+                multiply(bias_, w, bias_);
+            }
+        }
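+        // Absorb the shift: bias_ += b.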
+        if (!b.empty())
+        {
+            b = b.reshape(1, 1);
+            if (numFusedBias == 1)
+                add(bias_, b.at<float>(0), bias_);
+            else
+                add(bias_, b, bias_);  // b is already a single row after the reshape above
+        }
+        return true;
+    }
+
     bool getMemoryShapes(const std::vector<MatShape> &inputs,
                          const int requiredOutputs,
                          std::vector<MatShape> &outputs,