cholesky_solve_with_broadcast, matrix_solve_with_broadcast and matrix_triangular_solve_with_broadcast
author     Ian Langmore <langmore@google.com>
           Tue, 3 Apr 2018 15:59:08 +0000 (08:59 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 3 Apr 2018 16:01:35 +0000 (09:01 -0700)
PiperOrigin-RevId: 191447378
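
Each new helper broadcasts the leading (batch) dimensions of its arguments
by replicating, via broadcast_matrix_batch_dims, and then calls the
corresponding op in linalg_ops. A minimal usage sketch, assuming TF 1.x
graph mode; only cholesky_solve_with_broadcast below comes from this
change, everything else is standard TensorFlow/NumPy:

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops.linalg import linear_operator_util

    rng = np.random.RandomState(0)
    m = rng.rand(3, 3)
    # Build an SPD matrix and factor it; chol has shape [3, 3], no batch dims.
    chol = np.linalg.cholesky(np.matmul(m, m.T) + 3. * np.eye(3))
    rhs = rng.rand(2, 3, 7)  # batch_shape = [2]

    # chol is replicated to shape [2, 3, 3] before tf.cholesky_solve is called.
    x = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)

    with tf.Session() as sess:
      print(sess.run(x).shape)  # ==> (2, 3, 7)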

tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
tensorflow/python/ops/linalg/linear_operator_util.py

diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
index e1edffc..7b291e2 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.linalg import linear_operator_util
 from tensorflow.python.platform import test
@@ -94,8 +95,8 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
 class BroadcastMatrixBatchDimsTest(test.TestCase):
 
   def test_zero_batch_matrices_returned_as_empty_list(self):
-    self.assertAllEqual(
-        [], linear_operator_util.broadcast_matrix_batch_dims([]))
+    self.assertAllEqual([],
+                        linear_operator_util.broadcast_matrix_batch_dims([]))
 
   def test_one_batch_matrix_returned_after_tensor_conversion(self):
     arr = rng.rand(2, 3, 4)
@@ -194,6 +195,44 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
       linear_operator_util.broadcast_matrix_batch_dims([y, x])
 
 
+class CholeskySolveWithBroadcastTest(test.TestCase):
+
+  def test_static_dims_broadcast(self):
+    # batch_shape = [2]
+    chol = rng.rand(3, 3)
+    rhs = rng.rand(2, 3, 7)
+    chol_broadcast = chol + np.zeros((2, 1, 1))
+
+    with self.test_session():
+      result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
+      self.assertAllEqual((2, 3, 7), result.get_shape())
+      expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
+      self.assertAllEqual(expected.eval(), result.eval())
+
+  def test_dynamic_dims_broadcast_64bit(self):
+    # batch_shape = [2, 2]
+    chol = rng.rand(2, 3, 3)
+    rhs = rng.rand(2, 1, 3, 7)
+    chol_broadcast = chol + np.zeros((2, 2, 1, 1))
+    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+    chol_ph = array_ops.placeholder(dtypes.float64)
+    rhs_ph = array_ops.placeholder(dtypes.float64)
+
+    with self.test_session() as sess:
+      result, expected = sess.run(
+          [
+              linear_operator_util.cholesky_solve_with_broadcast(
+                  chol_ph, rhs_ph),
+              linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
+          ],
+          feed_dict={
+              chol_ph: chol,
+              rhs_ph: rhs,
+          })
+      self.assertAllEqual(expected, result)
+
+
 class MatmulWithBroadcastTest(test.TestCase):
 
   def test_static_dims_broadcast(self):
@@ -209,7 +248,7 @@ class MatmulWithBroadcastTest(test.TestCase):
       expected = math_ops.matmul(x, y_broadcast)
       self.assertAllEqual(expected.eval(), result.eval())
 
-  def test_dynamic_dims_broadcast_32bit(self):
+  def test_dynamic_dims_broadcast_64bit(self):
     # batch_shape = [2]
     # for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
     x = rng.rand(2, 1, 3)
@@ -221,9 +260,90 @@ class MatmulWithBroadcastTest(test.TestCase):
 
     with self.test_session() as sess:
       result, expected = sess.run(
-          [linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
-           math_ops.matmul(x, y_broadcast)],
-          feed_dict={x_ph: x, y_ph: y})
+          [
+              linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
+              math_ops.matmul(x, y_broadcast)
+          ],
+          feed_dict={
+              x_ph: x,
+              y_ph: y
+          })
+      self.assertAllEqual(expected, result)
+
+
+class MatrixSolveWithBroadcastTest(test.TestCase):
+
+  def test_static_dims_broadcast(self):
+    # batch_shape = [2]
+    matrix = rng.rand(3, 3)
+    rhs = rng.rand(2, 3, 7)
+    matrix_broadcast = matrix + np.zeros((2, 1, 1))
+
+    with self.test_session():
+      result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
+      self.assertAllEqual((2, 3, 7), result.get_shape())
+      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
+      self.assertAllEqual(expected.eval(), result.eval())
+
+  def test_dynamic_dims_broadcast_64bit(self):
+    # batch_shape = [2, 2]
+    matrix = rng.rand(2, 3, 3)
+    rhs = rng.rand(2, 1, 3, 7)
+    matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
+    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
+
+    matrix_ph = array_ops.placeholder(dtypes.float64)
+    rhs_ph = array_ops.placeholder(dtypes.float64)
+
+    with self.test_session() as sess:
+      result, expected = sess.run(
+          [
+              linear_operator_util.matrix_solve_with_broadcast(
+                  matrix_ph, rhs_ph),
+              linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
+          ],
+          feed_dict={
+              matrix_ph: matrix,
+              rhs_ph: rhs,
+          })
+      self.assertAllEqual(expected, result)
+
+
+class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
+
+  def test_static_dims_broadcast(self):
+    # batch_shape = [2]
+    matrix = rng.rand(2, 3, 3)
+    rhs = rng.rand(3, 7)
+    rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+    with self.test_session():
+      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+          matrix, rhs)
+      self.assertAllEqual((2, 3, 7), result.get_shape())
+      expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+      self.assertAllEqual(expected.eval(), result.eval())
+
+  def test_dynamic_dims_broadcast_64bit(self):
+    # batch_shape = [2]
+    matrix = rng.rand(2, 3, 3)
+    rhs = rng.rand(3, 7)
+    rhs_broadcast = rhs + np.zeros((2, 1, 1))
+
+    matrix_ph = array_ops.placeholder(dtypes.float64)
+    rhs_ph = array_ops.placeholder(dtypes.float64)
+
+    with self.test_session() as sess:
+      result, expected = sess.run(
+          [
+              linear_operator_util.matrix_triangular_solve_with_broadcast(
+                  matrix_ph, rhs_ph),
+              linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+          ],
+          feed_dict={
+              matrix_ph: matrix,
+              rhs_ph: rhs,
+          })
       self.assertAllEqual(expected, result)
 
 
@@ -244,7 +364,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
       operator = DomainDimensionStubOperator(3)
       # Should not raise
       linear_operator_util.assert_compatible_matrix_dimensions(
-          operator, x).run()
+          operator, x).run()  # pyformat: disable
 
   def test_incompatible_dimensions_raise(self):
     with self.test_session():
@@ -252,7 +372,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
       operator = DomainDimensionStubOperator(3)
       with self.assertRaisesOpError("Incompatible matrix dimensions"):
         linear_operator_util.assert_compatible_matrix_dimensions(
-            operator, x).run()
+            operator, x).run()  # pyformat: disable
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py
index 427bd1e..9dd4076 100644
--- a/tensorflow/python/ops/linalg/linear_operator_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_util.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -102,6 +103,22 @@ def assert_is_batch_matrix(tensor):
         "%s" % tensor)
 
 
+def shape_tensor(shape, name=None):
+  """Convert Tensor using default type, unless empty list or tuple."""
+  # Works just like random_ops._ShapeTensor.
+  if isinstance(shape, (tuple, list)) and not shape:
+    dtype = dtypes.int32
+  else:
+    dtype = None
+  return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+
+
+################################################################################
+# Broadcasting versions of common linear algebra functions.
+# TODO(b/77519145) Do this more efficiently in some special cases.
+################################################################################
+
+
 def broadcast_matrix_batch_dims(batch_matrices, name=None):
   """Broadcast leading dimensions of zero or more [batch] matrices.
 
@@ -170,7 +187,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
     bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
     for mat in batch_matrices[1:]:
       bcast_batch_shape = array_ops.broadcast_static_shape(
-          bcast_batch_shape, mat.get_shape()[:-2])
+          bcast_batch_shape,
+          mat.get_shape()[:-2])
     if bcast_batch_shape.is_fully_defined():
       # The [1, 1] at the end will broadcast with anything.
       bcast_shape = bcast_batch_shape.concatenate([1, 1])
@@ -183,7 +201,8 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
     bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
     for mat in batch_matrices[1:]:
       bcast_batch_shape = array_ops.broadcast_dynamic_shape(
-          bcast_batch_shape, array_ops.shape(mat)[:-2])
+          bcast_batch_shape,
+          array_ops.shape(mat)[:-2])
     bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
     for i, mat in enumerate(batch_matrices):
       batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
@@ -195,6 +214,13 @@ def _broadcast_to_shape(x, shape):
   return x + array_ops.zeros(shape=shape, dtype=x.dtype)
 
 
+def cholesky_solve_with_broadcast(chol, rhs, name=None):
+  """Solve systems of linear equations."""
+  with ops.name_scope(name, "CholeskySolveWithBroadcast", [chol, rhs]):
+    chol, rhs = broadcast_matrix_batch_dims([chol, rhs])
+    return linalg_ops.cholesky_solve(chol, rhs)
+
+
 def matmul_with_broadcast(a,
                           b,
                           transpose_a=False,
@@ -206,6 +232,11 @@ def matmul_with_broadcast(a,
                           name=None):
   """Multiplies matrix `a` by matrix `b`, producing `a @ b`.
 
+  Works identically to `tf.matmul`, but broadcasts batch dims
+  of `a` and `b` (by replicating) if they are determined statically to be
+  different, or if static shapes are not fully defined.  Thus, this may result
+  in an inefficient replication of data.
+
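+  For example (shapes chosen only for illustration):
+
+  ```python
+  x = ...  # shape [2, 1, 3]; batch shape [2]
+  y = ...  # shape [3, 7]; no batch dims
+  # y is replicated to shape [2, 3, 7] before the batch matmul.
+  matmul_with_broadcast(x, y).shape  # ==> [2, 1, 7]
+  ```
+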
   The inputs must be matrices (or tensors of rank > 2, representing batches of
   matrices).
 
@@ -276,7 +307,7 @@ def matmul_with_broadcast(a,
     ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
       are both set to True.
   """
-  with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name:
+  with ops.name_scope(name, "MatMulWithBroadcast", [a, b]):
     a, b = broadcast_matrix_batch_dims([a, b])
     return math_ops.matmul(
         a,
@@ -289,11 +320,43 @@ def matmul_with_broadcast(a,
         b_is_sparse=b_is_sparse)
 
 
-def shape_tensor(shape, name=None):
-  """Convert Tensor using default type, unless empty list or tuple."""
-  # Works just like random_ops._ShapeTensor.
-  if isinstance(shape, (tuple, list)) and not shape:
-    dtype = dtypes.int32
-  else:
-    dtype = None
-  return ops.convert_to_tensor(shape, dtype=dtype, name=name)
+def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
+  """Solve systems of linear equations."""
+  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
+    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+    return linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
+
+
+def matrix_triangular_solve_with_broadcast(matrix,
+                                           rhs,
+                                           lower=True,
+                                           adjoint=False,
+                                           name=None):
+  """Solves triangular systems of linear equations with by backsubstitution.
+
+  Works identically to `tf.matrix_triangular_solve`, but broadcasts batch dims
+  of `matrix` and `rhs` (by replicating) if they are determined statically to be
+  different, or if static shapes are not fully defined.  Thus, this may result
+  in an inefficient replication of data.
+
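+  For example (shapes chosen only for illustration):
+
+  ```python
+  matrix = ...  # shape [2, 3, 3], lower triangular; batch shape [2]
+  rhs = ...     # shape [3, 7]; no batch dims
+  # rhs is replicated to shape [2, 3, 7] before the solve.
+  matrix_triangular_solve_with_broadcast(matrix, rhs).shape  # ==> [2, 3, 7]
+  ```
+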
+  Args:
+    matrix: A `Tensor`. Must be one of the following types:
+      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
+    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
+      Shape is `[..., M, K]`.
+    lower: An optional `bool`. Defaults to `True`. Indicates whether the
+      innermost matrices in `matrix` are lower or upper triangular.
+    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to
+      solve with `matrix` or its (block-wise) adjoint.
+    name: A name for the operation (optional).
+
+  Returns:
+    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
+  """
+  with ops.name_scope(
+      name, "MatrixTriangularSolveWithBroadcast", [matrix, rhs]):
+    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
+    return linalg_ops.matrix_triangular_solve(
+        matrix,
+        rhs,
+        lower=lower,
+        adjoint=adjoint)