Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
const Tensor& weight /* optional */, const Tensor& bias /* optional */,
double eps, bool cudnn_enabled) {
-
+
int64_t normalized_ndim = normalized_shape.size();
AT_CHECK(normalized_ndim >= 1,
"Expected normalized_shape to be at least 1-dimensional, i.e., ",
"containing at least one element, but got normalized_shape=",
normalized_shape);
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
'', (True, 'aten::_batch_norm_impl_index')),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
- ('layer_norm', (S, S, S, S), ([5],),),
- ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
- ('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
- ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
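+ # autodiff check: the differentiable graph for layer_norm should now contain
+ # the loop over the leading dims and the batch-norm call from the symbolic script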
+ ('layer_norm', (S, S, S, S), ([5],), '',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+ ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+ ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
+ ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
+ non_differentiable(torch.rand(S))), 'with_weight_and_bias',
+ (True, ['prim::Loop', 'aten::_batch_norm_impl_index'])),
('group_norm', (S, S, S), (1, torch.rand(5),),),
('local_response_norm', (S, S, S), (2, ),),
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '', (True, 'aten::nll_loss_forward')),
return torch.adaptive_avg_pool3d(self, output_size), backward
+
def batch_norm(input : Tensor,
weight : Optional[Tensor],
bias : Optional[Tensor],
return output, backward
def layer_norm(input : Tensor,
- normalied_shape : List[int],
+ normalized_shape : List[int],
weight : Optional[Tensor],
bias : Optional[Tensor],
eps : float,
cudnn_enable : bool):
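+ # Decompose layer_norm as batch norm over a reshaped view: the leading
+ # (unnormalized) dims are folded into a single axis of size n, so each slice
+ # over normalized_shape is standardized independently.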
+ input_ndim = input.dim()
+ normalized_ndim = len(normalized_shape)
+ n = 1
+ for i in range(input_ndim - normalized_ndim):
+ n *= input.size(i)
+
+ input_reshape = input.contiguous().view(1, n, -1)
+
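+ # Batch norm in training mode with no running stats, no affine params, and
+ # momentum 0; weight and bias are applied separately below so their gradients
+ # are easy to form in backward.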
bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
- input, weight, bias, None, None, True,
+ input_reshape, None, None, None, None, True,
0.0, eps, cudnn_enable)
- has_weight = weight is not None
- has_bias = bias is not None
- bn_out = bn_out.view(input.sizes())
+ bn_out = bn_out.view(input.size())
if weight is not None and bias is not None:
- output = bias.addcmul(bn_out, weight)
+ output = bias.addcmul(bn_out, weight, value=1)
elif weight is not None:
output = bn_out.mul(weight)
elif bias is not None:
output = bn_out.add(bias)
else:
output = bn_out
def backward(grad_output):
- if weight is not None:
- grad_output = grad_output * torch.t(weight)
- weight = grad_output * torch.t(bn_out)
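+ # Per-branch gradients for the affine parameters: broadcasted grads are
+ # reduced back to the parameter shape, and grad_bn_out is the gradient
+ # flowing into the normalized activations.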
+ if weight is not None and bias is not None:
+ grad_bn_out = grad_output * weight
+ grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+ grad_bias = grad_output._grad_sum_to_size(bias.size())
+ elif weight is not None:
+ grad_bn_out = grad_output * weight
+ grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
+ grad_bias = None
+ elif bias is not None:
+ grad_bn_out = grad_output
+ grad_weight = None
+ grad_bias = grad_output._grad_sum_to_size(bias.size())
+ else:
+ grad_bn_out = grad_output
+ grad_weight = None
+ grad_bias = None
- grad_output = grad_output.reshape(input.sizes())
- dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
- impl_idx, input, grad_output, weight, None, None,
- save1, save2, True, eps, [True, has_weight, has_bias])
- return dinput, None, dweight, dbias, None, None
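+ # Reshape the incoming gradient to the (1, n, -1) batch-norm view and reuse
+ # the batch-norm backward for the input gradient only; no weight/bias grads
+ # are requested from it.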
+ grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+
+ grad_input, _, _ = torch._batch_norm_impl_index_backward(
+ impl_idx, input_reshape, grad_bn_out, None, None, None,
+ save1, save2, True, eps, [True, False, False])
+
+ grad_input = grad_input.view(input.size())
+ return grad_input, None, grad_weight, grad_bias, None, None
return output, backward
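
A quick numerical sanity check of the decomposition used above (a standalone sketch, not part of the patch, using only public PyTorch ops): layer_norm over the trailing dims should match a stats-only batch_norm applied to the (1, n, -1) reshape of the input.

# Standalone sketch: layer_norm == batch_norm on a (1, n, -1) reshape.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 5)
normalized_shape = [5]
eps = 1e-5

# n is the product of the leading (unnormalized) dimensions: 2 * 3 * 4 = 24.
n = 1
for i in range(x.dim() - len(normalized_shape)):
    n *= x.size(i)

x_reshaped = x.contiguous().view(1, n, -1)
# Training-mode batch norm with no running stats and no affine parameters.
bn_out = F.batch_norm(x_reshaped, None, None, None, None,
                      training=True, momentum=0.0, eps=eps)
bn_out = bn_out.view(x.size())

ref = F.layer_norm(x, normalized_shape, eps=eps)
print(torch.allclose(bn_out, ref, atol=1e-5))  # expected: True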