From eed6828acf19260279b38a7fbaf79141c813f795 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 11 Apr 2018 14:02:49 -0700 Subject: [PATCH] BREAKING_CHANGE: Remove event_ndims in Bijector, and require `log_det_jacobian` methods to take event_ndims. The class level event_ndims parameter is being deprecated in favor of passing it in to the `log_det_jacobian` methods. Specific changes: - `log_det_jacobian` signatures are now `log_det_jacobian(input, event_ndims)` - Constructors no long have event_ndims passed in (e.g. Affine() vs. Affine(event_ndims=0)). - All bijectors must specify a subset of [forward_min_event_ndims, inverse_min_event_ndims]. This is the minimal dimensionality the bijector operates on, with it being "broadcasted" to any passed in event_ndims (e.g. Exp has forward_min_event_ndims = 0. That means it operates on scalars. However, we can use the bijector on any event_ndims > 0 (i.e. we've broadcasted the transformation to work on any amount of event_ndims > 0), and jacobian reduction will work in those cases. As a result of this change, all bijectors should "broadcast" (e.g. Sigmoid now works on any number of event_ndims). Other changes (internal and documentation): - Added clarifications on Jacobian Determinant vs. Jacobian Matrix. - Added clarifications on min_event_ndims, and what the jacobian reduction is over. - Changed caching of ildj to be keyed on event_ndims. - Several bug fixes to bugs unearthed while writing this code (e.g. transformed distribution shape computation being incorrect) PiperOrigin-RevId: 192504919 --- .../kernel_tests/bijectors/absolute_value_test.py | 35 +- .../bijectors/affine_linear_operator_test.py | 30 +- .../kernel_tests/bijectors/affine_scalar_test.py | 65 ++-- .../python/kernel_tests/bijectors/affine_test.py | 231 ++++++----- .../bijectors/batch_normalization_test.py | 5 +- .../python/kernel_tests/bijectors/chain_test.py | 132 ++++++- .../bijectors/cholesky_outer_product_test.py | 9 +- .../bijectors/conditional_bijector_test.py | 12 +- .../python/kernel_tests/bijectors/exp_test.py | 18 +- .../python/kernel_tests/bijectors/gumbel_test.py | 16 +- .../python/kernel_tests/bijectors/inline_test.py | 18 +- .../python/kernel_tests/bijectors/invert_test.py | 12 +- .../bijectors/kumaraswamy_bijector_test.py | 15 +- .../bijectors/masked_autoregressive_test.py | 5 +- .../python/kernel_tests/bijectors/permute_test.py | 11 +- .../kernel_tests/bijectors/power_transform_test.py | 17 +- .../python/kernel_tests/bijectors/real_nvp_test.py | 12 +- .../python/kernel_tests/bijectors/reshape_test.py | 7 +- .../python/kernel_tests/bijectors/sigmoid_test.py | 16 +- .../bijectors/sinh_arcsinh_bijector_test.py | 22 +- .../bijectors/softmax_centered_test.py | 14 +- .../python/kernel_tests/bijectors/softplus_test.py | 40 +- .../python/kernel_tests/bijectors/square_test.py | 7 +- .../python/kernel_tests/bijectors/weibull_test.py | 16 +- .../conditional_transformed_distribution_test.py | 3 +- .../python/kernel_tests/mvn_diag_test.py | 2 +- .../kernel_tests/transformed_distribution_test.py | 121 +++++- .../kernel_tests/vector_laplace_diag_test.py | 2 +- .../python/ops/bijectors/absolute_value.py | 29 +- .../distributions/python/ops/bijectors/affine.py | 10 +- .../python/ops/bijectors/affine_linear_operator.py | 36 +- .../python/ops/bijectors/affine_scalar.py | 13 +- .../python/ops/bijectors/batch_normalization.py | 6 +- .../distributions/python/ops/bijectors/chain.py | 157 +++++++- .../python/ops/bijectors/cholesky_outer_product.py | 2 +- .../python/ops/bijectors/conditional_bijector.py | 12 +- .../distributions/python/ops/bijectors/exp.py | 10 +- .../distributions/python/ops/bijectors/gumbel.py | 15 +- .../distributions/python/ops/bijectors/inline.py | 15 +- .../distributions/python/ops/bijectors/invert.py | 3 +- .../python/ops/bijectors/kumaraswamy.py | 27 +- .../python/ops/bijectors/masked_autoregressive.py | 3 +- .../distributions/python/ops/bijectors/permute.py | 8 +- .../python/ops/bijectors/power_transform.py | 16 +- .../distributions/python/ops/bijectors/real_nvp.py | 4 +- .../distributions/python/ops/bijectors/reshape.py | 8 +- .../distributions/python/ops/bijectors/sigmoid.py | 4 +- .../python/ops/bijectors/sinh_arcsinh.py | 29 +- .../python/ops/bijectors/softmax_centered.py | 12 +- .../distributions/python/ops/bijectors/softplus.py | 11 +- .../distributions/python/ops/bijectors/square.py | 2 +- .../distributions/python/ops/bijectors/weibull.py | 17 +- .../ops/conditional_transformed_distribution.py | 21 +- .../distributions/python/ops/poisson_lognormal.py | 2 +- .../python/ops/relaxed_onehot_categorical.py | 2 +- .../distributions/python/ops/sinh_arcsinh.py | 4 +- .../python/ops/vector_diffeomixture.py | 10 +- .../python/ops/vector_sinh_arcsinh_diag.py | 4 +- .../kernel_tests/distributions/bijector_test.py | 181 +++++++-- .../distributions/identity_bijector_test.py | 21 +- .../python/ops/distributions/bijector_impl.py | 429 ++++++++++++++++----- .../python/ops/distributions/bijector_test_util.py | 23 +- tensorflow/python/ops/distributions/bijectors.py | 31 -- .../python/ops/distributions/distributions.py | 2 - .../python/ops/distributions/identity_bijector.py | 8 +- .../ops/distributions/transformed_distribution.py | 58 ++- ...sorflow.distributions.bijectors.-bijector.pbtxt | 65 ---- ...sorflow.distributions.bijectors.-identity.pbtxt | 66 ---- .../tensorflow.distributions.bijectors.pbtxt | 11 - .../api/golden/tensorflow.distributions.pbtxt | 4 - 70 files changed, 1412 insertions(+), 872 deletions(-) delete mode 100644 tensorflow/python/ops/distributions/bijectors.py delete mode 100644 tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt delete mode 100644 tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt delete mode 100644 tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py index e0d65c7..042c8eb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -18,11 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - # pylint: disable=g-importing-member from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import AbsoluteValue -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,50 +32,38 @@ class AbsoluteValueTest(test.TestCase): def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) self.assertEqual("absolute_value", bijector.name) x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3] y = math_ops.abs(x) y_ = y.eval() - zeros = np.zeros((2, 3)) self.assertAllClose(y_, bijector.forward(x).eval()) self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) - self.assertAllClose((zeros, zeros), - sess.run(bijector.inverse_log_det_jacobian(y))) + self.assertAllClose((0., 0.), + sess.run(bijector.inverse_log_det_jacobian( + y, event_ndims=0))) # Run things twice to make sure there are no issues in caching the tuples # returned by .inverse* self.assertAllClose(y_, bijector.forward(x).eval()) self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) - self.assertAllClose((zeros, zeros), - sess.run(bijector.inverse_log_det_jacobian(y))) - - def testEventNdimsMustBeZeroOrRaiseStatic(self): - with self.test_session(): - with self.assertRaisesRegexp(ValueError, "event_ndims.*was not 0"): - AbsoluteValue(event_ndims=1) - - def testEventNdimsMustBeZeroOrRaiseDynamic(self): - with self.test_session() as sess: - event_ndims = array_ops.placeholder(dtypes.int32) - abs_bijector = AbsoluteValue(event_ndims=event_ndims, validate_args=True) - with self.assertRaisesOpError("event_ndims was not 0"): - sess.run(abs_bijector.inverse_log_det_jacobian([1.]), - feed_dict={event_ndims: 1}) + self.assertAllClose((0., 0.), + sess.run(bijector.inverse_log_det_jacobian( + y, event_ndims=0))) def testNegativeYRaisesForInverseIfValidateArgs(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): sess.run(bijector.inverse(-1.)) def testNegativeYRaisesForILDJIfValidateArgs(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): - sess.run(bijector.inverse_log_det_jacobian(-1.)) + sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 405ddd2..1e4ad72 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -38,9 +38,11 @@ class AffineLinearOperatorTest(test.TestCase): self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose(ildj, affine.inverse_log_det_jacobian( + y, event_ndims=2).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(), + affine.forward_log_det_jacobian(x, event_ndims=2).eval()) def testDiag(self): with self.test_session(): @@ -58,14 +60,16 @@ class AffineLinearOperatorTest(test.TestCase): self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, affine.inverse_log_det_jacobian(y, event_ndims=1).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=1).eval(), + affine.forward_log_det_jacobian(x, event_ndims=1).eval()) def testTriL(self): with self.test_session(): shift = np.array([-1, 0, 1], dtype=np.float32) - tril = np.array([[[1, 0, 0], + tril = np.array([[[3, 0, 0], [2, -1, 0], [3, 2, 1]], [[2, 0, 0], @@ -85,15 +89,17 @@ class AffineLinearOperatorTest(test.TestCase): # y = np.matmul(x, tril) + shift. y = np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift ildj = -np.sum(np.log(np.abs(np.diagonal( - tril, axis1=-2, axis2=-1))), - axis=-1) + tril, axis1=-2, axis2=-1)))) self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, affine.inverse_log_det_jacobian( + y, event_ndims=2).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(), + affine.forward_log_det_jacobian(x, event_ndims=2).eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py index 16173a1..d253362 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py @@ -40,13 +40,13 @@ class AffineScalarBijectorTest(test.TestCase): def testNoBatchScalar(self): with self.test_session() as sess: - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -55,19 +55,20 @@ class AffineScalarBijectorTest(test.TestCase): x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose([-np.log(2.)] * 3, - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): with self.test_session() as sess: - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value).astype(np.float64) x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) for run in (static_run, dynamic_run): mu = np.float64([1.]) @@ -77,18 +78,20 @@ class AffineScalarBijectorTest(test.TestCase): x = np.float64([1.]) # One sample from one batches. self.assertAllClose([2.], run(bijector.forward, x)) self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + 0., + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self): with self.test_session() as sess: - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value).astype(np.float64) x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) for run in (static_run, dynamic_run): multiplier = np.float64([2.]) @@ -98,19 +101,20 @@ class AffineScalarBijectorTest(test.TestCase): x = np.float64([1.]) # One sample from one batches. self.assertAllClose([2.], run(bijector.forward, x)) self.assertAllClose([0.5], run(bijector.inverse, x)) - self.assertAllClose([np.log(0.5)], - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + [np.log(0.5)], + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testTwoBatchScalarIdentityViaIdentity(self): with self.test_session() as sess: - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float32) x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] @@ -120,18 +124,20 @@ class AffineScalarBijectorTest(test.TestCase): x = [1., 1] # One sample from each of two batches. self.assertAllClose([2., 0], run(bijector.forward, x)) self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose([0., 0.], run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + 0., + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testTwoBatchScalarIdentityViaScale(self): with self.test_session() as sess: - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float32) x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] @@ -142,7 +148,8 @@ class AffineScalarBijectorTest(test.TestCase): self.assertAllClose([3., 0], run(bijector.forward, x)) self.assertAllClose([0., 2], run(bijector.inverse, x)) self.assertAllClose( - [-np.log(2), 0.], run(bijector.inverse_log_det_jacobian, x)) + [-np.log(2), 0.], + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testScalarCongruency(self): with self.test_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 077e617..9e14b9a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -40,14 +40,15 @@ class AffineBijectorTest(test.TestCase): def testNoBatchMultivariateIdentity(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] @@ -66,18 +67,20 @@ class AffineBijectorTest(test.TestCase): x = [[1., 1], [-1., -1]] self.assertAllClose([[2., 0], [0., -2]], run(bijector.forward, x)) self.assertAllClose([[0., 2], [-2., 0]], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + 0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateDiag(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] @@ -89,9 +92,12 @@ class AffineBijectorTest(test.TestCase): # = [-1, -1] + [1, -1] self.assertAllClose([3., 0], run(bijector.forward, x)) self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + # Reset bijector. + bijector = Affine(shift=mu, scale_diag=[2., 1]) # x is a 2-batch of 2-vectors. # The first vector is [1, 1], the second is [-1, -1]. # Each undergoes matmul(sigma, x) + shift. @@ -103,8 +109,9 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([[0., 2], [-1., 0]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateFullDynamic(self): with self.test_session() as sess: @@ -126,18 +133,20 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict)) self.assertAllClose( -np.log(4), - sess.run(bijector.inverse_log_det_jacobian(x), feed_dict)) + sess.run(bijector.inverse_log_det_jacobian(x, event_ndims=1), + feed_dict)) def testBatchMultivariateIdentity(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value, dtype=np.float32) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [[1., -1]] @@ -147,19 +156,21 @@ class AffineBijectorTest(test.TestCase): x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(4), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(4), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateDiag(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value, dtype=np.float32) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [[1., -1]] @@ -169,8 +180,9 @@ class AffineBijectorTest(test.TestCase): x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) - self.assertAllClose([-np.log(4)], - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + [-np.log(4)], + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateFullDynamic(self): with self.test_session() as sess: @@ -191,20 +203,22 @@ class AffineBijectorTest(test.TestCase): bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict)) - self.assertAllClose([-np.log(4)], - sess.run( - bijector.inverse_log_det_jacobian(x), feed_dict)) + self.assertAllClose( + [-np.log(4)], + sess.run(bijector.inverse_log_det_jacobian( + x, event_ndims=1), feed_dict)) def testIdentityWithDiagUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -216,19 +230,21 @@ class AffineBijectorTest(test.TestCase): x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.**3), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.**3), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -240,19 +256,21 @@ class AffineBijectorTest(test.TestCase): x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 5]], run(bijector.forward, x)) self.assertAllClose([[1., 0.5]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(4.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(4.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -262,19 +280,21 @@ class AffineBijectorTest(test.TestCase): x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 7]], run(bijector.forward, x)) self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(6.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(6.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityAndDiagWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -287,19 +307,21 @@ class AffineBijectorTest(test.TestCase): x = [[1., 2]] # One multivariate sample. self.assertAllClose([[2., 9]], run(bijector.forward, x)) self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(12.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(12.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -319,22 +341,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 1.5, 4 / 3.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(60.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(60.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -353,22 +377,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 1., 0.8], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(150.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(150.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -388,22 +414,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 14 / 15., 4 / 25.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(150.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(150.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdateNoDiagonal(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -423,11 +451,12 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([1 / 3., 8 / 9., 4 / 30.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(90.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(90.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateRaisesWhenSingular(self): with self.test_session(): @@ -530,6 +559,7 @@ class AffineBijectorTest(test.TestCase): backward = np.squeeze(backward, axis=-1) self.assertAllClose(backward, bijector.inverse(x).eval()) + scale *= np.ones(shape=x.shape[:-1], dtype=scale.dtype) ildj = -np.log(np.abs(np.linalg.det(scale))) # TODO(jvdillon): We need to make it so the scale_identity_multiplier # case does not deviate in expected shape. Fixing this will get rid of @@ -540,7 +570,8 @@ class AffineBijectorTest(test.TestCase): ildj = np.squeeze(ildj[0]) elif ildj.ndim < scale.ndim - 2: ildj = np.reshape(ildj, scale.shape[0:-2]) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian(x, event_ndims=1).eval()) def testLegalInputs(self): self._testLegalInputs( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py index a215a4a..c832fca 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py @@ -83,10 +83,11 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers, moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) denorm_x = batch_norm.forward(array_ops.identity(norm_x)) - fldj = batch_norm.forward_log_det_jacobian(x) + fldj = batch_norm.forward_log_det_jacobian( + x, event_ndims=len(event_dims)) # Use identity to invalidate cache. ildj = batch_norm.inverse_log_det_jacobian( - array_ops.identity(denorm_x)) + array_ops.identity(denorm_x), event_ndims=len(event_dims)) variables.global_variables_initializer().run() # Update variables. norm_x_ = sess.run(norm_x) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index a748acd..ca20442 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -20,21 +20,33 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test +class ShapeChanging(bijector.Bijector): + """Only used for op_ndims manipulation.""" + + def __init__(self, forward_min_event_ndims=0, inverse_min_event_ndims=3): + super(ShapeChanging, self).__init__( + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, + validate_args=False, name="shape_changer") + + class ChainBijectorTest(test.TestCase): """Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation.""" def testBijector(self): with self.test_session(): - chain = Chain((Exp(event_ndims=1), Softplus(event_ndims=1))) + chain = Chain((Exp(), Softplus())) self.assertEqual("chain_of_exp_of_softplus", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) @@ -42,9 +54,10 @@ class ChainBijectorTest(test.TestCase): self.assertAllClose(np.log(x - 1.), chain.inverse(x).eval()) self.assertAllClose( -np.sum(np.log(x - 1.), axis=2), - chain.inverse_log_det_jacobian(x).eval()) + chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - np.sum(x, axis=2), chain.forward_log_det_jacobian(x).eval()) + np.sum(x, axis=2), + chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testBijectorIdentity(self): with self.test_session(): @@ -54,31 +67,126 @@ class ChainBijectorTest(test.TestCase): [2., 3.]]]) self.assertAllClose(x, chain.forward(x).eval()) self.assertAllClose(x, chain.inverse(x).eval()) - self.assertAllClose(0., chain.inverse_log_det_jacobian(x).eval()) - self.assertAllClose(0., chain.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + 0., chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) + self.assertAllClose( + 0., chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): - bijector = Chain((Exp(), Softplus())) + chain = Chain((Exp(), Softplus())) assert_scalar_congruency( - bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + chain, lower_x=1e-3, upper_x=1.5, rtol=0.05) def testShapeGetters(self): with self.test_session(): - bijector = Chain([ + chain = Chain([ SoftmaxCentered(validate_args=True), SoftmaxCentered(validate_args=True), ]) x = tensor_shape.TensorShape([1]) y = tensor_shape.TensorShape([2 + 1]) - self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y, chain.forward_event_shape(x)) self.assertAllEqual( y.as_list(), - bijector.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) + chain.forward_event_shape_tensor(x.as_list()).eval()) + self.assertAllEqual(x, chain.inverse_event_shape(y)) self.assertAllEqual( x.as_list(), - bijector.inverse_event_shape_tensor(y.as_list()).eval()) + chain.inverse_event_shape_tensor(y.as_list()).eval()) + + def testMinEventNdimsChain(self): + chain = Chain([Exp(), Exp(), Exp()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Affine(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Exp(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Exp()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Exp(), Softplus(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingAddDims(self): + chain = Chain([ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(3, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(4, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(3, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(), ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(6, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingRemoveDims(self): + chain = Chain([ShapeChanging(3, 0)]) + self.assertEqual(3, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(3, 0), Affine()]) + self.assertEqual(3, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), ShapeChanging(3, 0)]) + self.assertEqual(4, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)]) + self.assertEqual(6, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingAddRemoveDims(self): + chain = Chain([ + ShapeChanging(2, 1), + ShapeChanging(3, 0), + ShapeChanging(1, 2)]) + self.assertEqual(4, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + def testChainExpAffine(self): + scale_diag = np.array([1., 2., 3.], dtype=np.float32) + chain = Chain([Exp(), Affine(scale_diag=scale_diag)]) + x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] + y = [1., 4., 27.] + self.assertAllClose(y, self.evaluate(chain.forward(x))) + self.assertAllClose(x, self.evaluate(chain.inverse(y))) + self.assertAllClose( + np.log(6, dtype=np.float32) + np.sum(scale_diag * x), + self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) + + self.assertAllClose( + -np.log(6, dtype=np.float32) - np.sum(scale_diag * x), + self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + + def testChainAffineExp(self): + scale_diag = np.array([1., 2., 3.], dtype=np.float32) + chain = Chain([Affine(scale_diag=scale_diag), Exp()]) + x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] + y = [1., 4., 9.] + self.assertAllClose(y, self.evaluate(chain.forward(x))) + self.assertAllClose(x, self.evaluate(chain.inverse(y))) + self.assertAllClose( + np.log(6, dtype=np.float32) + np.sum(x), + self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) + + self.assertAllClose( + -np.log(6, dtype=np.float32) - np.sum(x), + self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index f392e83..e281e81 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -51,10 +51,13 @@ class CholeskyOuterProductBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7) + ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=2).eval(), atol=0., rtol=1e-7) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian( + y, event_ndims=2).eval(), + bijector.forward_log_det_jacobian( + x, event_ndims=2).eval(), atol=0., rtol=1e-7) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 26e0d2a..8b279eb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -27,7 +27,7 @@ class _TestBijector(ConditionalBijector): def __init__(self): super(_TestBijector, self).__init__( - event_ndims=0, + forward_min_event_ndims=0, graph_parents=[], is_constant_jacobian=True, validate_args=False, @@ -51,11 +51,15 @@ class ConditionalBijectorTest(test.TestCase): def testConditionalBijector(self): b = _TestBijector() - for name in ["forward", "inverse", "inverse_log_det_jacobian", - "forward_log_det_jacobian"]: + for name in ["forward", "inverse"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1.0, arg1="b1", arg2="b2") + method(1., arg1="b1", arg2="b2") + + for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: + method = getattr(b, name) + with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): + method(1., event_ndims=0., arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py index 9970c0b..7be939c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py @@ -31,17 +31,21 @@ class ExpBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - bijector = Exp(event_ndims=1) + bijector = Exp() self.assertEqual("exp", bijector.name) x = [[[1.], [2.]]] y = np.exp(x) self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - -np.sum(np.log(y), axis=-1), - bijector.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-bijector.inverse_log_det_jacobian(np.exp(x)).eval(), - bijector.forward_log_det_jacobian(x).eval()) + -np.squeeze(np.log(y), axis=-1), + bijector.inverse_log_det_jacobian( + y, event_ndims=1).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian( + np.exp(x), event_ndims=1).eval(), + bijector.forward_log_det_jacobian( + x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): @@ -51,10 +55,10 @@ class ExpBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): - bijector = Exp(event_ndims=0) + bijector = Exp() x = np.linspace(-10, 10, num=10).astype(np.float32) y = np.logspace(-10, 10, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y) + assert_bijective_and_finite(bijector, x, y, event_ndims=0) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py index 9a90598..54e54c3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -34,7 +34,7 @@ class GumbelBijectorTest(test.TestCase): with self.test_session(): loc = 0.3 scale = 5. - bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True) + bijector = Gumbel(loc=loc, scale=scale, validate_args=True) self.assertEqual("gumbel", bijector.name) x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32) # Gumbel distribution @@ -43,13 +43,11 @@ class GumbelBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - np.squeeze(gumbel_dist.logpdf(x), axis=2), - bijector.forward_log_det_jacobian(x).eval()) + np.squeeze(gumbel_dist.logpdf(x), axis=-1), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -60,10 +58,10 @@ class GumbelBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): - bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True) + bijector = Gumbel(loc=0., scale=3.0, validate_args=True) x = np.linspace(-10., 10., num=10).astype(np.float32) y = np.linspace(0.01, 0.99, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py index 739fa6d..7d3bd75 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py @@ -33,15 +33,13 @@ class InlineBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - exp = Exp(event_ndims=1) + exp = Exp() inline = Inline( forward_fn=math_ops.exp, inverse_fn=math_ops.log, - inverse_log_det_jacobian_fn=( - lambda y: -math_ops.reduce_sum( # pylint: disable=g-long-lambda - math_ops.log(y), reduction_indices=-1)), - forward_log_det_jacobian_fn=( - lambda x: math_ops.reduce_sum(x, reduction_indices=-1)), + inverse_log_det_jacobian_fn=lambda y: -math_ops.log(y), + forward_log_det_jacobian_fn=lambda x: x, + forward_min_event_ndims=0, name="exp") self.assertEqual(exp.name, inline.name) @@ -51,9 +49,10 @@ class InlineBijectorTest(test.TestCase): self.assertAllClose(x, inline.inverse(y).eval()) self.assertAllClose( -np.sum(np.log(y), axis=-1), - inline.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-inline.inverse_log_det_jacobian(y).eval(), - inline.forward_log_det_jacobian(x).eval()) + inline.inverse_log_det_jacobian(y, event_ndims=1).eval()) + self.assertAllClose( + -inline.inverse_log_det_jacobian(y, event_ndims=1).eval(), + inline.forward_log_det_jacobian(x, event_ndims=1).eval()) def testShapeGetters(self): with self.test_session(): @@ -62,6 +61,7 @@ class InlineBijectorTest(test.TestCase): forward_event_shape_fn=lambda x: x.as_list() + [1], inverse_event_shape_tensor_fn=lambda x: x[:-1], inverse_event_shape_fn=lambda x: x[:-1], + forward_min_event_ndims=0, name="shape_only") x = tensor_shape.TensorShape([1, 2, 3]) y = tensor_shape.TensorShape([1, 2, 3, 1]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index 58ba9ce..8b14c83 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -34,9 +34,9 @@ class InvertBijectorTest(test.TestCase): with self.test_session(): for fwd in [ bijectors.Identity(), - bijectors.Exp(event_ndims=1), + bijectors.Exp(), bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]), - bijectors.Softplus(event_ndims=1), + bijectors.Softplus(), bijectors.SoftmaxCentered(), ]: rev = bijectors.Invert(fwd) @@ -46,11 +46,11 @@ class InvertBijectorTest(test.TestCase): self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval()) self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval()) self.assertAllClose( - fwd.forward_log_det_jacobian(x).eval(), - rev.inverse_log_det_jacobian(x).eval()) + fwd.forward_log_det_jacobian(x, event_ndims=1).eval(), + rev.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - fwd.inverse_log_det_jacobian(x).eval(), - rev.forward_log_det_jacobian(x).eval()) + fwd.inverse_log_det_jacobian(x, event_ndims=1).eval(), + rev.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py index 074b5f2..a808988 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -34,8 +34,7 @@ class KumaraswamyBijectorTest(test.TestCase): a = 2. b = 0.3 bijector = Kumaraswamy( - concentration1=a, concentration0=b, - event_ndims=0, validate_args=True) + concentration1=a, concentration0=b, validate_args=True) self.assertEqual("kumaraswamy", bijector.name) x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32) # Kumaraswamy cdf. This is the same as inverse(x). @@ -46,13 +45,11 @@ class KumaraswamyBijectorTest(test.TestCase): (b - 1) * np.log1p(-x ** a)) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - kumaraswamy_log_pdf, - bijector.inverse_log_det_jacobian(x).eval()) + np.squeeze(kumaraswamy_log_pdf, axis=-1), + bijector.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(x).eval(), - bijector.forward_log_det_jacobian(y).eval(), + -bijector.inverse_log_det_jacobian(x, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(y, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -73,7 +70,7 @@ class KumaraswamyBijectorTest(test.TestCase): # endpoints. y = np.linspace(.01, 0.99, num=10).astype(np.float32) x = 1 - (1 - y ** concentration1) ** concentration0 - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index dcfb0eb..5ba5a20 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -79,9 +79,10 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, forward_x = ma.forward(x) # Use identity to invalidate cache. inverse_y = ma.inverse(array_ops.identity(forward_x)) - fldj = ma.forward_log_det_jacobian(x) + fldj = ma.forward_log_det_jacobian(x, event_ndims=1) # Use identity to invalidate cache. - ildj = ma.inverse_log_det_jacobian(array_ops.identity(forward_x)) + ildj = ma.inverse_log_det_jacobian( + array_ops.identity(forward_x), event_ndims=1) variables.global_variables_initializer().run() [ forward_x_, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py index 54590de..7eef4ab 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -53,8 +53,8 @@ class PermuteBijectorTest(test.TestCase): bijector.permutation, bijector.inverse(expected_y), bijector.forward(expected_x), - bijector.forward_log_det_jacobian(expected_x), - bijector.inverse_log_det_jacobian(expected_y), + bijector.forward_log_det_jacobian(expected_x, event_ndims=1), + bijector.inverse_log_det_jacobian(expected_y, event_ndims=1), ], feed_dict={permutation_ph: expected_permutation}) self.assertEqual("permute", bijector.name) self.assertAllEqual(expected_permutation, permutation_) @@ -78,10 +78,9 @@ class PermuteBijectorTest(test.TestCase): x = np.random.randn(4, 2, 3) y = x[..., permutation] with self.test_session(): - bijector = Permute( - permutation=permutation, - validate_args=True) - assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + bijector = Permute(permutation=permutation, validate_args=True) + assert_bijective_and_finite( + bijector, x, y, event_ndims=1, rtol=1e-6, atol=0) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py index de1659a..85d2283 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py @@ -32,8 +32,7 @@ class PowerTransformBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): c = 0.2 - bijector = PowerTransform( - power=c, event_ndims=1, validate_args=True) + bijector = PowerTransform(power=c, validate_args=True) self.assertEqual("power_transform", bijector.name) x = np.array([[[-1.], [2.], [-5. + 1e-4]]]) y = (1. + x * c)**(1. / c) @@ -41,27 +40,25 @@ class PowerTransformBijectorTest(test.TestCase): self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( (c - 1.) * np.sum(np.log(y), axis=-1), - bijector.inverse_log_det_jacobian(y).eval()) + bijector.inverse_log_det_jacobian(y, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) def testScalarCongruency(self): with self.test_session(): - bijector = PowerTransform( - power=0.2, validate_args=True) + bijector = PowerTransform(power=0.2, validate_args=True) assert_scalar_congruency( bijector, lower_x=-2., upper_x=1.5, rtol=0.05) def testBijectiveAndFinite(self): with self.test_session(): - bijector = PowerTransform( - power=0.2, event_ndims=0, validate_args=True) + bijector = PowerTransform(power=0.2, validate_args=True) x = np.linspace(-4.999, 10, num=10).astype(np.float32) y = np.logspace(0.001, 10, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py index 46fe779..2d52895 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py @@ -52,24 +52,28 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): forward_x = nvp.forward(x) # Use identity to invalidate cache. inverse_y = nvp.inverse(array_ops.identity(forward_x)) - fldj = nvp.forward_log_det_jacobian(x) + forward_inverse_y = nvp.forward(inverse_y) + fldj = nvp.forward_log_det_jacobian(x, event_ndims=1) # Use identity to invalidate cache. - ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x)) + ildj = nvp.inverse_log_det_jacobian( + array_ops.identity(forward_x), event_ndims=1) variables.global_variables_initializer().run() [ forward_x_, inverse_y_, + forward_inverse_y_, ildj_, fldj_, ] = sess.run([ forward_x, inverse_y, + forward_inverse_y, ildj, fldj, ]) self.assertEqual("real_nvp", nvp.name) - self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) - self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-1, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-1, atol=0.) self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) def testMutuallyConsistent(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index e216d88..46f2c63 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -65,8 +65,8 @@ class _ReshapeBijectorTest(object): ildj_) = sess.run(( bijector.inverse(expected_y), bijector.forward(expected_x), - bijector.forward_log_det_jacobian(expected_x), - bijector.inverse_log_det_jacobian(expected_y), + bijector.forward_log_det_jacobian(expected_x, event_ndims=2), + bijector.inverse_log_det_jacobian(expected_y, event_ndims=2), ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) @@ -301,7 +301,8 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): event_shape_in=[2, 3], event_shape_out=[1, 2, 3], validate_args=True) - assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + assert_bijective_and_finite( + bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): if ops._USE_C_API: diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index e4f9d72..cea4a62 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -36,12 +36,13 @@ class SigmoidBijectorTest(test.TestCase): x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32) y = special.expit(x) ildj = -np.log(y) - np.log1p(-y) - self.assertAllClose(y, Sigmoid().forward(x).eval(), atol=0., rtol=1e-2) - self.assertAllClose(x, Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4) - self.assertAllClose(ildj, Sigmoid().inverse_log_det_jacobian(y).eval(), - atol=0., rtol=1e-6) - self.assertAllClose(-ildj, Sigmoid().forward_log_det_jacobian(x).eval(), - atol=0., rtol=1e-4) + bijector = Sigmoid() + self.assertAllClose(y, bijector.forward(x).eval(), atol=0., rtol=1e-2) + self.assertAllClose(x, bijector.inverse(y).eval(), atol=0., rtol=1e-4) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval(), atol=0., rtol=1e-6) + self.assertAllClose(-ildj, bijector.forward_log_det_jacobian( + x, event_ndims=0).eval(), atol=0., rtol=1e-4) def testScalarCongruency(self): with self.test_session(): @@ -52,7 +53,8 @@ class SigmoidBijectorTest(test.TestCase): x = np.linspace(-7., 7., 100).astype(np.float32) eps = 1e-3 y = np.linspace(eps, 1. - eps, 100).astype(np.float32) - assert_bijective_and_finite(Sigmoid(), x, y, atol=0., rtol=1e-4) + assert_bijective_and_finite( + Sigmoid(), x, y, event_ndims=0, atol=0., rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 172c180..45760a2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -39,7 +39,6 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector = SinhArcsinh( skewness=skewness, tailweight=tailweight, - event_ndims=1, validate_args=True) self.assertEqual("SinhArcsinh", bijector.name) x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32) @@ -50,10 +49,11 @@ class SinhArcsinhBijectorTest(test.TestCase): np.sum( np.log(np.cosh(np.arcsinh(y) / tailweight - skewness)) - np.log(tailweight) - np.log(np.sqrt(y**2 + 1)), - axis=-1), bijector.inverse_log_det_jacobian(y).eval()) + axis=-1), + bijector.inverse_log_det_jacobian(y, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -106,14 +106,15 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True) x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace( -2, 10, 1000))).astype(np.float32) - assert_bijective_and_finite(bijector, x, x, rtol=1e-3) + assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectiveAndFiniteSkewness1Tailweight3(self): with self.test_session(): bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True) x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace( -2, 5, 1000))).astype(np.float32) - assert_bijective_and_finite(bijector, x, x, rtol=1e-3) + assert_bijective_and_finite( + bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectorEndpoints(self): with self.test_session(): @@ -124,7 +125,8 @@ class SinhArcsinhBijectorTest(test.TestCase): [np.finfo(dtype).min, np.finfo(dtype).max], dtype=dtype) # Note that the above bijector is the identity bijector. Hence, the # log_det_jacobian will be 0. Because of this we use atol. - assert_bijective_and_finite(bijector, bounds, bounds, atol=2e-6) + assert_bijective_and_finite( + bijector, bounds, bounds, event_ndims=0, atol=2e-6) def testBijectorOverRange(self): with self.test_session(): @@ -156,12 +158,12 @@ class SinhArcsinhBijectorTest(test.TestCase): np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( y_float128**2 + 1)) - np.log(tailweight), - bijector.inverse_log_det_jacobian(y).eval(), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), rtol=1e-4, atol=0.) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), rtol=1e-4, atol=0.) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index cad4dd1..0f0a2fa 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -44,12 +44,12 @@ class SoftmaxCenteredBijectorTest(test.TestCase): self.assertAllClose(x, softmax.inverse(y).eval()) self.assertAllClose( -np.sum(np.log(y), axis=1), - softmax.inverse_log_det_jacobian(y).eval(), + softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), + -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(), + softmax.forward_log_det_jacobian(x, event_ndims=1).eval(), atol=0., rtol=1e-7) @@ -67,14 +67,14 @@ class SoftmaxCenteredBijectorTest(test.TestCase): feed_dict={y: real_y})) self.assertAllClose( -np.sum(np.log(real_y), axis=1), - softmax.inverse_log_det_jacobian(y).eval( + softmax.inverse_log_det_jacobian(y, event_ndims=1).eval( feed_dict={y: real_y}), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval( + -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval( feed_dict={y: real_y}), - softmax.forward_log_det_jacobian(x).eval( + softmax.forward_log_det_jacobian(x, event_ndims=1).eval( feed_dict={x: real_x}), atol=0., rtol=1e-7) @@ -104,7 +104,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): y = np.array([y_0, y_1, y_2]) y /= y.sum(axis=0) y = y.T # y.shape = [5, 3] - assert_bijective_and_finite(softmax, x, y) + assert_bijective_and_finite(softmax, x, y, event_ndims=1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index d9af9ae..3d8a0a3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -43,13 +43,13 @@ class SoftplusBijectorTest(test.TestCase): def testHingeSoftnessZeroRaises(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=0., validate_args=True) + bijector = Softplus(hinge_softness=0., validate_args=True) with self.assertRaisesOpError("must be non-zero"): bijector.forward([1., 1.]).eval() def testBijectorForwardInverseEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) y = self._softplus(x) @@ -59,7 +59,7 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.5) + bijector = Softplus(hinge_softness=1.5) x = 2 * rng.randn(2, 10) y = 1.5 * self._softplus(x / 1.5) @@ -68,16 +68,17 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorLogDetJacobianEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() y = 2 * rng.rand(2, 10) # No reduction needed if event_dims = 0. ildj = self._softplus_ildj_before_reduction(y) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval()) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval()) def testBijectorForwardInverseEventDimsOne(self): with self.test_session(): - bijector = Softplus(event_ndims=1) + bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) y = self._softplus(x) @@ -87,58 +88,59 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorLogDetJacobianEventDimsOne(self): with self.test_session(): - bijector = Softplus(event_ndims=1) + bijector = Softplus() y = 2 * rng.rand(2, 10) ildj_before = self._softplus_ildj_before_reduction(y) ildj = np.sum(ildj_before, axis=1) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval()) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithPositiveHingeSoftness(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.3) + bijector = Softplus(hinge_softness=1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithNegativeHingeSoftness(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=-1.3) + bijector = Softplus(hinge_softness=-1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testBijectiveAndFinite32bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.23) + bijector = Softplus(hinge_softness=1.23) x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=-0.7) + bijector = Softplus(hinge_softness=-0.7) x = np.linspace(-20., 20., 100).astype(np.float32) y = -np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFinite16bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() # softplus(-20) is zero, so we can't use such a large range as in 32bit. x = np.linspace(-10., 20., 100).astype(np.float16) # Note that float16 is only in the open set (0, inf) for a smaller @@ -146,7 +148,7 @@ class SoftplusBijectorTest(test.TestCase): # for the test. y = np.logspace(-6, 3, 100).astype(np.float16) assert_bijective_and_finite( - bijector, x, y, rtol=1e-1, atol=1e-3) + bijector, x, y, event_ndims=0, rtol=1e-1, atol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py index f03d6f1..30c7a73 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py @@ -41,10 +41,11 @@ class SquareBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7) + ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval(), atol=0., rtol=1e-7) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), atol=0., rtol=1e-7) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py index 7a31228..f57adcd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py @@ -36,7 +36,7 @@ class WeibullBijectorTest(test.TestCase): concentration = 0.3 bijector = Weibull( scale=scale, concentration=concentration, - event_ndims=1, validate_args=True) + validate_args=True) self.assertEqual("weibull", bijector.name) x = np.array([[[0.], [1.], [14.], [20.], [100.]]], dtype=np.float32) # Weibull distribution @@ -45,13 +45,11 @@ class WeibullBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - np.squeeze(weibull_dist.logpdf(x), axis=2), - bijector.forward_log_det_jacobian(x).eval()) + weibull_dist.logpdf(x), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), rtol=1e-4, atol=0.) @@ -64,12 +62,12 @@ class WeibullBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): bijector = Weibull( - scale=20., concentration=2., event_ndims=0, validate_args=True) + scale=20., concentration=2., validate_args=True) x = np.linspace(1., 8., num=10).astype(np.float32) y = np.linspace( -np.expm1(-1 / 400.), -np.expm1(-16), num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py index 5454719..4e8989b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py @@ -44,6 +44,7 @@ class _ChooseLocation(ConditionalBijector): graph_parents=[self._loc], is_constant_jacobian=True, validate_args=False, + forward_min_event_ndims=0, name=name) def _forward(self, x, z): @@ -52,7 +53,7 @@ class _ChooseLocation(ConditionalBijector): def _inverse(self, x, z): return x - self._gather_loc(z) - def _inverse_log_det_jacobian(self, x, z=None): + def _inverse_log_det_jacobian(self, x, event_ndims, z=None): return 0. def _gather_loc(self, z): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 933756a..9635134 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -68,7 +68,7 @@ class MultivariateNormalDiagTest(test.TestCase): dist = ds.TransformedDistribution( base_dist, validate_args=True, - bijector=bijectors.Softplus(event_ndims=1)) + bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index f0ba1ec..5fe1331 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -36,6 +37,35 @@ ds = distributions la = linalg +class DummyMatrixTransform(bs.Bijector): + """Tractable matrix transformation. + + This is a non-sensical bijector that has forward/inverse_min_event_ndims=2. + The main use is to check that transformed distribution calculations are done + appropriately. + """ + + def __init__(self): + super(DummyMatrixTransform, self).__init__( + forward_min_event_ndims=2, + is_constant_jacobian=False, + validate_args=False, + name="dummy") + + def _forward(self, x): + return x + + def _inverse(self, y): + return y + + # Note: These jacobians don't make sense. + def _forward_log_det_jacobian(self, x): + return -linalg_ops.matrix_determinant(x) + + def _inverse_log_det_jacobian(self, x): + return linalg_ops.matrix_determinant(x) + + class TransformedDistributionTest(test.TestCase): def _cls(self): @@ -55,7 +85,7 @@ class TransformedDistributionTest(test.TestCase): # you may or may not need a reduce_sum. log_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.Exp(event_ndims=0)) + bijector=bs.Exp()) sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu)) # sample @@ -87,7 +117,7 @@ class TransformedDistributionTest(test.TestCase): sigma = 2.0 abs_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.AbsoluteValue(event_ndims=0)) + bijector=bs.AbsoluteValue()) sp_normal = stats.norm(mu, sigma) # sample @@ -129,7 +159,7 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(grid, cdf_, rtol=1e-6, atol=0.) def testCachedSamples(self): - exp_forward_only = bs.Exp(event_ndims=0) + exp_forward_only = bs.Exp() exp_forward_only._inverse = self._make_unimplemented( "inverse") exp_forward_only._inverse_event_shape_tensor = self._make_unimplemented( @@ -153,7 +183,7 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(expected_log_pdf, log_pdf_val, rtol=1e-4, atol=0.) def testCachedSamplesInvert(self): - exp_inverse_only = bs.Exp(event_ndims=0) + exp_inverse_only = bs.Exp() exp_inverse_only._forward = self._make_unimplemented( "forward") exp_inverse_only._forward_event_shape_tensor = self._make_unimplemented( @@ -210,8 +240,11 @@ class TransformedDistributionTest(test.TestCase): int_identity = bs.Inline( forward_fn=array_ops.identity, inverse_fn=array_ops.identity, - inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), - forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + inverse_log_det_jacobian_fn=( + lambda y: math_ops.cast(0, dtypes.int32)), + forward_log_det_jacobian_fn=( + lambda x: math_ops.cast(0, dtypes.int32)), + forward_min_event_ndims=0, is_constant_jacobian=True) normal = self._cls()( distribution=ds.Normal(loc=0., scale=1.), @@ -435,6 +468,82 @@ class ScalarToMultiTest(test.TestCase): event_shape=[3], validate_args=True) + def testMatrixEvent(self): + with self.test_session() as sess: + batch_shape = [2] + event_shape = [2, 3, 3] + batch_shape_pl = array_ops.placeholder( + dtypes.int32, name="dynamic_batch_shape") + event_shape_pl = array_ops.placeholder( + dtypes.int32, name="dynamic_event_shape") + feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32), + event_shape_pl: np.array(event_shape, dtype=np.int32)} + + scale = 2. + loc = 0. + fake_mvn_dynamic = self._cls()( + distribution=ds.Normal( + loc=loc, + scale=scale), + bijector=DummyMatrixTransform(), + batch_shape=batch_shape_pl, + event_shape=event_shape_pl, + validate_args=True) + + fake_mvn_static = self._cls()( + distribution=ds.Normal( + loc=loc, + scale=scale), + bijector=DummyMatrixTransform(), + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=True) + + def actual_mvn_log_prob(x): + # This distribution is the normal PDF, reduced over the + # last 3 dimensions + a jacobian term which corresponds + # to the determinant of x. + return (np.sum( + stats.norm(loc, scale).logpdf(x), axis=(-1, -2, -3)) + + np.sum(np.linalg.det(x), axis=-1)) + + self.assertAllEqual([2, 3, 3], fake_mvn_static.event_shape) + self.assertAllEqual([2], fake_mvn_static.batch_shape) + + self.assertAllEqual(tensor_shape.TensorShape(None), + fake_mvn_dynamic.event_shape) + self.assertAllEqual(tensor_shape.TensorShape(None), + fake_mvn_dynamic.batch_shape) + + num_samples = 5e3 + for fake_mvn, feed_dict in ((fake_mvn_static, {}), + (fake_mvn_dynamic, feed_dict)): + # Ensure sample works by checking first, second moments. + y = fake_mvn.sample(int(num_samples), seed=0) + x = y[0:5, ...] + [ + x_, + fake_event_shape_, + fake_batch_shape_, + fake_log_prob_, + fake_prob_, + ] = sess.run([ + x, + fake_mvn.event_shape_tensor(), + fake_mvn.batch_shape_tensor(), + fake_mvn.log_prob(x), + fake_mvn.prob(x), + ], feed_dict=feed_dict) + + # Ensure all other functions work as intended. + self.assertAllEqual([5, 2, 2, 3, 3], x_.shape) + self.assertAllEqual([2, 3, 3], fake_event_shape_) + self.assertAllEqual([2], fake_batch_shape_) + self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_, + atol=0., rtol=1e-6) + self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_, + atol=0., rtol=1e-5) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py index c355ade..1226c66 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py @@ -61,7 +61,7 @@ class VectorLaplaceDiagTest(test.TestCase): dist = ds.TransformedDistribution( base_dist, validate_args=True, - bijector=bijectors.Softplus(event_ndims=1)) + bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 0fe9f6a..c9e31d7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,9 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import constant_op from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -72,38 +70,22 @@ class AbsoluteValue(bijector.Bijector): """ - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. validate_args: Python `bool` indicating whether arguments should be checked for correctness, in particular whether inputs to `inverse` and `inverse_log_det_jacobian` are non-negative. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. """ self._graph_parents = [] self._name = name - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - with self._name_scope("init"): super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, + is_constant_jacobian=True, validate_args=validate_args, name=name) @@ -121,8 +103,7 @@ class AbsoluteValue(bijector.Bijector): # If event_ndims = 2, # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. - batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + zeros = constant_op.constant(0., dtype=y.dtype) if self.validate_args: zeros = control_flow_ops.with_dependencies( [check_ops.assert_non_negative(y, message="Argument y was negative")], diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index bef7bbb..b4c2939 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -184,6 +184,7 @@ class Affine(bijector.Bijector): with self._name_scope("init", values=[ shift, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_diag, scale_perturb_factor]): + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -234,7 +235,7 @@ class Affine(bijector.Bijector): event_ndims=1, validate_args=validate_args) super(Affine, self).__init__( - event_ndims=1, + forward_min_event_ndims=1, graph_parents=( [self._scale] if tensor_util.is_tensor(self._scale) else self._scale.graph_parents + @@ -360,16 +361,17 @@ class Affine(bijector.Bijector): x, sample_shape, expand_batch_dim=False) return x - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self._is_only_identity_multiplier: # We don't pad in this case and instead let the fldj be applied # via broadcast. event_size = array_ops.shape(x)[-1] event_size = math_ops.cast(event_size, dtype=self._scale.dtype) return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() def _maybe_check_scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 89043b1..59f9742 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -22,9 +22,6 @@ from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator @@ -94,7 +91,6 @@ class AffineLinearOperator(bijector.Bijector): def __init__(self, shift=None, scale=None, - event_ndims=1, validate_args=False, name="affine_linear_operator"): """Instantiates the `AffineLinearOperator` bijector. @@ -103,14 +99,11 @@ class AffineLinearOperator(bijector.Bijector): shift: Floating-point `Tensor`. scale: Subclass of `LinearOperator`. Represents the (batch) positive definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. Raises: - ValueError: if `event_ndims` is not 0 or 1. TypeError: if `scale` is not a `LinearOperator`. TypeError: if `shift.dtype` does not match `scale.dtype`. ValueError: if not `scale.is_non_singular`. @@ -120,20 +113,6 @@ class AffineLinearOperator(bijector.Bijector): self._validate_args = validate_args graph_parents = [] with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -166,10 +145,10 @@ class AffineLinearOperator(bijector.Bijector): self._scale = scale self._shaper = _DistributionShape( batch_ndims=batch_ndims, - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args) super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, graph_parents=graph_parents, is_constant_jacobian=True, dtype=dtype, @@ -213,12 +192,13 @@ class AffineLinearOperator(bijector.Bijector): x, sample_shape, expand_batch_dim=False) return x - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) + return constant_op.constant(0., dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if self.validate_args else []): return self.scale.log_abs_determinant() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py index 8adaa54..cd792e2 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -99,7 +100,7 @@ class AffineScalar(bijector.Bijector): self._scale) super(AffineScalar, self).__init__( - event_ndims=0, + forward_min_event_ndims=0, is_constant_jacobian=True, validate_args=validate_args, name=name) @@ -131,8 +132,10 @@ class AffineScalar(bijector.Bijector): return x def _forward_log_det_jacobian(self, x): - log_det_jacobian = array_ops.zeros_like(x) + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self.scale is None: - return log_det_jacobian - log_det_jacobian += math_ops.log(math_ops.abs(self.scale)) - return log_det_jacobian + return constant_op.constant(0., dtype=x.dtype.base_dtype) + + return math_ops.log(math_ops.abs(self.scale)) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index 33fdd32..224cec8 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -157,7 +157,12 @@ class BatchNormalization(bijector.Bijector): gamma_constraint=g_constraint) self._validate_bn_layer(self.batchnorm) self._training = training + if isinstance(self.batchnorm.axis, int): + forward_min_event_ndims = 1 + else: + forward_min_event_ndims = len(self.batchnorm.axis) super(BatchNormalization, self).__init__( + forward_min_event_ndims=forward_min_event_ndims, validate_args=validate_args, name=name) def _validate_bn_layer(self, layer): @@ -186,7 +191,6 @@ class BatchNormalization(bijector.Bijector): input_shape = np.int32(x.shape.as_list()) ndims = len(input_shape) - # event_dims = self._compute_event_dims(x) reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis] # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 3ce7c26..85ad23e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -21,6 +21,9 @@ from __future__ import print_function import itertools from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector @@ -29,6 +32,91 @@ __all__ = [ ] +def _use_static_shape(input_tensor, ndims): + return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) + + +def _maybe_get_event_ndims_statically(event_ndims): + static_event_ndims = (event_ndims if isinstance(event_ndims, int) + else tensor_util.constant_value(event_ndims)) + if static_event_ndims is not None: + return static_event_ndims + + return event_ndims + + +def _compute_min_event_ndims(bijector_list, compute_forward=True): + """Computes the min_event_ndims associated with the give list of bijectors. + + Given a list `bijector_list` of bijectors, compute the min_event_ndims that is + associated with the composition of bijectors in that list. + + min_event_ndims is the # of right most dimensions for which the bijector has + done necessary computation on (i.e. the non-broadcastable part of the + computation). + + We can derive the min_event_ndims for a chain of bijectors as follows: + + In the case where there are no rank changing bijectors, this will simply be + `max(b.forward_min_event_ndims for b in bijector_list)`. This is because the + bijector with the most forward_min_event_ndims requires the most dimensions, + and hence the chain also requires operating on those dimensions. + + However in the case of rank changing, more care is needed in determining the + exact amount of dimensions. Padding dimensions causes subsequent bijectors to + operate on the padded dimensions, and Removing dimensions causes bijectors to + operate more left. + + Args: + bijector_list: List of bijectors to be composed by chain. + compute_forward: Boolean. If True, computes the min_event_ndims associated + with a forward call to Chain, and otherwise computes the min_event_ndims + associated with an inverse call to Chain. The latter is the same as the + min_event_ndims associated with a forward call to Invert(Chain(....)). + + Returns: + min_event_ndims + """ + min_event_ndims = 0 + # This is a mouthful, but what this encapsulates is that if not for rank + # changing bijectors, we'd only need to compute the largest of the min + # required ndims. Hence "max_min". Due to rank changing bijectors, we need to + # account for synthetic rank growth / synthetic rank decrease from a rank + # changing bijector. + rank_changed_adjusted_max_min_event_ndims = 0 + + if compute_forward: + bijector_list = reversed(bijector_list) + + for b in bijector_list: + if compute_forward: + current_min_event_ndims = b.forward_min_event_ndims + current_inverse_min_event_ndims = b.inverse_min_event_ndims + else: + current_min_event_ndims = b.inverse_min_event_ndims + current_inverse_min_event_ndims = b.forward_min_event_ndims + + # New dimensions were touched. + if rank_changed_adjusted_max_min_event_ndims < current_min_event_ndims: + min_event_ndims += ( + current_min_event_ndims - rank_changed_adjusted_max_min_event_ndims) + rank_changed_adjusted_max_min_event_ndims = max( + current_min_event_ndims, rank_changed_adjusted_max_min_event_ndims) + + # If the number of dimensions has increased via forward, then + # inverse_min_event_ndims > forward_min_event_ndims, and hence the + # dimensions we computed on, have moved left (so we have operated + # on additional dimensions). + # Conversely, if the number of dimensions has decreased via forward, + # then we have inverse_min_event_ndims < forward_min_event_ndims, + # and so we will have operated on fewer right most dimensions. + + number_of_changed_dimensions = ( + current_min_event_ndims - current_inverse_min_event_ndims) + rank_changed_adjusted_max_min_event_ndims -= number_of_changed_dimensions + return min_event_ndims + + class Chain(bijector.Bijector): """Bijector which applies a sequence of bijectors. @@ -93,21 +181,24 @@ class Chain(bijector.Bijector): raise ValueError("incompatible dtypes: %s" % dtype) elif len(dtype) == 2: dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims elif len(dtype) == 1: dtype = dtype[0] - event_ndims = bijectors[0].event_ndims else: dtype = None - event_ndims = None + + inverse_min_event_ndims = _compute_min_event_ndims( + bijectors, compute_forward=False) + forward_min_event_ndims = _compute_min_event_ndims( + bijectors, compute_forward=True) super(Chain, self).__init__( graph_parents=list(itertools.chain.from_iterable( b.graph_parents for b in bijectors)), + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), validate_args=validate_args, dtype=dtype, - event_ndims=event_ndims, name=name or ("identity" if not bijectors else "_of_".join(["chain"] + [b.name for b in bijectors]))) @@ -147,10 +238,31 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") + ildj = constant_op.constant( + 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + + if not self.bijectors: + return ildj + + event_ndims = _maybe_get_event_ndims_statically( + self.inverse_min_event_ndims) + + if _use_static_shape(y, event_ndims): + event_shape = y.shape[y.shape.ndims - event_ndims:] + else: + event_shape = array_ops.shape(y)[array_ops.rank(y) - event_ndims:] + for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + ildj += b.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **kwargs.get(b.name, {})) + + if _use_static_shape(y, event_ndims): + event_shape = b.inverse_event_shape(event_shape) + event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + else: + event_shape = b.inverse_event_shape_tensor(event_shape) + event_ndims = _maybe_get_event_ndims_statically( + array_ops.rank(event_shape)) y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -160,9 +272,34 @@ class Chain(bijector.Bijector): return x def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") + x = ops.convert_to_tensor(x, name="x") + + fldj = constant_op.constant( + 0., dtype=x.dtype, name="inverse_log_det_jacobian") + + if not self.bijectors: + return fldj + + event_ndims = _maybe_get_event_ndims_statically( + self.forward_min_event_ndims) + + if _use_static_shape(x, event_ndims): + event_shape = x.shape[x.shape.ndims - event_ndims:] + else: + event_shape = array_ops.shape(x)[array_ops.rank(x) - event_ndims:] + for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + fldj += b.forward_log_det_jacobian( + x, event_ndims=event_ndims, **kwargs.get(b.name, {})) + if _use_static_shape(x, event_ndims): + event_shape = b.forward_event_shape(event_shape) + event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + else: + event_shape = b.forward_event_shape_tensor(event_shape) + event_ndims = _maybe_get_event_ndims_statically( + array_ops.rank(event_shape)) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj + diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 8f09e16..caae2ad 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -80,7 +80,7 @@ class CholeskyOuterProduct(bijector.Bijector): self._graph_parents = [] self._name = name super(CholeskyOuterProduct, self).__init__( - event_ndims=2, + forward_min_event_ndims=2, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index ccb1f02..e9e994f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -44,12 +44,16 @@ class ConditionalBijector(bijector.Bijector): "**condition_kwargs": "Named arguments forwarded to subclass implementation."}) def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + self, y, event_ndims, name="inverse_log_det_jacobian", + **condition_kwargs): + return self._call_inverse_log_det_jacobian( + y, event_ndims, name, **condition_kwargs) @distribution_util.AppendDocstring(kwargs_dict={ "**condition_kwargs": "Named arguments forwarded to subclass implementation."}) def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): - return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) + self, x, event_ndims, name="forward_log_det_jacobian", + **condition_kwargs): + return self._call_forward_log_det_jacobian( + x, event_ndims, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index b1ff840..9fc1bbf 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -33,8 +33,8 @@ class Exp(power_transform.PowerTransform): ```python # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) + # batch ndim 2. + exp = Exp() x = [[[1., 2], [3, 4]], [[5, 6], @@ -48,19 +48,17 @@ class Exp(power_transform.PowerTransform): """ def __init__(self, - event_ndims=0, validate_args=False, name="exp"): """Instantiates the `Exp` bijector. Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. """ + # forward_min_event_ndims = 0. + # No forward_min_event_ndims specified as this is done in PowerTransform. super(Exp, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index 67f3978..e656a25 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -48,7 +48,6 @@ class Gumbel(bijector.Bijector): def __init__(self, loc=0., scale=1., - event_ndims=0, validate_args=False, name="gumbel"): """Instantiates the `Gumbel` bijector. @@ -60,8 +59,6 @@ class Gumbel(bijector.Bijector): scale: Positive Float-like `Tensor` that is the same dtype and is broadcastable with `loc`. This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -80,7 +77,9 @@ class Gumbel(bijector.Bijector): ], self._scale) super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + validate_args=validate_args, + forward_min_event_ndims=0, + name=name) @property def loc(self): @@ -102,15 +101,11 @@ class Gumbel(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + return math_ops.log(self.scale / (-math_ops.log(y) * y)) def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + return -z - math_ops.exp(-z) - math_ops.log(self.scale) def _maybe_assert_valid_y(self, y): if not self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index fab1b22..2bde956 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -40,7 +40,7 @@ class Inline(bijector.Bijector): name="exp") ``` - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + The above example is equivalent to the `Bijector` `Exp()`. """ def __init__(self, @@ -54,6 +54,8 @@ class Inline(bijector.Bijector): inverse_event_shape_tensor_fn=None, is_constant_jacobian=False, validate_args=False, + forward_min_event_ndims=None, + inverse_min_event_ndims=None, name="inline"): """Creates a `Bijector` from callables. @@ -76,10 +78,15 @@ class Inline(bijector.Bijector): constant for all input arguments. validate_args: Python `bool` indicating whether arguments should be checked for correctness. + forward_min_event_ndims: Python `int` indicating the minimal + dimensionality this bijector acts on. + inverse_min_event_ndims: Python `int` indicating the minimal + dimensionality this bijector acts on. name: Python `str`, name given to ops managed by this object. """ super(Inline, self).__init__( - event_ndims=0, + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -134,8 +141,8 @@ class Inline(bijector.Bijector): "inverse_log_det_jacobian_fn is not a callable function.") return self._inverse_log_det_jacobian_fn(y, **kwargs) - def _forward_log_det_jacobian(self, y, **kwargs): + def _forward_log_det_jacobian(self, x, **kwargs): if not callable(self._forward_log_det_jacobian_fn): raise NotImplementedError( "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) + return self._forward_log_det_jacobian_fn(x, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 2c603fe..1904239 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -66,8 +66,9 @@ class Invert(bijector_lib.Bijector): self._bijector = bijector super(Invert, self).__init__( - event_ndims=bijector.event_ndims, graph_parents=bijector.graph_parents, + forward_min_event_ndims=bijector.inverse_min_event_ndims, + inverse_min_event_ndims=bijector.forward_min_event_ndims, is_constant_jacobian=bijector.is_constant_jacobian, validate_args=validate_args, dtype=bijector.dtype, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py index f5de052..97000c1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -48,7 +47,6 @@ class Kumaraswamy(bijector.Bijector): def __init__(self, concentration1=None, concentration0=None, - event_ndims=0, validate_args=False, name="kumaraswamy"): """Instantiates the `Kumaraswamy` bijector. @@ -60,31 +58,14 @@ class Kumaraswamy(bijector.Bijector): concentration0: Python `float` scalar indicating the transform power, i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is `concentration0`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. """ self._graph_parents = [] self._name = name self._validate_args = validate_args - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - with self._name_scope("init", values=[concentration1, concentration0]): concentration1 = self._maybe_assert_valid_concentration( ops.convert_to_tensor(concentration1, name="concentration1"), @@ -96,7 +77,7 @@ class Kumaraswamy(bijector.Bijector): self._concentration1 = concentration1 self._concentration0 = concentration0 super(Kumaraswamy, self).__init__( - event_ndims=0, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -123,12 +104,10 @@ class Kumaraswamy(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( + return ( math_ops.log(self.concentration1) + math_ops.log(self.concentration0) + (self.concentration1 - 1) * math_ops.log(y) + - (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1), - axis=event_dims) + (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1)) def _maybe_assert_valid_concentration(self, concentration, validate_args): """Checks the validity of a concentration parameter.""" diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 84b2340..ef56cf6 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -61,7 +61,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): this property by zeroing out weights in its `masked_dense` layers. In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression" is implemented using a `tf.while_loop` and a deep neural network (DNN) with masked weights such that the autoregressive property is automatically met in the `inverse`. @@ -220,6 +220,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): self._shift_and_log_scale_fn = shift_and_log_scale_fn self._unroll_loop = unroll_loop super(MaskedAutoregressiveFlow, self).__init__( + forward_min_event_ndims=1, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 8654cc3..4978167 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -114,6 +114,7 @@ class Permute(bijector_lib.Bijector): ], permutation) self._permutation = permutation super(Permute, self).__init__( + forward_min_event_ndims=1, is_constant_jacobian=True, validate_args=validate_args, name=name or "permute") @@ -132,7 +133,10 @@ class Permute(bijector_lib.Bijector): axis=-1) def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. + return constant_op.constant(0., dtype=y.dtype.base_dtype) def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) + return constant_op.constant(0., dtype=x.dtype.base_dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index c37db61..71f123f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -43,7 +43,6 @@ class PowerTransform(bijector.Bijector): def __init__(self, power=0., - event_ndims=0, validate_args=False, name="power_transform"): """Instantiates the `PowerTransform` bijector. @@ -51,8 +50,6 @@ class PowerTransform(bijector.Bijector): Args: power: Python `float` scalar indicating the transform power, i.e., `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -70,7 +67,7 @@ class PowerTransform(bijector.Bijector): raise ValueError("`power` must be a non-negative TF constant.") self._power = power super(PowerTransform, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -97,18 +94,13 @@ class PowerTransform(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) + return (self.power - 1.) * math_ops.log(y) def _forward_log_det_jacobian(self, x): x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) + return x + return (1. / self.power - 1.) * math_ops.log1p(x * self.power) def _maybe_assert_valid_x(self, x): if not self.validate_args or self.power == 0.: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 71ab369..f09ab21 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -166,7 +166,7 @@ class RealNVP(bijector_lib.Bijector): self._input_depth = None self._shift_and_log_scale_fn = shift_and_log_scale_fn super(RealNVP, self).__init__( - event_ndims=1, + forward_min_event_ndims=1, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -224,7 +224,7 @@ class RealNVP(bijector_lib.Bijector): _, log_scale = self._shift_and_log_scale_fn( x0, self._input_depth - self._num_masked) if log_scale is None: - return constant_op.constant(0., dtype=x.dtype, name="ildj") + return constant_op.constant(0., dtype=x.dtype, name="fldj") return math_ops.reduce_sum(log_scale, axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 55eca06..82210cd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -128,9 +128,11 @@ class Reshape(bijector_lib.Bijector): self._event_shape_in = event_shape_in self._event_shape_out = event_shape_out - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") + super(Reshape, self).__init__( + forward_min_event_ndims=0, + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") def _maybe_check_valid_shape(self, shape, validate_args): """Check that a shape Tensor is int-type and otherwise sane.""" diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index a640dfe..5df8c88 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -33,7 +33,9 @@ class Sigmoid(bijector.Bijector): def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) def _forward(self, x): return math_ops.sigmoid(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 3a75e4a..2a32e8ab 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -91,7 +91,6 @@ class SinhArcsinh(bijector.Bijector): def __init__(self, skewness=None, tailweight=None, - event_ndims=0, validate_args=False, name="SinhArcsinh"): """Instantiates the `SinhArcsinh` bijector. @@ -101,8 +100,6 @@ class SinhArcsinh(bijector.Bijector): of type `float32`. tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -125,7 +122,9 @@ class SinhArcsinh(bijector.Bijector): message="Argument tailweight was not positive") ], self._tailweight) super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) @property def skewness(self): @@ -149,31 +148,29 @@ class SinhArcsinh(bijector.Bijector): # dx/dy # = cosh(arcsinh(y) / tailweight - skewness) # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + return ( math_ops.log(math_ops.cosh( math_ops.asinh(y) / self.tailweight - self.skewness) # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) + - math_ops.log(self.tailweight)) def _forward_log_det_jacobian(self, x): # y = sinh((arcsinh(x) + skewness) * tailweight) # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), # dy/dx # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + return ( math_ops.log(math_ops.cosh( (math_ops.asinh(x) + self.skewness) * self.tailweight) # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) + + math_ops.log(self.tailweight)) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index dc94fd0..f52b915 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -66,7 +66,7 @@ class SoftmaxCentered(bijector.Bijector): self._graph_parents = [] self._name = name super(SoftmaxCentered, self).__init__( - event_ndims=1, + forward_min_event_ndims=1, validate_args=validate_args, name=name) @@ -105,8 +105,6 @@ class SoftmaxCentered(bijector.Bijector): y.shape.assert_is_compatible_with(shape) y.set_shape(shape) - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). return nn_ops.softmax(y) def _inverse(self, y): @@ -162,8 +160,6 @@ class SoftmaxCentered(bijector.Bijector): # -log_normalization + reduce_sum(logits - log_normalization) log_normalization = nn_ops.softplus( math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) + return array_ops.squeeze( + (-log_normalization + math_ops.reduce_sum( + x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 81957fc..96a938c 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -62,7 +62,7 @@ class Softplus(bijector.Bijector): ```python # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(event_ndims=2) + softplus = Softplus() x = [[[1., 2], [3, 4]], [[5, 6], @@ -81,7 +81,6 @@ class Softplus(bijector.Bijector): "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) def __init__(self, - event_ndims=0, hinge_softness=None, validate_args=False, name="softplus"): @@ -101,7 +100,7 @@ class Softplus(bijector.Bijector): [nonzero_check], self.hinge_softness) super(Softplus, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -130,14 +129,12 @@ class Softplus(bijector.Bijector): # 1 - exp{-Y} approx Y. if self.hinge_softness is not None: y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) + return -math_ops.log(-math_ops.expm1(-y)) def _forward_log_det_jacobian(self, x): if self.hinge_softness is not None: x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) + return -nn_ops.softplus(-x) @property def hinge_softness(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py index 1e9dbf3..2ccfdc9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -59,7 +59,7 @@ class Square(bijector.Bijector): """ self._name = name super(Square, self).__init__( - event_ndims=0, + forward_min_event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index 00520bc..39129cd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -50,7 +50,6 @@ class Weibull(bijector.Bijector): def __init__(self, scale=1., concentration=1., - event_ndims=0, validate_args=False, name="weibull"): """Instantiates the `Weibull` bijector. @@ -62,8 +61,6 @@ class Weibull(bijector.Bijector): concentration: Positive Float-type `Tensor` that is the same dtype and is broadcastable with `scale`. This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -89,7 +86,7 @@ class Weibull(bijector.Bijector): ], self._concentration) super(Weibull, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -113,22 +110,18 @@ class Weibull(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( + return ( -math_ops.log1p(-y) + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) + math_ops.log(self.scale / self.concentration)) def _forward_log_det_jacobian(self, x): x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( + return ( -(x / self.scale) ** self.concentration + (self.concentration - 1) * math_ops.log(x) + math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - axis=event_dims) + -self.concentration * math_ops.log(self.scale)) def _maybe_assert_valid_x(self, x): if not self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 1d4c566..10b4536 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -105,7 +106,9 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + event_ndims = self._maybe_get_event_ndims_statically() + ildj = self.bijector.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_log_prob_for_one_fiber(y, x, ildj, distribution_kwargs) @@ -128,7 +131,9 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + event_ndims = self._maybe_get_event_ndims_statically() + ildj = self.bijector.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_prob_for_one_fiber(y, x, ildj, distribution_kwargs) @@ -214,3 +219,15 @@ class ConditionalTransformedDistribution( # implies the qth quantile of Y is g(x_q). inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) + + def _maybe_get_event_ndims_statically(self): + if self.event_shape.ndims is not None: + return self.event_shape.ndims + + event_ndims = array_ops.size(self.event_shape_tensor()) + static_event_ndims = tensor_util.constant_value(event_ndims) + + if static_event_ndims is not None: + return static_event_ndims + + return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 92f2bba..3314181 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -114,7 +114,7 @@ def quadrature_scheme_lognormal_quantiles( # Create a LogNormal distribution. dist = transformed_lib.TransformedDistribution( distribution=normal_lib.Normal(loc=loc, scale=scale), - bijector=Exp(event_ndims=0), + bijector=Exp(), validate_args=validate_args) batch_ndims = dist.batch_shape.ndims if batch_ndims is None: diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index f56ba07..02cf3c7 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -409,5 +409,5 @@ class RelaxedOneHotCategorical( validate_args=validate_args, allow_nan_stats=allow_nan_stats) super(RelaxedOneHotCategorical, self).__init__(dist, - bijectors.Exp(event_ndims=1), + bijectors.Exp(), name=name) diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index 0d8a192..cde6d85 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -166,13 +166,13 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): # Make the SAS bijector, 'F'. f = bijectors.SinhArcsinh( - skewness=skewness, tailweight=tailweight, event_ndims=0) + skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = bijectors.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), - tailweight=tailweight, event_ndims=0) + tailweight=tailweight) # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 971d65c..da271a8 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -427,7 +427,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._endpoint_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, - event_ndims=1, validate_args=validate_args, name="endpoint_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip(loc, scale))] @@ -467,7 +466,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._interpolated_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, - event_ndims=1, validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( @@ -621,9 +619,11 @@ class VectorDiffeomixture(distribution_lib.Distribution): log_prob = math_ops.reduce_sum(self.distribution.log_prob(y), axis=-2) # Because the affine transformation has a constant Jacobian, it is the case # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general. - fldj = array_ops.stack( - [aff.forward_log_det_jacobian(x) for aff in self.interpolated_affine], - axis=-1) + fldj = array_ops.stack([ + aff.forward_log_det_jacobian( + x, + event_ndims=array_ops.rank(self.event_shape_tensor()) + ) for aff in self.interpolated_affine], axis=-1) return math_ops.reduce_logsumexp( self.mixture_distribution.logits - fldj + log_prob, axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 003c66b..05919be 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -215,13 +215,13 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): tailweight = ops.convert_to_tensor( tailweight, dtype=dtype, name="tailweight") f = bijectors.SinhArcsinh( - skewness=skewness, tailweight=tailweight, event_ndims=1) + skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = bijectors.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), - tailweight=tailweight, event_ndims=0) + tailweight=tailweight) # Make the Affine bijector, Z --> loc + C * Z. c = 2 * scale_diag_part / f_noskew.forward( diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py index 9f9fb5c..1858224 100644 --- a/tensorflow/python/kernel_tests/distributions/bijector_test.py +++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import abc +import numpy as np import six from tensorflow.python.framework import constant_op @@ -43,11 +44,10 @@ class BaseBijectorTest(test.TestCase): """Minimal specification of a `Bijector`.""" def __init__(self): - super(_BareBonesBijector, self).__init__() + super(_BareBonesBijector, self).__init__(forward_min_event_ndims=0) with self.test_session() as sess: bij = _BareBonesBijector() - self.assertEqual(None, bij.event_ndims) self.assertEqual([], bij.graph_parents) self.assertEqual(False, bij.is_constant_jacobian) self.assertEqual(False, bij.validate_args) @@ -67,13 +67,21 @@ class BaseBijectorTest(test.TestCase): self.assertAllEqual(shape, inverse_event_shape_) self.assertAllEqual(shape, bij.inverse_event_shape(shape)) - for fn in ["forward", - "inverse", - "inverse_log_det_jacobian", - "forward_log_det_jacobian"]: - with self.assertRaisesRegexp( - NotImplementedError, fn + " not implemented"): - getattr(bij, fn)(0) + with self.assertRaisesRegexp( + NotImplementedError, "inverse not implemented"): + bij.inverse(0) + + with self.assertRaisesRegexp( + NotImplementedError, "forward not implemented"): + bij.forward(0) + + with self.assertRaisesRegexp( + NotImplementedError, "inverse_log_det_jacobian not implemented"): + bij.inverse_log_det_jacobian(0, event_ndims=0) + + with self.assertRaisesRegexp( + NotImplementedError, "forward_log_det_jacobian not implemented"): + bij.forward_log_det_jacobian(0, event_ndims=0) class IntentionallyMissingError(Exception): @@ -85,7 +93,7 @@ class BrokenBijector(bijector.Bijector): def __init__(self, forward_missing=False, inverse_missing=False): super(BrokenBijector, self).__init__( - event_ndims=0, validate_args=False, name="broken") + validate_args=False, forward_min_event_ndims=0, name="broken") self._forward_missing = forward_missing self._inverse_missing = inverse_missing @@ -120,35 +128,42 @@ class BijectorCachingTestBase(object): def testCachingOfForwardResults(self): broken_bijector = self.broken_bijector_cls(inverse_missing=True) - with self.test_session(): - x = constant_op.constant(1.1) + x = constant_op.constant(1.1) + + # Call forward and forward_log_det_jacobian one-by-one (not together). + y = broken_bijector.forward(x) + _ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0) - # Call forward and forward_log_det_jacobian one-by-one (not together). - y = broken_bijector.forward(x) - _ = broken_bijector.forward_log_det_jacobian(x) + # Now, everything should be cached if the argument is y. + broken_bijector.inverse_log_det_jacobian(y, event_ndims=0) + try: + broken_bijector.inverse(y) + broken_bijector.inverse_log_det_jacobian(y, event_ndims=0) + except IntentionallyMissingError: + raise AssertionError("Tests failed! Cached values not used.") - # Now, everything should be cached if the argument is y. - try: - broken_bijector.inverse(y) - broken_bijector.inverse_log_det_jacobian(y) - except IntentionallyMissingError: - raise AssertionError("Tests failed! Cached values not used.") + # Different event_ndims should not be cached. + with self.assertRaises(IntentionallyMissingError): + broken_bijector.inverse_log_det_jacobian(y, event_ndims=1) def testCachingOfInverseResults(self): broken_bijector = self.broken_bijector_cls(forward_missing=True) - with self.test_session(): - y = constant_op.constant(1.1) + y = constant_op.constant(1.1) - # Call inverse and inverse_log_det_jacobian one-by-one (not together). - x = broken_bijector.inverse(y) - _ = broken_bijector.inverse_log_det_jacobian(y) + # Call inverse and inverse_log_det_jacobian one-by-one (not together). + x = broken_bijector.inverse(y) + _ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0) - # Now, everything should be cached if the argument is x. - try: - broken_bijector.forward(x) - broken_bijector.forward_log_det_jacobian(x) - except IntentionallyMissingError: - raise AssertionError("Tests failed! Cached values not used.") + # Now, everything should be cached if the argument is x. + try: + broken_bijector.forward(x) + broken_bijector.forward_log_det_jacobian(x, event_ndims=0) + except IntentionallyMissingError: + raise AssertionError("Tests failed! Cached values not used.") + + # Different event_ndims should not be cached. + with self.assertRaises(IntentionallyMissingError): + broken_bijector.forward_log_det_jacobian(x, event_ndims=1) class BijectorCachingTest(BijectorCachingTestBase, test.TestCase): @@ -159,5 +174,107 @@ class BijectorCachingTest(BijectorCachingTestBase, test.TestCase): return BrokenBijector +class ExpOnlyJacobian(bijector.Bijector): + """Only used for jacobian calculations.""" + + def __init__(self, forward_min_event_ndims=0): + super(ExpOnlyJacobian, self).__init__( + validate_args=False, + is_constant_jacobian=False, + forward_min_event_ndims=forward_min_event_ndims, + name="exp") + + def _inverse_log_det_jacobian(self, y): + return -math_ops.log(y) + + def _forward_log_det_jacobian(self, x): + return math_ops.log(x) + + +class ConstantJacobian(bijector.Bijector): + """Only used for jacobian calculations.""" + + def __init__(self, forward_min_event_ndims=0): + super(ConstantJacobian, self).__init__( + validate_args=False, + is_constant_jacobian=True, + forward_min_event_ndims=forward_min_event_ndims, + name="c") + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(2., y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(-2., x.dtype) + + +class BijectorReduceEventDimsTest(test.TestCase): + """Test caching with BrokenBijector.""" + + def testReduceEventNdimsForward(self): + x = [[[1., 2.], [3., 4.]]] + bij = ExpOnlyJacobian() + self.assertAllClose( + np.log(x), + self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0))) + self.assertAllClose( + np.sum(np.log(x), axis=-1), + self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1))) + self.assertAllClose( + np.sum(np.log(x), axis=(-1, -2)), + self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2))) + + def testReduceEventNdimsForwardRaiseError(self): + x = [[[1., 2.], [3., 4.]]] + bij = ExpOnlyJacobian(forward_min_event_ndims=1) + with self.assertRaisesRegexp(ValueError, "must be larger than"): + bij.forward_log_det_jacobian(x, event_ndims=0) + + def testReduceEventNdimsInverse(self): + x = [[[1., 2.], [3., 4.]]] + bij = ExpOnlyJacobian() + self.assertAllClose( + -np.log(x), + self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0))) + self.assertAllClose( + np.sum(-np.log(x), axis=-1), + self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1))) + self.assertAllClose( + np.sum(-np.log(x), axis=(-1, -2)), + self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2))) + + def testReduceEventNdimsInverseRaiseError(self): + x = [[[1., 2.], [3., 4.]]] + bij = ExpOnlyJacobian(forward_min_event_ndims=1) + with self.assertRaisesRegexp(ValueError, "must be larger than"): + bij.inverse_log_det_jacobian(x, event_ndims=0) + + def testReduceEventNdimsForwardConstJacobian(self): + x = [[[1., 2.], [3., 4.]]] + bij = ConstantJacobian() + self.assertAllClose( + -2., + self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0))) + self.assertAllClose( + -4., + self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1))) + self.assertAllClose( + -8., + self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2))) + + def testReduceEventNdimsInverseConstJacobian(self): + x = [[[1., 2.], [3., 4.]]] + bij = ConstantJacobian() + self.assertAllClose( + 2., + self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0))) + self.assertAllClose( + 4., + self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1))) + self.assertAllClose( + 8., + self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2))) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py index e8f9d0b..b347c20 100644 --- a/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py +++ b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py @@ -27,14 +27,19 @@ class IdentityBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = X transformation.""" def testBijector(self): - with self.test_session(): - bijector = identity_bijector.Identity() - self.assertEqual("identity", bijector.name) - x = [[[0.], [1.]]] - self.assertAllEqual(x, bijector.forward(x).eval()) - self.assertAllEqual(x, bijector.inverse(x).eval()) - self.assertAllEqual(0., bijector.inverse_log_det_jacobian(x).eval()) - self.assertAllEqual(0., bijector.forward_log_det_jacobian(x).eval()) + bijector = identity_bijector.Identity(validate_args=True) + self.assertEqual("identity", bijector.name) + x = [[[0.], [1.]]] + self.assertAllEqual(x, self.evaluate(bijector.forward(x))) + self.assertAllEqual(x, self.evaluate(bijector.inverse(x))) + self.assertAllEqual( + 0., + self.evaluate( + bijector.inverse_log_det_jacobian(x, event_ndims=3))) + self.assertAllEqual( + 0., + self.evaluate( + bijector.forward_log_det_jacobian(x, event_ndims=3))) def testScalarCongruency(self): with self.test_session(): diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index ed43555..4ebc600 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -23,7 +23,6 @@ import collections import contextlib import re -import numpy as np import six from tensorflow.python.framework import dtypes @@ -31,8 +30,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops -from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -41,23 +40,24 @@ __all__ = [ class _Mapping(collections.namedtuple( - "_Mapping", ["x", "y", "ildj", "kwargs"])): + "_Mapping", ["x", "y", "ildj_map", "kwargs"])): """Helper class to make it easier to manage caching in `Bijector`.""" - def __new__(cls, x=None, y=None, ildj=None, kwargs=None): + def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None): """Custom __new__ so namedtuple items have defaults. Args: x: `Tensor`. Forward. y: `Tensor`. Inverse. - ildj: `Tensor`. Inverse log det Jacobian. + ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor` + representing the inverse log det jacobian. kwargs: Python dictionary. Extra args supplied to forward/inverse/etc functions. Returns: mapping: New instance of _Mapping. """ - return super(_Mapping, cls).__new__(cls, x, y, ildj, kwargs) + return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs) @property def x_key(self): @@ -69,13 +69,14 @@ class _Mapping(collections.namedtuple( """Returns key used for caching X=g^{-1}(Y).""" return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items()))) - def merge(self, x=None, y=None, ildj=None, kwargs=None, mapping=None): + def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None): """Returns new _Mapping with args merged with self. Args: x: `Tensor`. Forward. y: `Tensor`. Inverse. - ildj: `Tensor`. Inverse log det Jacobian. + ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor` + representing the inverse log det jacobian. kwargs: Python dictionary. Extra args supplied to forward/inverse/etc functions. mapping: Instance of _Mapping to merge. Can only be specified if no other @@ -88,15 +89,30 @@ class _Mapping(collections.namedtuple( ValueError: if mapping and any other arg is not `None`. """ if mapping is None: - mapping = _Mapping(x=x, y=y, ildj=ildj, kwargs=kwargs) - elif not all(arg is None for arg in [x, y, ildj, kwargs]): - raise ValueError("Cannot specify mapping and individual args.") + mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs) + elif any(arg is not None for arg in [x, y, ildj_map, kwargs]): + raise ValueError("Cannot simultaneously specify mapping and individual " + "arguments.") + return _Mapping( x=self._merge(self.x, mapping.x), y=self._merge(self.y, mapping.y), - ildj=self._merge(self.ildj, mapping.ildj), + ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map), kwargs=self._merge(self.kwargs, mapping.kwargs)) + def _merge_dicts(self, old=None, new=None): + """Helper to merge two dictionaries.""" + old = dict() if old is None else old + new = dict() if new is None else new + for k, v in six.iteritems(new): + val = old.get(k, None) + if val is not None and val != v: + raise ValueError("Found different value for existing key " + "(key:{} old_value:{} new_value:{}".format( + k, old[k], v)) + old[k] = v + return old + def _merge(self, old, new): """Helper to merge which handles merging one value.""" if old is None: @@ -112,7 +128,6 @@ class _Mapping(collections.namedtuple( @six.add_metaclass(abc.ABCMeta) -@tf_export("distributions.bijectors.Bijector") class Bijector(object): r"""Interface for transformations of a `Distribution` sample. @@ -137,11 +152,11 @@ class Bijector(object): 2. Inverse\ Useful for "reversing" a transformation to compute one probability in terms of another. - 3. `(log o det o Jacobian o inverse)(x)`\ + 3. `log_det_jacobian(x)`\ "The log of the determinant of the matrix of all first-order partial derivatives of the inverse function."\ Useful for inverting a transformation to compute one probability in terms - of another. Geometrically, the det(Jacobian) is the volume of the + of another. Geometrically, the Jacobian determinant is the volume of the transformation and is used to scale the probability. By convention, transformations of random variables are named in terms of the @@ -164,7 +179,7 @@ class Bijector(object): ```python def transformed_log_prob(bijector, log_prob, x): - return (bijector.inverse_log_det_jacobian(x) + + return (bijector.inverse_log_det_jacobian(x, event_ndims=0) + log_prob(bijector.inverse(x))) ``` @@ -199,9 +214,11 @@ class Bijector(object): ```python class Exp(Bijector): - def __init__(self, event_ndims=0, validate_args=False, name="exp"): + def __init__(self, validate_args=False, name="exp"): super(Exp, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + validate_args=validate_args, + forward_min_event_ndims=0, + name=name) def _forward(self, x): return math_ops.exp(x) @@ -213,10 +230,11 @@ class Bijector(object): return -self._forward_log_det_jacobian(self._inverse(y)) def _forward_log_det_jacobian(self, x): - if self.event_ndims is None: - raise ValueError("Jacobian requires known event_ndims.") - event_dims = array_ops.shape(x)[-self.event_ndims:] - return math_ops.reduce_sum(x, axis=event_dims) + # Notice that we needn't do any reducing, even when`event_ndims > 0`. + # The base Bijector class will handle reducing for us; it knows how + # to do so because we called `super` `__init__` with + # `forward_min_event_ndims = 0`. + return x ``` - "Affine" @@ -237,18 +255,50 @@ class Bijector(object): MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d) ``` - #### Jacobian + #### Min_event_ndims and Naming + + Bijectors are named for the dimensionality of data they act on (i.e. without + broadcasting). We can think of bijectors having an intrinsic `min_event_ndims` + , which is the minimum number of dimensions for the bijector act on. For + instance, a Cholesky decomposition requires a matrix, and hence + `min_event_ndims=2`. + + Some examples: + + `AffineScalar: min_event_ndims=0` + `Affine: min_event_ndims=1` + `Cholesky: min_event_ndims=2` + `Exp: min_event_ndims=0` + `Sigmoid: min_event_ndims=0` + `SoftmaxCentered: min_event_ndims=1` + + Note the difference between `Affine` and `AffineScalar`. `AffineScalar` + operates on scalar events, whereas `Affine` operates on vector-valued events. - The Jacobian is a reduction over event dims. To see this, consider the `Exp` - `Bijector` applied to a `Tensor` which has sample, batch, and event (S, B, E) - shape semantics. Suppose the `Tensor`'s partitioned-shape is `(S=[4], B=[2], - E=[3, 3])`. The shape of the `Tensor` returned by `forward` and `inverse` is - unchanged, i.e., `[4, 2, 3, 3]`. However the shape returned by - `inverse_log_det_jacobian` is `[4, 2]` because the Jacobian is a reduction - over the event dimensions. + More generally, there is a `forward_min_event_ndims` and an + `inverse_min_event_ndims`. In most cases, these will be the same. + However, for some shape changing bijectors, these will be different + (e.g. a bijector which pads an extra dimension at the end, might have + `forward_min_event_ndims=0` and `inverse_min_event_ndims=1`. - It is sometimes useful to implement the inverse Jacobian as the negative - forward Jacobian. For example, + + #### Jacobian Determinant + + The Jacobian determinant is a reduction over `event_ndims - min_event_ndims` + (`forward_min_event_ndims` for `forward_log_det_jacobian` and + `inverse_min_event_ndims` for `inverse_log_det_jacobian`). + To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has + sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s + partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor` + returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`. + However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because + the Jacobian determinant is a reduction over the event dimensions. + + Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the + Jacobian determinant reduction is over `event_ndims - 1`. + + It is sometimes useful to implement the inverse Jacobian determinant as the + negative forward Jacobian determinant. For example, ```python def _inverse_log_det_jacobian(self, y): @@ -279,9 +329,54 @@ class Bijector(object): The claim follows from [properties of determinant]( https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups). - Generally its preferable to directly implement the inverse Jacobian. This - should have superior numerical stability and will often share subgraphs with - the `_inverse` implementation. + Generally its preferable to directly implement the inverse Jacobian + determinant. This should have superior numerical stability and will often + share subgraphs with the `_inverse` implementation. + + #### Is_constant_jacobian + + Certain bijectors will have constant jacobian matrices. For instance, the + `Affine` bijector encodes multiplication by a matrix plus a shift, with + jacobian matrix, the same aforementioned matrix. + + `is_constant_jacobian` encodes the fact that the jacobian matrix is constant. + The semantics of this argument are the following: + + * Repeated calls to "log_det_jacobian" functions with the same + `event_ndims` (but not necessarily same input), will return the first + computed jacobian (because the matrix is constant, and hence is input + independent). + * `log_det_jacobian` implementations are merely broadcastable to the true + `log_det_jacobian` (because, again, the jacobian matrix is input + independent). Specifically, `log_det_jacobian` is implemented as the + log jacobian determinant for a single input. + + ```python + class Identity(Bijector): + + def __init__(self, validate_args=False, name="identity"): + super(Identity, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + forward_min_event_ndims=0, + name=name) + + def _forward(self, x): + return x + + def _inverse(self, y): + return y + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(self._inverse(y)) + + def _forward_log_det_jacobian(self, x): + # The full log jacobian determinant would be array_ops.zero_like(x). + # However, we circumvent materializing that, since the jacobian + # calculation is input independent, and we specify it for one input. + return constant_op.constant(0., x.dtype.base_dtype) + + ``` #### Subclass Requirements @@ -364,14 +459,14 @@ class Bijector(object): ==> (-1., 1.) # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. - abs.inverse_log_det_jacobian(1.) + abs.inverse_log_det_jacobian(1., event_ndims=0) ==> (0., 0.) # Special case handling of 0. abs.inverse(0.) ==> (0., 0.) - abs.inverse_log_det_jacobian(0.) + abs.inverse_log_det_jacobian(0., event_ndims=0) ==> (0., 0.) ``` @@ -379,11 +474,12 @@ class Bijector(object): @abc.abstractmethod def __init__(self, - event_ndims=None, graph_parents=None, is_constant_jacobian=False, validate_args=False, dtype=None, + forward_min_event_ndims=None, + inverse_min_event_ndims=None, name=None): """Constructs Bijector. @@ -392,42 +488,61 @@ class Bijector(object): Examples: ```python - # Create the Y = g(X) = X transform which operates on vector events. - identity = Identity(event_ndims=1) + # Create the Y = g(X) = X transform. + identity = Identity() - # Create the Y = g(X) = exp(X) transform which operates on matrices. - exp = Exp(event_ndims=2) + # Create the Y = g(X) = exp(X) transform. + exp = Exp() ``` See `Bijector` subclass docstring for more details and specific examples. Args: - event_ndims: number of dimensions associated with event coordinates. graph_parents: Python list of graph prerequisites of this `Bijector`. - is_constant_jacobian: Python `bool` indicating that the Jacobian is not a - function of the input. + is_constant_jacobian: Python `bool` indicating that the Jacobian matrix is + not a function of the input. validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not enforced. + forward_min_event_ndims: Python `integer` indicating the minimum number of + dimensions `forward` operates on. + inverse_min_event_ndims: Python `integer` indicating the minimum number of + dimensions `inverse` operates on. Will be set to + `forward_min_event_ndims` by default, if no value is provided. name: The name to give Ops created by the initializer. Raises: + ValueError: If neither `forward_min_event_ndims` and + `inverse_min_event_ndims` are specified, or if either of them is + negative. ValueError: If a member of `graph_parents` is not a `Tensor`. """ - self._event_ndims = ( - ops.convert_to_tensor(event_ndims, dtype=dtypes.int32) - if event_ndims is not None else None) self._graph_parents = graph_parents or [] + + if forward_min_event_ndims is None and inverse_min_event_ndims is None: + raise ValueError("Must specify at least one of `forward_min_event_ndims` " + "and `inverse_min_event_ndims`.") + elif inverse_min_event_ndims is None: + inverse_min_event_ndims = forward_min_event_ndims + elif forward_min_event_ndims is None: + forward_min_event_ndims = inverse_min_event_ndims + + if forward_min_event_ndims < 0: + raise ValueError("forward_min_event_ndims must be a non-negative " + "integer.") + if inverse_min_event_ndims < 0: + raise ValueError("inverse_min_event_ndims must be a non-negative " + "integer.") + self._forward_min_event_ndims = forward_min_event_ndims + self._inverse_min_event_ndims = inverse_min_event_ndims self._is_constant_jacobian = is_constant_jacobian + self._constant_ildj_map = {} self._validate_args = validate_args self._dtype = dtype self._from_y = {} self._from_x = {} - # Using abbreviation ildj for "inverse log det Jacobian." - # This variable is not `None` iff is_constant_jacobian is `True`. - self._constant_ildj = None if name: self._name = name else: @@ -443,20 +558,26 @@ class Bijector(object): raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) @property - def event_ndims(self): - """Returns then number of event dimensions this bijector operates on.""" - return self._event_ndims - - @property def graph_parents(self): """Returns this `Bijector`'s graph_parents as a Python list.""" return self._graph_parents @property + def forward_min_event_ndims(self): + """Returns the minimal number of dimensions bijector.forward operates on.""" + return self._forward_min_event_ndims + + @property + def inverse_min_event_ndims(self): + """Returns the minimal number of dimensions bijector.inverse operates on.""" + return self._inverse_min_event_ndims + + @property def is_constant_jacobian(self): - """Returns true iff the Jacobian is not a function of x. + """Returns true iff the Jacobian matrix is not a function of x. - Note: Jacobian is either constant for both forward and inverse or neither. + Note: Jacobian matrix is either constant for both forward and inverse or + neither. Returns: is_constant_jacobian: Python `bool`. @@ -653,36 +774,57 @@ class Bijector(object): return self._call_inverse(y, name) def _inverse_log_det_jacobian(self, y): - """Subclass implementation of `inverse_log_det_jacobian` public function.""" + """Subclass implementation of `inverse_log_det_jacobian` public function. + + In particular, this method differs from the public function, in that it + does not take `event_ndims`. Thus, this implements the minimal Jacobian + determinant calculation (i.e. over `inverse_min_event_ndims`). + + Args: + y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation. + Returns: + inverse_log_det_jacobian: `Tensor`, if this bijector is injective. + If not injective, returns the k-tuple containing jacobians for the + unique `k` points `(x1, ..., xk)` such that `g(xi) = y`. + """ raise NotImplementedError("inverse_log_det_jacobian not implemented.") - def _call_inverse_log_det_jacobian(self, y, name, **kwargs): + def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs): with self._name_scope(name, [y]): - if self._constant_ildj is not None: - return self._constant_ildj + if event_ndims in self._constant_ildj_map: + return self._constant_ildj_map[event_ndims] y = ops.convert_to_tensor(y, name="y") self._maybe_assert_dtype(y) if not self._is_injective: # No caching for non-injective - return self._inverse_log_det_jacobian(y, **kwargs) + ildjs = self._inverse_log_det_jacobian(y, **kwargs) + return tuple(self._reduce_jacobian_det_over_event( + y, ildj, self.inverse_min_event_ndims, event_ndims) + for ildj in ildjs) mapping = self._lookup(y=y, kwargs=kwargs) - if mapping.ildj is not None: - return mapping.ildj + if mapping.ildj_map is not None and event_ndims in mapping.ildj_map: + return mapping.ildj_map[event_ndims] try: x = None # Not needed; leave cache as is. ildj = self._inverse_log_det_jacobian(y, **kwargs) + ildj = self._reduce_jacobian_det_over_event( + y, ildj, self.inverse_min_event_ndims, event_ndims) except NotImplementedError as original_exception: try: x = mapping.x if mapping.x is not None else self._inverse(y, **kwargs) ildj = -self._forward_log_det_jacobian(x, **kwargs) + ildj = self._reduce_jacobian_det_over_event( + x, ildj, self.forward_min_event_ndims, event_ndims) except NotImplementedError: raise original_exception - mapping = mapping.merge(x=x, ildj=ildj) + + mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj}) self._cache(mapping) if self.is_constant_jacobian: - self._constant_ildj = mapping.ildj - return mapping.ildj + self._constant_ildj_map[event_ndims] = ildj + return ildj - def inverse_log_det_jacobian(self, y, name="inverse_log_det_jacobian"): + def inverse_log_det_jacobian( + self, y, event_ndims, name="inverse_log_det_jacobian"): """Returns the (log o det o Jacobian o inverse)(y). Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.) @@ -691,7 +833,12 @@ class Bijector(object): evaluated at `g^{-1}(y)`. Args: - y: `Tensor`. The input to the "inverse" Jacobian evaluation. + y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation. + event_ndims: Number of dimensions in the probabilistic events being + transformed. Must be greater than or equal to + `self.inverse_min_event_ndims`. The result is summed over the final + dimensions to produce a scalar Jacobian determinant for each event, + i.e. it has shape `y.shape.ndims - event_ndims` dimensions. name: The name to give this op. Returns: @@ -705,45 +852,74 @@ class Bijector(object): `self.dtype`. NotImplementedError: if `_inverse_log_det_jacobian` is not implemented. """ - return self._call_inverse_log_det_jacobian(y, name) + with ops.control_dependencies(self._check_valid_event_ndims( + min_event_ndims=self.inverse_min_event_ndims, event_ndims=event_ndims)): + return self._call_inverse_log_det_jacobian(y, event_ndims, name) def _forward_log_det_jacobian(self, x): - """Subclass implementation of `forward_log_det_jacobian`.""" + """Subclass implementation of `forward_log_det_jacobian` public function. + + In particular, this method differs from the public function, in that it + does not take `event_ndims`. Thus, this implements the minimal Jacobian + determinant calculation (i.e. over `forward_min_event_ndims`). + + Args: + x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation. + + Returns: + forward_log_det_jacobian: `Tensor`, if this bijector is injective. + If not injective, returns the k-tuple containing jacobians for the + unique `k` points `(x1, ..., xk)` such that `g(xi) = y`. + """ + raise NotImplementedError( "forward_log_det_jacobian not implemented.") - def _call_forward_log_det_jacobian(self, x, name, **kwargs): + def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs): with self._name_scope(name, [x]): - if self._constant_ildj is not None: + if event_ndims in self._constant_ildj_map: # Need "-1. *" to avoid invalid-unary-operand-type linter warning. - return -1. * self._constant_ildj + return -1. * self._constant_ildj_map[event_ndims] x = ops.convert_to_tensor(x, name="x") self._maybe_assert_dtype(x) if not self._is_injective: - return self._forward_log_det_jacobian(x, **kwargs) # No caching. + fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching. + return tuple(self._reduce_jacobian_det_over_event( + x, fldj, self.forward_min_event_ndims, event_ndims) + for fldj in fldjs) mapping = self._lookup(x=x, kwargs=kwargs) - if mapping.ildj is not None: - return -mapping.ildj + if mapping.ildj_map is not None and event_ndims in mapping.ildj_map: + return -mapping.ildj_map[event_ndims] try: y = None # Not needed; leave cache as is. ildj = -self._forward_log_det_jacobian(x, **kwargs) + ildj = self._reduce_jacobian_det_over_event( + x, ildj, self.forward_min_event_ndims, event_ndims) except NotImplementedError as original_exception: try: y = mapping.y if mapping.y is not None else self._forward(x, **kwargs) ildj = self._inverse_log_det_jacobian(y, **kwargs) + ildj = self._reduce_jacobian_det_over_event( + y, ildj, self.inverse_min_event_ndims, event_ndims) except NotImplementedError: raise original_exception - mapping = mapping.merge(y=y, ildj=ildj) + mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj}) self._cache(mapping) if self.is_constant_jacobian: - self._constant_ildj = mapping.ildj - return -mapping.ildj + self._constant_ildj_map[event_ndims] = ildj + return -ildj - def forward_log_det_jacobian(self, x, name="forward_log_det_jacobian"): + def forward_log_det_jacobian( + self, x, event_ndims, name="forward_log_det_jacobian"): """Returns both the forward_log_det_jacobian. Args: - x: `Tensor`. The input to the "forward" Jacobian evaluation. + x: `Tensor`. The input to the "forward" Jacobian determinant evaluation. + event_ndims: Number of dimensions in the probabilistic events being + transformed. Must be greater than or equal to + `self.forward_min_event_ndims`. The result is summed over the final + dimensions to produce a scalar Jacobian determinant for each event, + i.e. it has shape `x.shape.ndims - event_ndims` dimensions. name: The name to give this op. Returns: @@ -761,7 +937,9 @@ class Bijector(object): raise NotImplementedError( "forward_log_det_jacobian cannot be implemented for non-injective " "transforms.") - return self._call_forward_log_det_jacobian(x, name) + with ops.control_dependencies(self._check_valid_event_ndims( + min_event_ndims=self.forward_min_event_ndims, event_ndims=event_ndims)): + return self._call_forward_log_det_jacobian(x, event_ndims, name) @contextlib.contextmanager def _name_scope(self, name=None, values=None): @@ -779,9 +957,6 @@ class Bijector(object): def _cache(self, mapping): """Helper which stores mapping info in forward/inverse dicts.""" - if self._constant_ildj is not None: - # Fold in ildj if known constant Jacobian. - mapping = mapping.merge(ildj=self._constant_ildj) # Merging from lookup is an added check that we're not overwriting anything # which is not None. mapping = mapping.merge(mapping=self._lookup( @@ -803,22 +978,66 @@ class Bijector(object): return self._from_y.get(mapping.y_key, mapping) return mapping - def _event_dims_tensor(self, sample): - """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`.""" - if self.event_ndims is None: - raise ValueError("Jacobian cannot be computed with unknown event_ndims") - static_event_ndims = tensor_util.constant_value(self.event_ndims) - static_rank = sample.get_shape().ndims - if static_event_ndims is not None and static_rank is not None: - return ops.convert_to_tensor( - static_rank + np.arange(-static_event_ndims, 0).astype(np.int32)) - - if static_event_ndims is not None: - event_range = np.arange(-static_event_ndims, 0).astype(np.int32) - else: - event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32) - - if static_rank is not None: - return event_range + static_rank + def _reduce_jacobian_det_over_event( + self, y, ildj, min_event_ndims, event_ndims): + """Reduce jacobian over event_ndims - min_event_ndims.""" + if not self.is_constant_jacobian: + return math_ops.reduce_sum( + ildj, + self._get_event_reduce_dims(min_event_ndims, event_ndims)) + + # In this case, we need to tile the jacobian over the event and reduce. + y_rank = array_ops.rank(y) + y_shape = array_ops.shape(y)[ + y_rank - event_ndims : y_rank - min_event_ndims] + + ones = array_ops.ones(y_shape, ildj.dtype) + reduced_ildj = math_ops.reduce_sum( + ones * ildj, + axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) + # The multiplication by ones can change the inferred static shape so we try + # to recover as much as possible. + if (isinstance(event_ndims, int) and + y.get_shape().ndims and ildj.get_shape().ndims): + y_shape = y.get_shape() + y_shape = y_shape[y_shape.ndims - event_ndims : + y_shape.ndims - min_event_ndims] + ildj_shape = ildj.get_shape() + broadcast_shape = array_ops.broadcast_static_shape( + ildj_shape, y_shape) + reduced_ildj.set_shape( + broadcast_shape[: broadcast_shape.ndims - ( + event_ndims - min_event_ndims)]) + + return reduced_ildj + + def _get_event_reduce_dims(self, min_event_ndims, event_ndims): + """Compute the reduction dimensions given event_ndims.""" + min_event_ndims_ = (min_event_ndims if isinstance(min_event_ndims, int) + else tensor_util.constant_value(min_event_ndims)) + event_ndims_ = (event_ndims if isinstance(event_ndims, int) + else tensor_util.constant_value(event_ndims)) + + if min_event_ndims_ is not None and event_ndims_ is not None: + return [-index for index in range(1, event_ndims_ - min_event_ndims_ + 1)] else: - return event_range + array_ops.rank(sample) + reduce_ndims = event_ndims - min_event_ndims + return math_ops.range(-reduce_ndims, 0) + + def _check_valid_event_ndims(self, min_event_ndims, event_ndims): + """Check whether event_ndims is atleast min_event_ndims.""" + min_event_ndims_ = (min_event_ndims if isinstance(min_event_ndims, int) + else tensor_util.constant_value(min_event_ndims)) + event_ndims_ = (event_ndims if isinstance(event_ndims, int) + else tensor_util.constant_value(event_ndims)) + + if min_event_ndims_ is not None and event_ndims_ is not None: + if min_event_ndims_ > event_ndims_: + raise ValueError("event_ndims ({}) must be larger than " + "min_event_ndims ({})".format( + event_ndims_, min_event_ndims_)) + return [] + + if self.validate_args: + return [check_ops.assert_greater_equal(event_ndims, min_event_ndims)] + return [] diff --git a/tensorflow/python/ops/distributions/bijector_test_util.py b/tensorflow/python/ops/distributions/bijector_test_util.py index ff3535c..784bfd5 100644 --- a/tensorflow/python/ops/distributions/bijector_test_util.py +++ b/tensorflow/python/ops/distributions/bijector_test_util.py @@ -79,9 +79,7 @@ def assert_scalar_congruency(bijector, Raises: AssertionError: If tests fail. """ - # Checks and defaults. - assert bijector.event_ndims.eval() == 0 if sess is None: sess = ops.get_default_session() @@ -111,7 +109,10 @@ def assert_scalar_congruency(bijector, # (b - a) = \int_a^b dx = \int_{y(a)}^{y(b)} |dx/dy| dy # "change_measure_dy_dx" below is a Monte Carlo approximation to the right # hand side, which should then be close to the left, which is (b - a). - dy_dx = math_ops.exp(bijector.inverse_log_det_jacobian(uniform_y_samps)) + # We assume event_ndims=0 because we assume scalar -> scalar. The log_det + # methods will handle whether they expect event_ndims > 0. + dy_dx = math_ops.exp(bijector.inverse_log_det_jacobian( + uniform_y_samps, event_ndims=0)) # E[|dx/dy|] under Uniform[lower_y, upper_y] # = \int_{y(a)}^{y(b)} |dx/dy| dP(u), where dP(u) is the uniform measure expectation_of_dy_dx_under_uniform = math_ops.reduce_mean(dy_dx) @@ -121,7 +122,8 @@ def assert_scalar_congruency(bijector, # We'll also check that dy_dx = 1 / dx_dy. dx_dy = math_ops.exp( - bijector.forward_log_det_jacobian(bijector.inverse(uniform_y_samps))) + bijector.forward_log_det_jacobian( + bijector.inverse(uniform_y_samps), event_ndims=0)) [ forward_on_10_pts_v, @@ -158,7 +160,8 @@ def assert_scalar_congruency(bijector, dy_dx_v, np.divide(1., dx_dy_v), atol=1e-5, rtol=1e-3) -def assert_bijective_and_finite(bijector, x, y, atol=0, rtol=1e-5, sess=None): +def assert_bijective_and_finite( + bijector, x, y, event_ndims, atol=0, rtol=1e-5, sess=None): """Assert that forward/inverse (along with jacobians) are inverses and finite. It is recommended to use x and y values that are very very close to the edge @@ -168,6 +171,8 @@ def assert_bijective_and_finite(bijector, x, y, atol=0, rtol=1e-5, sess=None): bijector: A Bijector instance. x: np.array of values in the domain of bijector.forward. y: np.array of values in the domain of bijector.inverse. + event_ndims: Integer describing the number of event dimensions this bijector + operates on. atol: Absolute tolerance. rtol: Relative tolerance. sess: TensorFlow session. Defaults to the default session. @@ -197,10 +202,10 @@ def assert_bijective_and_finite(bijector, x, y, atol=0, rtol=1e-5, sess=None): ] = sess.run([ bijector.inverse(f_x), bijector.forward(g_y), - bijector.inverse_log_det_jacobian(f_x), - bijector.forward_log_det_jacobian(x), - bijector.inverse_log_det_jacobian(y), - bijector.forward_log_det_jacobian(g_y), + bijector.inverse_log_det_jacobian(f_x, event_ndims=event_ndims), + bijector.forward_log_det_jacobian(x, event_ndims=event_ndims), + bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims), + bijector.forward_log_det_jacobian(g_y, event_ndims=event_ndims), f_x, g_y, ]) diff --git a/tensorflow/python/ops/distributions/bijectors.py b/tensorflow/python/ops/distributions/bijectors.py deleted file mode 100644 index 69c3a5d..0000000 --- a/tensorflow/python/ops/distributions/bijectors.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Core module for TensorFlow distribution bijectors.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -# go/tf-wildcard-import -# pylint: disable=wildcard-import,unused-import -from tensorflow.python.ops.distributions.bijector import Bijector -from tensorflow.python.ops.distributions.identity_bijector import Identity - -# pylint: enable=wildcard-import,unused-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ["Bijector", "Identity"] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/ops/distributions/distributions.py b/tensorflow/python/ops/distributions/distributions.py index 9df7d14..7c4b869 100644 --- a/tensorflow/python/ops/distributions/distributions.py +++ b/tensorflow/python/ops/distributions/distributions.py @@ -19,7 +19,6 @@ from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.python.ops.distributions import bijectors from tensorflow.python.ops.distributions.bernoulli import Bernoulli from tensorflow.python.ops.distributions.beta import Beta from tensorflow.python.ops.distributions.categorical import Categorical @@ -40,7 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - "bijectors", "Bernoulli", "Beta", "Categorical", diff --git a/tensorflow/python/ops/distributions/identity_bijector.py b/tensorflow/python/ops/distributions/identity_bijector.py index 2972c35..8628e68 100644 --- a/tensorflow/python/ops/distributions/identity_bijector.py +++ b/tensorflow/python/ops/distributions/identity_bijector.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.ops.distributions import bijector -from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -28,7 +27,6 @@ __all__ = [ ] -@tf_export("distributions.bijectors.Identity") class Identity(bijector.Bijector): """Compute Y = g(X) = X. @@ -37,7 +35,7 @@ class Identity(bijector.Bijector): ```python # Create the Y=g(X)=X transform which is intended for Tensors with 1 batch # ndim and 1 event ndim (i.e., vector of vectors). - identity = Identity(event_ndims=1) + identity = Identity() x = [[1., 2], [3, 4]] x == identity.forward(x) == identity.inverse(x) @@ -45,10 +43,10 @@ class Identity(bijector.Bijector): """ - def __init__(self, validate_args=False, event_ndims=0, name="identity"): + def __init__(self, validate_args=False, name="identity"): super(Identity, self).__init__( + forward_min_event_ndims=0, is_constant_jacobian=True, - event_ndims=event_ndims, validate_args=validate_args, name=name) diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py index 1efcf9d..1ad63a8 100644 --- a/tensorflow/python/ops/distributions/transformed_distribution.py +++ b/tensorflow/python/ops/distributions/transformed_distribution.py @@ -197,8 +197,7 @@ class TransformedDistribution(distribution_lib.Distribution): distribution=ds.Normal(loc=0., scale=1.), bijector=ds.bijectors.Affine( shift=-1., - scale_identity_multiplier=2., - event_ndims=0), + scale_identity_multiplier=2.) name="NormalTransformedDistribution") ``` @@ -419,48 +418,51 @@ class TransformedDistribution(distribution_lib.Distribution): # For caching to work, it is imperative that the bijector is the first to # modify the input. x = self.bijector.inverse(y) - ildj = self.bijector.inverse_log_det_jacobian(y) + event_ndims = self._maybe_get_event_ndims_statically() + + ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) if self.bijector._is_injective: # pylint: disable=protected-access - return self._finish_log_prob_for_one_fiber(y, x, ildj) + return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims) lp_on_fibers = [ - self._finish_log_prob_for_one_fiber(y, x_i, ildj_i) + self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims) for x_i, ildj_i in zip(x, ildj)] return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0) - def _finish_log_prob_for_one_fiber(self, y, x, ildj): + def _finish_log_prob_for_one_fiber(self, y, x, ildj, event_ndims): """Finish computation of log_prob on one element of the inverse image.""" x = self._maybe_rotate_dims(x, rotate_right=True) log_prob = self.distribution.log_prob(x) if self._is_maybe_event_override: log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) log_prob += math_ops.cast(ildj, log_prob.dtype) - if self._is_maybe_event_override: + if self._is_maybe_event_override and isinstance(event_ndims, int): log_prob.set_shape(array_ops.broadcast_static_shape( - y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape)) + x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape)) return log_prob def _prob(self, y): x = self.bijector.inverse(y) - ildj = self.bijector.inverse_log_det_jacobian(y) + event_ndims = self._maybe_get_event_ndims_statically() + ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) if self.bijector._is_injective: # pylint: disable=protected-access - return self._finish_prob_for_one_fiber(y, x, ildj) + return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims) prob_on_fibers = [ - self._finish_prob_for_one_fiber(y, x_i, ildj_i) + self._finish_prob_for_one_fiber(y, x_i, ildj_i, event_ndims) for x_i, ildj_i in zip(x, ildj)] return sum(prob_on_fibers) - def _finish_prob_for_one_fiber(self, y, x, ildj): + def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims): """Finish computation of prob on one element of the inverse image.""" x = self._maybe_rotate_dims(x, rotate_right=True) prob = self.distribution.prob(x) if self._is_maybe_event_override: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype)) - if self._is_maybe_event_override: + if self._is_maybe_event_override and isinstance(event_ndims, int): prob.set_shape(array_ops.broadcast_static_shape( - y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape)) + y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape)) return prob def _log_cdf(self, y): @@ -545,10 +547,17 @@ class TransformedDistribution(distribution_lib.Distribution): _ones_like(self.distribution.batch_shape_tensor()) ], 0) entropy = array_ops.tile(entropy, multiples) - dummy = array_ops.zeros([], self.dtype) - entropy -= math_ops.cast( - self.bijector.inverse_log_det_jacobian(dummy), - entropy.dtype) + dummy = array_ops.zeros( + shape=array_ops.concat( + [self.batch_shape_tensor(), self.event_shape_tensor()], + 0), + dtype=self.dtype) + event_ndims = (self.event_shape.ndims if self.event_shape.ndims is not None + else array_ops.size(self.event_shape_tensor())) + ildj = self.bijector.inverse_log_det_jacobian( + dummy, event_ndims=event_ndims) + + entropy -= math_ops.cast(ildj, entropy.dtype) entropy.set_shape(self.batch_shape) return entropy @@ -610,3 +619,16 @@ class TransformedDistribution(distribution_lib.Distribution): n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims return array_ops.transpose( x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n))) + + def _maybe_get_event_ndims_statically(self): + if self.event_shape.ndims is not None: + return self.event_shape.ndims + + event_ndims = array_ops.size(self.event_shape_tensor()) + + static_event_ndims = tensor_util.constant_value(event_ndims) + + if static_event_ndims is not None: + return static_event_ndims + + return event_ndims diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt deleted file mode 100644 index 11565bd..0000000 --- a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt +++ /dev/null @@ -1,65 +0,0 @@ -path: "tensorflow.distributions.bijectors.Bijector" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "dtype" - mtype: "" - } - member { - name: "event_ndims" - mtype: "" - } - member { - name: "graph_parents" - mtype: "" - } - member { - name: "is_constant_jacobian" - mtype: "" - } - member { - name: "name" - mtype: "" - } - member { - name: "validate_args" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'event_ndims\', \'graph_parents\', \'is_constant_jacobian\', \'validate_args\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "forward" - argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward\'], " - } - member_method { - name: "forward_event_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "forward_event_shape_tensor" - argspec: "args=[\'self\', \'input_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_event_shape_tensor\'], " - } - member_method { - name: "forward_log_det_jacobian" - argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_log_det_jacobian\'], " - } - member_method { - name: "inverse" - argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], " - } - member_method { - name: "inverse_event_shape" - argspec: "args=[\'self\', \'output_shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "inverse_event_shape_tensor" - argspec: "args=[\'self\', \'output_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_event_shape_tensor\'], " - } - member_method { - name: "inverse_log_det_jacobian" - argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_log_det_jacobian\'], " - } -} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt deleted file mode 100644 index 1e5fe62..0000000 --- a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt +++ /dev/null @@ -1,66 +0,0 @@ -path: "tensorflow.distributions.bijectors.Identity" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "dtype" - mtype: "" - } - member { - name: "event_ndims" - mtype: "" - } - member { - name: "graph_parents" - mtype: "" - } - member { - name: "is_constant_jacobian" - mtype: "" - } - member { - name: "name" - mtype: "" - } - member { - name: "validate_args" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'validate_args\', \'event_ndims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'0\', \'identity\'], " - } - member_method { - name: "forward" - argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward\'], " - } - member_method { - name: "forward_event_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "forward_event_shape_tensor" - argspec: "args=[\'self\', \'input_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_event_shape_tensor\'], " - } - member_method { - name: "forward_log_det_jacobian" - argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_log_det_jacobian\'], " - } - member_method { - name: "inverse" - argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], " - } - member_method { - name: "inverse_event_shape" - argspec: "args=[\'self\', \'output_shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "inverse_event_shape_tensor" - argspec: "args=[\'self\', \'output_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_event_shape_tensor\'], " - } - member_method { - name: "inverse_log_det_jacobian" - argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_log_det_jacobian\'], " - } -} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt deleted file mode 100644 index 1d0144f..0000000 --- a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt +++ /dev/null @@ -1,11 +0,0 @@ -path: "tensorflow.distributions.bijectors" -tf_module { - member { - name: "Bijector" - mtype: "" - } - member { - name: "Identity" - mtype: "" - } -} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt index 2fba7c5..90b60ef 100644 --- a/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt @@ -68,10 +68,6 @@ tf_module { name: "Uniform" mtype: "" } - member { - name: "bijectors" - mtype: "" - } member_method { name: "kl_divergence" argspec: "args=[\'distribution_a\', \'distribution_b\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " -- 2.7.4