return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast);
}
+// For some ambiguous cases, it is possible a channels last contiguous Tensor has
+// `suggest_memory_format` of Contiguous.
+// See https://github.com/pytorch/pytorch/issues/63224 for details.
+static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) {
+ return t.is_contiguous() ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+}
+
template<typename scalar_t>
std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
const Tensor& input, const Tensor& weight, const Tensor& bias,
&& running_mean.is_contiguous()
&& running_var.is_contiguous();
- Tensor output = at::empty_like(input, input.suggest_memory_format());
-
// inference contiguous path
if (all_contiguous) {
+ Tensor output = at::empty_like(input, suggest_memory_format_contig(input));
batch_norm_cpu_stub(kCPU, output, input, weight, bias,
save_mean, save_invstd, running_mean, running_var, train, eps);
return std::make_tuple(output, save_mean, save_invstd);
auto b = bias.defined() ? as_nd(bias) :
at::detail::scalar_tensor_static(0, input.scalar_type(), kCPU);
+ Tensor output = at::empty_like(input, input.suggest_memory_format());
auto iter = TensorIteratorConfig()
.add_output(output)
.add_input(input)
&& input.suggest_memory_format() == grad_out_.suggest_memory_format();
if (all_contiguous) {
+ if (grad_input_mask[0]) {
+ grad_input = at::empty_like(input, suggest_memory_format_contig(input));
+ }
batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
return std::make_tuple(grad_input, grad_weight, grad_bias);
// Forward batch-norm dispatch on CPU. Routes to the plain-contiguous or the
// channels-last implementation based on the input's actual layout (queried via
// is_contiguous) rather than suggest_memory_format, so tensors that satisfy
// both layouts are still dispatched deterministically (plain contiguous wins).
void batch_norm_cpu_kernel(Tensor& output, const Tensor& input,
    const Tensor& weight, const Tensor& bias, const Tensor& save_mean, const Tensor& save_invstd,
    const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {
  // Plain (NCHW-style) contiguous path.
  if (input.is_contiguous()) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_contiguous", [&] {
      batch_norm_cpu_contiguous_impl<scalar_t>(output, input, weight, bias,
          save_mean, save_invstd, running_mean, running_var, train, eps);
    });
    return;
  }
  // Channels-last (NHWC-style) contiguous path.
  if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_channels_last", [&] {
      batch_norm_cpu_channels_last_impl<scalar_t>(output, input, weight, bias,
          save_mean, save_invstd, running_mean, running_var, train, eps);
    });
    return;
  }
  // Neither layout is contiguous: the caller was supposed to make it so.
  TORCH_CHECK(false, "batch_norm_cpu_kernel: expecting input to be contiguous.");
}
// Per-channel mean / variance-sum collection for CPU batch norm.
// Dispatches on the input's actual contiguity (not suggest_memory_format) so
// that tensors contiguous in both layouts are routed deterministically.
void batch_norm_cpu_collect_stats_kernel(
    Tensor& mean, Tensor& var_sum, const Tensor& input) {
  if (input.is_contiguous()) {
    // Spatial extent per (n, c) slot; only needed on this path.
    int64_t spatial_size = input.numel() / input.size(0) / input.size(1);
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_contiguous", [&] {
      if (spatial_size == 1) { // NC11 is also channels last
        batch_norm_cpu_collect_stats_channels_last_impl<scalar_t>(mean, var_sum, input);
      } else {
        batch_norm_cpu_collect_stats_contiguous_impl<scalar_t>(mean, var_sum, input);
      }
    });
    return;
  }
  if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_collect_stats_channels_last", [&] {
      batch_norm_cpu_collect_stats_channels_last_impl<scalar_t>(mean, var_sum, input);
    });
    return;
  }
  TORCH_CHECK(false, "batch_norm_cpu_collect_stats_kernel: expecting input to be contiguous.");
}
const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
bool train, double eps) {
int64_t image_size = input.numel() / input.size(0) / input.size(1);
- switch (input.suggest_memory_format()) {
- case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_contiguous", [&] {
- if (image_size == 1) { // NC11 is also channels last
- batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
- grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
- } else {
- batch_norm_cpu_backward_contiguous_impl<scalar_t>(grad_input, grad_weight, grad_bias,
- grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
- }
- });
- break;
- }
- case at::MemoryFormat::ChannelsLast: {
- AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] {
+ if (input.is_contiguous()) {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_contiguous", [&] {
+ if (image_size == 1) { // NC11 is also channels last
batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
- });
- break;
- }
- default:
- TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+ } else {
+ batch_norm_cpu_backward_contiguous_impl<scalar_t>(grad_input, grad_weight, grad_bias,
+ grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
+ }
+ });
+ } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_backward_channels_last", [&] {
+ batch_norm_cpu_backward_channels_last_impl<scalar_t>(grad_input, grad_weight, grad_bias,
+ grad_output, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
+ });
+ } else {
+ TORCH_CHECK(false, "batch_norm_cpu_backward_kernel: expecting input to be contiguous.");
}
}