[AUTOTVM] Fix a bug in generating the search space (#4779)
authorwpan11nv <60017475+wpan11nv@users.noreply.github.com>
Wed, 29 Jan 2020 03:40:39 +0000 (19:40 -0800)
committerGitHub <noreply@github.com>
Wed, 29 Jan 2020 03:40:39 +0000 (19:40 -0800)
- Do not use numpy.prod which ignores integer (64 bits) overflows.
  This leads to an incorrect number of points in the search space.

python/tvm/autotvm/task/space.py
tests/python/unittest/test_autotvm_space.py

index f1422bf..d83a248 100644 (file)
@@ -226,7 +226,9 @@ class SplitSpace(TransformSpace):
     def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
         """Generate space by DFS"""
         if now == self.num_output - 1:
-            prod = np.prod(tmp_stack, dtype=np.int64)
+            prod = functools.reduce(lambda x, y: x * y, tmp_stack)
+            if prod > self.product:
+                return
             if self.product % prod == 0 or (not enforce_no_tail and prod < self.product):
                 self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
         else:
index 85d5724..95f3201 100644 (file)
@@ -62,6 +62,21 @@ def test_split():
     cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3)
     assert len(cfg.space_map['tile_c']) == 84
 
+    # Count the number of non-negative integer solutions of a + b + c + d = n
+    def count4(n):
+        cnt = 0
+        for a in range(0, n + 1):
+            for b in range(0, n - a + 1):
+                cnt += n - a - b + 1
+        return cnt
+
+    # test overflow
+    n = 25
+    cfg = ConfigSpace()
+    cfg.define_split('x', cfg.axis(2**n), policy='factors', num_outputs=4)
+    # count4(25) is 3276.
+    assert len(cfg.space_map['x']) == count4(n)
+
     # test fallback
     cfg = FallbackConfigEntity()
     cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)