From 5d1086ae98ccfe691161ff50c93036d432866741 Mon Sep 17 00:00:00 2001
From: Ian Langmore
Date: Tue, 3 Apr 2018 08:59:08 -0700
Subject: [PATCH] cholesky_solve_with_broadcast, matrix_solve_with_broadcast
 and matrix_triangular_solve_with_broadcast added to linear_operator_util.py

PiperOrigin-RevId: 191447378
---
 .../linalg/linear_operator_util_test.py       | 136 +++++++++++++++++++--
 .../python/ops/linalg/linear_operator_util.py |  85 +++++++++++--
 2 files changed, 202 insertions(+), 19 deletions(-)

diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
index e1edffc..7b291e2 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.linalg import linear_operator_util
 from tensorflow.python.platform import test
@@ -94,8 +95,8 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
 class BroadcastMatrixBatchDimsTest(test.TestCase):
 
   def test_zero_batch_matrices_returned_as_empty_list(self):
-    self.assertAllEqual(
-        [], linear_operator_util.broadcast_matrix_batch_dims([]))
+    self.assertAllEqual([],
+                        linear_operator_util.broadcast_matrix_batch_dims([]))
 
   def test_one_batch_matrix_returned_after_tensor_conversion(self):
     arr = rng.rand(2, 3, 4)
@@ -194,6 +195,44 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
       linear_operator_util.broadcast_matrix_batch_dims([y, x])
 
 
+class CholeskySolveWithBroadcastTest(test.TestCase):
+
+  def test_static_dims_broadcast(self):
+    # batch_shape = [2]
+    chol = rng.rand(3, 3)
+    rhs = rng.rand(2, 3, 7)
+    chol_broadcast = chol + np.zeros((2, 1, 1))
+
+    with self.test_session():
+      result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
+      self.assertAllEqual((2, 3, 7), result.get_shape())
+      expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
+      self.assertAllEqual(expected.eval(), result.eval())
+
+  def test_dynamic_dims_broadcast_64bit(self):
+    # batch_shape = [2, 2]
+    chol = rng.rand(2, 3, 3)
+    rhs = rng.rand(2, 1, 3, 7)
+    chol_broadcast = chol + np.zeros((2, 2, 1, 1))
+    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+    chol_ph = array_ops.placeholder(dtypes.float64)
+    rhs_ph = array_ops.placeholder(dtypes.float64)
+
+    with self.test_session() as sess:
+      result, expected = sess.run(
+          [
+              linear_operator_util.cholesky_solve_with_broadcast(
+                  chol_ph, rhs_ph),
+              linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
+          ],
+          feed_dict={
+              chol_ph: chol,
+              rhs_ph: rhs,
+          })
+      self.assertAllEqual(expected, result)
+
+
 class MatmulWithBroadcastTest(test.TestCase):
 
   def test_static_dims_broadcast(self):
@@ -209,7 +248,7 @@ class MatmulWithBroadcastTest(test.TestCase):
       expected = math_ops.matmul(x, y_broadcast)
       self.assertAllEqual(expected.eval(), result.eval())
 
-  def test_dynamic_dims_broadcast_32bit(self):
+  def test_dynamic_dims_broadcast_64bit(self):
     # batch_shape = [2]
     # for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
     x = rng.rand(2, 1, 3)
@@ -221,9 +260,90 @@ class MatmulWithBroadcastTest(test.TestCase):
 
     with self.test_session() as sess:
       result, expected = sess.run(
-          [linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
-           math_ops.matmul(x, y_broadcast)],
-          feed_dict={x_ph: x, y_ph: y})
+          [
+              linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
+              math_ops.matmul(x, y_broadcast)
+          ],
+          feed_dict={
+              x_ph: x,
+              y_ph: y
+          })
       self.assertAllEqual(expected, result)
+
+
+class MatrixSolveWithBroadcastTest(test.TestCase):
+
+  def test_static_dims_broadcast(self):
+    # batch_shape = [2]
+    matrix = rng.rand(3, 3)
+    rhs = rng.rand(2, 3, 7)
+    matrix_broadcast = matrix + np.zeros((2, 1, 1))
+
+    with self.test_session():
+      result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
+      self.assertAllEqual((2, 3, 7), result.get_shape())
+      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
+      self.assertAllEqual(expected.eval(), result.eval())
+
+  def test_dynamic_dims_broadcast_64bit(self):
+    # batch_shape = [2, 2]
+    matrix = rng.rand(2, 3, 3)
+    rhs = rng.rand(2, 1, 3, 7)
+    matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
+    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+    matrix_ph = array_ops.placeholder(dtypes.float64)
+    rhs_ph = array_ops.placeholder(dtypes.float64)
+
+    with self.test_session() as sess:
+      result, expected = sess.run(
+          [
+              linear_operator_util.matrix_solve_with_broadcast(
+                  matrix_ph, rhs_ph),
+              linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
+          ],
+          feed_dict={
+              matrix_ph: matrix,
+              rhs_ph: rhs,
+          })
+      self.assertAllEqual(expected, result)
+
+
+class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
+
+  def test_static_dims_broadcast(self):
+    # batch_shape = [2]
+    matrix = rng.rand(2, 3, 3)
+    rhs = rng.rand(3, 7)
+    rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+    with self.test_session():
+      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+          matrix, rhs)
+      self.assertAllEqual((2, 3, 7), result.get_shape())
+      expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+      self.assertAllEqual(expected.eval(), result.eval())
+
+  def test_dynamic_dims_broadcast_64bit(self):
+    # batch_shape = [2]
+    matrix = rng.rand(2, 3, 3)
+    rhs = rng.rand(3, 7)
+    rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+    matrix_ph = array_ops.placeholder(dtypes.float64)
+    rhs_ph = array_ops.placeholder(dtypes.float64)
+
+    with self.test_session() as sess:
+      result, expected = sess.run(
+          [
+              linear_operator_util.matrix_triangular_solve_with_broadcast(
+                  matrix_ph, rhs_ph),
+              linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+          ],
+          feed_dict={
+              matrix_ph: matrix,
+              rhs_ph: rhs,
+          })
+      self.assertAllEqual(expected, result)
 
 
@@ -244,7 +364,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
       operator = DomainDimensionStubOperator(3)
       # Should not raise
       linear_operator_util.assert_compatible_matrix_dimensions(
-          operator, x).run()
+          operator, x).run()  # pyformat: disable
 
   def test_incompatible_dimensions_raise(self):
     with self.test_session():
@@ -252,7 +372,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
       operator = DomainDimensionStubOperator(3)
       with self.assertRaisesOpError("Incompatible matrix dimensions"):
         linear_operator_util.assert_compatible_matrix_dimensions(
-            operator, x).run()
+            operator, x).run()  # pyformat: disable
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py
index 427bd1e..9dd4076 100644
--- a/tensorflow/python/ops/linalg/linear_operator_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_util.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -102,6 +103,22 @@ def assert_is_batch_matrix(tensor):
         "%s" % tensor)
 
 
+def shape_tensor(shape, name=None):
+  """Convert Tensor using default type, unless empty list or tuple."""
+  # Works just like random_ops._ShapeTensor.
+  if isinstance(shape, (tuple, list)) and not shape:
+    dtype = dtypes.int32
+  else:
+    dtype = None
+  return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+
+
+################################################################################
+# Broadcasting versions of common linear algebra functions.
+# TODO(b/77519145) Do this more efficiently in some special cases.
+################################################################################
+
+
 def broadcast_matrix_batch_dims(batch_matrices, name=None):
   """Broadcast leading dimensions of zero or more [batch] matrices.
 
@@ -170,7 +187,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
     bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
     for mat in batch_matrices[1:]:
       bcast_batch_shape = array_ops.broadcast_static_shape(
-          bcast_batch_shape, mat.get_shape()[:-2])
+          bcast_batch_shape,
+          mat.get_shape()[:-2])
     if bcast_batch_shape.is_fully_defined():
       # The [1, 1] at the end will broadcast with anything.
       bcast_shape = bcast_batch_shape.concatenate([1, 1])
@@ -183,7 +201,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
     bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
     for mat in batch_matrices[1:]:
       bcast_batch_shape = array_ops.broadcast_dynamic_shape(
-          bcast_batch_shape, array_ops.shape(mat)[:-2])
+          bcast_batch_shape,
+          array_ops.shape(mat)[:-2])
     bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
     for i, mat in enumerate(batch_matrices):
       batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
@@ -195,6 +214,13 @@ def _broadcast_to_shape(x, shape):
   return x + array_ops.zeros(shape=shape, dtype=x.dtype)
 
 
+def cholesky_solve_with_broadcast(chol, rhs, name=None):
+  """Solve systems of linear equations."""
+  with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]):
+    chol, rhs = broadcast_matrix_batch_dims([chol, rhs])
+    return linalg_ops.cholesky_solve(chol, rhs)
+
+
 def matmul_with_broadcast(a,
                           b,
                           transpose_a=False,
@@ -206,6 +232,11 @@ def matmul_with_broadcast(a,
                           name=None):
   """Multiplies matrix `a` by matrix `b`, producing `a @ b`.
 
+  Works identically to `tf.matmul`, but broadcasts batch dims
+  of `a` and `b` (by replicating) if they are determined statically to be
+  different, or if static shapes are not fully defined. Thus, this may result
+  in an inefficient replication of data.
+
   The inputs must be matrices (or tensors of rank > 2, representing batches of
   matrices).
 
@@ -276,7 +307,7 @@ def matmul_with_broadcast(a,
     ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
       are both set to True.
   """
-  with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name:
+  with ops.name_scope(name, "MatMulWithBroadcast", [a, b]):
     a, b = broadcast_matrix_batch_dims([a, b])
     return math_ops.matmul(
         a,
@@ -289,11 +320,43 @@ def matmul_with_broadcast(a,
         b_is_sparse=b_is_sparse)
 
 
-def shape_tensor(shape, name=None):
-  """Convert Tensor using default type, unless empty list or tuple."""
-  # Works just like random_ops._ShapeTensor.
-  if isinstance(shape, (tuple, list)) and not shape:
-    dtype = dtypes.int32
-  else:
-    dtype = None
-  return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
+  """Solve systems of linear equations."""
+  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
+    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+    return linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
+
+
+def matrix_triangular_solve_with_broadcast(matrix,
+                                           rhs,
+                                           lower=True,
+                                           adjoint=False,
+                                           name=None):
+  """Solves triangular systems of linear equations by backsubstitution.
+
+  Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
+  of `matrix` and `rhs` (by replicating) if they are determined statically to be
+  different, or if static shapes are not fully defined. Thus, this may result
+  in an inefficient replication of data.
+
+  Args:
+    matrix: A `Tensor`. Must be one of the following types:
+      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
+    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
+      Shape is `[..., M, K]`.
+    lower: An optional `bool`. Defaults to `True`. Indicates whether the
+      innermost matrices in `matrix` are lower or upper triangular.
+    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to solve
+      with `matrix` or its (block-wise) adjoint.
+    name: A name for the operation (optional).
+
+  Returns:
+    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
+  """
+  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
+    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+    return linalg_ops.matrix_triangular_solve(
+        matrix,
+        rhs,
+        lower=lower,
+        adjoint=adjoint)
-- 
2.7.4
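
Usage sketch (illustrative, not part of the patch): a minimal example of how the new broadcasting solve helpers behave, assuming a TensorFlow 1.x environment where the private module tensorflow.python.ops.linalg.linear_operator_util is importable. The shapes, the random inputs, and the well-conditioned triangular matrix below are assumptions chosen only to show the batch-dim broadcasting.

import numpy as np
import tensorflow as tf

from tensorflow.python.ops.linalg import linear_operator_util

# One lower-triangular 3x3 matrix and a batch of two right-hand sides.
# The helper replicates the matrix across the rhs batch shape [2] before
# calling tf.matrix_triangular_solve, per broadcast_matrix_batch_dims.
matrix = np.tril(np.random.rand(3, 3)) + np.eye(3)  # keep it well conditioned
rhs = np.random.rand(2, 3, 7)

solution = linear_operator_util.matrix_triangular_solve_with_broadcast(
    tf.constant(matrix), tf.constant(rhs))

with tf.Session() as sess:
  print(sess.run(solution).shape)  # (2, 3, 7): the broadcast batch shape

The same pattern applies to cholesky_solve_with_broadcast and matrix_solve_with_broadcast; each simply broadcasts the batch dimensions of its two arguments and then delegates to the corresponding non-broadcasting op.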