Add support for batch_norm fusion to the JIT (#15146)
author Adam Paszke <adam.paszke@gmail.com>
Tue, 8 Jan 2019 14:57:45 +0000 (06:57 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 8 Jan 2019 15:00:19 +0000 (07:00 -0800)
Summary:
We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the statistics and then fusing everything else
with ReLU and the following pointwise ops provides nice speedups.
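For reference, here is a rough Python sketch of what the decomposition amounts to. This is illustrative only, not the exact TorchScript graph the pass injects; `torch.batch_norm_update_stats` is the reduction kernel added by this PR, and the reshape stands in for the internal `aten::_ncf_unsqueeze` helper:

```python
# Illustrative sketch of the batch_norm decomposition: one reduction kernel
# computes the per-channel statistics, and everything after it is pointwise,
# so it can be fused with a following ReLU and other pointwise ops.
import torch

def batch_norm_decomposed(input, weight, bias, running_mean, running_var,
                          training, momentum, eps):
    if training:
        # reduction: handled by the new aten::batch_norm_update_stats kernel
        mean, var = torch.batch_norm_update_stats(
            input, running_mean, running_var, momentum)
    else:
        mean, var = running_mean, running_var
    # broadcast per-channel stats over an (N, C, ...) input
    shape = [1, -1] + [1] * (input.dim() - 2)
    mean = mean.reshape(shape)
    invstd = 1 / torch.sqrt(var.reshape(shape) + eps)
    # everything below is pointwise and therefore fusable
    out = (input - mean) * invstd
    if weight is not None:
        out = out * weight.reshape(shape)
    if bias is not None:
        out = out + bias.reshape(shape)
    return out
```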

Note that this is limited to inference mode for now, because we don't
support convolutions and batch norm in AD, so the fuser isn't applied to
those parts.

This commit gives us a 7% end-to-end speedup for ResNet50 with batch size 32. Note that this only applies to inference mode at the moment, due to the lack of AD support for CNN operations (I'll be adding that soon), and not to the standard `torchvision` models, because they use in-place ops which aren't supported by the fuser (we need a way of proving that de-inplacing them is safe).
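As a usage illustration, here is a minimal sketch adapted from the `test_fuse_batch_norm` test added in this PR (assumes a CUDA device and the default fuser settings): an out-of-place batch_norm + relu pattern run without autograd, where the decomposition lets the pointwise tail fuse into a single kernel.

```python
import torch
import torch.nn as nn

class ResLike(torch.jit.ScriptModule):
    def __init__(self):
        super(ResLike, self).__init__()
        self.bn = nn.BatchNorm2d(16)

    @torch.jit.script_method
    def forward(self, x, y):
        # note: out-of-place ops only; in-place variants block fusion
        return y + torch.relu(self.bn(x))

model = ResLike().cuda()
x = torch.randn(2, 16, 8, 8, device='cuda')
y = torch.randn(2, 16, 8, 8, device='cuda')
with torch.no_grad():
    out = model(x, y)
    # the optimized graph should contain aten::batch_norm_update_stats
    # and a prim::FusionGroup covering the pointwise tail
    print(model.graph_for(x, y))
```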

cc zou3519 zdevito mruberry ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15146

Differential Revision: D13548303

Pulled By: zou3519

fbshipit-source-id: a2e2e5abc383f637fae19bd1b423f20c2cbc056a

aten/src/ATen/core/interned_strings.h
aten/src/ATen/native/Normalization.cpp
aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/Normalization.cuh
aten/src/ATen/native/native_functions.yaml
test/test_jit.py
torch/csrc/jit/fuser/codegen.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/utils/subgraph_utils.cpp
torch/csrc/jit/passes/utils/subgraph_utils.h

aten/src/ATen/core/interned_strings.h
index 40f9b39..5e9144b 100644
@@ -71,6 +71,7 @@ namespace c10 {
   _(prim, MMBatchSide)             \
   _(prim, min)                     \
   _(prim, max)                     \
+  _(aten, _ncf_unsqueeze)          \
   _(aten, warn)                    \
   _(aten, floordiv)                \
   _(aten, __round_to_zero_floordiv)\
aten/src/ATen/native/Normalization.cpp
index 1843f78..85f517f 100644
@@ -36,24 +36,35 @@ static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
   return t.accessor<scalar_t, 1>();
 }
 
+template<typename T>
+struct InvStd {
+  T operator()(T var, double epsilon) const {
+    T invstd = 0;
+    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
+      invstd = static_cast<T>(1) / std::sqrt(var + epsilon);
+    }
+    return invstd;
+  }
+};
+
+template<typename T>
+struct Var {
+  T operator()(T var, double epsilon) const {
+    return var;
+  }
+};
 
 template<typename scalar_t>
-std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_template(const Tensor& input, const Tensor& weight, const Tensor& bias,
-                               const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) {
+std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
+    const Tensor& input, const Tensor& weight, const Tensor& bias,
+    const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
+    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
+    bool train, double eps) {
 
-  using accscalar_t = at::acc_type<scalar_t, false>;
   Tensor output = at::empty_like(input);
 
   int64_t n_input = input.size(1);
-  int64_t n = input.numel() / n_input;
 
-  Tensor save_mean;
-  Tensor save_invstd;
-  const int64_t zero = 0;
-  if (train) {
-    save_mean = at::empty({n_input}, input.options());
-    save_invstd = at::empty({n_input}, input.options());
-  }
   auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
   auto save_invstd_a = conditional_accessor_1d<scalar_t>(save_invstd);
 
@@ -61,58 +72,79 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_template(const Tensor& input, co
   auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
 
   parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
-      for (int64_t f = b_begin; f < b_end; ++f) {
-        Tensor in = input.select(1, f);
-        Tensor out = output.select(1, f);
+    for (int64_t f = b_begin; f < b_end; ++f) {
+      Tensor in = input.select(1, f);
+      Tensor out = output.select(1, f);
+
+      scalar_t mean, invstd;
+      if (train) {
+        mean = save_mean_a[f];
+        invstd = save_invstd_a[f];
+      } else {
+        mean = running_mean_a[f];
+        invstd = 1 / std::sqrt(running_var_a[f] + eps);
+      }
 
-        scalar_t mean, invstd;
+      // compute output
+      scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
+      scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
 
-        if (train) {
-          // compute mean per input
-          accscalar_t sum = 0;
-          CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
-              sum += i;
-            });
-
-          mean = (scalar_t) (sum / n);
-          save_mean_a[f] = mean;
-
-          // compute variance per input
-          sum = 0;
-          CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
-              sum += (i - mean) * (i - mean);
-            });
-
-          if (sum == 0 && eps == 0.0) {
-            invstd = 0;
-          } else {
-            invstd = (scalar_t) (1 / std::sqrt(sum/n + eps));
-          }
-          save_invstd_a[f] = invstd;
+      CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
+        o = ((i - mean) * invstd) * w + b;
+      });
+    }
+  });
+  return std::make_tuple(output, save_mean, save_invstd);
+}
 
-          // update running averages
-          if (running_mean.defined()) {
-            running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
-          }
-          if (running_var.defined()) {
-            accscalar_t unbiased_var = sum / (n - 1);
-            running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
-          }
-        } else {
-          mean = running_mean_a[f];
-          invstd = 1 / std::sqrt(running_var_a[f] + eps);
-        }
+template<typename scalar_t, template<typename T> class VarTransform>
+std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
+    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
+    double momentum, double eps) {
 
-        // compute output
-        scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
-        scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
+  using accscalar_t = at::acc_type<scalar_t, false>;
 
-        CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
-            o = ((i - mean) * invstd) * w + b;
-          });
+  int64_t n_input = input.size(1);
+  int64_t n = input.numel() / n_input;
+
+  Tensor save_mean = at::empty({n_input}, input.options());
+  Tensor save_var_transform = at::empty({n_input}, input.options());
+  auto save_mean_a = save_mean.accessor<scalar_t, 1>();
+  auto save_var_transform_a = save_var_transform.accessor<scalar_t, 1>();
+
+  auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
+  auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
+
+  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
+    for (int64_t f = b_begin; f < b_end; ++f) {
+      Tensor in = input.select(1, f);
+
+      // compute mean per input
+      accscalar_t sum = 0;
+      CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+          sum += i;
+        });
+      scalar_t mean = sum / n;
+      save_mean_a[f] = mean;
+
+      // compute variance per input
+      accscalar_t var_sum = 0;
+      CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+        var_sum += (i - mean) * (i - mean);
+      });
+      save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);
+
+      // update running averages
+      if (running_mean.defined()) {
+        running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
       }
-    });
-  return std::make_tuple(output, save_mean, save_invstd);
+      if (running_var.defined()) {
+        accscalar_t unbiased_var = var_sum / (n - 1);
+        running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
+      }
+    }
+  });
+  return std::make_tuple(save_mean, save_var_transform);
 }
 
 
@@ -419,11 +451,23 @@ Tensor group_norm(const Tensor& input, int64_t num_groups,
     }
 }
 
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
+        const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
+  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_update_stats", [&] {
+      return batch_norm_cpu_update_stats_template<scalar_t, Var>(self, running_mean, running_var, momentum, 0);
+    });
+}
+
 std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias,
                                                   const Tensor& running_mean, const Tensor& running_var,
                                                   bool train, double momentum, double eps) {
   return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
-      return batch_norm_cpu_template<scalar_t>(self, weight, bias, running_mean, running_var, train, momentum, eps);
+      if (!train) {
+        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
+      } else {
+        auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
+        return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
+      }
     });
 }
 
aten/src/ATen/native/cuda/Normalization.cu
index 3e76d27..551e603 100644
@@ -24,4 +24,15 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_o
     });
 }
 
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
+        const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_update_stats_cuda_template<scalar_t, int32_t>(self, running_mean, running_var, momentum);
+      } else {
+        return batch_norm_update_stats_cuda_template<scalar_t, int64_t>(self, running_mean, running_var, momentum);
+      }
+    });
+}
+
 } } // namespace at::native
aten/src/ATen/native/cuda/Normalization.cuh
index e186ef3..bf3e9a5 100644
@@ -200,8 +200,25 @@ __global__ void batch_norm_transform_input_kernel(
   }
 }
 
+template<typename T>
+struct InvStd {
+  __device__ __forceinline__ T operator()(T var, double epsilon) const {
+    T invstd = 0;
+    if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
+      invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
+    }
+    return invstd;
+  }
+};
 
-template <typename scalar_t, typename accscalar_t, typename index_t>
+template<typename T>
+struct Var {
+  __device__ __forceinline__ T operator()(T var, double epsilon) const {
+    return var;
+  }
+};
+
+template <template<typename T> class VarTransform, typename scalar_t, typename accscalar_t, typename index_t>
 __global__ void batch_norm_collect_statistics_kernel(
     const PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> input,
     const accscalar_t epsilon,
@@ -209,7 +226,7 @@ __global__ void batch_norm_collect_statistics_kernel(
     PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
     PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
     PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
-    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_invstd) {
+    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
 
   __shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE];
 
@@ -252,7 +269,7 @@ __global__ void batch_norm_collect_statistics_kernel(
 
   // this writes each warps  item into shared memory
   // there are at most WARP_SIZE items left because
-  // there are at most WARP_SIZE**2 threads at the beginning  
+  // there are at most WARP_SIZE**2 threads at the beginning
   __syncthreads();
   if (tid % WARP_SIZE == 0) {
     shared_n[tid / WARP_SIZE] = n;
@@ -280,12 +297,8 @@ __global__ void batch_norm_collect_statistics_kernel(
 
   // Save the mean, variance, and moving averages
   if (tid == 0) {
-    accscalar_t invstd = 0;
-    if (var_n != static_cast<accscalar_t>(0) || epsilon != static_cast<accscalar_t>(0)) {
-      invstd = static_cast<accscalar_t>(1) / device_sqrt(var_n / N + epsilon);
-    }
     save_mean[plane] = avg;
-    save_invstd[plane] = invstd;
+    save_transformed_var[plane] = VarTransform<accscalar_t>{}(var_n / N, epsilon);
     if (running_mean.data() != NULL) {
       running_mean[plane] = static_cast<scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
     }
@@ -431,7 +444,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda_template(const Tensor& input_
     dim3 blocks(input.size(1));
     tf = getNumThreads(input.size(2));
     dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
-    batch_norm_collect_statistics_kernel<scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    batch_norm_collect_statistics_kernel<InvStd, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
       (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
     batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
       (input, output, save_mean, save_invstd, weight, bias, epsilon);
@@ -488,4 +501,39 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tenso
   return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
 }
 
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
+        const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  int64_t n_channels = input_.size(1);
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+
+  auto input_options = input_.options();
+  if (input_.type().scalarType() == at::ScalarType::Half) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  Tensor save_mean_ = at::empty({n_channels}, input_options);
+  Tensor save_var_ = at::empty({n_channels}, input_options);
+
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+  auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+  auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto save_var = save_var_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
+  // the feature dimension, we'll use some threads for blocks
+  dim3 blocks(input.size(1));
+  int tf = getNumThreads(input.size(2));
+  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+  // NB: epsilon is unused by the Var transform, so we set it to 0
+  batch_norm_collect_statistics_kernel<Var, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    (input, 0., momentum, running_mean, running_var, save_mean, save_var);
+  THCudaCheck(cudaGetLastError());
+  return std::make_tuple(save_mean_, save_var_);
+
+}
+
 } } // namespace at::native
aten/src/ATen/native/native_functions.yaml
index 4f4eaf8..1a35f06 100644
     CPU: batch_norm_backward_cpu
     CUDA: batch_norm_backward_cuda
 
+- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, double momentum) -> (Tensor, Tensor)
+  dispatch:
+    CPU: batch_norm_update_stats_cpu
+    CUDA: batch_norm_update_stats_cuda
+
 - func: ones(IntList size, TensorOptions options={}) -> Tensor
 
 - func: ones_out(Tensor result, IntList size) -> Tensor
test/test_jit.py
index ac48e67..f43d0a4 100644
@@ -10385,6 +10385,58 @@ class TestFuser(JitTestCase):
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_fuse_batch_norm(self):
+
+        class ResLike(torch.jit.ScriptModule):
+            def __init__(self, optimize=True):
+                super(ResLike, self).__init__(optimize)
+                self.bn = nn.BatchNorm2d(16)
+
+            @torch.jit.script_method
+            def forward(self, x, y):
+                return y + torch.relu(self.bn(x))
+
+        model = ResLike().cuda()
+        model_noopt = ResLike(optimize=False).cuda()
+        model_noopt.load_state_dict(model.state_dict())
+        x = torch.randn(2, 16, 8, 8, device='cuda')
+        y = torch.randn(2, 16, 8, 8, device='cuda')
+        # FIXME: We need differentiation for CNNs for this optimization to trigger
+        with torch.no_grad():
+            out = model(x, y)
+            graph = model.graph_for(x, y)
+            rep = str(graph)
+
+            out_noopt = model_noopt(x, y)
+            rep_noopt = str(model_noopt.graph_for(x, y))
+            self.assertEqual(out, out_noopt, prec=3e-5)
+
+        # Check that batch_norm has really been decomposed
+        self.assertIn('aten::batch_norm_update_stats', rep)
+        self.assertNotIn('aten::batch_norm(', rep)
+        self.assertIn('aten::batch_norm(', rep_noopt)
+
+        # Make sure the fusion group is big, and contains aten::sqrt, which could
+        # originate only from decomposing batch_norm in this case
+        fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
+        self.assertEqual(len(fusion_groups), 1)
+        fused_graph = fusion_groups[0].g('Subgraph')
+        self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_threshold(self):
+        def f(x):
+            return torch.threshold(x, 0, -10) + x + x + x
+
+        x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
+        scripted = torch.jit.script(f)
+
+        self.assertEqual(f(x), scripted(x))
+        self.assertAllFused(scripted.graph_for(x))
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
     @skipIfRocm
     @enable_cpu_fuser
torch/csrc/jit/fuser/codegen.cpp
index 512865e..ef50e70 100644
@@ -140,6 +140,8 @@ static std::string encodeRHS(const Node* n) {
       {aten::abs, "fabs(${0})"},
       {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
       {aten::relu, "${0} < 0 ? 0.f : ${0} "},
+      {aten::threshold,
+       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
       {aten::log, "logf(${0})"},
       {aten::log10, "log10f(${0})"},
       {aten::log1p, "log1pf(${0})"},
torch/csrc/jit/passes/graph_fuser.cpp
index 0157000..cb69295 100644
@@ -3,11 +3,14 @@
 #include <ATen/ExpandUtils.h>
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
+#include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 #include <unordered_map>
 
@@ -68,6 +71,7 @@ bool isSimpleMap(Node* node) {
       //"aten::rand_like(Tensor self) -> Tensor",
       "aten::reciprocal(Tensor self) -> Tensor",
       "aten::relu(Tensor self) -> Tensor",
+      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
       "aten::remainder(Tensor self, Tensor other) -> Tensor",
       "aten::round(Tensor self) -> Tensor",
       "aten::rsqrt(Tensor self) -> Tensor",
@@ -116,6 +120,44 @@ bool isSimpleMap(Node* node) {
   return true;
 }
 
+RegisterOperators reg_bn_unsqueeze({Operator(
+    "aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
+    [](const Node* node) {
+      return [](Stack& stack) {
+        const int64_t ndim = pop(stack).toInt();
+        auto self = pop(stack).toTensor();
+        c10::SmallVector<int64_t, 8> sizes(ndim, 1);
+        JIT_ASSERT(self.dim() == 1);
+        sizes.at(1) = self.size(0);
+        push(stack, self.reshape(sizes));
+        return 0;
+      };
+    })});
+
+// Yes, no, or no value if we can't tell
+c10::optional<bool> isDefined(Value* tensor) {
+  if (tensor->type()->isSubtypeOf(DynamicType::get())) {
+    return true;
+  }
+  if (tensor->node()->kind() == prim::None ||
+      tensor->node()->kind() == prim::Undefined) {
+    return false;
+  }
+  return {};
+}
+
+bool isFusableBatchNorm(Node* batch_norm) {
+  if (!batch_norm->matches(
+          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
+    return false;
+  }
+  // If we can't determine if weight and bias is defined statically there's
+  // really no point in decomposing batch norm into simpler ops, since it won't
+  // get fused into a single kernel.
+  return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
+      isDefined(batch_norm->namedInput(attr::bias)).has_value();
+}
+
 Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
   JIT_ASSERT(!sizes.empty());
   Graph* graph = sizes[0]->owningGraph();
@@ -127,6 +169,7 @@ Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
 
 struct GraphFuser {
   Block* block_;
+  c10::optional<AliasDb> aliasDb_;
   std::shared_ptr<Graph> graph_;
 
   GraphFuser(Block* block, std::shared_ptr<Graph> graph)
@@ -139,6 +182,10 @@ struct GraphFuser {
   }
 
   bool isFusable(Node* node) {
+    return isFusableMap(node) || isFusableBatchNorm(node);
+  }
+
+  bool isFusableMap(Node* node) {
     // We don't want to bother with cross-block node movements, as they
     // are not necessarily correct.
     if (node->owningBlock() != block_)
@@ -169,7 +216,7 @@ struct GraphFuser {
   // cannot be fused because it is not a simple map, can be put in a fusion
   // group as long as no items in the group read the output of concat
   bool isFusableAsExitNode(Node* node) {
-    return isFusable(node) || isFusableOnlyAsExitNode(node);
+    return isFusableMap(node) || isFusableOnlyAsExitNode(node);
   }
 
   bool isFusableOnlyAsExitNode(Node* node) {
@@ -205,6 +252,59 @@ struct GraphFuser {
     return *n->g(attr::Subgraph);
   }
 
+  void decomposeBatchNorm(Node* batch_norm) {
+    static std::shared_ptr<Graph> bn_graph;
+    static std::once_flag flag;
+    std::call_once(
+        flag,
+        [](std::shared_ptr<Graph>* graph_ptr) {
+          static const char* source = R"SCRIPT(
+        def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
+            if training:
+                norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
+            else:
+                norm_mean = torch._unwrap_optional(running_mean)
+                norm_var = torch._unwrap_optional(running_var)
+            norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
+            norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
+            norm_invstd = 1 / (eps + torch.sqrt(norm_var))
+            return ((input - norm_mean) * norm_invstd)
+      )SCRIPT";
+          auto module = std::make_shared<script::Module>();
+          defineMethodsInModule(
+              module, source, script::nativeResolver, /*self=*/nullptr);
+          *graph_ptr = module->get_method("batch_norm").graph();
+        },
+        &bn_graph);
+
+    JIT_ASSERT(isFusableBatchNorm(batch_norm));
+    WithInsertPoint insert_guard{batch_norm};
+    Value* input = batch_norm->namedInput(attr::input);
+    Value* input_dim = graph_->insert(aten::dim, {input});
+    std::vector<Value*> inputs{input,
+                               batch_norm->namedInput(attr::running_mean),
+                               batch_norm->namedInput(attr::running_var),
+                               batch_norm->namedInput(attr::training),
+                               batch_norm->namedInput(attr::momentum),
+                               batch_norm->namedInput(attr::eps)};
+    Value* new_output =
+        SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
+    auto weight = batch_norm->namedInput(attr::weight);
+    auto bias = batch_norm->namedInput(attr::bias);
+    if (isDefined(weight).value()) {
+      Value* expanded_weight =
+          graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
+      new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
+    }
+    if (isDefined(bias).value()) {
+      Value* expanded_bias =
+          graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
+      new_output = graph_->insert(aten::add, {new_output, expanded_bias});
+    }
+    batch_norm->output()->replaceAllUsesWith(new_output);
+    batch_norm->destroy();
+  }
+
   void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
     // Now we have two fusion groups!
     // Revert the fusion - place all inner nodes of producer back in the outer
@@ -349,10 +449,7 @@ struct GraphFuser {
     *insertion_point = n;
   }
 
-  at::optional<Node*> tryFuse(
-      Node* consumer,
-      Value* producer,
-      const AliasDb& aliasDb) {
+  at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
     // this handles cases where producer can be moved _into_ the fusion group of
     // consumer.
     // TODO: extend to fusion of consumer into _producer's_ fusion blob
@@ -368,7 +465,8 @@ struct GraphFuser {
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
         // an output of the fusion group.
-        producer->node()->moveBeforeTopologicallyValid(real_consumer, aliasDb);
+        producer->node()->moveBeforeTopologicallyValid(
+            real_consumer, aliasDb_.value());
 
     if (!shouldFuse) {
       return at::nullopt;
@@ -396,6 +494,14 @@ struct GraphFuser {
     } else if (consumer->kind() != prim::FusionGroup) {
       group = createSingletonFusionGroup(consumer);
     }
+    if (producer->node()->matches(
+            "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
+      // We don't do any fusions in here, but simply decompose the batch norm
+      // into a kernel that computes the stats + pointwise ops which will be
+      // considered in this fusion next.
+      decomposeBatchNorm(producer->node());
+      return group;
+    }
     if (producer->node()->kind() == prim::FusionGroup) {
       mergeFusionGroups(group, producer->node());
       return group;
@@ -649,7 +755,7 @@ struct GraphFuser {
         chunk->inputs().begin(),
         chunk->inputs().end(),
         [&](Value* producer_for_chunk) {
-          return isFusable(producer_for_chunk->node()) &&
+          return isFusableMap(producer_for_chunk->node()) &&
               allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
         });
     if (it == chunk->inputs().end()) {
@@ -777,9 +883,7 @@ struct GraphFuser {
   }
 
   // returns where to continue scanning, and whether any fusion was made
-  std::pair<graph_node_list::iterator, bool> scanNode(
-      Node* consumer,
-      const AliasDb& aliasDb) {
+  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
     if (isFusableAsExitNode(consumer)) {
       auto consumer_inputs = consumer->kind() == aten::cat
           ? consumer->namedInput(attr::tensors)->node()->inputs()
@@ -797,7 +901,7 @@ struct GraphFuser {
           // we scan this consumer again to perform the fusion
           return std::make_pair(consumer->reverseIterator(), true);
         }
-        auto fusion_group = tryFuse(consumer, producer, aliasDb);
+        auto fusion_group = tryFuse(consumer, producer);
         if (fusion_group) {
           // after fusion, consumer moves into a FusionGroup, so inputs is no
           // longer valid so we rescan the new FusionGroup for more fusions...
@@ -948,6 +1052,10 @@ struct GraphFuser {
     }
   }
 
+  void refreshAliasDb() {
+    aliasDb_ = AliasAnalysis(graph_);
+  }
+
   void run() {
     // Run the pass until no changes are made.
     // This is neccessary, because the algorithm can miss out on certain fusion
@@ -968,10 +1076,10 @@ struct GraphFuser {
     bool any_changed = true;
     while (any_changed) {
       any_changed = false;
-      auto aliasDb = AliasAnalysis(graph_);
+      refreshAliasDb();
       for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
         bool changed;
-        std::tie(it, changed) = scanNode(*it, aliasDb);
+        std::tie(it, changed) = scanNode(*it);
         any_changed |= changed;
       }
     }
torch/csrc/jit/passes/utils/subgraph_utils.cpp
index 5266cf4..77f3dbf 100644
@@ -4,72 +4,50 @@ namespace torch {
 namespace jit {
 namespace SubgraphUtils {
 namespace {
-bool isSubgraphNodeKind(Symbol s) {
-  return s == prim::DifferentiableGraph || s == prim::FusionGroup;
-}
 
-bool isSubgraphNodeKind(Node* n) {
-  return isSubgraphNodeKind(n->kind());
+bool hasSubgraph(Node* n) {
+  return n->hasAttribute(attr::Subgraph);
 }
 
 // Combine the nodes in two subgraph together. The nodes will end up in
 // `mergeTo`, and `mergeFrom` is destroyed.
 void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
-  const auto nodes = unmergeSubgraph(mergeFrom);
-  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
-    mergeNodeIntoSubgraph(*it, mergeTo);
+  Node* nodeBeforeMergeFrom = mergeFrom->prev();
+  Node* nodeAfterMergeFrom = mergeFrom->next();
+  unmergeSubgraph(mergeFrom);
+  std::vector<Node*> nodes;
+  const auto end_it = nodeBeforeMergeFrom->reverseIterator();
+  auto it = nodeAfterMergeFrom->reverseIterator();
+  ++it;
+  while (it != end_it) {
+    // NB: mergeNodeIntoSubgraph destroys node, hence the complications
+    Node* node = *it;
+    ++it;
+    mergeNodeIntoSubgraph(node, mergeTo);
   }
 }
 } // namespace
 
 std::shared_ptr<Graph> getSubgraph(Node* n) {
-  JIT_ASSERT(isSubgraphNodeKind(n));
   return n->g(attr::Subgraph);
 }
 
-std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
+void unmergeSubgraph(Node* subgraphNode) {
   JIT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
-  auto outerGraph = subgraphNode->owningGraph();
-
-  std::vector<Node*> temporary_nodes;
-  auto subgraph = getSubgraph(subgraphNode);
-
-  // Initialize a map of inner graph values to outer graph values
-  std::unordered_map<const Value*, Value*> innerToOuter;
-  const auto innerInputs = subgraph->inputs();
-  const auto outerInputs = subgraphNode->inputs();
-  for (size_t i = 0; i < innerInputs.size(); ++i) {
-    innerToOuter[innerInputs[i]] = outerInputs[i];
-  }
-
-  // Clone all nodes
-  for (auto inner : subgraph->nodes()) {
-    Node* outer = outerGraph->createClone(
-        inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
-    outer->insertBefore(subgraphNode);
-    temporary_nodes.emplace_back(outer);
-    const auto innerOutputs = inner->outputs();
-    const auto outerOutputs = outer->outputs();
-    for (size_t i = 0; i < innerOutputs.size(); ++i) {
-      innerToOuter[innerOutputs[i]] = outerOutputs[i];
-    }
-  }
 
-  // Replace uses of group outputs and destroy the group
-  const auto subgraphOutputs = subgraph->outputs();
+  // Inline the graph, replace uses of node outputs and destroy the node
+  const auto subgraphOutputs = inlineGraph(
+      getSubgraph(subgraphNode), subgraphNode->inputs(), subgraphNode);
   JIT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
   for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
-    const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
-    subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
+    subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
   }
   subgraphNode->destroy();
-
-  return temporary_nodes;
 }
 
 void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
-  JIT_ASSERT(isSubgraphNodeKind(subgraphNode));
-  if (isSubgraphNodeKind(toMerge)) {
+  JIT_ASSERT(hasSubgraph(subgraphNode));
+  if (hasSubgraph(toMerge)) {
     return mergeSubgraph(subgraphNode, toMerge);
   }
 
@@ -150,8 +128,40 @@ void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
   toMerge->destroy();
 }
 
+// Invariant we depend on in mergeSubgraph: All inlined nodes are created
+// between the node preceding insertBefore and insertBefore.
+std::vector<Value*> inlineGraph(
+    const std::shared_ptr<Graph>& subgraph,
+    at::ArrayRef<Value*> outerInputs,
+    Node* insertBefore) {
+  auto outerGraph = insertBefore->owningGraph();
+
+  // Initialize a map of inner graph values to outer graph values
+  std::unordered_map<const Value*, Value*> innerToOuter;
+  const auto innerInputs = subgraph->inputs();
+  JIT_ASSERT(outerInputs.size() == innerInputs.size());
+  for (size_t i = 0; i < innerInputs.size(); ++i) {
+    innerToOuter[innerInputs[i]] = outerInputs[i];
+  }
+
+  // Clone all nodes
+  for (auto inner : subgraph->nodes()) {
+    Node* outer = outerGraph->createClone(
+        inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
+    outer->insertBefore(insertBefore);
+    const auto innerOutputs = inner->outputs();
+    const auto outerOutputs = outer->outputs();
+    for (size_t i = 0; i < innerOutputs.size(); ++i) {
+      innerToOuter[innerOutputs[i]] = outerOutputs[i];
+    }
+  }
+
+  return fmap(subgraph->outputs(), [&](Value* output) {
+    return innerToOuter.at(output);
+  });
+}
+
 Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
-  JIT_ASSERT(isSubgraphNodeKind(subgraphKind));
   auto graph = n->owningGraph();
   auto subgraph = graph->create(subgraphKind, 0);
   subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
@@ -159,6 +169,7 @@ Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
   mergeNodeIntoSubgraph(n, subgraph);
   return subgraph;
 }
+
 } // namespace SubgraphUtils
 } // namespace jit
 } // namespace torch
torch/csrc/jit/passes/utils/subgraph_utils.h
index dc81902..4d0c449 100644
@@ -26,11 +26,16 @@ void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
 
 // Move nodes from a subgraph node to the outer graph.
 // `subgraphNode` is destroyed.
-std::vector<Node*> unmergeSubgraph(Node* subgraphNode);
+void unmergeSubgraph(Node* subgraphNode);
 
 // Convenience function
 std::shared_ptr<Graph> getSubgraph(Node* n);
 
+std::vector<Value*> inlineGraph(
+    const std::shared_ptr<Graph>& subgraph,
+    at::ArrayRef<Value*> outerInputs,
+    Node* insertBefore);
+
 } // namespace SubgraphUtils
 } // namespace jit
 } // namespace torch