Remove code duplication (#15880)
author     Sebastian Messmer <messmer@fb.com>
           Tue, 15 Jan 2019 01:55:13 +0000 (17:55 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Tue, 15 Jan 2019 01:59:37 +0000 (17:59 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15880

The layer_norm reference implementation was duplicated in layer_norm_op_test.py. This removes the duplicate by hoisting the forward and gradient reference functions to module level and binding their test-specific arguments (axis, epsilon) with functools.partial.

Reviewed By: dzhulgakov

Differential Revision: D13611232

fbshipit-source-id: cee96c78d3255c3a4e34300693bf9260cf096615
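
For context, the deduplication hinges on functools.partial: the shared reference functions now take their test parameters (axis, epsilon) as leading arguments, and each test binds them up front so the resulting callable only needs the op inputs that the test harness passes. A minimal, self-contained sketch of that binding pattern (the names below are illustrative, not part of this commit):

from functools import partial

import numpy as np


def _ref(axis, epsilon, X):
    # Stand-in for the shared reference: flatten everything after `axis`
    # and normalize each row, mirroring the shape handling in the diff below.
    left = int(np.prod(X.shape[:axis]))
    reshaped = np.reshape(X, [left, -1])
    mean = np.mean(reshaped, axis=1, keepdims=True)
    stdev = np.sqrt(np.var(reshaped, axis=1, keepdims=True) + epsilon)
    return np.reshape((reshaped - mean) / stdev, X.shape)


reference = partial(_ref, 1, 1e-4)   # binds axis=1, epsilon=1e-4
X = np.random.rand(2, 3, 4).astype(np.float32)
out = reference(X)                   # equivalent to _ref(1, 1e-4, X)
assert out.shape == X.shape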

diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py
index 6be3fe0..ed618ed 100644
--- a/caffe2/python/operator_test/layer_norm_op_test.py
+++ b/caffe2/python/operator_test/layer_norm_op_test.py
@@ -11,9 +11,52 @@ import caffe2.python.serialized_test.serialized_test_util as serial
 import numpy as np
 import os
 import unittest
+from functools import partial
 import torch
 
 
+def _layer_norm_ref(axis, epsilon, X):
+    left = int(np.prod(X.shape[:axis]))
+    reshaped = np.reshape(X, [left, -1])
+    mean = np.mean(reshaped, axis=1).reshape([left, 1])
+    stdev = np.sqrt(
+        np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
+        np.power(mean, 2) + epsilon
+    )
+    norm = (reshaped - mean) / (stdev)
+    norm = np.reshape(norm, X.shape)
+    mean = np.reshape(mean, X.shape[:axis] + (1,))
+    stdev = np.reshape(stdev, X.shape[:axis] + (1,))
+    return [norm, mean, stdev]
+
+
+def _layer_norm_grad_ref(axis, gout_full, norm, mean_full, stdev_full, X_full):
+    left = int(np.prod(X_full.shape[:axis]))
+    right = int(np.prod(X_full.shape[axis:]))
+    X = np.reshape(X_full, [left, right])
+    stdev = np.reshape(stdev_full, [left, 1])
+    mean = np.reshape(mean_full, [left, 1])
+    gout = np.reshape(gout_full, [left, right])
+    dstdev_end = (-1.0) / np.power(stdev, 2.0) \
+            * np.sum((X - mean) * gout, axis=1).reshape([left, 1])
+    dmean_end = np.sum(-1.0 / stdev * gout, axis=1).reshape([left, 1])
+    dx_end = 1.0 / stdev * gout
+
+    # stdev block
+    dmean_stdev = -1.0 * mean / stdev * dstdev_end
+    dx_stdev = X / (right * stdev) * dstdev_end
+
+    # mean block
+    dmean = dmean_end + dmean_stdev
+    dxmean = (1.0 / right) * dmean
+
+    # final outputs
+    dx = dx_end + dx_stdev + dxmean
+    dx = dx.reshape(X_full.shape)
+
+    return [dx]
+
+
 class TestLayerNormOp(serial.SerializedTestCase):
     @serial.given(X=hu.tensors(n=1), **hu.gcs)
     def test_layer_norm_grad_op(self, X, gc, dc):
@@ -30,54 +73,14 @@ class TestLayerNormOp(serial.SerializedTestCase):
             epsilon=epsilon,
         )
 
-        def layer_norm_ref(X):
-            left = int(np.prod(X.shape[:axis]))
-            reshaped = np.reshape(X, [left, -1])
-            mean = np.mean(reshaped, axis=1).reshape([left, 1])
-            stdev = np.sqrt(
-                np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
-                np.power(mean, 2) + epsilon
-            )
-            norm = (reshaped - mean) / (stdev)
-            norm = np.reshape(norm, X.shape)
-            mean = np.reshape(mean, X.shape[:axis] + (1,))
-            stdev = np.reshape(stdev, X.shape[:axis] + (1,))
-            return [norm, mean, stdev]
-
-        norm, mean, stdev = layer_norm_ref(X)
+        norm, mean, stdev = _layer_norm_ref(axis, epsilon, X)
         gout = norm
 
-        def layer_norm_grad_ref(gout_full, norm, mean_full, stdev_full, X_full):
-            left = int(np.prod(X_full.shape[:axis]))
-            right = int(np.prod(X_full.shape[axis:]))
-            X = np.reshape(X_full, [left, right])
-            stdev = np.reshape(stdev_full, [left, 1])
-            mean = np.reshape(mean_full, [left, 1])
-            gout = np.reshape(gout_full, [left, right])
-            dstdev_end = (-1.0) / np.power(stdev, 2.0) \
-                    * np.sum((X - mean) * gout, axis=1).reshape([left, 1])
-            dmean_end = np.sum(-1.0 / stdev * gout, axis=1).reshape([left, 1])
-            dx_end = 1.0 / stdev * gout
-
-            # stdev block
-            dmean_stdev = -1.0 * mean / stdev * dstdev_end
-            dx_stdev = X / (right * stdev) * dstdev_end
-
-            # mean block
-            dmean = dmean_end + dmean_stdev
-            dxmean = (1.0 / right) * dmean
-
-            # final outputs
-            dx = dx_end + dx_stdev + dxmean
-            dx = dx.reshape(X_full.shape)
-
-            return [dx]
-
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
             inputs=[gout, norm, mean, stdev, X],
-            reference=layer_norm_grad_ref
+            reference=partial(_layer_norm_grad_ref, axis)
         )
         self.assertDeviceChecks(
             device_options=dc,
@@ -101,25 +104,11 @@ class TestLayerNormOp(serial.SerializedTestCase):
             epsilon=epsilon,
         )
 
-        def layer_norm_ref(X):
-            left = int(np.prod(X.shape[:axis]))
-            reshaped = np.reshape(X, [left, -1])
-            mean = np.mean(reshaped, axis=1).reshape([left, 1])
-            stdev = np.sqrt(
-                np.mean(np.power(reshaped, 2), axis=1).reshape([left, 1]) -
-                np.power(mean, 2) + epsilon
-            )
-            norm = (reshaped - mean) / (stdev)
-            norm = np.reshape(norm, X.shape)
-            mean = np.reshape(mean, X.shape[:axis] + (1,))
-            stdev = np.reshape(stdev, X.shape[:axis] + (1,))
-            return [norm, mean, stdev]
-
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
             inputs=[X],
-            reference=layer_norm_ref
+            reference=partial(_layer_norm_ref, axis, epsilon)
         )
         self.assertDeviceChecks(
             device_options=dc,
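
As a quick sanity check outside the harness, the shared reference can be compared against torch.nn.functional.layer_norm, which applies the same biased-variance-plus-epsilon normalization over the trailing dimensions. This is a hypothetical snippet, not part of the commit, and it assumes the updated test module from this diff is importable in your environment:

import numpy as np
import torch
import torch.nn.functional as F

from caffe2.python.operator_test.layer_norm_op_test import _layer_norm_ref

axis, epsilon = 1, 1e-4
X = np.random.rand(2, 3, 4).astype(np.float32)

# Forward reference from the diff above; returns [norm, mean, stdev].
norm, mean, stdev = _layer_norm_ref(axis, epsilon, X)

# torch normalizes over the trailing dims given by normalized_shape.
expected = F.layer_norm(torch.from_numpy(X), X.shape[axis:], eps=epsilon)
np.testing.assert_allclose(norm, expected.numpy(), rtol=1e-4, atol=1e-4)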