Remove cached tensors from LinearOperator.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Mar 2018 23:05:22 +0000 (16:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 23:08:16 +0000 (16:08 -0700)
CSE should already ensure that tensors are deduped (i.e. multiple Cholesky's for the same matrix are not performed).

This also makes it easier to use LinearOperators in tf.while_loop.

PiperOrigin-RevId: 190141732

tensorflow/python/ops/linalg/linear_operator.py

index 957a795..c7513d5 100644 (file)
@@ -204,16 +204,6 @@ class LinearOperator(object):
     self._is_positive_definite = is_positive_definite
     self._name = name or type(self).__name__
 
-    # We will cache some tensors to avoid repeatedly adding shape
-    # manipulation ops to the graph.
-    # Naming convention:
-    #   self._cached_X_tensor is the cached version of self._X_tensor.
-    self._cached_shape_tensor = None
-    self._cached_batch_shape_tensor = None
-    self._cached_domain_dimension_tensor = None
-    self._cached_range_dimension_tensor = None
-    self._cached_tensor_rank_tensor = None
-
   @contextlib.contextmanager
   def _name_scope(self, name=None, values=None):
     """Helper function to standardize op scope."""
@@ -299,15 +289,11 @@ class LinearOperator(object):
       `int32` `Tensor`
     """
     with self._name_scope(name):
-      # Be clean by avoiding adding shape Ops to the graph too many times.
-      if self._cached_shape_tensor is None:
-        # Prefer to use statically defined shape if available.
-        if self.shape.is_fully_defined():
-          self._cached_shape_tensor = linear_operator_util.shape_tensor(
-              self.shape.as_list())
-        else:
-          self._cached_shape_tensor = self._shape_tensor()
-      return self._cached_shape_tensor
+      # Prefer to use statically defined shape if available.
+      if self.shape.is_fully_defined():
+        return linear_operator_util.shape_tensor(self.shape.as_list())
+      else:
+        return self._shape_tensor()
 
   @property
   def batch_shape(self):
@@ -338,14 +324,12 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_batch_shape_tensor is None:
-        # Prefer to use statically defined shape if available.
-        if self.batch_shape.is_fully_defined():
-          self._cached_batch_shape_tensor = linear_operator_util.shape_tensor(
-              self.batch_shape.as_list(), name="batch_shape")
-        else:
-          self._cached_batch_shape_tensor = self.shape_tensor()[:-2]
-      return self._cached_batch_shape_tensor
+      # Prefer to use statically defined shape if available.
+      if self.batch_shape.is_fully_defined():
+        return linear_operator_util.shape_tensor(
+            self.batch_shape.as_list(), name="batch_shape")
+      else:
+        return self.shape_tensor()[:-2]
 
   @property
   def tensor_rank(self, name="tensor_rank"):
@@ -378,14 +362,11 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_tensor_rank_tensor is None:
-        # Prefer to use statically defined shape if available.
-        if self.tensor_rank is not None:
-          self._cached_tensor_rank_tensor = ops.convert_to_tensor(
-              self.tensor_rank)
-        else:
-          self._cached_tensor_rank_tensor = array_ops.size(self.shape_tensor())
-      return self._cached_tensor_rank_tensor
+      # Prefer to use statically defined shape if available.
+      if self.tensor_rank is not None:
+        return ops.convert_to_tensor(self.tensor_rank)
+      else:
+        return array_ops.size(self.shape_tensor())
 
   @property
   def domain_dimension(self):
@@ -416,14 +397,11 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_domain_dimension_tensor is None:
-        # Prefer to use statically defined shape if available.
-        if self.domain_dimension.value is not None:
-          self._cached_domain_dimension_tensor = ops.convert_to_tensor(
-              self.domain_dimension.value)
-        else:
-          self._cached_domain_dimension_tensor = self.shape_tensor()[-1]
-      return self._cached_domain_dimension_tensor
+      # Prefer to use statically defined shape if available.
+      if self.domain_dimension.value is not None:
+        return ops.convert_to_tensor(self.domain_dimension.value)
+      else:
+        return self.shape_tensor()[-1]
 
   @property
   def range_dimension(self):
@@ -454,14 +432,11 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_range_dimension_tensor is None:
-        # Prefer to use statically defined shape if available.
-        if self.range_dimension.value is not None:
-          self._cached_range_dimension_tensor = ops.convert_to_tensor(
-              self.range_dimension.value)
-        else:
-          self._cached_range_dimension_tensor = self.shape_tensor()[-2]
-      return self._cached_range_dimension_tensor
+      # Prefer to use statically defined shape if available.
+      if self.range_dimension.value is not None:
+        return ops.convert_to_tensor(self.range_dimension.value)
+      else:
+        return self.shape_tensor()[-2]
 
   def _assert_non_singular(self):
     """Private default implementation of _assert_non_singular."""
@@ -471,8 +446,7 @@ class LinearOperator(object):
     if self._can_use_cholesky():
       return self.assert_positive_definite()
     else:
-      singular_values = linalg_ops.svd(
-          self._get_cached_dense_matrix(), compute_uv=False)
+      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
       # TODO(langmore) Add .eig and .cond as methods.
       cond = (math_ops.reduce_max(singular_values, axis=-1) /
               math_ops.reduce_min(singular_values, axis=-1))
@@ -524,7 +498,7 @@ class LinearOperator(object):
     # and sufficient.
     if self.is_self_adjoint:
       return check_ops.assert_positive(
-          array_ops.matrix_diag_part(self._get_cached_chol()),
+          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
           message="Matrix was not positive definite.")
     # We have no generic check for positive definite.
     raise NotImplementedError("assert_positive_definite is not implemented.")
@@ -547,7 +521,7 @@ class LinearOperator(object):
       return self._assert_positive_definite()
 
   def _assert_self_adjoint(self):
-    dense = self._get_cached_dense_matrix()
+    dense = self.to_dense()
     logging.warn(
         "Using (possibly slow) default implementation of assert_self_adjoint."
         "  Requires conversion to a dense matrix.")
@@ -692,7 +666,7 @@ class LinearOperator(object):
         "Using (possibly slow) default implementation of determinant."
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     if self._can_use_cholesky():
-      diag = array_ops.matrix_diag_part(self._get_cached_chol())
+      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
       return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
     _, log_abs_det = linalg.slogdet(self._matrix)
     return log_abs_det
@@ -726,9 +700,9 @@ class LinearOperator(object):
         "  Requires conversion to a dense matrix and O(N^3) operations.")
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
     if self._can_use_cholesky():
-      return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
-    return linalg_ops.matrix_solve(
-        self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
+      return linalg_ops.cholesky_solve(
+          linalg_ops.cholesky(self.to_dense()), rhs)
+    return linalg_ops.matrix_solve(self.to_dense(), rhs, adjoint=adjoint)
 
   def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
     """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
@@ -866,7 +840,7 @@ class LinearOperator(object):
 
   def _diag_part(self):
     """Generic and often inefficient implementation.  Override often."""
-    return array_ops.matrix_diag_part(self._get_cached_dense_matrix())
+    return array_ops.matrix_diag_part(self.to_dense())
 
   def diag_part(self, name="diag_part"):
     """Efficiently get the [batch] diagonal part of this operator.
@@ -915,7 +889,7 @@ class LinearOperator(object):
 
   def _add_to_tensor(self, x):
     # Override if a more efficient implementation is available.
-    return self._get_cached_dense_matrix() + x
+    return self.to_dense() + x
 
   def add_to_tensor(self, x, name="add_to_tensor"):
     """Add matrix represented by this operator to `x`.  Equivalent to `A + x`.
@@ -936,13 +910,3 @@ class LinearOperator(object):
     # TODO(langmore) Add complex types when tf.cholesky can use them.
     return (not self.dtype.is_complex and self.is_self_adjoint and
             self.is_positive_definite)
-
-  def _get_cached_dense_matrix(self):
-    if not hasattr(self, "_cached_dense_matrix"):
-      self._cached_dense_matrix = self.to_dense()
-    return self._cached_dense_matrix
-
-  def _get_cached_chol(self):
-    if not hasattr(self, "_cached_chol"):
-      self._cached_chol = linalg_ops.cholesky(self._get_cached_dense_matrix())
-    return self._cached_chol