From 4ed9de8680864a6a77b0861d1a74ba4faaeaa041 Mon Sep 17 00:00:00 2001
From: Sebastian Messmer
Date: Mon, 14 Jan 2019 17:55:13 -0800
Subject: [PATCH] Remove code duplication (#15880)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15880

The layer_norm reference was implemented twice. Removing one of them.

Reviewed By: dzhulgakov

Differential Revision: D13611232

fbshipit-source-id: cee96c78d3255c3a4e34300693bf9260cf096615
---
 caffe2/python/operator_test/layer_norm_op_test.py | 103 ++++++++++------------
 1 file changed, 46 insertions(+), 57 deletions(-)

diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py
index 6be3fe0..ed618ed 100644
--- a/caffe2/python/operator_test/layer_norm_op_test.py
+++ b/caffe2/python/operator_test/layer_norm_op_test.py
@@ -11,9 +11,52 @@ import caffe2.python.serialized_test.serialized_test_util as serial
 import numpy as np
 import os
 import unittest
+from functools import partial
 
 import torch
 
+def _layer_norm_ref(axis, epsilon, X):
+    left = int(np.prod(X.shape[:axis]))
+    reshaped = np.reshape(X, [left, -1])
+    mean = np.mean(reshaped, axis=1).reshape([left, 1])
+    stdev = np.sqrt(
+        np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
+        np.power(mean, 2) + epsilon
+    )
+    norm = (reshaped - mean) / (stdev)
+    norm = np.reshape(norm, X.shape)
+    mean = np.reshape(mean, X.shape[:axis] + (1,))
+    stdev = np.reshape(stdev, X.shape[:axis] + (1,))
+    return [norm, mean, stdev]
+
+
+def _layer_norm_grad_ref(axis, gout_full, norm, mean_full, stdev_full, X_full):
+    left = int(np.prod(X_full.shape[:axis]))
+    right = int(np.prod(X_full.shape[axis:]))
+    X = np.reshape(X_full, [left, right])
+    stdev = np.reshape(stdev_full, [left, 1])
+    mean = np.reshape(mean_full, [left, 1])
+    gout = np.reshape(gout_full, [left, right])
+    dstdev_end = (-1.0) / np.power(stdev, 2.0) \
+        * np.sum((X - mean) * gout, axis=1).reshape([left, 1])
+    dmean_end = np.sum(-1.0 / stdev * gout, axis=1).reshape([left, 1])
+    dx_end = 1.0 / stdev * gout
+
+    # stdev block
+    dmean_stdev = -1.0 * mean / stdev * dstdev_end
+    dx_stdev = X / (right * stdev) * dstdev_end
+
+    # mean block
+    dmean = dmean_end + dmean_stdev
+    dxmean = (1.0 / right) * dmean
+
+    # final outputs
+    dx = dx_end + dx_stdev + dxmean
+    dx = dx.reshape(X_full.shape)
+
+    return [dx]
+
+
 class TestLayerNormOp(serial.SerializedTestCase):
     @serial.given(X=hu.tensors(n=1), **hu.gcs)
     def test_layer_norm_grad_op(self, X, gc, dc):
@@ -30,54 +73,14 @@ class TestLayerNormOp(serial.SerializedTestCase):
             epsilon=epsilon,
         )
 
-        def layer_norm_ref(X):
-            left = int(np.prod(X.shape[:axis]))
-            reshaped = np.reshape(X, [left, -1])
-            mean = np.mean(reshaped, axis=1).reshape([left, 1])
-            stdev = np.sqrt(
-                np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
-                np.power(mean, 2) + epsilon
-            )
-            norm = (reshaped - mean) / (stdev)
-            norm = np.reshape(norm, X.shape)
-            mean = np.reshape(mean, X.shape[:axis] + (1,))
-            stdev = np.reshape(stdev, X.shape[:axis] + (1,))
-            return [norm, mean, stdev]
-
-        norm, mean, stdev = layer_norm_ref(X)
+        norm, mean, stdev = _layer_norm_ref(axis, epsilon, X)
         gout = norm
 
-        def layer_norm_grad_ref(gout_full, norm, mean_full, stdev_full, X_full):
-            left = int(np.prod(X_full.shape[:axis]))
-            right = int(np.prod(X_full.shape[axis:]))
-            X = np.reshape(X_full, [left, right])
-            stdev = np.reshape(stdev_full, [left, 1])
-            mean = np.reshape(mean_full, [left, 1])
-            gout = np.reshape(gout_full, [left, right])
-            dstdev_end = (-1.0) / np.power(stdev, 2.0) \
-                * np.sum((X - mean) * gout, axis=1).reshape([left, 1])
-            dmean_end = np.sum(-1.0 / stdev * gout, axis=1).reshape([left, 1])
-            dx_end = 1.0 / stdev * gout
-
-            # stdev block
-            dmean_stdev = -1.0 * mean / stdev * dstdev_end
-            dx_stdev = X / (right * stdev) * dstdev_end
-
-            # mean block
-            dmean = dmean_end + dmean_stdev
-            dxmean = (1.0 / right) * dmean
-
-            # final outputs
-            dx = dx_end + dx_stdev + dxmean
-            dx = dx.reshape(X_full.shape)
-
-            return [dx]
-
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
             inputs=[gout, norm, mean, stdev, X],
-            reference=layer_norm_grad_ref
+            reference=partial(_layer_norm_grad_ref, axis)
         )
         self.assertDeviceChecks(
             device_options=dc,
@@ -101,25 +104,11 @@ class TestLayerNormOp(serial.SerializedTestCase):
             epsilon=epsilon,
         )
 
-        def layer_norm_ref(X):
-            left = int(np.prod(X.shape[:axis]))
-            reshaped = np.reshape(X, [left, -1])
-            mean = np.mean(reshaped, axis=1).reshape([left, 1])
-            stdev = np.sqrt(
-                np.mean(np.power(reshaped, 2), axis=1).reshape([left, 1]) -
-                np.power(mean, 2) + epsilon
-            )
-            norm = (reshaped - mean) / (stdev)
-            norm = np.reshape(norm, X.shape)
-            mean = np.reshape(mean, X.shape[:axis] + (1,))
-            stdev = np.reshape(stdev, X.shape[:axis] + (1,))
-            return [norm, mean, stdev]
-
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
             inputs=[X],
-            reference=layer_norm_ref
+            reference=partial(_layer_norm_ref, axis, epsilon)
         )
         self.assertDeviceChecks(
             device_options=dc,
-- 
2.7.4
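
Note on the pattern used above: assertReferenceChecks invokes the reference
function with the operator's inputs unpacked positionally, so the shared
module-level helpers take axis (and epsilon) as leading parameters and each
call site pre-binds them with functools.partial. A minimal standalone sketch
of that adaptation follows; the helper name and the axis=1, epsilon=1e-4
values are illustrative, not taken from the patch:

    from functools import partial

    import numpy as np

    def _layer_norm_fwd_ref(axis, epsilon, X):
        # Same math as the patched reference: flatten dims >= axis,
        # then normalize each row by its mean and (biased) stdev.
        left = int(np.prod(X.shape[:axis]))
        reshaped = np.reshape(X, [left, -1])
        mean = np.mean(reshaped, axis=1).reshape([left, 1])
        stdev = np.sqrt(
            np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
            np.power(mean, 2) + epsilon
        )
        return np.reshape((reshaped - mean) / stdev, X.shape)

    # partial turns the (axis, epsilon, X) helper into the X-only callable
    # that a harness passing inputs=[X] expects, mirroring
    # reference=partial(_layer_norm_ref, axis, epsilon) in the patch.
    ref = partial(_layer_norm_fwd_ref, 1, 1e-4)
    X = np.random.randn(2, 3).astype(np.float32)
    print(ref(X))

Pre-binding the hypothesis-drawn parameters this way is what lets both tests
share one module-level implementation instead of each defining its own
closure over axis and epsilon, which is exactly the duplication this commit
removes.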