From: Xiaoqiang Zheng
Date: Wed, 17 Apr 2019 02:22:13 +0000 (-0700)
Subject: Add a fast path for batch-norm CPU inference. (#19152)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~207
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5627940e9c4ea18ba6c15a2f46f57d8905937c43;p=platform%2Fupstream%2Fpytorch.git

Add a fast path for batch-norm CPU inference. (#19152)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19152

Adding a fast path for batch-norm CPU inference when all tensors are contiguous.
* Leverage vectorization through simple loops.
* Fold the linear terms before the main computation (a verification sketch follows this message).
* For resnext-101, this version runs 18.95 times faster.
* Add a microbenchmark:
  * (buck build mode/opt -c python.package_style=inplace --show-output //caffe2/benchmarks/operator_benchmark:batchnorm_benchmark) && \
    (OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 buck-out/gen/caffe2/benchmarks/operator_benchmark/batchnorm_benchmark#binary.par)
  * batch_norm: data shape: [1, 256, 3136], bandwidth: 22.26 GB/s
  * batch_norm: data shape: [1, 65536, 1], bandwidth: 5.57 GB/s
  * batch_norm: data shape: [128, 2048, 1], bandwidth: 18.21 GB/s

Reviewed By: soumith, BIT-silence

Differential Revision: D14889728

fbshipit-source-id: 20c9e567e38ff7dbb9097873b85160eca2b0a795
---
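A minimal sketch of the folding mentioned above (not part of the patch; it assumes any PyTorch build that provides torch.nn.functional.batch_norm and uses the benchmark's first shape as made-up data): the per-channel scale and shift are precomputed once, so the normalization reduces to one multiply-add per element.

    import torch
    import torch.nn.functional as F

    # Fold (x - mean) / sqrt(var + eps) * weight + bias into x * alpha + beta,
    # where alpha = weight / sqrt(var + eps) and beta = bias - mean * alpha.
    x = torch.rand(1, 256, 3136)  # [N, C, image_size]
    C = x.size(1)
    mean, var = torch.rand(C), torch.rand(C)
    weight, bias = torch.rand(C), torch.rand(C)
    eps = 1e-5

    alpha = weight / torch.sqrt(var + eps)
    beta = bias - mean * alpha
    folded = x * alpha.view(1, C, 1) + beta.view(1, C, 1)

    reference = F.batch_norm(x, mean, var, weight, bias, training=False, eps=eps)
    print((folded - reference).abs().max())  # differs only by float32 rounding noise

These alpha/beta are the same per-channel terms that the C++ fast path below computes before its contiguous inner loop.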
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index b1d3b31..ca0215c 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -54,6 +54,77 @@ struct Var {
   }
 };
 
+/// A fast path for CPU inference when all tensors are contiguous.
+/// This code achieves machine bandwidth peak without AVX support.
+/// If this changes for future architectures, we can move it to the cpu/
+/// directory.
+template<typename scalar_t>
+void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input,
+    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+    const Tensor& mean, const Tensor& variance, double eps) {
+  int64_t n_batch = input.size(0);
+  int64_t n_channel = input.size(1);
+  int64_t image_size = input.numel() / n_batch / n_channel;
+
+  scalar_t* output_data = output.data<scalar_t>();
+  const scalar_t* input_data = input.data<scalar_t>();
+  const scalar_t* weight_data = weight.defined() ? weight.data<scalar_t>() : nullptr;
+  const scalar_t* bias_data = bias.defined() ? bias.data<scalar_t>() : nullptr;
+  const scalar_t* mean_data = mean.data<scalar_t>();
+  const scalar_t* var_data = variance.data<scalar_t>();
+
+  /// Collect the linear and constant terms regarding the input.
+  /// output(n, c, h, w)
+  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
+  ///         + bias(c)
+  ///     = input(n, c, h, w) * inv_var(c) * weight(c)
+  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
+  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
+  /// So the linear term is alpha(c) = inv_var(c) * weight(c),
+  /// and the constant term is beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c).
+  /// Note that this is only a good idea if (input_size >> c); in degenerate
+  /// cases where image_size == 1 && batch_size == 1, it is slow.
+  Tensor alpha = at::empty_like(mean);
+  Tensor beta = at::empty_like(mean);
+  scalar_t* alpha_data = alpha.data<scalar_t>();
+  scalar_t* beta_data = beta.data<scalar_t>();
+  for (int64_t c = 0; c < n_channel; c++) {
+    scalar_t inv_var = 1 / std::sqrt(var_data[c] + static_cast<scalar_t>(eps));
+    scalar_t weight_v = weight_data ? weight_data[c] : 1;
+    scalar_t bias_v = bias_data ? bias_data[c] : 0;
+    alpha_data[c] = inv_var * weight_v;
+    beta_data[c] = bias_v - mean_data[c] * inv_var * weight_v;
+  }
+
+  // Apply the linear terms to the input,
+  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c).
+  // No need to use parallel_for as this function is supposed to be
+  // memory-limited.
+  // Keep the loop structure simple to make sure compiler vectorization kicks in.
+  if (image_size != 1) {
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        for (int64_t i = 0; i < image_size; ++i) {
+          // Keep all the offset calculation within the inner loop for
+          // simplicity. Compilers are very good at hoisting the common part
+          // outside.
+          int64_t offset = n * n_channel * image_size + c * image_size + i;
+          output_data[offset] = input_data[offset] * alpha_data[c] +
+              beta_data[c];
+        }
+      }
+    }
+  } else {
+    // image_size == 1
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        int64_t offset = n * n_channel + c;
+        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
+      }
+    }
+  }
+}
+
 template<typename scalar_t>
 std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu_transform_input_template(
     const Tensor& input, const Tensor& weight, const Tensor& bias,
@@ -63,6 +134,16 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu_transform_input_template(
 
   Tensor output = at::empty_like(input);
 
+  // Check if we should use the fast path.
+  if (!train && input.is_contiguous()
+      && (!weight.defined() || weight.is_contiguous())
+      && (!bias.defined() || bias.is_contiguous())
+      && running_mean.is_contiguous()
+      && running_var.is_contiguous()) {
+    batch_norm_cpu_inference_contiguous<scalar_t>(output, input, weight, bias,
+        running_mean, running_var, eps);
+    return std::make_tuple(output, save_mean, save_invstd);
+  }
   int64_t n_input = input.size(1);
 
   auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
@@ -384,7 +465,7 @@ Tensor instance_norm(
 Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */, double eps,
     bool cudnn_enabled) {
-
+
   int64_t normalized_ndim = normalized_shape.size();
 
   AT_CHECK(normalized_ndim >= 1,
diff --git a/benchmarks/operator_benchmark/batchnorm_benchmark.py b/benchmarks/operator_benchmark/batchnorm_benchmark.py
new file mode 100644
index 0000000..4577cc8
--- /dev/null
+++ b/benchmarks/operator_benchmark/batchnorm_benchmark.py
@@ -0,0 +1,38 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import time
+
+import numpy
+import torch
+import torch.nn.functional as F
+
+
+def benchmark_batch_norm(data_shape):
+    C = data_shape[1]
+    x = torch.rand(data_shape)
+    mean = torch.rand(C)
+    var = torch.rand(C)
+    weight = torch.rand(C)
+    bias = torch.rand(C)
+    NITER = 10000
+    input_size = numpy.prod(data_shape)
+    total_size = 2 * input_size + 4 * C
+    for i in range(-10, NITER):
+        if i == 0:
+            s = time.time()
+        F.batch_norm(x, mean, var, weight, bias)
+    elapsed_sec = (time.time() - s) / NITER
+    print(
+        "batch_norm: data shape: %s, bandwidth: %.2f GB/s"
+        % (data_shape, (total_size * 4) / elapsed_sec / 1e9)
+    )
+
+
+def main():
+    data_shapes = [[1, 256, 3136], [1, 2 ** 16, 1], [128, 2048, 1]]
+    for data_shape in data_shapes:
+        benchmark_batch_norm(data_shape)
+
+
+if __name__ == "__main__":
+    main()
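For reading the bandwidth figures in the commit message: the benchmark charges each call with one read of the input, one write of the output, and one read each of mean, var, weight and bias, all float32, which is where total_size = 2 * input_size + 4 * C and the factor of 4 bytes come from. A small worked example (not part of the patch), using the first benchmarked shape:

    # Bytes moved per F.batch_norm call for data_shape = [1, 256, 3136].
    input_size = 1 * 256 * 3136            # elements read from x and written to output
    C = 256
    total_size = 2 * input_size + 4 * C    # plus mean, var, weight, bias
    bytes_per_call = total_size * 4        # float32
    print(bytes_per_call / 1e9)            # ~0.0064 GB per call
    # At the reported 22.26 GB/s this works out to roughly 0.29 ms per call.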