Testing for folded conv_bn_relu (#19298)
authorJerry Zhang <jerryzh@fb.com>
Wed, 17 Apr 2019 01:57:13 +0000 (18:57 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 17 Apr 2019 02:04:06 +0000 (19:04 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19298

Proper testing for conv_bn_relu folding

Differential Revision: D13998891

fbshipit-source-id: ceb58ccec19885cbbf38964ee0d0db070e098b4a

caffe2/python/brew.py
caffe2/python/helpers/normalization.py

index bb6d3f0..2722c21 100644 (file)
@@ -45,6 +45,7 @@ class HelperWrapper(object):
         'instance_norm': instance_norm,
         'spatial_bn': spatial_bn,
         'spatial_gn': spatial_gn,
+        'moments_with_running_stats': moments_with_running_stats,
         'relu': relu,
         'prelu': prelu,
         'tanh': tanh,
index a47ac58..6c0889c 100644 (file)
@@ -290,3 +290,49 @@ def layer_norm(
     )
 
     return biased, mean, stdev
+
+def moments_with_running_stats(model, blob_in, blob_out, dim_in,
+                                     RunningMeanInitializer=None, RunningVarianceInitializer=None,
+                                     order="NCHW", **kwargs):
+
+    if model.init_params:
+        rm_init = ("ConstantFill", {'value': 0.0})
+        riv_init = ("ConstantFill", {'value': 1.0})
+
+        RunningMeanInitializer = initializers.update_initializer(
+            RunningMeanInitializer, rm_init, ("ConstantFill", {})
+        )
+        RunningVarianceInitializer = initializers.update_initializer(
+            RunningVarianceInitializer, riv_init, ("ConstantFill", {})
+        )
+    else:
+        RunningMeanInitializer = initializers.ExternalInitializer()
+        RunningVarianceInitializer = initializers.ExternalInitializer()
+
+    running_mean = model.create_param(
+        param_name=blob_out + '_rm',
+        shape=[dim_in],
+        initializer=RunningMeanInitializer,
+        tags=ParameterTags.COMPUTED_PARAM
+    )
+
+    # this is just running variance
+    running_inv_var = model.create_param(
+        param_name=blob_out + '_riv',
+        shape=[dim_in],
+        initializer=RunningVarianceInitializer,
+        tags=ParameterTags.COMPUTED_PARAM
+    )
+
+    blob_outs = [blob_out + "_sm", blob_out + "_sv"]
+    if order == 'NCHW':
+        blob_outputs = model.net.Moments(
+            [blob_in], blob_outs,
+            axes=[0, 2, 3],
+            order=order, keepdims=False, **kwargs)
+    elif order == 'NHWC':
+        blob_outputs = model.net.Moments(
+            [blob_in], blob_outs,
+            axes=[0, 1, 2],
+            order=order, keepdims=False, **kwargs)
+    return blob_outputs