from __future__ import division
from __future__ import print_function
+import collections
import functools
import operator
import threading
from tensorflow.python.util import tf_inspect
+class _TensorCache(object):
+ """Simple cache which evicts items based on length in a FIFO manner."""
+
+ def __init__(self, max_items=256):
+ self._data = collections.OrderedDict()
+ self._max_items = max_items if max_items else 256
+
+ def put(self, key, value):
+ self._data[key] = value
+
+ if len(self._data) > self._max_items:
+ self._data.popitem(last=False)
+
+ def get(self, key):
+ return self._data.get(key, None)
+
+ def flush(self):
+ self._data = {}
+
+
_op_attr_type_cache = {}
raise ValueError("`grad` not a Tensor or IndexedSlices.")
-_last_zero_shape_dtype = [None, None]
-_last_zero = [None]
+_zeros_cache = _TensorCache()
def _fast_fill(value, shape, dtype):
def _zeros(shape, dtype):
  """Returns a zeros tensor for `shape` and `dtype`, reusing cached results.

  Results are memoized in the module-level `_zeros_cache`, keyed on
  (shape, dtype, current device name).

  Args:
    shape: shape of the requested tensor. It is part of the cache key, so it
      must be hashable.
    dtype: dtype of the requested tensor.

  Returns:
    The cached or newly filled zeros tensor, or None when `dtype` is
    `dtypes.variant`.
  """
  if dtype == dtypes.variant:
    # TODO(apassos): need to save enough information about variant tensors to do
    # a zeros
    return None
  # The device name is only needed for the cache key, so it is looked up
  # after the early return above.
  device = context.context().device_name
  cache_key = shape, dtype, device
  cached = _zeros_cache.get(cache_key)
  if cached is None:
    cached = _fast_fill(0, shape, dtype)
    _zeros_cache.put(cache_key, cached)
  return cached
def _ones(shape, dtype):