implement floordiv with correct integer and division by 0 semantics (#15813)
authorZachary DeVito <zdevito@fb.com>
Tue, 8 Jan 2019 21:09:11 +0000 (13:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 8 Jan 2019 21:44:18 +0000 (13:44 -0800)
Summary:
fixes #15768
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15813

Differential Revision: D13594872

Pulled By: zdevito

fbshipit-source-id: c6c78c9e17fb16ec2bdc42402d203592cf35b7db

test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp

index 1374f1e..0f43f48 100644 (file)
@@ -4437,6 +4437,19 @@ a")
             self.checkScript(div_int_nofuture, (), optimize=True)
             self.checkScript(div_float_nofuture, (), optimize=True)
 
+    def test_floor_div(self):
+        @torch.jit.script
+        def foo(a, b):
+            # type: (int, int) -> int
+            return a // b
+        for i in range(-8, 8):
+            for j in range(-8, 8):
+                if j != 0:
+                    self.assertEqual(foo(i, j), i // j)
+                else:
+                    with self.assertRaisesRegex(RuntimeError, 'division by 0'):
+                        foo(i, j)
+
     def test_number_augassign(self):
         def func():
             z = 1
index af36268..00c5d91 100644 (file)
@@ -72,6 +72,20 @@ Operation listConstruct(int64_t num_inputs) {
   };
 }
 
+static int64_t floordiv(int64_t a, int64_t b) {
+  if (b == 0) {
+    throw std::runtime_error("division by 0");
+  }
+  if ((a > 0) == (b > 0)) {
+    // simple case, both have same sign
+    return a / b;
+  } else {
+    // in python division rounds down, it doesnt not truncate like in c++
+    auto r = lldiv(a,  b);
+    return (r.rem) ? r.quot - 1 : r.quot;
+  }
+}
+
 RegisterOperators reg({
     Operator(
         prim::FusionGroup,
@@ -1182,12 +1196,9 @@ RegisterOperators reg2({
         float),
     DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float),
 
-    // in c++ int division rounds to the integer closer to 0, in python floordiv
-    // rounds to lower integer
     DEFINE_GENERIC_OP(
         aten::floordiv,
-        static_cast<int64_t>(
-            std::floor(static_cast<double>(a) / static_cast<double>(b))),
+        floordiv(a, b),
         std::floor(a / b),
         int,
         float),