From f85f23bea87ea856ac839806b5d62f4257fb684b Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Fri, 2 Feb 2018 13:08:31 -0800 Subject: [PATCH] Support `float16` `dtype` in `tf.linalg.*`. Note: not all `LinearOperator` functions will support `float16`. This change merely enables constructing the `LinearOperator` object(s) using this `dtype`. PiperOrigin-RevId: 184323477 --- tensorflow/python/ops/linalg/linalg_impl.py | 8 ++++---- tensorflow/python/ops/linalg/linear_operator_diag.py | 11 ++++++++--- tensorflow/python/ops/linalg/linear_operator_full_matrix.py | 10 ++++++++-- .../python/ops/linalg/linear_operator_low_rank_update.py | 10 +++++++--- .../python/ops/linalg/linear_operator_lower_triangular.py | 9 +++++++-- 5 files changed, 34 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index db33a08..a5096ff 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -65,8 +65,8 @@ def logdet(matrix, name=None): ``` Args: - matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or - `complex128` with shape `[..., M, M]`. + matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. name: A name to give this `Op`. Defaults to `logdet`. Returns: @@ -99,8 +99,8 @@ def adjoint(matrix, name=None): # [3 - 3j, 6 - 6j]] Args: - matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or - `complex128` with shape `[..., M, M]`. + matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. name: A name to give this `Op` (optional). Returns: diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index a4724d0..2217bfd 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -121,8 +121,8 @@ class LinearOperatorDiag(linear_operator.LinearOperator): Args: diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. - The diagonal of the operator. Allowed dtypes: `float32`, `float64`, - `complex64`, `complex128`. + The diagonal of the operator. Allowed dtypes: `float16`, `float32`, + `float64`, `complex64`, `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. If `diag.dtype` is real, this is auto-set to `True`. @@ -167,7 +167,12 @@ class LinearOperatorDiag(linear_operator.LinearOperator): def _check_diag(self, diag): """Static check of diag.""" allowed_dtypes = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ] dtype = diag.dtype if dtype not in allowed_dtypes: diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py index dd4c7cb..8fb59ca 100644 --- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py +++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py @@ -114,7 +114,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): Args: matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`. - Allowed dtypes: `float32`, `float64`, `complex64`, `complex128`. + Allowed dtypes: `float16`, `float32`, `float64`, `complex64`, + `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian transpose. @@ -147,7 +148,12 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): def _check_matrix(self, matrix): """Static check of the `matrix` argument.""" allowed_dtypes = [ - dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ] matrix = ops.convert_to_tensor(matrix, name="matrix") diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index ad3bb2e..36eed89 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -150,8 +150,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): `is_X` matrix property hints, which will trigger the appropriate code path. Args: - base_operator: Shape `[B1,...,Bb, M, N]` real `float32` or `float64` - `LinearOperator`. This is `L` above. + base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or + `float64` `LinearOperator`. This is `L` above. u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`. This is `U` above. diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype` @@ -188,7 +188,11 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): # because if diag has non-zero imaginary part, it will not be # self-adjoint positive definite. dtype = base_operator.dtype - allowed_dtypes = [dtypes.float32, dtypes.float64] + allowed_dtypes = [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] if dtype not in allowed_dtypes: raise TypeError( "Argument matrix must have dtype in %s. Found: %s" diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index 6ea55f0..6419030 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -118,7 +118,8 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): Args: tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. The lower triangular part of `tril` defines this operator. The strictly - upper triangle is ignored. Allowed dtypes: `float32`, `float64`. + upper triangle is ignored. Allowed dtypes: `float16`, `float32`, + `float64`. is_non_singular: Expect that this operator is non-singular. This operator is non-singular if and only if its diagonal elements are all non-zero. @@ -164,7 +165,11 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): """Static check of the `tril` argument.""" # TODO(langmore) Add complex types once matrix_triangular_solve works for # them. - allowed_dtypes = [dtypes.float32, dtypes.float64] + allowed_dtypes = [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + ] dtype = tril.dtype if dtype not in allowed_dtypes: raise TypeError( -- 2.7.4