Add cache for _zeros in backprop
authorAkshay Modi <nareshmodi@google.com>
Tue, 13 Feb 2018 20:10:40 +0000 (12:10 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 20:14:29 +0000 (12:14 -0800)
PiperOrigin-RevId: 185567508

tensorflow/python/eager/backprop.py
tensorflow/python/framework/test_util.py

index d76fecf..d8e13d7 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import functools
 import operator
 import threading
@@ -42,6 +43,26 @@ from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 
 
+class _TensorCache(object):
+  """Simple cache which evicts items based on length in a FIFO manner."""
+
+  def __init__(self, max_items=256):
+    self._data = collections.OrderedDict()
+    self._max_items = max_items if max_items else 256
+
+  def put(self, key, value):
+    self._data[key] = value
+
+    if len(self._data) > self._max_items:
+      self._data.popitem(last=False)
+
+  def get(self, key):
+    return self._data.get(key, None)
+
+  def flush(self):
+    self._data = {}
+
+
 _op_attr_type_cache = {}
 
 
@@ -734,8 +755,7 @@ def _num_elements(grad):
   raise ValueError("`grad` not a Tensor or IndexedSlices.")
 
 
-_last_zero_shape_dtype = [None, None]
-_last_zero = [None]
+_zeros_cache = _TensorCache()
 
 
 def _fast_fill(value, shape, dtype):
@@ -744,14 +764,17 @@ def _fast_fill(value, shape, dtype):
 
 def _zeros(shape, dtype):
   """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
+  device = context.context().device_name
   if dtype == dtypes.variant:
     # TODO(apassos): need to save enough information about variant tensors to do
     # a zeros
     return None
-  if [shape, dtype] != _last_zero_shape_dtype:
-    _last_zero_shape_dtype[:] = [shape, dtype]
-    _last_zero[0] = _fast_fill(0, shape, dtype)
-  return _last_zero[0]
+  cache_key = shape, dtype, device
+  cached = _zeros_cache.get(cache_key)
+  if cached is None:
+    cached = _fast_fill(0, shape, dtype)
+    _zeros_cache.put(cache_key, cached)
+  return cached
 
 
 def _ones(shape, dtype):
index bfdd988..c09e2d8 100644 (file)
@@ -463,8 +463,7 @@ def assert_no_new_tensors(f):
       f(self, **kwargs)
     # Make an effort to clear caches, which would otherwise look like leaked
     # Tensors.
-    backprop._last_zero = [None]
-    backprop._shape_dtype = [None, None]
+    backprop._zeros_cache.flush()
     context.get_default_context().scalar_cache().clear()
     gc.collect()
     tensors_after = [