From acc83ad54ef85eacd3c7ebfe20ee8225164e5bd1 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Tue, 8 Jan 2019 13:09:11 -0800 Subject: [PATCH] implement floordiv with correct integer and division by 0 semantics (#15813) Summary: fixes #15768 Pull Request resolved: https://github.com/pytorch/pytorch/pull/15813 Differential Revision: D13594872 Pulled By: zdevito fbshipit-source-id: c6c78c9e17fb16ec2bdc42402d203592cf35b7db --- test/test_jit.py | 13 +++++++++++++ torch/csrc/jit/register_prim_ops.cpp | 19 +++++++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 1374f1e..0f43f48 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4437,6 +4437,19 @@ a") self.checkScript(div_int_nofuture, (), optimize=True) self.checkScript(div_float_nofuture, (), optimize=True) + def test_floor_div(self): + @torch.jit.script + def foo(a, b): + # type: (int, int) -> int + return a // b + for i in range(-8, 8): + for j in range(-8, 8): + if j != 0: + self.assertEqual(foo(i, j), i // j) + else: + with self.assertRaisesRegex(RuntimeError, 'division by 0'): + foo(i, j) + def test_number_augassign(self): def func(): z = 1 diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index af36268..00c5d91 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -72,6 +72,20 @@ Operation listConstruct(int64_t num_inputs) { }; } +static int64_t floordiv(int64_t a, int64_t b) { + if (b == 0) { + throw std::runtime_error("division by 0"); + } + if ((a > 0) == (b > 0)) { + // simple case, both have same sign + return a / b; + } else { + // in python division rounds down, it doesnt not truncate like in c++ + auto r = lldiv(a, b); + return (r.rem) ? r.quot - 1 : r.quot; + } +} + RegisterOperators reg({ Operator( prim::FusionGroup, @@ -1182,12 +1196,9 @@ RegisterOperators reg2({ float), DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float), - // in c++ int division rounds to the integer closer to 0, in python floordiv - // rounds to lower integer DEFINE_GENERIC_OP( aten::floordiv, - static_cast( - std::floor(static_cast(a) / static_cast(b))), + floordiv(a, b), std::floor(a / b), int, float), -- 2.7.4