From: A. Unique TensorFlower Date: Sat, 24 Feb 2018 00:05:57 +0000 (-0800) Subject: Add Kumaraswamy Bijector, and let Kumaraswamy distribution depend on it. X-Git-Tag: upstream/v1.7.0~31^2~395 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=73b14e0c9b9ed70e7b44b5ea95ad2cef9feb7102;p=platform%2Fupstream%2Ftensorflow.git Add Kumaraswamy Bijector, and let Kumaraswamy distribution depend on it. PiperOrigin-RevId: 186838045 --- diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 35dd2ee..ed79ef7 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -252,6 +252,21 @@ cuda_py_test( ) cuda_py_test( + name = "kumaraswamy_test", + srcs = ["python/kernel_tests/kumaraswamy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( name = "moving_stats_test", size = "small", srcs = ["python/kernel_tests/moving_stats_test.py"], @@ -916,6 +931,25 @@ cuda_py_test( ) cuda_py_test( + name = "kumaraswamy_bijector_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( name = "masked_autoregressive_test", size = "small", srcs = ["python/kernel_tests/bijectors/masked_autoregressive_test.py"], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py new file mode 100644 index 0000000..ad11d9f --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Kumaraswamy Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class KumaraswamyBijectorTest(test.TestCase): + """Tests correctness of the Kumaraswamy bijector.""" + + def testBijector(self): + with self.test_session(): + a = 2. + b = 0.3 + bijector = Kumaraswamy( + concentration1=a, concentration0=b, + event_ndims=0, validate_args=True) + self.assertEqual("kumaraswamy", bijector.name) + x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32) + # Kumaraswamy cdf. This is the same as inverse(x). + y = 1. - (1. - x ** a) ** b + self.assertAllClose(y, bijector.inverse(x).eval()) + self.assertAllClose(x, bijector.forward(y).eval()) + kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) + + (b - 1) * np.log1p(-x ** a)) + + self.assertAllClose( + # We should lose a dimension from calculating the determinant of the + # jacobian. + kumaraswamy_log_pdf, + bijector.inverse_log_det_jacobian(x).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(x).eval(), + bijector.forward_log_det_jacobian(y).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Kumaraswamy(concentration1=0.5, concentration0=1.1), + lower_x=0., upper_x=1., n=int(10e3), rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + concentration1 = 1.2 + concentration0 = 2. + bijector = Kumaraswamy( + concentration1=concentration1, + concentration0=concentration0, validate_args=True) + # Omitting the endpoints 0 and 1, since idlj will be inifinity at these + # endpoints. + y = np.linspace(.01, 0.99, num=10).astype(np.float32) + x = 1 - (1 - y ** concentration1) ** concentration0 + assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py index ea3c86b..2980e2b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -130,10 +130,8 @@ class KumaraswamyTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): + with self.assertRaisesOpError("sample must be non-negative"): dist.prob([-1., 0.1, 0.5]).eval() - with self.assertRaisesOpError("sample must be positive"): - dist.prob([0., 0.1, 0.5]).eval() with self.assertRaisesOpError("sample must be no larger than `1`"): dist.prob([.1, .2, 1.2]).eval() @@ -249,13 +247,13 @@ class KumaraswamyTest(test.TestCase): a = np.array([1., 2, 3]) b = np.array([2., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."): dist.mode().eval() a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."): dist.mode().eval() def testKumaraswamyModeEnableAllowNanStats(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 93923c3..9437f56 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -26,6 +26,7 @@ @@Identity @@Inline @@Invert +@@Kumaraswamy @@MaskedAutoregressiveFlow @@Permute @@PowerTransform @@ -59,6 +60,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py new file mode 100644 index 0000000..f5de052 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -0,0 +1,153 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kumaraswamy bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Kumaraswamy", +] + + +class Kumaraswamy(bijector.Bijector): + """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`. + + This bijector maps inputs from `[0, 1]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the [Kumaraswamy distribution]( + https://en.wikipedia.org/wiki/Kumaraswamy_distribution): + + ```none + Y ~ Kumaraswamy(a, b) + pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1) + ``` + """ + + def __init__(self, + concentration1=None, + concentration0=None, + event_ndims=0, + validate_args=False, + name="kumaraswamy"): + """Instantiates the `Kumaraswamy` bijector. + + Args: + concentration1: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is + `concentration1`. + concentration0: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is + `concentration0`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init", values=[concentration1, concentration0]): + concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args=validate_args) + concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args=validate_args) + + self._concentration1 = concentration1 + self._concentration0 = concentration0 + super(Kumaraswamy, self).__init__( + event_ndims=0, + validate_args=validate_args, + name=name) + + @property + def concentration1(self): + """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration1 + + @property + def concentration0(self): + """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration0 + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.exp( + math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0)) + / self.concentration1) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.exp(math_ops.log1p( + -(1 - y**self.concentration1)**self.concentration0)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.concentration1) + math_ops.log(self.concentration0) + + (self.concentration1 - 1) * math_ops.log(y) + + (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1), + axis=event_dims) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid(self, x): + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + x, + message="sample must be non-negative"), + check_ops.assert_less_equal( + x, array_ops.ones([], self.concentration0.dtype), + message="sample must be no larger than `1`."), + ], x) diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 74d5d87..120b38d 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -20,15 +20,17 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.distributions import beta from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.distributions import uniform from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -60,7 +62,7 @@ def _harmonic_number(x): @tf_export("distributions.Kumaraswamy") -class Kumaraswamy(beta.Beta): +class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. The Kumaraswamy distribution is defined over the `(0, 1)` interval using @@ -151,59 +153,32 @@ class Kumaraswamy(beta.Beta): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ + concentration1 = ops.convert_to_tensor( + concentration1, name="concentration1") + concentration0 = ops.convert_to_tensor( + concentration0, name="concentration0") super(Kumaraswamy, self).__init__( - concentration1=concentration1, - concentration0=concentration0, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, + distribution=uniform.Uniform( + low=array_ops.zeros([], dtype=concentration1.dtype), + high=array_ops.ones([], dtype=concentration1.dtype), + allow_nan_stats=allow_nan_stats), + bijector=bijectors.Kumaraswamy( + concentration1=concentration1, concentration0=concentration0, + validate_args=validate_args), + batch_shape=distribution_util.get_broadcast_shape( + concentration1, concentration0), name=name) self._reparameterization_type = distribution.FULLY_REPARAMETERIZED - def _sample_n(self, n, seed=None): - expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 - expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 - shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) - uniform_sample = random_ops.random_uniform( - shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) - - kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( - 1. / expanded_concentration1) - return kumaraswamy_sample - - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _log_cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return math_ops.log1p(-(1 - x**a)**b) + @property + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self.bijector.concentration1 - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return 1 - (1 - x**a)**b - - def _survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return (1 - x**a)**b - - def _log_survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return b * math_ops.log1p(-x**a) - - def _log_unnormalized_prob(self, x): - x = self._maybe_assert_valid_sample(x) - a = self.concentration1 - b = self.concentration0 - return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) - - def _log_normalization(self): - a = self.concentration1 - b = self.concentration0 - return -(math_ops.log(a) + math_ops.log(b)) + @property + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self.bijector.concentration0 def _entropy(self): a = self.concentration1 @@ -213,10 +188,11 @@ class Kumaraswamy(beta.Beta): def _moment(self, n): """Compute the n'th (uncentered) moment.""" + total_concentration = self.concentration1 + self.concentration0 expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 + total_concentration, dtype=self.dtype) * self.concentration1 expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 + total_concentration, dtype=self.dtype) * self.concentration0 beta_arg0 = 1 + n / expanded_concentration1 beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( @@ -246,13 +222,14 @@ class Kumaraswamy(beta.Beta): name="nan") is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration1.dtype), self.concentration1, message="Mode undefined for concentration1 <= 1."), check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration0.dtype), self.concentration0, message="Mode undefined for concentration0 <= 1.") ], mode)