self._is_positive_definite = is_positive_definite
self._name = name or type(self).__name__
- # We will cache some tensors to avoid repeatedly adding shape
- # manipulation ops to the graph.
- # Naming convention:
- # self._cached_X_tensor is the cached version of self._X_tensor.
- self._cached_shape_tensor = None
- self._cached_batch_shape_tensor = None
- self._cached_domain_dimension_tensor = None
- self._cached_range_dimension_tensor = None
- self._cached_tensor_rank_tensor = None
-
@contextlib.contextmanager
def _name_scope(self, name=None, values=None):
"""Helper function to standardize op scope."""
Returns:
  `int32` `Tensor`
"""
with self._name_scope(name):
- # Be clean by avoiding adding shape Ops to the graph too many times.
- if self._cached_shape_tensor is None:
- # Prefer to use statically defined shape if available.
- if self.shape.is_fully_defined():
- self._cached_shape_tensor = linear_operator_util.shape_tensor(
- self.shape.as_list())
- else:
- self._cached_shape_tensor = self._shape_tensor()
- return self._cached_shape_tensor
+ # Prefer to use statically defined shape if available.
+ if self.shape.is_fully_defined():
+ return linear_operator_util.shape_tensor(self.shape.as_list())
+ else:
+ return self._shape_tensor()
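A minimal sketch of this static-shape fast path, written against the public tf.* API rather than the internal linear_operator_util helper (the function name here is illustrative):

import tensorflow as tf

def shape_tensor_sketch(x):
  # If the static shape is fully known at graph-build time, return a
  # constant; no runtime shape op is added to the graph.
  if x.shape.is_fully_defined():
    return tf.constant(x.shape.as_list(), dtype=tf.int32)
  # Otherwise fall back to a dynamic shape op, resolved at runtime.
  return tf.shape(x)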
@property
def batch_shape(self):
"""
# Derived classes get this "for free" once .shape is implemented.
with self._name_scope(name):
- if self._cached_batch_shape_tensor is None:
- # Prefer to use statically defined shape if available.
- if self.batch_shape.is_fully_defined():
- self._cached_batch_shape_tensor = linear_operator_util.shape_tensor(
- self.batch_shape.as_list(), name="batch_shape")
- else:
- self._cached_batch_shape_tensor = self.shape_tensor()[:-2]
- return self._cached_batch_shape_tensor
+ # Prefer to use statically defined shape if available.
+ if self.batch_shape.is_fully_defined():
+ return linear_operator_util.shape_tensor(
+ self.batch_shape.as_list(), name="batch_shape")
+ else:
+ return self.shape_tensor()[:-2]
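The [:-2] fallback relies on the [B1,...,Bb, M, N] shape convention; a small sketch with a hypothetical shape vector:

import tensorflow as tf

# The batch shape is everything except the trailing two matrix dimensions.
full_shape = tf.constant([2, 3, 4, 4], dtype=tf.int32)  # hypothetical [B1, B2, M, N]
batch_shape = full_shape[:-2]                           # -> [2, 3]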
@property
def tensor_rank(self):
"""
# Derived classes get this "for free" once .shape is implemented.
with self._name_scope(name):
- if self._cached_tensor_rank_tensor is None:
- # Prefer to use statically defined shape if available.
- if self.tensor_rank is not None:
- self._cached_tensor_rank_tensor = ops.convert_to_tensor(
- self.tensor_rank)
- else:
- self._cached_tensor_rank_tensor = array_ops.size(self.shape_tensor())
- return self._cached_tensor_rank_tensor
+ # Prefer to use statically defined shape if available.
+ if self.tensor_rank is not None:
+ return ops.convert_to_tensor(self.tensor_rank)
+ else:
+ return array_ops.size(self.shape_tensor())
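Similarly, the dynamic fallback recovers the rank b + 2 as the length of the runtime shape vector; a sketch with hypothetical values:

import tensorflow as tf

shape = tf.constant([2, 3, 4, 4], dtype=tf.int32)  # hypothetical [B1, B2, M, N]
rank = tf.size(shape)                              # length of the shape vector -> 4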
@property
def domain_dimension(self):
"""
# Derived classes get this "for free" once .shape is implemented.
with self._name_scope(name):
- if self._cached_domain_dimension_tensor is None:
- # Prefer to use statically defined shape if available.
- if self.domain_dimension.value is not None:
- self._cached_domain_dimension_tensor = ops.convert_to_tensor(
- self.domain_dimension.value)
- else:
- self._cached_domain_dimension_tensor = self.shape_tensor()[-1]
- return self._cached_domain_dimension_tensor
+ # Prefer to use statically defined shape if available.
+ if self.domain_dimension.value is not None:
+ return ops.convert_to_tensor(self.domain_dimension.value)
+ else:
+ return self.shape_tensor()[-1]
@property
def range_dimension(self):
"""
# Derived classes get this "for free" once .shape is implemented.
with self._name_scope(name):
- if self._cached_range_dimension_tensor is None:
- # Prefer to use statically defined shape if available.
- if self.range_dimension.value is not None:
- self._cached_range_dimension_tensor = ops.convert_to_tensor(
- self.range_dimension.value)
- else:
- self._cached_range_dimension_tensor = self.shape_tensor()[-2]
- return self._cached_range_dimension_tensor
+ # Prefer to use statically defined shape if available.
+ if self.range_dimension.value is not None:
+ return ops.convert_to_tensor(self.range_dimension.value)
+ else:
+ return self.shape_tensor()[-2]
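The [-1] and [-2] indices follow from the same convention: an operator shaped [B1,...,Bb, M, N] maps R^N into R^M, so the domain dimension N is the last entry and the range dimension M the second to last. A sketch with a hypothetical shape:

import tensorflow as tf

shape = tf.constant([7, 5, 3], dtype=tf.int32)  # hypothetical [B, M, N]
domain_dimension = shape[-1]  # N = 3, the space the operator acts on
range_dimension = shape[-2]   # M = 5, the space it maps into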
def _assert_non_singular(self):
"""Private default implementation of _assert_non_singular."""
if self._can_use_cholesky():
return self.assert_positive_definite()
else:
- singular_values = linalg_ops.svd(
- self._get_cached_dense_matrix(), compute_uv=False)
+ singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
# TODO(langmore) Add .eig and .cond as methods.
cond = (math_ops.reduce_max(singular_values, axis=-1) /
math_ops.reduce_min(singular_values, axis=-1))
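A sketch of this condition-number computation using the public tf.linalg API on a hypothetical diagonal matrix, where cond(A) = sigma_max / sigma_min:

import tensorflow as tf

a = tf.constant([[4.0, 0.0], [0.0, 2.0]])  # hypothetical SPD matrix
singular_values = tf.linalg.svd(a, compute_uv=False)
# Ratio of largest to smallest singular value; 4 / 2 = 2 here.
cond = (tf.reduce_max(singular_values, axis=-1) /
        tf.reduce_min(singular_values, axis=-1))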
# and sufficient.
if self.is_self_adjoint:
return check_ops.assert_positive(
- array_ops.matrix_diag_part(self._get_cached_chol()),
+ array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
message="Matrix was not positive definite.")
# We have no generic check for positive definite.
raise NotImplementedError("assert_positive_definite is not implemented.")
return self._assert_positive_definite()
def _assert_self_adjoint(self):
- dense = self._get_cached_dense_matrix()
+ dense = self.to_dense()
logging.warn(
"Using (possibly slow) default implementation of assert_self_adjoint."
" Requires conversion to a dense matrix.")
"Using (possibly slow) default implementation of determinant."
" Requires conversion to a dense matrix and O(N^3) operations.")
if self._can_use_cholesky():
- diag = array_ops.matrix_diag_part(self._get_cached_chol())
+ diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
_, log_abs_det = linalg.slogdet(self.to_dense())
return log_abs_det
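A sketch of the identity behind the Cholesky branch: if A = L L^H, then log|det A| = 2 * sum_i log(L_ii), so the determinant never has to be formed directly. With a hypothetical matrix:

import tensorflow as tf

a = tf.constant([[4.0, 0.0], [0.0, 9.0]])  # hypothetical SPD matrix
chol = tf.linalg.cholesky(a)               # L, with a == L @ L^T
# Doubled sum of log-diagonals of L; equals log(det(a)) = log(36) here.
log_abs_det = 2.0 * tf.reduce_sum(
    tf.math.log(tf.linalg.diag_part(chol)), axis=-1)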
" Requires conversion to a dense matrix and O(N^3) operations.")
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
if self._can_use_cholesky():
- return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
- return linalg_ops.matrix_solve(
- self._get_cached_dense_matrix(), rhs, adjoint=adjoint)
+ return linalg_ops.cholesky_solve(
+ linalg_ops.cholesky(self.to_dense()), rhs)
+ return linalg_ops.matrix_solve(self.to_dense(), rhs, adjoint=adjoint)
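A sketch of the fast path above: for a symmetric positive-definite matrix, cholesky_solve reuses the triangular factor, which is cheaper and more stable than a general LU-based matrix_solve. Hypothetical values:

import tensorflow as tf

a = tf.constant([[4.0, 0.0], [0.0, 2.0]])  # hypothetical SPD matrix
rhs = tf.constant([[8.0], [6.0]])
chol = tf.linalg.cholesky(a)
x = tf.linalg.cholesky_solve(chol, rhs)    # solves a @ x = rhs -> [[2.], [3.]]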
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
"""Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
def _diag_part(self):
"""Generic and often inefficient implementation. Override often."""
- return array_ops.matrix_diag_part(self._get_cached_dense_matrix())
+ return array_ops.matrix_diag_part(self.to_dense())
def diag_part(self, name="diag_part"):
"""Efficiently get the [batch] diagonal part of this operator.
def _add_to_tensor(self, x):
# Override if a more efficient implementation is available.
- return self._get_cached_dense_matrix() + x
+ return self.to_dense() + x
def add_to_tensor(self, x, name="add_to_tensor"):
"""Add matrix represented by this operator to `x`. Equivalent to `A + x`.
# TODO(langmore) Add complex types when tf.cholesky can use them.
return (not self.dtype.is_complex and self.is_self_adjoint and
self.is_positive_definite)
-
- def _get_cached_dense_matrix(self):
- if not hasattr(self, "_cached_dense_matrix"):
- self._cached_dense_matrix = self.to_dense()
- return self._cached_dense_matrix
-
- def _get_cached_chol(self):
- if not hasattr(self, "_cached_chol"):
- self._cached_chol = linalg_ops.cholesky(self._get_cached_dense_matrix())
- return self._cached_chol