Add asserts for LinearOperatorBlockDiag.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Mar 2018 01:05:03 +0000 (18:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Mar 2018 01:07:30 +0000 (18:07 -0700)
All asserts are only dependent on the constituent operators due to properties of the direct sum.

PiperOrigin-RevId: 190156830

tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py

index 5d7a996..80649bd 100644 (file)
@@ -24,6 +24,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops.linalg import linear_operator
 from tensorflow.python.ops.linalg import linear_operator_util
 
@@ -333,6 +334,18 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
     mat.set_shape(self.shape)
     return mat
 
+  def _assert_non_singular(self):
+    return control_flow_ops.group([
+        operator.assert_non_singular() for operator in self.operators])
+
+  def _assert_self_adjoint(self):
+    return control_flow_ops.group([
+        operator.assert_self_adjoint() for operator in self.operators])
+
+  def _assert_positive_definite(self):
+    return control_flow_ops.group([
+        operator.assert_positive_definite() for operator in self.operators])
+
   def _split_input_into_blocks(self, x, axis=-1):
     """Split `x` into blocks matching `operators`'s `domain_dimension`.