import numpy as np
import os
import unittest
+from functools import partial
import torch
def _layer_norm_ref(axis, epsilon, X):
    """NumPy reference implementation of the layer-norm forward pass.

    The dimensions of ``X`` before ``axis`` are flattened into rows and the
    dimensions from ``axis`` onward are flattened into columns; every row is
    normalized independently.

    Args:
        axis: Index of the first dimension that is normalized over.
        epsilon: Value added to the variance before taking the square root.
        X: Input array (any array-like accepted by ``np.asarray``).

    Returns:
        A list ``[norm, mean, stdev]`` where ``norm`` has the shape of ``X``
        and ``mean`` / ``stdev`` have shape ``X.shape[:axis] + (1,)``.
    """
    # Coerce so that plain nested lists expose ``.shape`` as a tuple below.
    X = np.asarray(X)
    left = int(np.prod(X.shape[:axis]))
    reshaped = np.reshape(X, [left, -1])
    mean = np.mean(reshaped, axis=1).reshape([left, 1])
    # Variance is taken as E[x^2] - E[x]^2; epsilon goes inside the sqrt.
    stdev = np.sqrt(
        np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
        np.power(mean, 2) + epsilon
    )
    norm = (reshaped - mean) / stdev
    norm = np.reshape(norm, X.shape)
    # Per-row statistics keep the leading dims plus a trailing singleton.
    mean = np.reshape(mean, X.shape[:axis] + (1,))
    stdev = np.reshape(stdev, X.shape[:axis] + (1,))
    return [norm, mean, stdev]
+
+
def _layer_norm_grad_ref(axis, gout_full, norm, mean_full, stdev_full, X_full):
    """NumPy reference implementation of the layer-norm backward pass.

    All inputs are viewed as ``[left, right]`` matrices, where ``left`` is the
    product of the dimensions before ``axis`` and ``right`` the product of the
    dimensions from ``axis`` onward. The gradient is accumulated through three
    paths: directly through the output, through the standard deviation, and
    through the mean.

    Args:
        axis: Index of the first normalized dimension.
        gout_full: Gradient w.r.t. the normalized output, same size as
            ``X_full``.
        norm: Normalized output of the forward pass. Accepted to keep the
            positional input order, but not used in the computation.
        mean_full: Per-row mean from the forward pass (``left`` elements).
        stdev_full: Per-row standard deviation from the forward pass
            (``left`` elements).
        X_full: Original forward input.

    Returns:
        A single-element list ``[dx]`` with ``dx`` shaped like ``X_full``.
    """
    # Coerce so that array-likes expose ``.shape``.
    X_full = np.asarray(X_full)
    left = int(np.prod(X_full.shape[:axis]))
    right = int(np.prod(X_full.shape[axis:]))
    X = np.reshape(X_full, [left, right])
    stdev = np.reshape(stdev_full, [left, 1])
    mean = np.reshape(mean_full, [left, 1])
    gout = np.reshape(gout_full, [left, right])

    # Direct gradients of norm = (X - mean) / stdev w.r.t. each operand.
    dstdev_end = (-1.0) / np.power(stdev, 2.0) \
        * np.sum((X - mean) * gout, axis=1).reshape([left, 1])
    dmean_end = np.sum(-1.0 / stdev * gout, axis=1).reshape([left, 1])
    dx_end = 1.0 / stdev * gout

    # stdev block: push the stdev gradient back onto mean and X.
    dmean_stdev = -1.0 * mean / stdev * dstdev_end
    dx_stdev = X / (right * stdev) * dstdev_end

    # mean block: each element contributes 1/right to the row mean.
    dmean = dmean_end + dmean_stdev
    dxmean = (1.0 / right) * dmean

    # final outputs
    dx = dx_end + dx_stdev + dxmean
    dx = dx.reshape(X_full.shape)

    return [dx]
+
+
class TestLayerNormOp(serial.SerializedTestCase):
@serial.given(X=hu.tensors(n=1), **hu.gcs)
def test_layer_norm_grad_op(self, X, gc, dc):
epsilon=epsilon,
)
- def layer_norm_ref(X):
- left = int(np.prod(X.shape[:axis]))
- reshaped = np.reshape(X, [left, -1])
- mean = np.mean(reshaped, axis=1).reshape([left, 1])
- stdev = np.sqrt(
- np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
- np.power(mean, 2) + epsilon
- )
- norm = (reshaped - mean) / (stdev)
- norm = np.reshape(norm, X.shape)
- mean = np.reshape(mean, X.shape[:axis] + (1,))
- stdev = np.reshape(stdev, X.shape[:axis] + (1,))
- return [norm, mean, stdev]
-
- norm, mean, stdev = layer_norm_ref(X)
+ norm, mean, stdev = _layer_norm_ref(axis, epsilon, X)
gout = norm
- def layer_norm_grad_ref(gout_full, norm, mean_full, stdev_full, X_full):
- left = int(np.prod(X_full.shape[:axis]))
- right = int(np.prod(X_full.shape[axis:]))
- X = np.reshape(X_full, [left, right])
- stdev = np.reshape(stdev_full, [left, 1])
- mean = np.reshape(mean_full, [left, 1])
- gout = np.reshape(gout_full, [left, right])
- dstdev_end = (-1.0) / np.power(stdev, 2.0) \
- * np.sum((X - mean) * gout, axis=1).reshape([left, 1])
- dmean_end = np.sum(-1.0 / stdev * gout, axis=1).reshape([left, 1])
- dx_end = 1.0 / stdev * gout
-
- # stdev block
- dmean_stdev = -1.0 * mean / stdev * dstdev_end
- dx_stdev = X / (right * stdev) * dstdev_end
-
- # mean block
- dmean = dmean_end + dmean_stdev
- dxmean = (1.0 / right) * dmean
-
- # final outputs
- dx = dx_end + dx_stdev + dxmean
- dx = dx.reshape(X_full.shape)
-
- return [dx]
-
self.assertReferenceChecks(
device_option=gc,
op=op,
inputs=[gout, norm, mean, stdev, X],
- reference=layer_norm_grad_ref
+ reference=partial(_layer_norm_grad_ref, axis)
)
self.assertDeviceChecks(
device_options=dc,
epsilon=epsilon,
)
- def layer_norm_ref(X):
- left = int(np.prod(X.shape[:axis]))
- reshaped = np.reshape(X, [left, -1])
- mean = np.mean(reshaped, axis=1).reshape([left, 1])
- stdev = np.sqrt(
- np.mean(np.power(reshaped, 2), axis=1).reshape([left, 1]) -
- np.power(mean, 2) + epsilon
- )
- norm = (reshaped - mean) / (stdev)
- norm = np.reshape(norm, X.shape)
- mean = np.reshape(mean, X.shape[:axis] + (1,))
- stdev = np.reshape(stdev, X.shape[:axis] + (1,))
- return [norm, mean, stdev]
-
self.assertReferenceChecks(
device_option=gc,
op=op,
inputs=[X],
- reference=layer_norm_ref
+ reference=partial(_layer_norm_ref, axis, epsilon)
)
self.assertDeviceChecks(
device_options=dc,