Support `float16` `dtype` in `tf.linalg.*`.
authorJoshua V. Dillon <jvdillon@google.com>
Fri, 2 Feb 2018 21:08:31 +0000 (13:08 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 21:15:17 +0000 (13:15 -0800)
Note: not all `LinearOperator` functions will support `float16`. This change
merely enables constructing the `LinearOperator` object(s) using this `dtype`.

PiperOrigin-RevId: 184323477

tensorflow/python/ops/linalg/linalg_impl.py
tensorflow/python/ops/linalg/linear_operator_diag.py
tensorflow/python/ops/linalg/linear_operator_full_matrix.py
tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
tensorflow/python/ops/linalg/linear_operator_lower_triangular.py

index db33a08..a5096ff 100644 (file)
@@ -65,8 +65,8 @@ def logdet(matrix, name=None):
   ```
 
   Args:
-    matrix:  A `Tensor`. Must be `float32`, `float64`, `complex64`, or
-      `complex128` with shape `[..., M, M]`.
+    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
+      or `complex128` with shape `[..., M, M]`.
     name:  A name to give this `Op`.  Defaults to `logdet`.
 
   Returns:
@@ -99,8 +99,8 @@ def adjoint(matrix, name=None):
                         #  [3 - 3j, 6 - 6j]]
 
   Args:
-    matrix:  A `Tensor`. Must be `float32`, `float64`, `complex64`, or
-      `complex128` with shape `[..., M, M]`.
+    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
+      or `complex128` with shape `[..., M, M]`.
     name:  A name to give this `Op` (optional).
 
   Returns:
index a4724d0..2217bfd 100644 (file)
@@ -121,8 +121,8 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
 
     Args:
       diag:  Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
-        The diagonal of the operator.  Allowed dtypes: `float32`, `float64`,
-          `complex64`, `complex128`.
+        The diagonal of the operator.  Allowed dtypes: `float16`, `float32`,
+          `float64`, `complex64`, `complex128`.
       is_non_singular:  Expect that this operator is non-singular.
       is_self_adjoint:  Expect that this operator is equal to its hermitian
         transpose.  If `diag.dtype` is real, this is auto-set to `True`.
@@ -167,7 +167,12 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
   def _check_diag(self, diag):
     """Static check of diag."""
     allowed_dtypes = [
-        dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
+        dtypes.float16,
+        dtypes.float32,
+        dtypes.float64,
+        dtypes.complex64,
+        dtypes.complex128,
+    ]
 
     dtype = diag.dtype
     if dtype not in allowed_dtypes:
index dd4c7cb..8fb59ca 100644 (file)
@@ -114,7 +114,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
 
     Args:
       matrix:  Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
-        Allowed dtypes: `float32`, `float64`, `complex64`, `complex128`.
+        Allowed dtypes: `float16`, `float32`, `float64`, `complex64`,
+        `complex128`.
       is_non_singular:  Expect that this operator is non-singular.
       is_self_adjoint:  Expect that this operator is equal to its hermitian
         transpose.
@@ -147,7 +148,12 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
   def _check_matrix(self, matrix):
     """Static check of the `matrix` argument."""
     allowed_dtypes = [
-        dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
+        dtypes.float16,
+        dtypes.float32,
+        dtypes.float64,
+        dtypes.complex64,
+        dtypes.complex128,
+    ]
 
     matrix = ops.convert_to_tensor(matrix, name="matrix")
 
index ad3bb2e..36eed89 100644 (file)
@@ -150,8 +150,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     `is_X` matrix property hints, which will trigger the appropriate code path.
 
     Args:
-      base_operator:  Shape `[B1,...,Bb, M, N]` real `float32` or `float64`
-        `LinearOperator`.  This is `L` above.
+      base_operator:  Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or
+        `float64` `LinearOperator`.  This is `L` above.
       u:  Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
         This is `U` above.
       diag_update:  Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
@@ -188,7 +188,11 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     #    because if diag has non-zero imaginary part, it will not be
     #    self-adjoint positive definite.
     dtype = base_operator.dtype
-    allowed_dtypes = [dtypes.float32, dtypes.float64]
+    allowed_dtypes = [
+        dtypes.float16,
+        dtypes.float32,
+        dtypes.float64,
+    ]
     if dtype not in allowed_dtypes:
       raise TypeError(
           "Argument matrix must have dtype in %s.  Found: %s"
index 6ea55f0..6419030 100644 (file)
@@ -118,7 +118,8 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
     Args:
       tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
         The lower triangular part of `tril` defines this operator.  The strictly
-        upper triangle is ignored.  Allowed dtypes: `float32`, `float64`.
+        upper triangle is ignored.  Allowed dtypes: `float16`, `float32`,
+        `float64`.
       is_non_singular:  Expect that this operator is non-singular.
         This operator is non-singular if and only if its diagonal elements are
         all non-zero.
@@ -164,7 +165,11 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
     """Static check of the `tril` argument."""
     # TODO(langmore) Add complex types once matrix_triangular_solve works for
     # them.
-    allowed_dtypes = [dtypes.float32, dtypes.float64]
+    allowed_dtypes = [
+        dtypes.float16,
+        dtypes.float32,
+        dtypes.float64,
+    ]
     dtype = tril.dtype
     if dtype not in allowed_dtypes:
       raise TypeError(