],
)
+cuda_py_test(
+ name = "kumaraswamy_test",
+ srcs = ["python/kernel_tests/kumaraswamy_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
cuda_py_test(
name = "moving_stats_test",
size = "small",
],
)
+cuda_py_test(
+ name = "kumaraswamy_bijector_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
cuda_py_test(
name = "masked_autoregressive_test",
size = "small",
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Kumaraswamy Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy
+from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
+from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
+from tensorflow.python.platform import test
+
+
+class KumaraswamyBijectorTest(test.TestCase):
+ """Tests correctness of the Kumaraswamy bijector."""
+
+ def testBijector(self):
+ with self.test_session():
+ a = 2.
+ b = 0.3
+ bijector = Kumaraswamy(
+ concentration1=a, concentration0=b,
+ event_ndims=0, validate_args=True)
+ self.assertEqual("kumaraswamy", bijector.name)
+ x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32)
+ # Kumaraswamy cdf. This is the same as inverse(x).
+ y = 1. - (1. - x ** a) ** b
+ self.assertAllClose(y, bijector.inverse(x).eval())
+ self.assertAllClose(x, bijector.forward(y).eval())
+ kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) +
+ (b - 1) * np.log1p(-x ** a))
+
+ self.assertAllClose(
+ # We should lose a dimension from calculating the determinant of the
+ # jacobian.
+ kumaraswamy_log_pdf,
+ bijector.inverse_log_det_jacobian(x).eval())
+ self.assertAllClose(
+ -bijector.inverse_log_det_jacobian(x).eval(),
+ bijector.forward_log_det_jacobian(y).eval(),
+ rtol=1e-4,
+ atol=0.)
+
+ def testScalarCongruency(self):
+ with self.test_session():
+ assert_scalar_congruency(
+ Kumaraswamy(concentration1=0.5, concentration0=1.1),
+ lower_x=0., upper_x=1., n=int(10e3), rtol=0.02)
+
+ def testBijectiveAndFinite(self):
+ with self.test_session():
+ concentration1 = 1.2
+ concentration0 = 2.
+ bijector = Kumaraswamy(
+ concentration1=concentration1,
+ concentration0=concentration0, validate_args=True)
+ # Omitting the endpoints 0 and 1, since idlj will be inifinity at these
+ # endpoints.
+ y = np.linspace(.01, 0.99, num=10).astype(np.float32)
+ x = 1 - (1 - y ** concentration1) ** concentration0
+ assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ test.main()
dist.prob([.1, .3, .6]).eval()
dist.prob([.2, .3, .5]).eval()
# Either condition can trigger.
- with self.assertRaisesOpError("sample must be positive"):
+ with self.assertRaisesOpError("sample must be non-negative"):
dist.prob([-1., 0.1, 0.5]).eval()
- with self.assertRaisesOpError("sample must be positive"):
- dist.prob([0., 0.1, 0.5]).eval()
with self.assertRaisesOpError("sample must be no larger than `1`"):
dist.prob([.1, .2, 1.2]).eval()
a = np.array([1., 2, 3])
b = np.array([2., 4, 1.2])
dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
+ with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."):
dist.mode().eval()
a = np.array([2., 2, 3])
b = np.array([1., 4, 1.2])
dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False)
- with self.assertRaisesOpError("Condition x < y.*"):
+ with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."):
dist.mode().eval()
def testKumaraswamyModeEnableAllowNanStats(self):
@@Identity
@@Inline
@@Invert
+@@Kumaraswamy
@@MaskedAutoregressiveFlow
@@Permute
@@PowerTransform
from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
+from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import *
from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import *
from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kumaraswamy bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+
+__all__ = [
+ "Kumaraswamy",
+]
+
+
+class Kumaraswamy(bijector.Bijector):
+ """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`.
+
+ This bijector maps inputs from `[0, 1]` to [0, 1]`. The inverse of the
+ bijector applied to a uniform random variable `X ~ U(0, 1) gives back a
+ random variable with the [Kumaraswamy distribution](
+ https://en.wikipedia.org/wiki/Kumaraswamy_distribution):
+
+ ```none
+ Y ~ Kumaraswamy(a, b)
+ pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1)
+ ```
+ """
+
+ def __init__(self,
+ concentration1=None,
+ concentration0=None,
+ event_ndims=0,
+ validate_args=False,
+ name="kumaraswamy"):
+ """Instantiates the `Kumaraswamy` bijector.
+
+ Args:
+ concentration1: Python `float` scalar indicating the transform power,
+ i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is
+ `concentration1`.
+ concentration0: Python `float` scalar indicating the transform power,
+ i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is
+ `concentration0`.
+ event_ndims: Python scalar indicating the number of dimensions associated
+ with a particular draw from the distribution. Currently only zero is
+ supported.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+
+ Raises:
+ ValueError: If `event_ndims` is not zero.
+ """
+ self._graph_parents = []
+ self._name = name
+ self._validate_args = validate_args
+
+ event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+ event_ndims_const = tensor_util.constant_value(event_ndims)
+ if event_ndims_const is not None and event_ndims_const not in (0,):
+ raise ValueError("event_ndims(%s) was not 0" % event_ndims_const)
+ else:
+ if validate_args:
+ event_ndims = control_flow_ops.with_dependencies(
+ [check_ops.assert_equal(
+ event_ndims, 0, message="event_ndims was not 0")],
+ event_ndims)
+
+ with self._name_scope("init", values=[concentration1, concentration0]):
+ concentration1 = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration1, name="concentration1"),
+ validate_args=validate_args)
+ concentration0 = self._maybe_assert_valid_concentration(
+ ops.convert_to_tensor(concentration0, name="concentration0"),
+ validate_args=validate_args)
+
+ self._concentration1 = concentration1
+ self._concentration0 = concentration0
+ super(Kumaraswamy, self).__init__(
+ event_ndims=0,
+ validate_args=validate_args,
+ name=name)
+
+ @property
+ def concentration1(self):
+ """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`."""
+ return self._concentration1
+
+ @property
+ def concentration0(self):
+ """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`."""
+ return self._concentration0
+
+ def _forward(self, x):
+ x = self._maybe_assert_valid(x)
+ return math_ops.exp(
+ math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0))
+ / self.concentration1)
+
+ def _inverse(self, y):
+ y = self._maybe_assert_valid(y)
+ return math_ops.exp(math_ops.log1p(
+ -(1 - y**self.concentration1)**self.concentration0))
+
+ def _inverse_log_det_jacobian(self, y):
+ y = self._maybe_assert_valid(y)
+ event_dims = self._event_dims_tensor(y)
+ return math_ops.reduce_sum(
+ math_ops.log(self.concentration1) + math_ops.log(self.concentration0) +
+ (self.concentration1 - 1) * math_ops.log(y) +
+ (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1),
+ axis=event_dims)
+
+ def _maybe_assert_valid_concentration(self, concentration, validate_args):
+ """Checks the validity of a concentration parameter."""
+ if not validate_args:
+ return concentration
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_positive(
+ concentration,
+ message="Concentration parameter must be positive."),
+ ], concentration)
+
+ def _maybe_assert_valid(self, x):
+ if not self.validate_args:
+ return x
+ return control_flow_ops.with_dependencies([
+ check_ops.assert_non_negative(
+ x,
+ message="sample must be non-negative"),
+ check_ops.assert_less_equal(
+ x, array_ops.ones([], self.concentration0.dtype),
+ message="sample must be no larger than `1`."),
+ ], x)
import numpy as np
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.contrib.distributions.python.ops import distribution_util
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
-from tensorflow.python.ops.distributions import beta
from tensorflow.python.ops.distributions import distribution
-from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.ops.distributions import transformed_distribution
+from tensorflow.python.ops.distributions import uniform
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@tf_export("distributions.Kumaraswamy")
-class Kumaraswamy(beta.Beta):
+class Kumaraswamy(transformed_distribution.TransformedDistribution):
"""Kumaraswamy distribution.
The Kumaraswamy distribution is defined over the `(0, 1)` interval using
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
+ concentration1 = ops.convert_to_tensor(
+ concentration1, name="concentration1")
+ concentration0 = ops.convert_to_tensor(
+ concentration0, name="concentration0")
super(Kumaraswamy, self).__init__(
- concentration1=concentration1,
- concentration0=concentration0,
- validate_args=validate_args,
- allow_nan_stats=allow_nan_stats,
+ distribution=uniform.Uniform(
+ low=array_ops.zeros([], dtype=concentration1.dtype),
+ high=array_ops.ones([], dtype=concentration1.dtype),
+ allow_nan_stats=allow_nan_stats),
+ bijector=bijectors.Kumaraswamy(
+ concentration1=concentration1, concentration0=concentration0,
+ validate_args=validate_args),
+ batch_shape=distribution_util.get_broadcast_shape(
+ concentration1, concentration0),
name=name)
self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
- def _sample_n(self, n, seed=None):
- expanded_concentration1 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration1
- expanded_concentration0 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration0
- shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
- uniform_sample = random_ops.random_uniform(
- shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed)
-
- kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**(
- 1. / expanded_concentration1)
- return kumaraswamy_sample
-
- @distribution_util.AppendDocstring(_kumaraswamy_sample_note)
- def _log_cdf(self, x):
- a = self.concentration1
- b = self.concentration0
- return math_ops.log1p(-(1 - x**a)**b)
+ @property
+ def concentration1(self):
+ """Concentration parameter associated with a `1` outcome."""
+ return self.bijector.concentration1
- @distribution_util.AppendDocstring(_kumaraswamy_sample_note)
- def _cdf(self, x):
- a = self.concentration1
- b = self.concentration0
- return 1 - (1 - x**a)**b
-
- def _survival_function(self, x):
- a = self.concentration1
- b = self.concentration0
- return (1 - x**a)**b
-
- def _log_survival_function(self, x):
- a = self.concentration1
- b = self.concentration0
- return b * math_ops.log1p(-x**a)
-
- def _log_unnormalized_prob(self, x):
- x = self._maybe_assert_valid_sample(x)
- a = self.concentration1
- b = self.concentration0
- return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a)
-
- def _log_normalization(self):
- a = self.concentration1
- b = self.concentration0
- return -(math_ops.log(a) + math_ops.log(b))
+ @property
+ def concentration0(self):
+ """Concentration parameter associated with a `0` outcome."""
+ return self.bijector.concentration0
def _entropy(self):
a = self.concentration1
def _moment(self, n):
"""Compute the n'th (uncentered) moment."""
+ total_concentration = self.concentration1 + self.concentration0
expanded_concentration1 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration1
+ total_concentration, dtype=self.dtype) * self.concentration1
expanded_concentration0 = array_ops.ones_like(
- self.total_concentration, dtype=self.dtype) * self.concentration0
+ total_concentration, dtype=self.dtype) * self.concentration0
beta_arg0 = 1 + n / expanded_concentration1
beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1)
log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta(
name="nan")
is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.)
return array_ops.where(is_defined, mode, nan)
+
return control_flow_ops.with_dependencies([
check_ops.assert_less(
- array_ops.ones([], dtype=self.dtype),
+ array_ops.ones([], dtype=self.concentration1.dtype),
self.concentration1,
message="Mode undefined for concentration1 <= 1."),
check_ops.assert_less(
- array_ops.ones([], dtype=self.dtype),
+ array_ops.ones([], dtype=self.concentration0.dtype),
self.concentration0,
message="Mode undefined for concentration0 <= 1.")
], mode)