x = torch.ones((8, 1))
torch.testing.assert_close(eager(x), script(x))
+    def test_batch_norm(self):
+        def test(fn, args):
+            trace = torch.jit.trace(fn, args)
+            self.assertAllFused(trace.graph_for(*args))
+            torch.testing.assert_allclose(fn(*args), trace(*args))
+
+        def bn(i, x):
+            return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_no_weight(i, x):
+            return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_no_bias(i, x):
+            return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()
+
+        def bn_neither(i, x):
+            return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()
+
+        for device in self.devices:
+            i = torch.randn(4, 16, 32, 40, device=device)
+            x = torch.randn(16, device=device)
+            for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
+                test(fn, (i, x))
+
+
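For reference, with `training=False` these `torch.batch_norm` calls reduce to a per-channel affine transform over the NCHW input, which is what the fused kernel computes. Note the tests reuse one `randn` tensor for mean, var, weight, and bias, so a negative "variance" can produce NaNs; `assert_allclose` tolerates that since it compares NaNs as equal by default. A minimal reference sketch (`batch_norm_ref` is an illustrative helper, not part of the patch):

import torch

def batch_norm_ref(i, weight, bias, mean, var, eps):
    # Per-channel normalization of an NCHW input using running statistics;
    # a missing weight/bias behaves like 1/0, matching the kernel's defaults.
    shape = (1, i.shape[1], 1, 1)
    w = torch.ones_like(mean) if weight is None else weight
    b = torch.zeros_like(mean) if bias is None else bias
    inv_std = torch.rsqrt(var.reshape(shape) + eps)
    return (i - mean.reshape(shape)) * inv_std * w.reshape(shape) + b.reshape(shape)

i = torch.randn(4, 16, 32, 40)
mean, bias = torch.randn(16), torch.randn(16)
var = torch.rand(16) + 0.1  # keep the variance positive so rsqrt stays finite
expected = torch.batch_norm(i, None, bias, mean, var, False, 0.1, 1e-4, False)
torch.testing.assert_close(batch_norm_ref(i, None, bias, mean, var, 1e-4), expected)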
works_list = [
'__radd__',
'__rdiv__',
        constant(inputs[7]) // eps
    };
+   ExprHandle weight = FloatImm::make(1);
+   ExprHandle bias = FloatImm::make(0);
    if (hasWeight) {
-     exprInputs.push_back(tensorOrConstant(inputs[1], {c}));
+     weight = tensorOrConstant(inputs[1], {c});
+     exprInputs.push_back(weight);
    }
    if (hasBias) {
-     exprInputs.push_back(tensorOrConstant(inputs[2], {c}));
+     bias = tensorOrConstant(inputs[2], {c});
+     exprInputs.push_back(bias);
    }
    promoteInputs(exprInputs);
    ExprHandle mean = exprInputs[1];
    ExprHandle var = exprInputs[2];
    ExprHandle eps = exprInputs[3];
-   ExprHandle weight = FloatImm::make(1);
-   ExprHandle bias = FloatImm::make(0);
-
-   if (hasWeight) {
-     weight = exprInputs[4];
-   }
-   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-   if (hasBias) {
-     bias = exprInputs[5];
-   }
-   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
    auto inv_var = rsqrt(var + eps);
    auto alpha = inv_var * weight;
    auto beta = bias - mean * alpha;
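The kernel.cpp hunk above fixes how the optional weight and bias are recovered: the old code read them from the fixed slots `exprInputs[4]` and `exprInputs[5]`, but when only one of the two is present the survivor lands at index 4, so the `bn_no_weight` case read `exprInputs[5]` past the end of the vector. Capturing the handles at push time sidesteps the index bookkeeping. The `alpha`/`beta` form then folds the whole normalization into one multiply-add per element, since (x - mean) * rsqrt(var + eps) * weight + bias == x * alpha + beta with per-channel constants. A quick numeric check of that algebra (illustrative code, not part of the patch):

import torch

# alpha and beta depend only on the channel, so the per-element work
# collapses to a single multiply-add (plus the fused relu).
x = torch.randn(4, 16, 32, 40)
mean, weight, bias = torch.randn(16), torch.randn(16), torch.randn(16)
var = torch.rand(16) + 0.1
eps = 1e-4

inv_var = torch.rsqrt(var + eps)
alpha = inv_var * weight
beta = bias - mean * alpha

shape = (1, 16, 1, 1)
fused = x * alpha.reshape(shape) + beta.reshape(shape)
direct = ((x - mean.reshape(shape)) * inv_var.reshape(shape)
          * weight.reshape(shape) + bias.reshape(shape))
torch.testing.assert_close(fused, direct)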