[TOPI] Memoize winograd matrix (#3687)
authorLianmin Zheng <lianminzheng@gmail.com>
Fri, 2 Aug 2019 15:50:33 +0000 (23:50 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 2 Aug 2019 15:50:33 +0000 (08:50 -0700)
* [TOPI] Memoize winograd matrix

* lint

* Fix name

python/tvm/contrib/pickle_memoize.py
topi/python/topi/nn/winograd_util.py

index c657971..b5abf9b 100644 (file)
@@ -34,9 +34,11 @@ class Cache(object):
     ----------
     key: str
        The file key to the function
+    save_at_exit: bool
+        Whether save the cache to file when the program exits
     """
     cache_by_key = {}
-    def __init__(self, key):
+    def __init__(self, key, save_at_exit):
         cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
         if not os.path.exists(cache_dir):
             os.mkdir(cache_dir)
@@ -49,6 +51,7 @@ class Cache(object):
         else:
             self.cache = {}
         self.dirty = False
+        self.save_at_exit = save_at_exit
 
     def save(self):
         if self.dirty:
@@ -60,16 +63,19 @@ class Cache(object):
 def _atexit():
     """Save handler."""
     for value in Cache.cache_by_key.values():
-        value.save()
+        if value.save_at_exit:
+            value.save()
 
 
-def memoize(key):
+def memoize(key, save_at_exit=False):
     """Memoize the result of function and reuse multiple times.
 
     Parameters
     ----------
     key: str
         The unique key to the file
+    save_at_exit: bool
+        Whether save the cache to file when the program exits
 
     Returns
     -------
@@ -81,9 +87,9 @@ def memoize(key):
         allow_types = (string_types, int, float)
         fkey = key + "." + f.__name__ + ".pkl"
         if fkey not in Cache.cache_by_key:
-            Cache.cache_by_key[fkey] = Cache(fkey)
+            Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
         cache = Cache.cache_by_key[fkey]
-        cargs = tuple(x.cell_contents for x in f.__closure__)
+        cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else ()
         cargs = (len(cargs),) + cargs
 
         def _memoized_f(func, *args, **kwargs):
index db57f76..464b633 100644 (file)
@@ -25,6 +25,7 @@
 from operator import mul
 from functools import reduce
 import numpy as np
+from tvm.contrib.pickle_memoize import memoize
 from ..util import const_matrix
 
 
@@ -131,6 +132,8 @@ def _interpolation_points(degree):
 
     return np.array(in_pts[degree-1], dtype=np.float64)
 
+
+@memoize("topi.nn.winograd_matrices", save_at_exit=False)
 def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
     """Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
     """