from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test
class BroadcastMatrixBatchDimsTest(test.TestCase):
def test_zero_batch_matrices_returned_as_empty_list(self):
- self.assertAllEqual(
- [], linear_operator_util.broadcast_matrix_batch_dims([]))
+ self.assertAllEqual([],
+ linear_operator_util.broadcast_matrix_batch_dims([]))
def test_one_batch_matrix_returned_after_tensor_conversion(self):
arr = rng.rand(2, 3, 4)
linear_operator_util.broadcast_matrix_batch_dims([y, x])
+class CholeskySolveWithBroadcastTest(test.TestCase):
+
+ def test_static_dims_broadcast(self):
+ # batch_shape = [2]
+ chol = rng.rand(3, 3)
+ rhs = rng.rand(2, 3, 7)
+ chol_broadcast = chol + np.zeros((2, 1, 1))
+
+ with self.test_session():
+ result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
+ self.assertAllEqual((2, 3, 7), result.get_shape())
+ expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
+ self.assertAllEqual(expected.eval(), result.eval())
+
+ def test_dynamic_dims_broadcast_64bit(self):
+ # batch_shape = [2, 2]
+ chol = rng.rand(2, 3, 3)
+ rhs = rng.rand(2, 1, 3, 7)
+ chol_broadcast = chol + np.zeros((2, 2, 1, 1))
+ rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+ chol_ph = array_ops.placeholder(dtypes.float64)
+ rhs_ph = array_ops.placeholder(dtypes.float64)
+
+ with self.test_session() as sess:
+ result, expected = sess.run(
+ [
+ linear_operator_util.cholesky_solve_with_broadcast(
+ chol_ph, rhs_ph),
+ linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
+ ],
+ feed_dict={
+ chol_ph: chol,
+ rhs_ph: rhs,
+ })
+ self.assertAllEqual(expected, result)
+
+
class MatmulWithBroadcastTest(test.TestCase):
def test_static_dims_broadcast(self):
expected = math_ops.matmul(x, y_broadcast)
self.assertAllEqual(expected.eval(), result.eval())
- def test_dynamic_dims_broadcast_32bit(self):
+ def test_dynamic_dims_broadcast_64bit(self):
# batch_shape = [2]
# for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
x = rng.rand(2, 1, 3)
with self.test_session() as sess:
result, expected = sess.run(
- [linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
- math_ops.matmul(x, y_broadcast)],
- feed_dict={x_ph: x, y_ph: y})
+ [
+ linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
+ math_ops.matmul(x, y_broadcast)
+ ],
+ feed_dict={
+ x_ph: x,
+ y_ph: y
+ })
+ self.assertAllEqual(expected, result)
+
+
+class MatrixSolveWithBroadcastTest(test.TestCase):
+
+ def test_static_dims_broadcast(self):
+ # batch_shape = [2]
+ matrix = rng.rand(3, 3)
+ rhs = rng.rand(2, 3, 7)
+ matrix_broadcast = matrix + np.zeros((2, 1, 1))
+
+ with self.test_session():
+ result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
+ self.assertAllEqual((2, 3, 7), result.get_shape())
+ expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
+ self.assertAllEqual(expected.eval(), result.eval())
+
+ def test_dynamic_dims_broadcast_64bit(self):
+ # batch_shape = [2, 2]
+ matrix = rng.rand(2, 3, 3)
+ rhs = rng.rand(2, 1, 3, 7)
+ matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
+ rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+ matrix_ph = array_ops.placeholder(dtypes.float64)
+ rhs_ph = array_ops.placeholder(dtypes.float64)
+
+ with self.test_session() as sess:
+ result, expected = sess.run(
+ [
+ linear_operator_util.matrix_solve_with_broadcast(
+ matrix_ph, rhs_ph),
+ linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
+ ],
+ feed_dict={
+ matrix_ph: matrix,
+ rhs_ph: rhs,
+ })
+ self.assertAllEqual(expected, result)
+
+
+class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
+
+ def test_static_dims_broadcast(self):
+ # batch_shape = [2]
+ matrix = rng.rand(2, 3, 3)
+ rhs = rng.rand(3, 7)
+ rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+ with self.test_session():
+ result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+ matrix, rhs)
+ self.assertAllEqual((2, 3, 7), result.get_shape())
+ expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+ self.assertAllEqual(expected.eval(), result.eval())
+
+ def test_dynamic_dims_broadcast_64bit(self):
+ # batch_shape = [2]
+ matrix = rng.rand(2, 3, 3)
+ rhs = rng.rand(3, 7)
+ rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+ matrix_ph = array_ops.placeholder(dtypes.float64)
+ rhs_ph = array_ops.placeholder(dtypes.float64)
+
+ with self.test_session() as sess:
+ result, expected = sess.run(
+ [
+ linear_operator_util.matrix_triangular_solve_with_broadcast(
+ matrix_ph, rhs_ph),
+ linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+ ],
+ feed_dict={
+ matrix_ph: matrix,
+ rhs_ph: rhs,
+ })
self.assertAllEqual(expected, result)
operator = DomainDimensionStubOperator(3)
# Should not raise
linear_operator_util.assert_compatible_matrix_dimensions(
- operator, x).run()
+ operator, x).run() # pyformat: disable
def test_incompatible_dimensions_raise(self):
with self.test_session():
operator = DomainDimensionStubOperator(3)
with self.assertRaisesOpError("Incompatible matrix dimensions"):
linear_operator_util.assert_compatible_matrix_dimensions(
- operator, x).run()
+ operator, x).run() # pyformat: disable
if __name__ == "__main__":
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
"%s" % tensor)
+def shape_tensor(shape, name=None):
+  """Convert `shape` to a `Tensor`, forcing `int32` for an empty list/tuple."""
+ # Works just like random_ops._ShapeTensor.
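+  # An empty list or tuple would otherwise be converted with the default
+  # inferred dtype (float32), which is not a valid dtype for a shape, so
+  # force int32 in that case.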
+ if isinstance(shape, (tuple, list)) and not shape:
+ dtype = dtypes.int32
+ else:
+ dtype = None
+ return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+
+
+################################################################################
+# Broadcasting versions of common linear algebra functions.
+# TODO(b/77519145) Do this more efficiently in some special cases.
+################################################################################
+
+
def broadcast_matrix_batch_dims(batch_matrices, name=None):
"""Broadcast leading dimensions of zero or more [batch] matrices.
bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_static_shape(
- bcast_batch_shape, mat.get_shape()[:-2])
+ bcast_batch_shape,
+ mat.get_shape()[:-2])
if bcast_batch_shape.is_fully_defined():
# The [1, 1] at the end will broadcast with anything.
bcast_shape = bcast_batch_shape.concatenate([1, 1])
bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_dynamic_shape(
- bcast_batch_shape, array_ops.shape(mat)[:-2])
+ bcast_batch_shape,
+ array_ops.shape(mat)[:-2])
bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
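+  # Note: adding zeros of the full target shape relies on broadcasting to
+  # replicate `x` up to `shape` without an explicit tile.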
return x + array_ops.zeros(shape=shape, dtype=x.dtype)
+def cholesky_solve_with_broadcast(chol, rhs, name=None):
+  """Solve systems of linear equations given a Cholesky factorization.
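+
+  Works identically to `tf.cholesky_solve`, but broadcasts batch dims of
+  `chol` and `rhs` (by replicating) if they are determined statically to be
+  different, or if static shapes are not fully defined. Thus, this may result
+  in an inefficient replication of data.
+
+  For example, an illustrative sketch (shapes and values here are chosen
+  arbitrarily):
+
+  ```python
+  # A single 3x3 Cholesky factor and a 2-batch of right-hand sides; `chol`
+  # is broadcast (replicated) to match the batch shape of `rhs`.
+  chol = tf.eye(3)  # Cholesky factor of the 3x3 identity matrix.
+  rhs = tf.ones(shape=(2, 3, 7))
+  x = cholesky_solve_with_broadcast(chol, rhs)
+  x.get_shape()  # ==> TensorShape([2, 3, 7])
+  ```
+  """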
+ with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]):
+ chol, rhs = broadcast_matrix_batch_dims([chol, rhs])
+ return linalg_ops.cholesky_solve(chol, rhs)
+
+
def matmul_with_broadcast(a,
b,
transpose_a=False,
name=None):
"""Multiplies matrix `a` by matrix `b`, producing `a @ b`.
+ Works identically to `tf.matmul`, but broadcasts batch dims
+ of `a` and `b` (by replicating) if they are determined statically to be
+ different, or if static shapes are not fully defined. Thus, this may result
+ in an inefficient replication of data.
+
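+  For example, an illustrative sketch (shapes here are chosen arbitrarily):
+
+  ```python
+  # A 2-batch of 3x4 matrices times a single 4x5 matrix, which is
+  # broadcast (replicated) across the batch.
+  x = tf.ones(shape=(2, 3, 4))
+  y = tf.ones(shape=(4, 5))
+  result = matmul_with_broadcast(x, y)
+  result.get_shape()  # ==> TensorShape([2, 3, 5])
+  ```
+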
The inputs must be matrices (or tensors of rank > 2, representing batches of
matrices).
ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
are both set to True.
"""
- with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name:
+ with ops.name_scope(name, "MatMulWithBroadcast", [a, b]):
a, b = broadcast_matrix_batch_dims([a, b])
return math_ops.matmul(
a,
b_is_sparse=b_is_sparse)
-def shape_tensor(shape, name=None):
- """Convert Tensor using default type, unless empty list or tuple."""
- # Works just like random_ops._ShapeTensor.
- if isinstance(shape, (tuple, list)) and not shape:
- dtype = dtypes.int32
- else:
- dtype = None
- return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
+  """Solve systems of linear equations.
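+
+  Works identically to `tf.matrix_solve`, but broadcasts batch dims of
+  `matrix` and `rhs` (by replicating) if they are determined statically to be
+  different, or if static shapes are not fully defined. Thus, this may result
+  in an inefficient replication of data.
+
+  For example, an illustrative sketch (shapes and values here are chosen
+  arbitrarily):
+
+  ```python
+  # A single 3x3 matrix and a 2-batch of right-hand sides; `matrix` is
+  # broadcast (replicated) across the batch.
+  matrix = 2. * tf.eye(3)
+  rhs = tf.ones(shape=(2, 3, 7))
+  x = matrix_solve_with_broadcast(matrix, rhs)
+  x.get_shape()  # ==> TensorShape([2, 3, 7])
+  ```
+  """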
+ with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
+ matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+ return linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
+
+
+def matrix_triangular_solve_with_broadcast(matrix,
+ rhs,
+ lower=True,
+ adjoint=False,
+ name=None):
+  """Solves triangular systems of linear equations by backsubstitution.
+
+ Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
+ of `matrix` and `rhs` (by replicating) if they are determined statically to be
+ different, or if static shapes are not fully defined. Thus, this may result
+ in an inefficient replication of data.
+
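+  For example, an illustrative sketch (shapes here are chosen arbitrarily):
+
+  ```python
+  # A 2-batch of 3x3 lower-triangular matrices and a single right-hand
+  # side, which is broadcast (replicated) across the batch.
+  matrix = tf.ones(shape=(2, 3, 3))  # Only the lower triangle is used.
+  rhs = tf.ones(shape=(3, 7))
+  x = matrix_triangular_solve_with_broadcast(matrix, rhs)
+  x.get_shape()  # ==> TensorShape([2, 3, 7])
+  ```
+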
+ Args:
+ matrix: A Tensor. Must be one of the following types:
+ `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
+ rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
+ Shape is `[..., M, K]`.
+ lower: An optional `bool`. Defaults to `True`. Indicates whether the
+ innermost matrices in `matrix` are lower or upper triangular.
+ adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve
+ with matrix or its (block-wise) adjoint.
+ name: A name for the operation (optional).
+
+ Returns:
+ `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
+ """
+  with ops.name_scope(
+      name, "MatrixTriangularSolveWithBroadcast", [matrix, rhs]):
+ matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+ return linalg_ops.matrix_triangular_solve(
+ matrix,
+ rhs,
+ lower=lower,
+ adjoint=adjoint)