grad_weight = at::empty_like(weight, at::MemoryFormat::Contiguous);
}
if (grad_input_mask[2]) {
- grad_bias = at::empty_like(weight, at::MemoryFormat::Contiguous);
+ grad_bias = at::empty({input.size(1)}, input.options());
}
// since we are directly manipulating pointers in contiguous path,
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
auto num_features = input.sizes()[1];
+
+ if (input.numel() == 0) {
+ Tensor reserve = at::empty({0}, input.options().dtype(kByte));
+ auto options = input.options().dtype(
+ at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
+ auto save_mean = at::empty({num_features}, options);
+ auto save_invstd = at::empty({num_features}, options);
+
+ // don't return view of input, don't return empty tensor because it will break gradient chain
+ auto out = input.clone();
+ if (weight.defined()) out = out * weight[0];
+ if (bias.defined()) out = out + bias[0];
+ return std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t>(
+ out, save_mean, save_invstd, reserve, 0);
+ }
+
if (running_mean.defined()) {
check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
} else if (!training) {
const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();});
- if (impl_index == 0) {
+ if (input.numel() == 0) {
+ std::vector<int64_t> dims(input.dim() - 1);
+ dims[0] = 0;
+ std::iota(dims.begin() + 1, dims.end(), 2);
+
+ // don't return empty tensor because it will break gradient chain
+ Tensor grad_input;
+ Tensor grad_weight;
+ Tensor grad_bias;
+ if (output_mask[2]) {
+ grad_bias = grad_output.sum(dims);
+ }
+ if (output_mask[1]) {
+ grad_weight = (grad_output * input).sum(dims);
+ }
+ if (output_mask[0] && weight.defined()) {
+ grad_input = grad_output * weight[0];
+ }
+ return std::make_tuple(grad_input, grad_weight, grad_bias);
+ }
+
+ // backward in inference mode is not supported in cudnn, fallback to native
+ // TODO: verify the same thing in miopen
+ if (impl_index == 0 || (!train)) {
return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
} else if (impl_index == 1) {
// TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});
- if (input.numel()==0){
- //don't return view of input, don't return empty tensor because it will break gradient chain
- auto out = input.clone();
- if (weight.defined()) out = out * weight[0];
- if (bias.defined()) out = out + bias[0];
- return out;
- }
return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
training, momentum, eps, cudnn_enabled));
}
return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm", [&] {
if (!train) {
- return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
+ auto save_mean = at::empty({0}, self.options());
+ auto save_var = at::empty({0}, self.options());
+ return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps);
} else {
auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
// save_mean and save_invstd, so it needs recalculated.
const auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true);
Tensor mean;
- if (save_mean->defined()) {
+ TORCH_INTERNAL_ASSERT(save_mean->defined(), "save_mean should always be defined\n");
+ if (save_mean->numel() != 0) {
mean = *save_mean;
} else if (needs_reduction) {
TORCH_CHECK(!train && running_mean->defined());
}
Tensor invstd;
- if (save_invstd->defined()) {
+ TORCH_INTERNAL_ASSERT(save_invstd->defined(), "save_invstd should always be defined\n");
+ if (save_invstd->numel() != 0) {
invstd = *save_invstd;
} else {
TORCH_CHECK(!train && running_var->defined());
#endif // CUDNN_VERSION >= 7400
} else {
reserve = at::empty({0}, input->options().dtype(kByte));
+ // This keeps a consistent output with native_batch_norm
+ save_mean = at::empty({0}, weight_t.options());
+ save_var = at::empty({0}, weight_t.options());
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
handle, mode, &one, &zero,
idesc.desc(), input->data_ptr(),
save_mean.data_ptr(),
save_var.data_ptr()));
} else {
+ save_mean = at::empty({0}, weight_t.options());
+ save_var = at::empty({0}, weight_t.options());
MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
handle, mode, &one, &zero,
idesc.desc(), input->data_ptr(),
self.assertEqual(w.grad, w_ref.grad)
self.assertEqual(b.grad, b_ref.grad)
+ @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
+ def test_batch_norm_inference_backward_cuda(self):
+ with enable_profiling_mode_for_profiling_tests():
+ class MyBatchNorm(torch.nn.Module):
+ def __init__(self, num_features, affine, track_running_stats):
+ super(MyBatchNorm, self).__init__()
+ self.bn = torch.nn.BatchNorm2d(
+ num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()
+
+ def forward(self, x: torch.Tensor):
+ o = self.bn(x)
+ o = torch.nn.functional.relu(o)
+ return o
+
+ batch = 4
+ c = 2
+ hw = 3
+ # Initialize param and input values
+ x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
+ grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
+
+ training = False
+ affine = True
+ track_running_stats = True
+
+ module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
+ ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
+ module.eval()
+ ref_module.eval()
+
+ jit_module = torch.jit.script(module)
+ ref_module.load_state_dict(module.state_dict())
+
+ x = x_init.detach().clone()
+ x.requires_grad_()
+ x_ref = x_init.detach().clone()
+ x_ref.requires_grad_()
+
+ # Test symbolic differentiation
+ # Run Forward and Backward thrice to trigger autodiff graph
+ for i in range(0, 3):
+ y = jit_module(x)
+ y.backward(grad)
+ x.grad.zero_()
+
+ module.bn.running_mean.zero_()
+ module.bn.running_var.fill_(1.0)
+ ref_module.bn.running_mean.zero_()
+ ref_module.bn.running_var.fill_(1.0)
+
+ # run jitted module
+ y = jit_module(x)
+ y.backward(grad)
+ # reference computation
+ y_ref = ref_module(x_ref)
+ y_ref.backward(grad)
+
+ self.assertEqual(y_ref, y)
+ self.assertEqual(x.grad, x_ref.grad)
+ self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
+ self.assertEqual(module.bn.running_var, ref_module.bn.running_var)
+
def test_zeros(self):
class M(torch.jit.ScriptModule):
__constants__ = ['d']
return result, backward
)",
R"(
- def batch_norm_disabled(input : Tensor,
+ def batch_norm(input : Tensor,
weight : Optional[Tensor],
bias : Optional[Tensor],
running_mean : Optional[Tensor],
('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
- ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),
- '', (False, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S),
+ (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
+ 'training', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (0, S, S, S),
+ (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+ 'size_zero', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (0, S, S, S),
+ (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+ 'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S),
+ (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
+ 'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ None, non_differentiable(torch.ones(S)), True, ),
+ 'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ non_differentiable(torch.randn(S)), None, True, ),
+ 'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ None, None, False, ),
+ 'inference', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
+ 'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ None, non_differentiable(torch.ones(S)), False, ),
+ 'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
+ ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
+ non_differentiable(torch.randn(S)), None, False, ),
+ 'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
('layer_norm', (S, S, S, S), ([5],), '',
(True, ['aten::native_layer_norm'])),