From e7ec9100b45480710817ce6259bdbb4d4c2a48ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 8 Mar 2018 16:56:18 -0800 Subject: [PATCH] Check df parameter > 0 for Chi2. PiperOrigin-RevId: 188413552 --- tensorflow/contrib/distributions/python/ops/chi2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index bdd5571..e610f46 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -21,6 +21,8 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma @@ -87,7 +89,11 @@ class Chi2(gamma.Gamma): # allow_nan_stats=True # through to the parent class results in unnecessary asserts. with ops.name_scope(name, values=[df]): - self._df = ops.convert_to_tensor(df, name="df") + with ops.control_dependencies([ + check_ops.assert_positive(df), + ] if validate_args else []): + self._df = array_ops.identity(df, name="df") + super(Chi2, self).__init__( concentration=0.5 * self._df, rate=constant_op.constant(0.5, dtype=self._df.dtype), -- 2.7.4