Add a fast path for batch-norm CPU inference. (#19152)
authorXiaoqiang Zheng <zhengxq@fb.com>
Wed, 17 Apr 2019 02:22:13 +0000 (19:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 17 Apr 2019 02:27:54 +0000 (19:27 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19152

Adding a fast path for batch-norm CPU inference when all tensors are contiguous.
* Leverage vectorization through smiple loops.
* Folding linear terms before computation.
* For resnext-101, this version gets 18.95 times faster.
* Add a microbenchmark:
* (buck build mode/opt -c python.package_style=inplace --show-output //caffe2/benchmarks/operator_benchmark:batchnorm_benchmark) && \
(OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 buck-out/gen/caffe2/benchmarks/operator_benchmark/batchnorm_benchmark#binary.par)
* batch_norm: data shape: [1, 256, 3136], bandwidth: 22.26 GB/s
* batch_norm: data shape: [1, 65536, 1], bandwidth: 5.57 GB/s
* batch_norm: data shape: [128, 2048, 1], bandwidth: 18.21 GB/s

Reviewed By: soumith, BIT-silence

Differential Revision: D14889728

fbshipit-source-id: 20c9e567e38ff7dbb9097873b85160eca2b0a795

aten/src/ATen/native/Normalization.cpp
benchmarks/operator_benchmark/batchnorm_benchmark.py [new file with mode: 0644]

index b1d3b31..ca0215c 100644 (file)
@@ -54,6 +54,77 @@ struct Var {
   }
 };
 
+/// A fast path for CPU inference when all tensors are contiguous.
+/// This code achieves machine bandwidth peak without AVX support.
+/// If this changes for future architectures, we can move it to the cpu/
+/// directory.
+template<typename scalar_t>
+void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input,
+    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+    const Tensor& mean, const Tensor& variance, double eps) {
+  int64_t n_batch = input.size(0);
+  int64_t n_channel = input.size(1);
+  int64_t image_size = input.numel() / n_batch / n_channel;
+
+  scalar_t* output_data = output.data<scalar_t>();
+  const scalar_t* input_data = input.data<scalar_t>();
+  const scalar_t* weight_data = weight.defined() ? weight.data<scalar_t>() : nullptr;
+  const scalar_t* bias_data = bias.defined() ? bias.data<scalar_t>() : nullptr;
+  const scalar_t* mean_data = mean.data<scalar_t>();
+  const scalar_t* var_data = variance.data<scalar_t>();
+
+  /// Collect the linear and constant terms regarding the input.
+  /// output(n, c, h, w)
+  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
+  ///         + bias(c)
+  ///     = input(n, c, h, w) * inv_var(c) * weight(c) +
+  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
+  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
+  /// So the linear term, alpha(c) = inv_var(c) * weight(c),
+  ///   the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
+  /// Note that this is only a good idea if (input_size >> c), in degenerate
+  /// cases where image_size == 1 && batch_size == 1, it is slow.
+  Tensor alpha = at::empty_like(mean);
+  Tensor beta = at::empty_like(mean);
+  scalar_t* alpha_data = alpha.data<scalar_t>();
+  scalar_t* beta_data = beta.data<scalar_t>();
+  for (int64_t c = 0; c < n_channel; c++) {
+    scalar_t inv_var = 1 / std::sqrt(var_data[c] + static_cast<scalar_t>(eps));
+    scalar_t weight_v = weight_data ? weight_data[c] : 1;
+    scalar_t bias_v = bias_data ? bias_data[c] : 0;
+    alpha_data[c] = inv_var * weight_v;
+    beta_data[c] = bias_v - mean_data[c] * inv_var * weight_v;
+  }
+
+  // Apply the linear terms to the input,
+  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
+  // No need to use parallel_for as this function is supposed to be
+  // memory-limited.
+  // Keep the loop struture simple to make sure compiler vetorization kicks in.
+  if (image_size != 1) {
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        for (int64_t i = 0; i < image_size; ++i) {
+          // Keep all the offset calculation within the inner loop for
+          // simplicity. Compilers are very good at hoisting the common part
+          // outside.
+          int64_t offset = n * n_channel * image_size + c * image_size + i;
+          output_data[offset] = input_data[offset] * alpha_data[c] +
+              beta_data[c];
+        }
+      }
+    }
+  } else {
+    // image_size == 1
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        int64_t offset = n * n_channel + c;
+        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
+      }
+    }
+  }
+}
+
 template<typename scalar_t>
 std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
     const Tensor& input, const Tensor& weight, const Tensor& bias,
@@ -63,6 +134,16 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
 
   Tensor output = at::empty_like(input);
 
+  // Check if we should use the fast path.
+  if (!train && input.is_contiguous()
+      && (!weight.defined() || weight.is_contiguous())
+      && (!bias.defined() || bias.is_contiguous())
+      && running_mean.is_contiguous()
+      && running_var.is_contiguous()) {
+    batch_norm_cpu_inference_contiguous<scalar_t>(output, input, weight, bias,
+      running_mean, running_var, eps);
+    return std::make_tuple(output, save_mean, save_invstd);
+  }
   int64_t n_input = input.size(1);
 
   auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
@@ -384,7 +465,7 @@ Tensor instance_norm(
 Tensor layer_norm(const Tensor& input, IntArrayRef normalized_shape,
     const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     double eps, bool cudnn_enabled) {
-  
+
     int64_t normalized_ndim = normalized_shape.size();
 
     AT_CHECK(normalized_ndim >= 1,
diff --git a/benchmarks/operator_benchmark/batchnorm_benchmark.py b/benchmarks/operator_benchmark/batchnorm_benchmark.py
new file mode 100644 (file)
index 0000000..4577cc8
--- /dev/null
@@ -0,0 +1,38 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import time
+
+import numpy
+import torch
+import torch.nn.functional as F
+
+
+def benchmark_batch_norm(data_shape):
+    C = data_shape[1]
+    x = torch.rand(data_shape)
+    mean = torch.rand(C)
+    var = torch.rand(C)
+    weight = torch.rand(C)
+    bias = torch.rand(C)
+    NITER = 10000
+    input_size = numpy.prod(data_shape)
+    total_size = 2 * input_size + 4 * C
+    for i in range(-10, NITER):
+        if i == 0:
+            s = time.time()
+        F.batch_norm(x, mean, var, weight, bias)
+    elapsed_sec = (time.time() - s) / NITER
+    print(
+        "batch_norm: data shape: %s, bandwidth: %.2f GB/s"
+        % (data_shape, (total_size * 4) / elapsed_sec / 1e9)
+    )
+
+
+def main():
+    data_shapes = [[1, 256, 3136], [1, 2 ** 16, 1], [128, 2048, 1]]
+    for data_shape in data_shapes:
+        benchmark_batch_norm(data_shape)
+
+
+if __name__ == "__main__":
+    main()