[nnc] Fix batchnorm implementation (#64112)
author Bert Maher <bertrand@fb.com>
Sun, 29 Aug 2021 02:18:10 +0000 (19:18 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sun, 29 Aug 2021 02:20:35 +0000 (19:20 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64112

Fixes #64062
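
In computeBatchNorm, the optional weight and bias were read back out of
exprInputs at the fixed positions 4 and 5 after promoteInputs. Those positions
are only correct when both optionals are present: with weight=None but a bias
tensor given, the bias is pushed at index 4, so the old code read the wrong
element (or indexed past the end of the vector). The fix captures the weight
and bias ExprHandles at the point where they are appended, defaulting to 1 and
0 when absent, so the computation no longer depends on their position in
exprInputs.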

Test Plan: Imported from OSS

Reviewed By: zhxchen17

Differential Revision: D30622897

Pulled By: bertmaher

fbshipit-source-id: 7d7c6131aa786e61fa1d0a517288396a0bdb1d22

test/test_jit_fuser_te.py
torch/csrc/jit/tensorexpr/operators/norm.cpp

index 014f142..6d2432a 100644
@@ -1912,6 +1912,31 @@ class TestTEFuser(JitTestCase):
             x = torch.ones((8, 1))
             torch.testing.assert_close(eager(x), script(x))
 
+    def test_batch_norm(self):
+        def test(fn, args):
+            trace = torch.jit.trace(fn, args)
+            self.assertAllFused(trace.graph_for(*args))
+            torch.testing.assert_allclose(fn(*args), trace(*args))
+
+        def bn(i, x):
+            return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_no_weight(i, x):
+            return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_no_bias(i, x):
+            return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_neither(i, x):
+            return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()
+
+        for device in self.devices:
+            i = torch.randn(4, 16, 32, 40, device=device)
+            x = torch.randn(16, device=device)
+            for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
+                test(fn, (i, x))
+
+
 works_list = [
     '__radd__',
     '__rdiv__',
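
The new test_batch_norm case traces all four combinations of optional weight
and bias (bn, bn_no_weight, bn_no_bias, bn_neither) on every available device,
and checks both that the resulting graph is fully fused and that the fused
kernel's output matches eager execution.
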
index 610f928..2e19d73 100644
@@ -38,11 +38,15 @@ Tensor computeBatchNorm(
             constant(inputs[7]) // eps
         };
 
+        ExprHandle weight = FloatImm::make(1);
+        ExprHandle bias = FloatImm::make(0);
         if (hasWeight) {
-          exprInputs.push_back(tensorOrConstant(inputs[1], {c}));
+          weight = tensorOrConstant(inputs[1], {c});
+          exprInputs.push_back(weight);
         }
         if (hasBias) {
-          exprInputs.push_back(tensorOrConstant(inputs[2], {c}));
+          bias = tensorOrConstant(inputs[2], {c});
+          exprInputs.push_back(bias);
         }
         promoteInputs(exprInputs);
 
@@ -50,18 +54,7 @@ Tensor computeBatchNorm(
         ExprHandle mean = exprInputs[1];
         ExprHandle var = exprInputs[2];
         ExprHandle eps = exprInputs[3];
-        ExprHandle weight = FloatImm::make(1);
-        ExprHandle bias = FloatImm::make(0);
-
-        if (hasWeight) {
-          weight = exprInputs[4];
-        }
-        // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-        if (hasBias) {
-          bias = exprInputs[5];
-        }
 
-        // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
         auto inv_var = rsqrt(var + eps);
         auto alpha = inv_var * weight;
         auto beta = bias - mean * alpha;
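
For reference, the lowering above folds inference-mode batch norm into a
per-channel affine transform: alpha = weight * rsqrt(var + eps), beta =
bias - mean * alpha, and the output is input * alpha + beta, with weight and
bias defaulting to 1 and 0 when they are absent. The sketch below (plain
PyTorch, not part of the patch; folded_batch_norm is an illustrative helper)
checks that this rewrite matches torch.batch_norm:

    import torch

    # Illustrative only: the same per-channel affine folding that the NNC
    # lowering performs, written in plain PyTorch.
    #   alpha = weight / sqrt(var + eps)
    #   beta  = bias - mean * alpha
    # A missing weight/bias defaults to 1 and 0, as in computeBatchNorm.
    def folded_batch_norm(x, weight, bias, mean, var, eps=1e-4):
        if weight is None:
            weight = torch.ones_like(mean)
        if bias is None:
            bias = torch.zeros_like(mean)
        alpha = torch.rsqrt(var + eps) * weight
        beta = bias - mean * alpha
        # Broadcast the per-channel alpha/beta over an NCHW input.
        return x * alpha.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)

    x = torch.randn(4, 16, 32, 40)
    stats = torch.randn(16)
    var = torch.rand(16) + 0.5  # keep the variance well away from zero
    # weight=None, bias=stats: one of the cases the old fixed-index lookup broke.
    expected = torch.batch_norm(x, None, stats, stats, var, False, 0.1, 1e-4, False)
    torch.testing.assert_close(folded_batch_norm(x, None, stats, stats, var), expected)

Because the fix takes weight and bias from the handles captured when they are
pushed, rather than from fixed indices in exprInputs, the same expression is
generated correctly for all four presence combinations.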