From 4f969db325a7a70878bd3eae5bbb3fecd598d4ca Mon Sep 17 00:00:00 2001
From: Bert Maher
Date: Sat, 28 Aug 2021 19:18:10 -0700
Subject: [PATCH] [nnc] Fix batchnorm implementation (#64112)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64112

Fixes #64062

Test Plan: Imported from OSS

Reviewed By: zhxchen17

Differential Revision: D30622897

Pulled By: bertmaher

fbshipit-source-id: 7d7c6131aa786e61fa1d0a517288396a0bdb1d22
---
 test/test_jit_fuser_te.py                    | 25 +++++++++++++++++++++++++
 torch/csrc/jit/tensorexpr/operators/norm.cpp | 19 ++++++-------------
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 014f142..6d2432a 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1912,6 +1912,31 @@ class TestTEFuser(JitTestCase):
         x = torch.ones((8, 1))
         torch.testing.assert_close(eager(x), script(x))
 
+    def test_batch_norm(self):
+        def test(fn, args):
+            trace = torch.jit.trace(fn, args)
+            self.assertAllFused(trace.graph_for(*args))
+            torch.testing.assert_allclose(fn(*args), trace(*args))
+
+        def bn(i, x):
+            return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_no_weight(i, x):
+            return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_no_bias(i, x):
+            return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_neither(i, x):
+            return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()
+
+        for device in self.devices:
+            i = torch.randn(4, 16, 32, 40, device=device)
+            x = torch.randn(16, device=device)
+            for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
+                test(fn, (i, x))
+
+
 works_list = [
     '__radd__',
     '__rdiv__',
diff --git a/torch/csrc/jit/tensorexpr/operators/norm.cpp b/torch/csrc/jit/tensorexpr/operators/norm.cpp
index 610f928..2e19d73 100644
--- a/torch/csrc/jit/tensorexpr/operators/norm.cpp
+++ b/torch/csrc/jit/tensorexpr/operators/norm.cpp
@@ -38,11 +38,15 @@ Tensor computeBatchNorm(
       constant(inputs[7]) // eps
   };
 
+  ExprHandle weight = FloatImm::make(1);
+  ExprHandle bias = FloatImm::make(0);
   if (hasWeight) {
-    exprInputs.push_back(tensorOrConstant(inputs[1], {c}));
+    weight = tensorOrConstant(inputs[1], {c});
+    exprInputs.push_back(weight);
   }
   if (hasBias) {
-    exprInputs.push_back(tensorOrConstant(inputs[2], {c}));
+    bias = tensorOrConstant(inputs[2], {c});
+    exprInputs.push_back(bias);
   }
 
   promoteInputs(exprInputs);
@@ -50,18 +54,7 @@
   ExprHandle mean = exprInputs[1];
   ExprHandle var = exprInputs[2];
   ExprHandle eps = exprInputs[3];
-  ExprHandle weight = FloatImm::make(1);
-  ExprHandle bias = FloatImm::make(0);
-
-  if (hasWeight) {
-    weight = exprInputs[4];
-  }
-  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-  if (hasBias) {
-    bias = exprInputs[5];
-  }
-
   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
   auto inv_var = rsqrt(var + eps);
   auto alpha = inv_var * weight;
   auto beta = bias - mean * alpha;
-- 
2.7.4
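
Note (reviewer annotation, not part of the patch): in the old code, weight and bias
were re-read from exprInputs at the fixed indices 4 and 5, which point at the wrong
slots when only one of the two is present (with no weight, bias lands at index 4 but
was read from index 5) -- the cause of #64062. The fix binds the handles at the point
where they are pushed. For reference, the fused form the kernel computes is

    alpha = weight * rsqrt(var + eps)
    beta  = bias - mean * alpha
    out   = input * alpha + beta

A minimal PyTorch sketch of that rewrite, checked against torch.batch_norm (tensor
names and shapes are illustrative, not taken from the patch):

    import torch

    i = torch.randn(4, 16, 32, 40)                     # NCHW input
    mean, var = torch.randn(16), torch.rand(16) + 0.5  # per-channel stats
    weight, bias = torch.randn(16), torch.randn(16)
    eps = 1e-4

    # Per-channel constants, reshaped to broadcast over NCHW; the whole
    # normalization then reduces to one multiply-add per element.
    alpha = torch.rsqrt(var + eps) * weight
    beta = bias - mean * alpha
    out = i * alpha.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)

    # training=False, so batch_norm uses the given running stats.
    expected = torch.batch_norm(i, weight, bias, mean, var, False, 0.1, eps, False)
    torch.testing.assert_allclose(out, expected)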