BREAKING_CHANGE: Remove event_ndims in Bijector, and require `log_det_jacobian` metho...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Apr 2018 21:02:49 +0000 (14:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 11 Apr 2018 21:05:19 +0000 (14:05 -0700)
The class level event_ndims parameter is being deprecated in favor of passing it in
to the `log_det_jacobian` methods.

Specific changes:

  - `log_det_jacobian` signatures are now `log_det_jacobian(input, event_ndims)`

  - Constructors no long have event_ndims passed in (e.g. Affine() vs. Affine(event_ndims=0)).

  - All bijectors must specify a subset of [forward_min_event_ndims, inverse_min_event_ndims]. This is the minimal dimensionality the bijector operates on, with it being "broadcasted" to any passed in event_ndims (e.g. Exp has forward_min_event_ndims = 0. That means it operates on scalars. However, we can use the bijector on any event_ndims > 0 (i.e. we've broadcasted the transformation to work on any amount of event_ndims > 0), and jacobian reduction will work in those cases.

As a result of this change, all bijectors should "broadcast" (e.g. Sigmoid now works on any number of event_ndims).

Other changes (internal and documentation):
  - Added clarifications on Jacobian Determinant vs. Jacobian Matrix.
  - Added clarifications on min_event_ndims, and what the jacobian reduction is over.
  - Changed caching of ildj to be keyed on event_ndims.
  - Several bug fixes to bugs unearthed while writing this code (e.g. transformed distribution shape computation being incorrect)

PiperOrigin-RevId: 192504919

70 files changed:
tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py
tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py
tensorflow/contrib/distributions/python/ops/bijectors/affine.py
tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py
tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py
tensorflow/contrib/distributions/python/ops/bijectors/chain.py
tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py
tensorflow/contrib/distributions/python/ops/bijectors/exp.py
tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py
tensorflow/contrib/distributions/python/ops/bijectors/inline.py
tensorflow/contrib/distributions/python/ops/bijectors/invert.py
tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py
tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
tensorflow/contrib/distributions/python/ops/bijectors/permute.py
tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py
tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py
tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
tensorflow/contrib/distributions/python/ops/bijectors/softplus.py
tensorflow/contrib/distributions/python/ops/bijectors/square.py
tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
tensorflow/python/kernel_tests/distributions/bijector_test.py
tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
tensorflow/python/ops/distributions/bijector_impl.py
tensorflow/python/ops/distributions/bijector_test_util.py
tensorflow/python/ops/distributions/bijectors.py [deleted file]
tensorflow/python/ops/distributions/distributions.py
tensorflow/python/ops/distributions/identity_bijector.py
tensorflow/python/ops/distributions/transformed_distribution.py
tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt [deleted file]
tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt [deleted file]
tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt [deleted file]
tensorflow/tools/api/golden/tensorflow.distributions.pbtxt

index e0d65c7..042c8eb 100644 (file)
@@ -18,11 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
 # pylint: disable=g-importing-member
 from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import AbsoluteValue
-from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -35,50 +32,38 @@ class AbsoluteValueTest(test.TestCase):
 
   def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self):
     with self.test_session() as sess:
-      bijector = AbsoluteValue(event_ndims=0, validate_args=True)
+      bijector = AbsoluteValue(validate_args=True)
       self.assertEqual("absolute_value", bijector.name)
       x = array_ops.constant([[0., 1., -1], [0., -5., 3.]])  # Shape [2, 3]
       y = math_ops.abs(x)
 
       y_ = y.eval()
-      zeros = np.zeros((2, 3))
 
       self.assertAllClose(y_, bijector.forward(x).eval())
       self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y)))
-      self.assertAllClose((zeros, zeros),
-                          sess.run(bijector.inverse_log_det_jacobian(y)))
+      self.assertAllClose((0., 0.),
+                          sess.run(bijector.inverse_log_det_jacobian(
+                              y, event_ndims=0)))
 
       # Run things twice to make sure there are no issues in caching the tuples
       # returned by .inverse*
       self.assertAllClose(y_, bijector.forward(x).eval())
       self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y)))
-      self.assertAllClose((zeros, zeros),
-                          sess.run(bijector.inverse_log_det_jacobian(y)))
-
-  def testEventNdimsMustBeZeroOrRaiseStatic(self):
-    with self.test_session():
-      with self.assertRaisesRegexp(ValueError, "event_ndims.*was not 0"):
-        AbsoluteValue(event_ndims=1)
-
-  def testEventNdimsMustBeZeroOrRaiseDynamic(self):
-    with self.test_session() as sess:
-      event_ndims = array_ops.placeholder(dtypes.int32)
-      abs_bijector = AbsoluteValue(event_ndims=event_ndims, validate_args=True)
-      with self.assertRaisesOpError("event_ndims was not 0"):
-        sess.run(abs_bijector.inverse_log_det_jacobian([1.]),
-                 feed_dict={event_ndims: 1})
+      self.assertAllClose((0., 0.),
+                          sess.run(bijector.inverse_log_det_jacobian(
+                              y, event_ndims=0)))
 
   def testNegativeYRaisesForInverseIfValidateArgs(self):
     with self.test_session() as sess:
-      bijector = AbsoluteValue(event_ndims=0, validate_args=True)
+      bijector = AbsoluteValue(validate_args=True)
       with self.assertRaisesOpError("y was negative"):
         sess.run(bijector.inverse(-1.))
 
   def testNegativeYRaisesForILDJIfValidateArgs(self):
     with self.test_session() as sess:
-      bijector = AbsoluteValue(event_ndims=0, validate_args=True)
+      bijector = AbsoluteValue(validate_args=True)
       with self.assertRaisesOpError("y was negative"):
-        sess.run(bijector.inverse_log_det_jacobian(-1.))
+        sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0))
 
 
 if __name__ == "__main__":
index 405ddd2..1e4ad72 100644 (file)
@@ -38,9 +38,11 @@ class AffineLinearOperatorTest(test.TestCase):
       self.assertEqual(affine.name, "affine_linear_operator")
       self.assertAllClose(y, affine.forward(x).eval())
       self.assertAllClose(x, affine.inverse(y).eval())
-      self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval())
-      self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(),
-                          affine.forward_log_det_jacobian(x).eval())
+      self.assertAllClose(ildj, affine.inverse_log_det_jacobian(
+          y, event_ndims=2).eval())
+      self.assertAllClose(
+          -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(),
+          affine.forward_log_det_jacobian(x, event_ndims=2).eval())
 
   def testDiag(self):
     with self.test_session():
@@ -58,14 +60,16 @@ class AffineLinearOperatorTest(test.TestCase):
       self.assertEqual(affine.name, "affine_linear_operator")
       self.assertAllClose(y, affine.forward(x).eval())
       self.assertAllClose(x, affine.inverse(y).eval())
-      self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval())
-      self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(),
-                          affine.forward_log_det_jacobian(x).eval())
+      self.assertAllClose(
+          ildj, affine.inverse_log_det_jacobian(y, event_ndims=1).eval())
+      self.assertAllClose(
+          -affine.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          affine.forward_log_det_jacobian(x, event_ndims=1).eval())
 
   def testTriL(self):
     with self.test_session():
       shift = np.array([-1, 0, 1], dtype=np.float32)
-      tril = np.array([[[1, 0, 0],
+      tril = np.array([[[3, 0, 0],
                         [2, -1, 0],
                         [3, 2, 1]],
                        [[2, 0, 0],
@@ -85,15 +89,17 @@ class AffineLinearOperatorTest(test.TestCase):
       # y = np.matmul(x, tril) + shift.
       y = np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift
       ildj = -np.sum(np.log(np.abs(np.diagonal(
-          tril, axis1=-2, axis2=-1))),
-                     axis=-1)
+          tril, axis1=-2, axis2=-1))))
 
       self.assertEqual(affine.name, "affine_linear_operator")
       self.assertAllClose(y, affine.forward(x).eval())
       self.assertAllClose(x, affine.inverse(y).eval())
-      self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval())
-      self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(),
-                          affine.forward_log_det_jacobian(x).eval())
+      self.assertAllClose(
+          ildj, affine.inverse_log_det_jacobian(
+              y, event_ndims=2).eval())
+      self.assertAllClose(
+          -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(),
+          affine.forward_log_det_jacobian(x, event_ndims=2).eval())
 
 
 if __name__ == "__main__":
index 16173a1..d253362 100644 (file)
@@ -40,13 +40,13 @@ class AffineScalarBijectorTest(test.TestCase):
   def testNoBatchScalar(self):
     with self.test_session() as sess:
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
         x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(fun(x, **kwargs), feed_dict={x: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -55,19 +55,20 @@ class AffineScalarBijectorTest(test.TestCase):
         x = [1., 2, 3]  # Three scalar samples (no batches).
         self.assertAllClose([1., 3, 5], run(bijector.forward, x))
         self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
-        self.assertAllClose([-np.log(2.)] * 3,
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(2.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
 
   def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
     with self.test_session() as sess:
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value).astype(np.float64)
         x = array_ops.placeholder(dtypes.float64, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(fun(x, **kwargs), feed_dict={x: x_value})
 
       for run in (static_run, dynamic_run):
         mu = np.float64([1.])
@@ -77,18 +78,20 @@ class AffineScalarBijectorTest(test.TestCase):
         x = np.float64([1.])  # One sample from one batches.
         self.assertAllClose([2.], run(bijector.forward, x))
         self.assertAllClose([0.], run(bijector.inverse, x))
-        self.assertAllClose([0.], run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            0.,
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
 
   def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self):
     with self.test_session() as sess:
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value).astype(np.float64)
         x = array_ops.placeholder(dtypes.float64, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(fun(x, **kwargs), feed_dict={x: x_value})
 
       for run in (static_run, dynamic_run):
         multiplier = np.float64([2.])
@@ -98,19 +101,20 @@ class AffineScalarBijectorTest(test.TestCase):
         x = np.float64([1.])  # One sample from one batches.
         self.assertAllClose([2.], run(bijector.forward, x))
         self.assertAllClose([0.5], run(bijector.inverse, x))
-        self.assertAllClose([np.log(0.5)],
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            [np.log(0.5)],
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
 
   def testTwoBatchScalarIdentityViaIdentity(self):
     with self.test_session() as sess:
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
+      def dynamic_run(fun, x_value, **kwargs):
+        x_value = np.array(x_value).astype(np.float32)
         x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(fun(x, **kwargs), feed_dict={x: x_value})
 
       for run in (static_run, dynamic_run):
         mu = [1., -1]
@@ -120,18 +124,20 @@ class AffineScalarBijectorTest(test.TestCase):
         x = [1., 1]  # One sample from each of two batches.
         self.assertAllClose([2., 0], run(bijector.forward, x))
         self.assertAllClose([0., 2], run(bijector.inverse, x))
-        self.assertAllClose([0., 0.], run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            0.,
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
 
   def testTwoBatchScalarIdentityViaScale(self):
     with self.test_session() as sess:
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value)
+      def dynamic_run(fun, x_value, **kwargs):
+        x_value = np.array(x_value).astype(np.float32)
         x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(fun(x, **kwargs), feed_dict={x: x_value})
 
       for run in (static_run, dynamic_run):
         mu = [1., -1]
@@ -142,7 +148,8 @@ class AffineScalarBijectorTest(test.TestCase):
         self.assertAllClose([3., 0], run(bijector.forward, x))
         self.assertAllClose([0., 2], run(bijector.inverse, x))
         self.assertAllClose(
-            [-np.log(2), 0.], run(bijector.inverse_log_det_jacobian, x))
+            [-np.log(2), 0.],
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
 
   def testScalarCongruency(self):
     with self.test_session():
index 077e617..9e14b9a 100644 (file)
@@ -40,14 +40,15 @@ class AffineBijectorTest(test.TestCase):
 
   def testNoBatchMultivariateIdentity(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = [1., -1]
@@ -66,18 +67,20 @@ class AffineBijectorTest(test.TestCase):
         x = [[1., 1], [-1., -1]]
         self.assertAllClose([[2., 0], [0., -2]], run(bijector.forward, x))
         self.assertAllClose([[0., 2], [-2., 0]], run(bijector.inverse, x))
-        self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testNoBatchMultivariateDiag(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = [1., -1]
@@ -89,9 +92,12 @@ class AffineBijectorTest(test.TestCase):
         # = [-1, -1] + [1, -1]
         self.assertAllClose([3., 0], run(bijector.forward, x))
         self.assertAllClose([0., 2], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(2.),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(2.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
+        # Reset bijector.
+        bijector = Affine(shift=mu, scale_diag=[2., 1])
         # x is a 2-batch of 2-vectors.
         # The first vector is [1, 1], the second is [-1, -1].
         # Each undergoes matmul(sigma, x) + shift.
@@ -103,8 +109,9 @@ class AffineBijectorTest(test.TestCase):
         self.assertAllClose([[0., 2],
                              [-1., 0]],
                             run(bijector.inverse, x))
-        self.assertAllClose(-np.log(2.),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(2.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testNoBatchMultivariateFullDynamic(self):
     with self.test_session() as sess:
@@ -126,18 +133,20 @@ class AffineBijectorTest(test.TestCase):
       self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict))
       self.assertAllClose(
           -np.log(4),
-          sess.run(bijector.inverse_log_det_jacobian(x), feed_dict))
+          sess.run(bijector.inverse_log_det_jacobian(x, event_ndims=1),
+                   feed_dict))
 
   def testBatchMultivariateIdentity(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value, dtype=np.float32)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+      def dynamic_run(fun, x_value, **kwargs):
+        x_value = np.array(x_value)
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = [[1., -1]]
@@ -147,19 +156,21 @@ class AffineBijectorTest(test.TestCase):
         x = [[[1., 1]]]
         self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
         self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(4),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(4),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testBatchMultivariateDiag(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
-        x_value = np.array(x_value, dtype=np.float32)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+      def dynamic_run(fun, x_value, **kwargs):
+        x_value = np.array(x_value)
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = [[1., -1]]
@@ -169,8 +180,9 @@ class AffineBijectorTest(test.TestCase):
         x = [[[1., 1]]]
         self.assertAllClose([[[3., 1]]], run(bijector.forward, x))
         self.assertAllClose([[[0., 1]]], run(bijector.inverse, x))
-        self.assertAllClose([-np.log(4)],
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            [-np.log(4)],
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testBatchMultivariateFullDynamic(self):
     with self.test_session() as sess:
@@ -191,20 +203,22 @@ class AffineBijectorTest(test.TestCase):
       bijector = Affine(shift=mu, scale_diag=scale_diag)
       self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict))
       self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict))
-      self.assertAllClose([-np.log(4)],
-                          sess.run(
-                              bijector.inverse_log_det_jacobian(x), feed_dict))
+      self.assertAllClose(
+          [-np.log(4)],
+          sess.run(bijector.inverse_log_det_jacobian(
+              x, event_ndims=1), feed_dict))
 
   def testIdentityWithDiagUpdate(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -216,19 +230,21 @@ class AffineBijectorTest(test.TestCase):
         x = [1., 2, 3]  # Three scalar samples (no batches).
         self.assertAllClose([1., 3, 5], run(bijector.forward, x))
         self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(2.**3),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(2.**3),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testIdentityWithTriL(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -240,19 +256,21 @@ class AffineBijectorTest(test.TestCase):
         x = [[1., 2]]  # One multivariate sample.
         self.assertAllClose([[1., 5]], run(bijector.forward, x))
         self.assertAllClose([[1., 0.5]], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(4.),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(4.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testDiagWithTriL(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -262,19 +280,21 @@ class AffineBijectorTest(test.TestCase):
         x = [[1., 2]]  # One multivariate sample.
         self.assertAllClose([[1., 7]], run(bijector.forward, x))
         self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(6.),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(6.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testIdentityAndDiagWithTriL(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -287,19 +307,21 @@ class AffineBijectorTest(test.TestCase):
         x = [[1., 2]]  # One multivariate sample.
         self.assertAllClose([[2., 9]], run(bijector.forward, x))
         self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x))
-        self.assertAllClose(-np.log(12.),
-                            run(bijector.inverse_log_det_jacobian, x))
+        self.assertAllClose(
+            -np.log(12.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testIdentityWithVDVTUpdate(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -319,22 +341,24 @@ class AffineBijectorTest(test.TestCase):
         self.assertAllClose([0.2, 1.5, 4 / 3.], run(bijector.inverse, x))
         self.assertAllClose(
             run(bijector_ref.inverse, x), run(bijector.inverse, x))
-        self.assertAllClose(-np.log(60.),
-                            run(bijector.inverse_log_det_jacobian, x))
         self.assertAllClose(
-            run(bijector.inverse_log_det_jacobian, x),
-            run(bijector_ref.inverse_log_det_jacobian, x))
+            -np.log(60.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
+        self.assertAllClose(
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1),
+            run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testDiagWithVDVTUpdate(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -353,22 +377,24 @@ class AffineBijectorTest(test.TestCase):
         self.assertAllClose([0.2, 1., 0.8], run(bijector.inverse, x))
         self.assertAllClose(
             run(bijector_ref.inverse, x), run(bijector.inverse, x))
-        self.assertAllClose(-np.log(150.),
-                            run(bijector.inverse_log_det_jacobian, x))
         self.assertAllClose(
-            run(bijector.inverse_log_det_jacobian, x),
-            run(bijector_ref.inverse_log_det_jacobian, x))
+            -np.log(150.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
+        self.assertAllClose(
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1),
+            run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testTriLWithVDVTUpdate(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -388,22 +414,24 @@ class AffineBijectorTest(test.TestCase):
         self.assertAllClose([0.2, 14 / 15., 4 / 25.], run(bijector.inverse, x))
         self.assertAllClose(
             run(bijector_ref.inverse, x), run(bijector.inverse, x))
-        self.assertAllClose(-np.log(150.),
-                            run(bijector.inverse_log_det_jacobian, x))
         self.assertAllClose(
-            run(bijector.inverse_log_det_jacobian, x),
-            run(bijector_ref.inverse_log_det_jacobian, x))
+            -np.log(150.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
+        self.assertAllClose(
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1),
+            run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testTriLWithVDVTUpdateNoDiagonal(self):
     with self.test_session() as sess:
+      placeholder = array_ops.placeholder(dtypes.float32, name="x")
 
-      def static_run(fun, x):
-        return fun(x).eval()
+      def static_run(fun, x, **kwargs):
+        return fun(x, **kwargs).eval()
 
-      def dynamic_run(fun, x_value):
+      def dynamic_run(fun, x_value, **kwargs):
         x_value = np.array(x_value)
-        x = array_ops.placeholder(dtypes.float32, name="x")
-        return sess.run(fun(x), feed_dict={x: x_value})
+        return sess.run(
+            fun(placeholder, **kwargs), feed_dict={placeholder: x_value})
 
       for run in (static_run, dynamic_run):
         mu = -1.
@@ -423,11 +451,12 @@ class AffineBijectorTest(test.TestCase):
         self.assertAllClose([1 / 3., 8 / 9., 4 / 30.], run(bijector.inverse, x))
         self.assertAllClose(
             run(bijector_ref.inverse, x), run(bijector.inverse, x))
-        self.assertAllClose(-np.log(90.),
-                            run(bijector.inverse_log_det_jacobian, x))
         self.assertAllClose(
-            run(bijector.inverse_log_det_jacobian, x),
-            run(bijector_ref.inverse_log_det_jacobian, x))
+            -np.log(90.),
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
+        self.assertAllClose(
+            run(bijector.inverse_log_det_jacobian, x, event_ndims=1),
+            run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
 
   def testNoBatchMultivariateRaisesWhenSingular(self):
     with self.test_session():
@@ -530,6 +559,7 @@ class AffineBijectorTest(test.TestCase):
             backward = np.squeeze(backward, axis=-1)
           self.assertAllClose(backward, bijector.inverse(x).eval())
 
+          scale *= np.ones(shape=x.shape[:-1], dtype=scale.dtype)
           ildj = -np.log(np.abs(np.linalg.det(scale)))
           # TODO(jvdillon): We need to make it so the scale_identity_multiplier
           # case does not deviate in expected shape. Fixing this will get rid of
@@ -540,7 +570,8 @@ class AffineBijectorTest(test.TestCase):
             ildj = np.squeeze(ildj[0])
           elif ildj.ndim < scale.ndim - 2:
             ildj = np.reshape(ildj, scale.shape[0:-2])
-          self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(x).eval())
+          self.assertAllClose(
+              ildj, bijector.inverse_log_det_jacobian(x, event_ndims=1).eval())
 
   def testLegalInputs(self):
     self._testLegalInputs(
index a215a4a..c832fca 100644 (file)
@@ -83,10 +83,11 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
           moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean)
           moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance)
           denorm_x = batch_norm.forward(array_ops.identity(norm_x))
-          fldj = batch_norm.forward_log_det_jacobian(x)
+          fldj = batch_norm.forward_log_det_jacobian(
+              x, event_ndims=len(event_dims))
           # Use identity to invalidate cache.
           ildj = batch_norm.inverse_log_det_jacobian(
-              array_ops.identity(denorm_x))
+              array_ops.identity(denorm_x), event_ndims=len(event_dims))
         variables.global_variables_initializer().run()
         # Update variables.
         norm_x_ = sess.run(norm_x)
index a748acd..ca20442 100644 (file)
@@ -20,21 +20,33 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
 from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
 from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
 from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
 from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops.distributions import bijector
 from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
 from tensorflow.python.platform import test
 
 
+class ShapeChanging(bijector.Bijector):
+  """Only used for op_ndims manipulation."""
+
+  def __init__(self, forward_min_event_ndims=0, inverse_min_event_ndims=3):
+    super(ShapeChanging, self).__init__(
+        forward_min_event_ndims=forward_min_event_ndims,
+        inverse_min_event_ndims=inverse_min_event_ndims,
+        validate_args=False, name="shape_changer")
+
+
 class ChainBijectorTest(test.TestCase):
   """Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation."""
 
   def testBijector(self):
     with self.test_session():
-      chain = Chain((Exp(event_ndims=1), Softplus(event_ndims=1)))
+      chain = Chain((Exp(), Softplus()))
       self.assertEqual("chain_of_exp_of_softplus", chain.name)
       x = np.asarray([[[1., 2.],
                        [2., 3.]]])
@@ -42,9 +54,10 @@ class ChainBijectorTest(test.TestCase):
       self.assertAllClose(np.log(x - 1.), chain.inverse(x).eval())
       self.assertAllClose(
           -np.sum(np.log(x - 1.), axis=2),
-          chain.inverse_log_det_jacobian(x).eval())
+          chain.inverse_log_det_jacobian(x, event_ndims=1).eval())
       self.assertAllClose(
-          np.sum(x, axis=2), chain.forward_log_det_jacobian(x).eval())
+          np.sum(x, axis=2),
+          chain.forward_log_det_jacobian(x, event_ndims=1).eval())
 
   def testBijectorIdentity(self):
     with self.test_session():
@@ -54,31 +67,126 @@ class ChainBijectorTest(test.TestCase):
                        [2., 3.]]])
       self.assertAllClose(x, chain.forward(x).eval())
       self.assertAllClose(x, chain.inverse(x).eval())
-      self.assertAllClose(0., chain.inverse_log_det_jacobian(x).eval())
-      self.assertAllClose(0., chain.forward_log_det_jacobian(x).eval())
+      self.assertAllClose(
+          0., chain.inverse_log_det_jacobian(x, event_ndims=1).eval())
+      self.assertAllClose(
+          0., chain.forward_log_det_jacobian(x, event_ndims=1).eval())
 
   def testScalarCongruency(self):
     with self.test_session():
-      bijector = Chain((Exp(), Softplus()))
+      chain = Chain((Exp(), Softplus()))
       assert_scalar_congruency(
-          bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
+          chain, lower_x=1e-3, upper_x=1.5, rtol=0.05)
 
   def testShapeGetters(self):
     with self.test_session():
-      bijector = Chain([
+      chain = Chain([
           SoftmaxCentered(validate_args=True),
           SoftmaxCentered(validate_args=True),
       ])
       x = tensor_shape.TensorShape([1])
       y = tensor_shape.TensorShape([2 + 1])
-      self.assertAllEqual(y, bijector.forward_event_shape(x))
+      self.assertAllEqual(y, chain.forward_event_shape(x))
       self.assertAllEqual(
           y.as_list(),
-          bijector.forward_event_shape_tensor(x.as_list()).eval())
-      self.assertAllEqual(x, bijector.inverse_event_shape(y))
+          chain.forward_event_shape_tensor(x.as_list()).eval())
+      self.assertAllEqual(x, chain.inverse_event_shape(y))
       self.assertAllEqual(
           x.as_list(),
-          bijector.inverse_event_shape_tensor(y.as_list()).eval())
+          chain.inverse_event_shape_tensor(y.as_list()).eval())
+
+  def testMinEventNdimsChain(self):
+    chain = Chain([Exp(), Exp(), Exp()])
+    self.assertEqual(0, chain.forward_min_event_ndims)
+    self.assertEqual(0, chain.inverse_min_event_ndims)
+
+    chain = Chain([Affine(), Affine(), Affine()])
+    self.assertEqual(1, chain.forward_min_event_ndims)
+    self.assertEqual(1, chain.inverse_min_event_ndims)
+
+    chain = Chain([Exp(), Affine()])
+    self.assertEqual(1, chain.forward_min_event_ndims)
+    self.assertEqual(1, chain.inverse_min_event_ndims)
+
+    chain = Chain([Affine(), Exp()])
+    self.assertEqual(1, chain.forward_min_event_ndims)
+    self.assertEqual(1, chain.inverse_min_event_ndims)
+
+    chain = Chain([Affine(), Exp(), Softplus(), Affine()])
+    self.assertEqual(1, chain.forward_min_event_ndims)
+    self.assertEqual(1, chain.inverse_min_event_ndims)
+
+  def testMinEventNdimsShapeChangingAddDims(self):
+    chain = Chain([ShapeChanging()])
+    self.assertEqual(0, chain.forward_min_event_ndims)
+    self.assertEqual(3, chain.inverse_min_event_ndims)
+
+    chain = Chain([ShapeChanging(), Affine()])
+    self.assertEqual(1, chain.forward_min_event_ndims)
+    self.assertEqual(4, chain.inverse_min_event_ndims)
+
+    chain = Chain([Affine(), ShapeChanging()])
+    self.assertEqual(0, chain.forward_min_event_ndims)
+    self.assertEqual(3, chain.inverse_min_event_ndims)
+
+    chain = Chain([ShapeChanging(), ShapeChanging()])
+    self.assertEqual(0, chain.forward_min_event_ndims)
+    self.assertEqual(6, chain.inverse_min_event_ndims)
+
+  def testMinEventNdimsShapeChangingRemoveDims(self):
+    chain = Chain([ShapeChanging(3, 0)])
+    self.assertEqual(3, chain.forward_min_event_ndims)
+    self.assertEqual(0, chain.inverse_min_event_ndims)
+
+    chain = Chain([ShapeChanging(3, 0), Affine()])
+    self.assertEqual(3, chain.forward_min_event_ndims)
+    self.assertEqual(0, chain.inverse_min_event_ndims)
+
+    chain = Chain([Affine(), ShapeChanging(3, 0)])
+    self.assertEqual(4, chain.forward_min_event_ndims)
+    self.assertEqual(1, chain.inverse_min_event_ndims)
+
+    chain = Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)])
+    self.assertEqual(6, chain.forward_min_event_ndims)
+    self.assertEqual(0, chain.inverse_min_event_ndims)
+
+  def testMinEventNdimsShapeChangingAddRemoveDims(self):
+    chain = Chain([
+        ShapeChanging(2, 1),
+        ShapeChanging(3, 0),
+        ShapeChanging(1, 2)])
+    self.assertEqual(4, chain.forward_min_event_ndims)
+    self.assertEqual(1, chain.inverse_min_event_ndims)
+
+  def testChainExpAffine(self):
+    scale_diag = np.array([1., 2., 3.], dtype=np.float32)
+    chain = Chain([Exp(), Affine(scale_diag=scale_diag)])
+    x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
+    y = [1., 4., 27.]
+    self.assertAllClose(y, self.evaluate(chain.forward(x)))
+    self.assertAllClose(x, self.evaluate(chain.inverse(y)))
+    self.assertAllClose(
+        np.log(6, dtype=np.float32) + np.sum(scale_diag * x),
+        self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))
+
+    self.assertAllClose(
+        -np.log(6, dtype=np.float32) - np.sum(scale_diag * x),
+        self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))
+
+  def testChainAffineExp(self):
+    scale_diag = np.array([1., 2., 3.], dtype=np.float32)
+    chain = Chain([Affine(scale_diag=scale_diag), Exp()])
+    x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)]
+    y = [1., 4., 9.]
+    self.assertAllClose(y, self.evaluate(chain.forward(x)))
+    self.assertAllClose(x, self.evaluate(chain.inverse(y)))
+    self.assertAllClose(
+        np.log(6, dtype=np.float32) + np.sum(x),
+        self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))
+
+    self.assertAllClose(
+        -np.log(6, dtype=np.float32) - np.sum(x),
+        self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1)))
 
 
 if __name__ == "__main__":
index f392e83..e281e81 100644 (file)
@@ -51,10 +51,13 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
       self.assertAllClose(y, bijector.forward(x).eval())
       self.assertAllClose(x, bijector.inverse(y).eval())
       self.assertAllClose(
-          ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7)
+          ildj, bijector.inverse_log_det_jacobian(
+              y, event_ndims=2).eval(), atol=0., rtol=1e-7)
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
+          -bijector.inverse_log_det_jacobian(
+              y, event_ndims=2).eval(),
+          bijector.forward_log_det_jacobian(
+              x, event_ndims=2).eval(),
           atol=0.,
           rtol=1e-7)
 
index 26e0d2a..8b279eb 100644 (file)
@@ -27,7 +27,7 @@ class _TestBijector(ConditionalBijector):
 
   def __init__(self):
     super(_TestBijector, self).__init__(
-        event_ndims=0,
+        forward_min_event_ndims=0,
         graph_parents=[],
         is_constant_jacobian=True,
         validate_args=False,
@@ -51,11 +51,15 @@ class ConditionalBijectorTest(test.TestCase):
 
   def testConditionalBijector(self):
     b = _TestBijector()
-    for name in ["forward", "inverse", "inverse_log_det_jacobian",
-                 "forward_log_det_jacobian"]:
+    for name in ["forward", "inverse"]:
       method = getattr(b, name)
       with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"):
-        method(1.0, arg1="b1", arg2="b2")
+        method(1., arg1="b1", arg2="b2")
+
+    for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]:
+      method = getattr(b, name)
+      with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"):
+        method(1., event_ndims=0., arg1="b1", arg2="b2")
 
 
 if __name__ == "__main__":
index 9970c0b..7be939c 100644 (file)
@@ -31,17 +31,21 @@ class ExpBijectorTest(test.TestCase):
 
   def testBijector(self):
     with self.test_session():
-      bijector = Exp(event_ndims=1)
+      bijector = Exp()
       self.assertEqual("exp", bijector.name)
       x = [[[1.], [2.]]]
       y = np.exp(x)
       self.assertAllClose(y, bijector.forward(x).eval())
       self.assertAllClose(x, bijector.inverse(y).eval())
       self.assertAllClose(
-          -np.sum(np.log(y), axis=-1),
-          bijector.inverse_log_det_jacobian(y).eval())
-      self.assertAllClose(-bijector.inverse_log_det_jacobian(np.exp(x)).eval(),
-                          bijector.forward_log_det_jacobian(x).eval())
+          -np.squeeze(np.log(y), axis=-1),
+          bijector.inverse_log_det_jacobian(
+              y, event_ndims=1).eval())
+      self.assertAllClose(
+          -bijector.inverse_log_det_jacobian(
+              np.exp(x), event_ndims=1).eval(),
+          bijector.forward_log_det_jacobian(
+              x, event_ndims=1).eval())
 
   def testScalarCongruency(self):
     with self.test_session():
@@ -51,10 +55,10 @@ class ExpBijectorTest(test.TestCase):
 
   def testBijectiveAndFinite(self):
     with self.test_session():
-      bijector = Exp(event_ndims=0)
+      bijector = Exp()
       x = np.linspace(-10, 10, num=10).astype(np.float32)
       y = np.logspace(-10, 10, num=10).astype(np.float32)
-      assert_bijective_and_finite(bijector, x, y)
+      assert_bijective_and_finite(bijector, x, y, event_ndims=0)
 
 
 if __name__ == "__main__":
index 9a90598..54e54c3 100644 (file)
@@ -34,7 +34,7 @@ class GumbelBijectorTest(test.TestCase):
     with self.test_session():
       loc = 0.3
       scale = 5.
-      bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True)
+      bijector = Gumbel(loc=loc, scale=scale, validate_args=True)
       self.assertEqual("gumbel", bijector.name)
       x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32)
       # Gumbel distribution
@@ -43,13 +43,11 @@ class GumbelBijectorTest(test.TestCase):
       self.assertAllClose(y, bijector.forward(x).eval())
       self.assertAllClose(x, bijector.inverse(y).eval())
       self.assertAllClose(
-          # We should lose a dimension from calculating the determinant of the
-          # jacobian.
-          np.squeeze(gumbel_dist.logpdf(x), axis=2),
-          bijector.forward_log_det_jacobian(x).eval())
+          np.squeeze(gumbel_dist.logpdf(x), axis=-1),
+          bijector.forward_log_det_jacobian(x, event_ndims=1).eval())
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
+          -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          bijector.forward_log_det_jacobian(x, event_ndims=1).eval(),
           rtol=1e-4,
           atol=0.)
 
@@ -60,10 +58,10 @@ class GumbelBijectorTest(test.TestCase):
 
   def testBijectiveAndFinite(self):
     with self.test_session():
-      bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True)
+      bijector = Gumbel(loc=0., scale=3.0, validate_args=True)
       x = np.linspace(-10., 10., num=10).astype(np.float32)
       y = np.linspace(0.01, 0.99, num=10).astype(np.float32)
-      assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+      assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3)
 
 
 if __name__ == "__main__":
index 739fa6d..7d3bd75 100644 (file)
@@ -33,15 +33,13 @@ class InlineBijectorTest(test.TestCase):
 
   def testBijector(self):
     with self.test_session():
-      exp = Exp(event_ndims=1)
+      exp = Exp()
       inline = Inline(
           forward_fn=math_ops.exp,
           inverse_fn=math_ops.log,
-          inverse_log_det_jacobian_fn=(
-              lambda y: -math_ops.reduce_sum(  # pylint: disable=g-long-lambda
-                  math_ops.log(y), reduction_indices=-1)),
-          forward_log_det_jacobian_fn=(
-              lambda x: math_ops.reduce_sum(x, reduction_indices=-1)),
+          inverse_log_det_jacobian_fn=lambda y: -math_ops.log(y),
+          forward_log_det_jacobian_fn=lambda x: x,
+          forward_min_event_ndims=0,
           name="exp")
 
       self.assertEqual(exp.name, inline.name)
@@ -51,9 +49,10 @@ class InlineBijectorTest(test.TestCase):
       self.assertAllClose(x, inline.inverse(y).eval())
       self.assertAllClose(
           -np.sum(np.log(y), axis=-1),
-          inline.inverse_log_det_jacobian(y).eval())
-      self.assertAllClose(-inline.inverse_log_det_jacobian(y).eval(),
-                          inline.forward_log_det_jacobian(x).eval())
+          inline.inverse_log_det_jacobian(y, event_ndims=1).eval())
+      self.assertAllClose(
+          -inline.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          inline.forward_log_det_jacobian(x, event_ndims=1).eval())
 
   def testShapeGetters(self):
     with self.test_session():
@@ -62,6 +61,7 @@ class InlineBijectorTest(test.TestCase):
           forward_event_shape_fn=lambda x: x.as_list() + [1],
           inverse_event_shape_tensor_fn=lambda x: x[:-1],
           inverse_event_shape_fn=lambda x: x[:-1],
+          forward_min_event_ndims=0,
           name="shape_only")
       x = tensor_shape.TensorShape([1, 2, 3])
       y = tensor_shape.TensorShape([1, 2, 3, 1])
index 58ba9ce..8b14c83 100644 (file)
@@ -34,9 +34,9 @@ class InvertBijectorTest(test.TestCase):
     with self.test_session():
       for fwd in [
           bijectors.Identity(),
-          bijectors.Exp(event_ndims=1),
+          bijectors.Exp(),
           bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]),
-          bijectors.Softplus(event_ndims=1),
+          bijectors.Softplus(),
           bijectors.SoftmaxCentered(),
       ]:
         rev = bijectors.Invert(fwd)
@@ -46,11 +46,11 @@ class InvertBijectorTest(test.TestCase):
         self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval())
         self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval())
         self.assertAllClose(
-            fwd.forward_log_det_jacobian(x).eval(),
-            rev.inverse_log_det_jacobian(x).eval())
+            fwd.forward_log_det_jacobian(x, event_ndims=1).eval(),
+            rev.inverse_log_det_jacobian(x, event_ndims=1).eval())
         self.assertAllClose(
-            fwd.inverse_log_det_jacobian(x).eval(),
-            rev.forward_log_det_jacobian(x).eval())
+            fwd.inverse_log_det_jacobian(x, event_ndims=1).eval(),
+            rev.forward_log_det_jacobian(x, event_ndims=1).eval())
 
   def testScalarCongruency(self):
     with self.test_session():
index 074b5f2..a808988 100644 (file)
@@ -34,8 +34,7 @@ class KumaraswamyBijectorTest(test.TestCase):
       a = 2.
       b = 0.3
       bijector = Kumaraswamy(
-          concentration1=a, concentration0=b,
-          event_ndims=0, validate_args=True)
+          concentration1=a, concentration0=b, validate_args=True)
       self.assertEqual("kumaraswamy", bijector.name)
       x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32)
       # Kumaraswamy cdf. This is the same as inverse(x).
@@ -46,13 +45,11 @@ class KumaraswamyBijectorTest(test.TestCase):
                              (b - 1) * np.log1p(-x ** a))
 
       self.assertAllClose(
-          # We should lose a dimension from calculating the determinant of the
-          # jacobian.
-          kumaraswamy_log_pdf,
-          bijector.inverse_log_det_jacobian(x).eval())
+          np.squeeze(kumaraswamy_log_pdf, axis=-1),
+          bijector.inverse_log_det_jacobian(x, event_ndims=1).eval())
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(x).eval(),
-          bijector.forward_log_det_jacobian(y).eval(),
+          -bijector.inverse_log_det_jacobian(x, event_ndims=1).eval(),
+          bijector.forward_log_det_jacobian(y, event_ndims=1).eval(),
           rtol=1e-4,
           atol=0.)
 
@@ -73,7 +70,7 @@ class KumaraswamyBijectorTest(test.TestCase):
       # endpoints.
       y = np.linspace(.01, 0.99, num=10).astype(np.float32)
       x = 1 - (1 - y ** concentration1) ** concentration0
-      assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+      assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3)
 
 
 if __name__ == "__main__":
index dcfb0eb..5ba5a20 100644 (file)
@@ -79,9 +79,10 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
       forward_x = ma.forward(x)
       # Use identity to invalidate cache.
       inverse_y = ma.inverse(array_ops.identity(forward_x))
-      fldj = ma.forward_log_det_jacobian(x)
+      fldj = ma.forward_log_det_jacobian(x, event_ndims=1)
       # Use identity to invalidate cache.
-      ildj = ma.inverse_log_det_jacobian(array_ops.identity(forward_x))
+      ildj = ma.inverse_log_det_jacobian(
+          array_ops.identity(forward_x), event_ndims=1)
       variables.global_variables_initializer().run()
       [
           forward_x_,
index 54590de..7eef4ab 100644 (file)
@@ -53,8 +53,8 @@ class PermuteBijectorTest(test.TestCase):
           bijector.permutation,
           bijector.inverse(expected_y),
           bijector.forward(expected_x),
-          bijector.forward_log_det_jacobian(expected_x),
-          bijector.inverse_log_det_jacobian(expected_y),
+          bijector.forward_log_det_jacobian(expected_x, event_ndims=1),
+          bijector.inverse_log_det_jacobian(expected_y, event_ndims=1),
       ], feed_dict={permutation_ph: expected_permutation})
       self.assertEqual("permute", bijector.name)
       self.assertAllEqual(expected_permutation, permutation_)
@@ -78,10 +78,9 @@ class PermuteBijectorTest(test.TestCase):
     x = np.random.randn(4, 2, 3)
     y = x[..., permutation]
     with self.test_session():
-      bijector = Permute(
-          permutation=permutation,
-          validate_args=True)
-      assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
+      bijector = Permute(permutation=permutation, validate_args=True)
+      assert_bijective_and_finite(
+          bijector, x, y, event_ndims=1, rtol=1e-6, atol=0)
 
 if __name__ == "__main__":
   test.main()
index de1659a..85d2283 100644 (file)
@@ -32,8 +32,7 @@ class PowerTransformBijectorTest(test.TestCase):
   def testBijector(self):
     with self.test_session():
       c = 0.2
-      bijector = PowerTransform(
-          power=c, event_ndims=1, validate_args=True)
+      bijector = PowerTransform(power=c, validate_args=True)
       self.assertEqual("power_transform", bijector.name)
       x = np.array([[[-1.], [2.], [-5. + 1e-4]]])
       y = (1. + x * c)**(1. / c)
@@ -41,27 +40,25 @@ class PowerTransformBijectorTest(test.TestCase):
       self.assertAllClose(x, bijector.inverse(y).eval())
       self.assertAllClose(
           (c - 1.) * np.sum(np.log(y), axis=-1),
-          bijector.inverse_log_det_jacobian(y).eval())
+          bijector.inverse_log_det_jacobian(y, event_ndims=1).eval())
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
+          -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          bijector.forward_log_det_jacobian(x, event_ndims=1).eval(),
           rtol=1e-4,
           atol=0.)
 
   def testScalarCongruency(self):
     with self.test_session():
-      bijector = PowerTransform(
-          power=0.2, validate_args=True)
+      bijector = PowerTransform(power=0.2, validate_args=True)
       assert_scalar_congruency(
           bijector, lower_x=-2., upper_x=1.5, rtol=0.05)
 
   def testBijectiveAndFinite(self):
     with self.test_session():
-      bijector = PowerTransform(
-          power=0.2, event_ndims=0, validate_args=True)
+      bijector = PowerTransform(power=0.2, validate_args=True)
       x = np.linspace(-4.999, 10, num=10).astype(np.float32)
       y = np.logspace(0.001, 10, num=10).astype(np.float32)
-      assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+      assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3)
 
 
 if __name__ == "__main__":
index 46fe779..2d52895 100644 (file)
@@ -52,24 +52,28 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
       forward_x = nvp.forward(x)
       # Use identity to invalidate cache.
       inverse_y = nvp.inverse(array_ops.identity(forward_x))
-      fldj = nvp.forward_log_det_jacobian(x)
+      forward_inverse_y = nvp.forward(inverse_y)
+      fldj = nvp.forward_log_det_jacobian(x, event_ndims=1)
       # Use identity to invalidate cache.
-      ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x))
+      ildj = nvp.inverse_log_det_jacobian(
+          array_ops.identity(forward_x), event_ndims=1)
       variables.global_variables_initializer().run()
       [
           forward_x_,
           inverse_y_,
+          forward_inverse_y_,
           ildj_,
           fldj_,
       ] = sess.run([
           forward_x,
           inverse_y,
+          forward_inverse_y,
           ildj,
           fldj,
       ])
       self.assertEqual("real_nvp", nvp.name)
-      self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.)
-      self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.)
+      self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-1, atol=0.)
+      self.assertAllClose(x_, inverse_y_, rtol=1e-1, atol=0.)
       self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.)
 
   def testMutuallyConsistent(self):
index e216d88..46f2c63 100644 (file)
@@ -65,8 +65,8 @@ class _ReshapeBijectorTest(object):
        ildj_) = sess.run((
            bijector.inverse(expected_y),
            bijector.forward(expected_x),
-           bijector.forward_log_det_jacobian(expected_x),
-           bijector.inverse_log_det_jacobian(expected_y),
+           bijector.forward_log_det_jacobian(expected_x, event_ndims=2),
+           bijector.inverse_log_det_jacobian(expected_y, event_ndims=2),
        ), feed_dict=feed_dict)
       self.assertEqual("reshape", bijector.name)
       self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
@@ -301,7 +301,8 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
           event_shape_in=[2, 3],
           event_shape_out=[1, 2, 3],
           validate_args=True)
-      assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
+      assert_bijective_and_finite(
+          bijector, x, y, event_ndims=2, rtol=1e-6, atol=0)
 
   def testInvalidDimensionsOpError(self):
     if ops._USE_C_API:
index e4f9d72..cea4a62 100644 (file)
@@ -36,12 +36,13 @@ class SigmoidBijectorTest(test.TestCase):
       x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32)
       y = special.expit(x)
       ildj = -np.log(y) - np.log1p(-y)
-      self.assertAllClose(y, Sigmoid().forward(x).eval(), atol=0., rtol=1e-2)
-      self.assertAllClose(x, Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4)
-      self.assertAllClose(ildj, Sigmoid().inverse_log_det_jacobian(y).eval(),
-                          atol=0., rtol=1e-6)
-      self.assertAllClose(-ildj, Sigmoid().forward_log_det_jacobian(x).eval(),
-                          atol=0., rtol=1e-4)
+      bijector = Sigmoid()
+      self.assertAllClose(y, bijector.forward(x).eval(), atol=0., rtol=1e-2)
+      self.assertAllClose(x, bijector.inverse(y).eval(), atol=0., rtol=1e-4)
+      self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(
+          y, event_ndims=0).eval(), atol=0., rtol=1e-6)
+      self.assertAllClose(-ildj, bijector.forward_log_det_jacobian(
+          x, event_ndims=0).eval(), atol=0., rtol=1e-4)
 
   def testScalarCongruency(self):
     with self.test_session():
@@ -52,7 +53,8 @@ class SigmoidBijectorTest(test.TestCase):
       x = np.linspace(-7., 7., 100).astype(np.float32)
       eps = 1e-3
       y = np.linspace(eps, 1. - eps, 100).astype(np.float32)
-      assert_bijective_and_finite(Sigmoid(), x, y, atol=0., rtol=1e-4)
+      assert_bijective_and_finite(
+          Sigmoid(), x, y, event_ndims=0, atol=0., rtol=1e-4)
 
 
 if __name__ == "__main__":
index 172c180..45760a2 100644 (file)
@@ -39,7 +39,6 @@ class SinhArcsinhBijectorTest(test.TestCase):
       bijector = SinhArcsinh(
           skewness=skewness,
           tailweight=tailweight,
-          event_ndims=1,
           validate_args=True)
       self.assertEqual("SinhArcsinh", bijector.name)
       x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32)
@@ -50,10 +49,11 @@ class SinhArcsinhBijectorTest(test.TestCase):
           np.sum(
               np.log(np.cosh(np.arcsinh(y) / tailweight - skewness)) -
               np.log(tailweight) - np.log(np.sqrt(y**2 + 1)),
-              axis=-1), bijector.inverse_log_det_jacobian(y).eval())
+              axis=-1),
+          bijector.inverse_log_det_jacobian(y, event_ndims=1).eval())
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
+          -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          bijector.forward_log_det_jacobian(x, event_ndims=1).eval(),
           rtol=1e-4,
           atol=0.)
 
@@ -106,14 +106,15 @@ class SinhArcsinhBijectorTest(test.TestCase):
       bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True)
       x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace(
           -2, 10, 1000))).astype(np.float32)
-      assert_bijective_and_finite(bijector, x, x, rtol=1e-3)
+      assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3)
 
   def testBijectiveAndFiniteSkewness1Tailweight3(self):
     with self.test_session():
       bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True)
       x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace(
           -2, 5, 1000))).astype(np.float32)
-      assert_bijective_and_finite(bijector, x, x, rtol=1e-3)
+      assert_bijective_and_finite(
+          bijector, x, x, event_ndims=0, rtol=1e-3)
 
   def testBijectorEndpoints(self):
     with self.test_session():
@@ -124,7 +125,8 @@ class SinhArcsinhBijectorTest(test.TestCase):
             [np.finfo(dtype).min, np.finfo(dtype).max], dtype=dtype)
         # Note that the above bijector is the identity bijector. Hence, the
         # log_det_jacobian will be 0. Because of this we use atol.
-        assert_bijective_and_finite(bijector, bounds, bounds, atol=2e-6)
+        assert_bijective_and_finite(
+            bijector, bounds, bounds, event_ndims=0, atol=2e-6)
 
   def testBijectorOverRange(self):
     with self.test_session():
@@ -156,12 +158,12 @@ class SinhArcsinhBijectorTest(test.TestCase):
                 np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt(
                     y_float128**2 + 1)) -
             np.log(tailweight),
-            bijector.inverse_log_det_jacobian(y).eval(),
+            bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
             rtol=1e-4,
             atol=0.)
         self.assertAllClose(
-            -bijector.inverse_log_det_jacobian(y).eval(),
-            bijector.forward_log_det_jacobian(x).eval(),
+            -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
+            bijector.forward_log_det_jacobian(x, event_ndims=0).eval(),
             rtol=1e-4,
             atol=0.)
 
index cad4dd1..0f0a2fa 100644 (file)
@@ -44,12 +44,12 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
       self.assertAllClose(x, softmax.inverse(y).eval())
       self.assertAllClose(
           -np.sum(np.log(y), axis=1),
-          softmax.inverse_log_det_jacobian(y).eval(),
+          softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(),
           atol=0.,
           rtol=1e-7)
       self.assertAllClose(
-          -softmax.inverse_log_det_jacobian(y).eval(),
-          softmax.forward_log_det_jacobian(x).eval(),
+          -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(),
+          softmax.forward_log_det_jacobian(x, event_ndims=1).eval(),
           atol=0.,
           rtol=1e-7)
 
@@ -67,14 +67,14 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
           feed_dict={y: real_y}))
       self.assertAllClose(
           -np.sum(np.log(real_y), axis=1),
-          softmax.inverse_log_det_jacobian(y).eval(
+          softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(
               feed_dict={y: real_y}),
           atol=0.,
           rtol=1e-7)
       self.assertAllClose(
-          -softmax.inverse_log_det_jacobian(y).eval(
+          -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(
               feed_dict={y: real_y}),
-          softmax.forward_log_det_jacobian(x).eval(
+          softmax.forward_log_det_jacobian(x, event_ndims=1).eval(
               feed_dict={x: real_x}),
           atol=0.,
           rtol=1e-7)
@@ -104,7 +104,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
       y = np.array([y_0, y_1, y_2])
       y /= y.sum(axis=0)
       y = y.T  # y.shape = [5, 3]
-      assert_bijective_and_finite(softmax, x, y)
+      assert_bijective_and_finite(softmax, x, y, event_ndims=1)
 
 
 if __name__ == "__main__":
index d9af9ae..3d8a0a3 100644 (file)
@@ -43,13 +43,13 @@ class SoftplusBijectorTest(test.TestCase):
 
   def testHingeSoftnessZeroRaises(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0, hinge_softness=0., validate_args=True)
+      bijector = Softplus(hinge_softness=0., validate_args=True)
       with self.assertRaisesOpError("must be non-zero"):
         bijector.forward([1., 1.]).eval()
 
   def testBijectorForwardInverseEventDimsZero(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0)
+      bijector = Softplus()
       self.assertEqual("softplus", bijector.name)
       x = 2 * rng.randn(2, 10)
       y = self._softplus(x)
@@ -59,7 +59,7 @@ class SoftplusBijectorTest(test.TestCase):
 
   def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0, hinge_softness=1.5)
+      bijector = Softplus(hinge_softness=1.5)
       x = 2 * rng.randn(2, 10)
       y = 1.5 * self._softplus(x / 1.5)
 
@@ -68,16 +68,17 @@ class SoftplusBijectorTest(test.TestCase):
 
   def testBijectorLogDetJacobianEventDimsZero(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0)
+      bijector = Softplus()
       y = 2 * rng.rand(2, 10)
       # No reduction needed if event_dims = 0.
       ildj = self._softplus_ildj_before_reduction(y)
 
-      self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval())
+      self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(
+          y, event_ndims=0).eval())
 
   def testBijectorForwardInverseEventDimsOne(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=1)
+      bijector = Softplus()
       self.assertEqual("softplus", bijector.name)
       x = 2 * rng.randn(2, 10)
       y = self._softplus(x)
@@ -87,58 +88,59 @@ class SoftplusBijectorTest(test.TestCase):
 
   def testBijectorLogDetJacobianEventDimsOne(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=1)
+      bijector = Softplus()
       y = 2 * rng.rand(2, 10)
       ildj_before = self._softplus_ildj_before_reduction(y)
       ildj = np.sum(ildj_before, axis=1)
 
-      self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval())
+      self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(
+          y, event_ndims=1).eval())
 
   def testScalarCongruency(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0)
+      bijector = Softplus()
       assert_scalar_congruency(
           bijector, lower_x=-2., upper_x=2.)
 
   def testScalarCongruencyWithPositiveHingeSoftness(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0, hinge_softness=1.3)
+      bijector = Softplus(hinge_softness=1.3)
       assert_scalar_congruency(
           bijector, lower_x=-2., upper_x=2.)
 
   def testScalarCongruencyWithNegativeHingeSoftness(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0, hinge_softness=-1.3)
+      bijector = Softplus(hinge_softness=-1.3)
       assert_scalar_congruency(
           bijector, lower_x=-2., upper_x=2.)
 
   def testBijectiveAndFinite32bit(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0)
+      bijector = Softplus()
       x = np.linspace(-20., 20., 100).astype(np.float32)
       y = np.logspace(-10, 10, 100).astype(np.float32)
       assert_bijective_and_finite(
-          bijector, x, y, rtol=1e-2, atol=1e-2)
+          bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
 
   def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0, hinge_softness=1.23)
+      bijector = Softplus(hinge_softness=1.23)
       x = np.linspace(-20., 20., 100).astype(np.float32)
       y = np.logspace(-10, 10, 100).astype(np.float32)
       assert_bijective_and_finite(
-          bijector, x, y, rtol=1e-2, atol=1e-2)
+          bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
 
   def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0, hinge_softness=-0.7)
+      bijector = Softplus(hinge_softness=-0.7)
       x = np.linspace(-20., 20., 100).astype(np.float32)
       y = -np.logspace(-10, 10, 100).astype(np.float32)
       assert_bijective_and_finite(
-          bijector, x, y, rtol=1e-2, atol=1e-2)
+          bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
 
   def testBijectiveAndFinite16bit(self):
     with self.test_session():
-      bijector = Softplus(event_ndims=0)
+      bijector = Softplus()
       # softplus(-20) is zero, so we can't use such a large range as in 32bit.
       x = np.linspace(-10., 20., 100).astype(np.float16)
       # Note that float16 is only in the open set (0, inf) for a smaller
@@ -146,7 +148,7 @@ class SoftplusBijectorTest(test.TestCase):
       # for the test.
       y = np.logspace(-6, 3, 100).astype(np.float16)
       assert_bijective_and_finite(
-          bijector, x, y, rtol=1e-1, atol=1e-3)
+          bijector, x, y, event_ndims=0, rtol=1e-1, atol=1e-3)
 
 
 if __name__ == "__main__":
index f03d6f1..30c7a73 100644 (file)
@@ -41,10 +41,11 @@ class SquareBijectorTest(test.TestCase):
       self.assertAllClose(y, bijector.forward(x).eval())
       self.assertAllClose(x, bijector.inverse(y).eval())
       self.assertAllClose(
-          ildj, bijector.inverse_log_det_jacobian(y).eval(), atol=0., rtol=1e-7)
+          ildj, bijector.inverse_log_det_jacobian(
+              y, event_ndims=0).eval(), atol=0., rtol=1e-7)
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
+          -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
+          bijector.forward_log_det_jacobian(x, event_ndims=0).eval(),
           atol=0.,
           rtol=1e-7)
 
index 7a31228..f57adcd 100644 (file)
@@ -36,7 +36,7 @@ class WeibullBijectorTest(test.TestCase):
       concentration = 0.3
       bijector = Weibull(
           scale=scale, concentration=concentration,
-          event_ndims=1, validate_args=True)
+          validate_args=True)
       self.assertEqual("weibull", bijector.name)
       x = np.array([[[0.], [1.], [14.], [20.], [100.]]], dtype=np.float32)
       # Weibull distribution
@@ -45,13 +45,11 @@ class WeibullBijectorTest(test.TestCase):
       self.assertAllClose(y, bijector.forward(x).eval())
       self.assertAllClose(x, bijector.inverse(y).eval())
       self.assertAllClose(
-          # We should lose a dimension from calculating the determinant of the
-          # jacobian.
-          np.squeeze(weibull_dist.logpdf(x), axis=2),
-          bijector.forward_log_det_jacobian(x).eval())
+          weibull_dist.logpdf(x),
+          bijector.forward_log_det_jacobian(x, event_ndims=0).eval())
       self.assertAllClose(
-          -bijector.inverse_log_det_jacobian(y).eval(),
-          bijector.forward_log_det_jacobian(x).eval(),
+          -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
+          bijector.forward_log_det_jacobian(x, event_ndims=0).eval(),
           rtol=1e-4,
           atol=0.)
 
@@ -64,12 +62,12 @@ class WeibullBijectorTest(test.TestCase):
   def testBijectiveAndFinite(self):
     with self.test_session():
       bijector = Weibull(
-          scale=20., concentration=2., event_ndims=0, validate_args=True)
+          scale=20., concentration=2., validate_args=True)
       x = np.linspace(1., 8., num=10).astype(np.float32)
       y = np.linspace(
           -np.expm1(-1 / 400.),
           -np.expm1(-16), num=10).astype(np.float32)
-      assert_bijective_and_finite(bijector, x, y, rtol=1e-3)
+      assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3)
 
 
 if __name__ == "__main__":
index 5454719..4e8989b 100644 (file)
@@ -44,6 +44,7 @@ class _ChooseLocation(ConditionalBijector):
           graph_parents=[self._loc],
           is_constant_jacobian=True,
           validate_args=False,
+          forward_min_event_ndims=0,
           name=name)
 
   def _forward(self, x, z):
@@ -52,7 +53,7 @@ class _ChooseLocation(ConditionalBijector):
   def _inverse(self, x, z):
     return x - self._gather_loc(z)
 
-  def _inverse_log_det_jacobian(self, x, z=None):
+  def _inverse_log_det_jacobian(self, x, event_ndims, z=None):
     return 0.
 
   def _gather_loc(self, z):
index 933756a..9635134 100644 (file)
@@ -68,7 +68,7 @@ class MultivariateNormalDiagTest(test.TestCase):
       dist = ds.TransformedDistribution(
           base_dist,
           validate_args=True,
-          bijector=bijectors.Softplus(event_ndims=1))
+          bijector=bijectors.Softplus())
       samps = dist.sample(5)  # Shape [5, 1, 3].
       self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape())
 
index f0ba1ec..5fe1331 100644 (file)
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
@@ -36,6 +37,35 @@ ds = distributions
 la = linalg
 
 
+class DummyMatrixTransform(bs.Bijector):
+  """Tractable matrix transformation.
+
+  This is a non-sensical bijector that has forward/inverse_min_event_ndims=2.
+  The main use is to check that transformed distribution calculations are done
+  appropriately.
+  """
+
+  def __init__(self):
+    super(DummyMatrixTransform, self).__init__(
+        forward_min_event_ndims=2,
+        is_constant_jacobian=False,
+        validate_args=False,
+        name="dummy")
+
+  def _forward(self, x):
+    return x
+
+  def _inverse(self, y):
+    return y
+
+  # Note: These jacobians don't make sense.
+  def _forward_log_det_jacobian(self, x):
+    return -linalg_ops.matrix_determinant(x)
+
+  def _inverse_log_det_jacobian(self, x):
+    return linalg_ops.matrix_determinant(x)
+
+
 class TransformedDistributionTest(test.TestCase):
 
   def _cls(self):
@@ -55,7 +85,7 @@ class TransformedDistributionTest(test.TestCase):
       # you may or may not need a reduce_sum.
       log_normal = self._cls()(
           distribution=ds.Normal(loc=mu, scale=sigma),
-          bijector=bs.Exp(event_ndims=0))
+          bijector=bs.Exp())
       sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu))
 
       # sample
@@ -87,7 +117,7 @@ class TransformedDistributionTest(test.TestCase):
       sigma = 2.0
       abs_normal = self._cls()(
           distribution=ds.Normal(loc=mu, scale=sigma),
-          bijector=bs.AbsoluteValue(event_ndims=0))
+          bijector=bs.AbsoluteValue())
       sp_normal = stats.norm(mu, sigma)
 
       # sample
@@ -129,7 +159,7 @@ class TransformedDistributionTest(test.TestCase):
       self.assertAllClose(grid, cdf_, rtol=1e-6, atol=0.)
 
   def testCachedSamples(self):
-    exp_forward_only = bs.Exp(event_ndims=0)
+    exp_forward_only = bs.Exp()
     exp_forward_only._inverse = self._make_unimplemented(
         "inverse")
     exp_forward_only._inverse_event_shape_tensor = self._make_unimplemented(
@@ -153,7 +183,7 @@ class TransformedDistributionTest(test.TestCase):
       self.assertAllClose(expected_log_pdf, log_pdf_val, rtol=1e-4, atol=0.)
 
   def testCachedSamplesInvert(self):
-    exp_inverse_only = bs.Exp(event_ndims=0)
+    exp_inverse_only = bs.Exp()
     exp_inverse_only._forward = self._make_unimplemented(
         "forward")
     exp_inverse_only._forward_event_shape_tensor = self._make_unimplemented(
@@ -210,8 +240,11 @@ class TransformedDistributionTest(test.TestCase):
       int_identity = bs.Inline(
           forward_fn=array_ops.identity,
           inverse_fn=array_ops.identity,
-          inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32),
-          forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32),
+          inverse_log_det_jacobian_fn=(
+              lambda y: math_ops.cast(0, dtypes.int32)),
+          forward_log_det_jacobian_fn=(
+              lambda x: math_ops.cast(0, dtypes.int32)),
+          forward_min_event_ndims=0,
           is_constant_jacobian=True)
       normal = self._cls()(
           distribution=ds.Normal(loc=0., scale=1.),
@@ -435,6 +468,82 @@ class ScalarToMultiTest(test.TestCase):
             event_shape=[3],
             validate_args=True)
 
+  def testMatrixEvent(self):
+    with self.test_session() as sess:
+      batch_shape = [2]
+      event_shape = [2, 3, 3]
+      batch_shape_pl = array_ops.placeholder(
+          dtypes.int32, name="dynamic_batch_shape")
+      event_shape_pl = array_ops.placeholder(
+          dtypes.int32, name="dynamic_event_shape")
+      feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32),
+                   event_shape_pl: np.array(event_shape, dtype=np.int32)}
+
+      scale = 2.
+      loc = 0.
+      fake_mvn_dynamic = self._cls()(
+          distribution=ds.Normal(
+              loc=loc,
+              scale=scale),
+          bijector=DummyMatrixTransform(),
+          batch_shape=batch_shape_pl,
+          event_shape=event_shape_pl,
+          validate_args=True)
+
+      fake_mvn_static = self._cls()(
+          distribution=ds.Normal(
+              loc=loc,
+              scale=scale),
+          bijector=DummyMatrixTransform(),
+          batch_shape=batch_shape,
+          event_shape=event_shape,
+          validate_args=True)
+
+      def actual_mvn_log_prob(x):
+        # This distribution is the normal PDF, reduced over the
+        # last 3 dimensions + a jacobian term which corresponds
+        # to the determinant of x.
+        return (np.sum(
+            stats.norm(loc, scale).logpdf(x), axis=(-1, -2, -3)) +
+                np.sum(np.linalg.det(x), axis=-1))
+
+      self.assertAllEqual([2, 3, 3], fake_mvn_static.event_shape)
+      self.assertAllEqual([2], fake_mvn_static.batch_shape)
+
+      self.assertAllEqual(tensor_shape.TensorShape(None),
+                          fake_mvn_dynamic.event_shape)
+      self.assertAllEqual(tensor_shape.TensorShape(None),
+                          fake_mvn_dynamic.batch_shape)
+
+      num_samples = 5e3
+      for fake_mvn, feed_dict in ((fake_mvn_static, {}),
+                                  (fake_mvn_dynamic, feed_dict)):
+        # Ensure sample works by checking first, second moments.
+        y = fake_mvn.sample(int(num_samples), seed=0)
+        x = y[0:5, ...]
+        [
+            x_,
+            fake_event_shape_,
+            fake_batch_shape_,
+            fake_log_prob_,
+            fake_prob_,
+        ] = sess.run([
+            x,
+            fake_mvn.event_shape_tensor(),
+            fake_mvn.batch_shape_tensor(),
+            fake_mvn.log_prob(x),
+            fake_mvn.prob(x),
+        ], feed_dict=feed_dict)
+
+        # Ensure all other functions work as intended.
+        self.assertAllEqual([5, 2, 2, 3, 3], x_.shape)
+        self.assertAllEqual([2, 3, 3], fake_event_shape_)
+        self.assertAllEqual([2], fake_batch_shape_)
+        self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_,
+                            atol=0., rtol=1e-6)
+        self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_,
+                            atol=0., rtol=1e-5)
+
 
 if __name__ == "__main__":
   test.main()
index c355ade..1226c66 100644 (file)
@@ -61,7 +61,7 @@ class VectorLaplaceDiagTest(test.TestCase):
       dist = ds.TransformedDistribution(
           base_dist,
           validate_args=True,
-          bijector=bijectors.Softplus(event_ndims=1))
+          bijector=bijectors.Softplus())
       samps = dist.sample(5)  # Shape [5, 1, 3].
       self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape())
 
index 0fe9f6a..c9e31d7 100644 (file)
@@ -18,9 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
+from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -72,38 +70,22 @@ class AbsoluteValue(bijector.Bijector):
 
   """
 
-  def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"):
+  def __init__(self, validate_args=False, name="absolute_value"):
     """Instantiates the `AbsoluteValue` bijector.
 
     Args:
-      event_ndims: Python scalar indicating the number of dimensions associated
-        with a particular draw from the distribution.  Currently only zero is
-        supported.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness, in particular whether inputs to `inverse` and
         `inverse_log_det_jacobian` are non-negative.
       name: Python `str` name given to ops managed by this object.
-
-    Raises:
-      ValueError:  If `event_ndims` is not zero.
     """
     self._graph_parents = []
     self._name = name
 
-    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
-    event_ndims_const = tensor_util.constant_value(event_ndims)
-    if event_ndims_const is not None and event_ndims_const not in (0,):
-      raise ValueError("event_ndims(%s) was not 0" % event_ndims_const)
-    else:
-      if validate_args:
-        event_ndims = control_flow_ops.with_dependencies(
-            [check_ops.assert_equal(
-                event_ndims, 0, message="event_ndims was not 0")],
-            event_ndims)
-
     with self._name_scope("init"):
       super(AbsoluteValue, self).__init__(
-          event_ndims=event_ndims,
+          forward_min_event_ndims=0,
+          is_constant_jacobian=True,
           validate_args=validate_args,
           name=name)
 
@@ -121,8 +103,7 @@ class AbsoluteValue(bijector.Bijector):
     # If event_ndims = 2,
     # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
     # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
-    batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims]
-    zeros = array_ops.zeros(batch_shape, dtype=y.dtype)
+    zeros = constant_op.constant(0., dtype=y.dtype)
     if self.validate_args:
       zeros = control_flow_ops.with_dependencies(
           [check_ops.assert_non_negative(y, message="Argument y was negative")],
index bef7bbb..b4c2939 100644 (file)
@@ -184,6 +184,7 @@ class Affine(bijector.Bijector):
     with self._name_scope("init", values=[
         shift, scale_identity_multiplier, scale_diag, scale_tril,
         scale_perturb_diag, scale_perturb_factor]):
+
       # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`.
       dtype = dtypes.float32
 
@@ -234,7 +235,7 @@ class Affine(bijector.Bijector):
           event_ndims=1,
           validate_args=validate_args)
       super(Affine, self).__init__(
-          event_ndims=1,
+          forward_min_event_ndims=1,
           graph_parents=(
               [self._scale] if tensor_util.is_tensor(self._scale)
               else self._scale.graph_parents +
@@ -360,16 +361,17 @@ class Affine(bijector.Bijector):
         x, sample_shape, expand_batch_dim=False)
     return x
 
-  def _inverse_log_det_jacobian(self, y):
-    return -self._forward_log_det_jacobian(y)
-
   def _forward_log_det_jacobian(self, x):
+    # is_constant_jacobian = True for this bijector, hence the
+    # `log_det_jacobian` need only be specified for a single input, as this will
+    # be tiled to match `event_ndims`.
     if self._is_only_identity_multiplier:
       # We don't pad in this case and instead let the fldj be applied
       # via broadcast.
       event_size = array_ops.shape(x)[-1]
       event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
       return math_ops.log(math_ops.abs(self._scale)) * event_size
+
     return self.scale.log_abs_determinant()
 
   def _maybe_check_scale(self):
index 89043b1..59f9742 100644 (file)
@@ -22,9 +22,6 @@ from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops.distributions import bijector
 from tensorflow.python.ops.linalg import linear_operator
 
@@ -94,7 +91,6 @@ class AffineLinearOperator(bijector.Bijector):
   def __init__(self,
                shift=None,
                scale=None,
-               event_ndims=1,
                validate_args=False,
                name="affine_linear_operator"):
     """Instantiates the `AffineLinearOperator` bijector.
@@ -103,14 +99,11 @@ class AffineLinearOperator(bijector.Bijector):
       shift: Floating-point `Tensor`.
       scale:  Subclass of `LinearOperator`. Represents the (batch) positive
         definite matrix `M` in `R^{k x k}`.
-      event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions
-        associated with a particular draw from the distribution. Must be 0 or 1.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
 
     Raises:
-      ValueError: if `event_ndims` is not 0 or 1.
       TypeError: if `scale` is not a `LinearOperator`.
       TypeError: if `shift.dtype` does not match `scale.dtype`.
       ValueError: if not `scale.is_non_singular`.
@@ -120,20 +113,6 @@ class AffineLinearOperator(bijector.Bijector):
     self._validate_args = validate_args
     graph_parents = []
     with self._name_scope("init", values=[shift]):
-      event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
-      if tensor_util.constant_value(event_ndims) is not None:
-        event_ndims = tensor_util.constant_value(event_ndims)
-        if event_ndims not in (0, 1):
-          raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims))
-      else:
-        if validate_args:
-          # Shape tool will catch if event_ndims is negative.
-          event_ndims = control_flow_ops.with_dependencies(
-              [check_ops.assert_less(
-                  event_ndims, 2, message="event_ndims must be 0 or 1")],
-              event_ndims)
-        graph_parents += [event_ndims]
-
       # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`.
       dtype = dtypes.float32
 
@@ -166,10 +145,10 @@ class AffineLinearOperator(bijector.Bijector):
       self._scale = scale
       self._shaper = _DistributionShape(
           batch_ndims=batch_ndims,
-          event_ndims=event_ndims,
+          event_ndims=1,
           validate_args=validate_args)
       super(AffineLinearOperator, self).__init__(
-          event_ndims=event_ndims,
+          forward_min_event_ndims=1,
           graph_parents=graph_parents,
           is_constant_jacobian=True,
           dtype=dtype,
@@ -213,12 +192,13 @@ class AffineLinearOperator(bijector.Bijector):
           x, sample_shape, expand_batch_dim=False)
     return x
 
-  def _inverse_log_det_jacobian(self, y):
-    return -self._forward_log_det_jacobian(y)
-
-  def _forward_log_det_jacobian(self, x):  # pylint: disable=unused-argument
+  def _forward_log_det_jacobian(self, x):
+    # is_constant_jacobian = True for this bijector, hence the
+    # `log_det_jacobian` need only be specified for a single input, as this will
+    # be tiled to match `event_ndims`.
     if self.scale is None:
-      return constant_op.constant(0, dtype=x.dtype.base_dtype)
+      return constant_op.constant(0., dtype=x.dtype.base_dtype)
+
     with ops.control_dependencies(self._maybe_collect_assertions() if
                                   self.validate_args else []):
       return self.scale.log_abs_determinant()
index 8adaa54..cd792e2 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -99,7 +100,7 @@ class AffineScalar(bijector.Bijector):
               self._scale)
 
       super(AffineScalar, self).__init__(
-          event_ndims=0,
+          forward_min_event_ndims=0,
           is_constant_jacobian=True,
           validate_args=validate_args,
           name=name)
@@ -131,8 +132,10 @@ class AffineScalar(bijector.Bijector):
     return x
 
   def _forward_log_det_jacobian(self, x):
-    log_det_jacobian = array_ops.zeros_like(x)
+    # is_constant_jacobian = True for this bijector, hence the
+    # `log_det_jacobian` need only be specified for a single input, as this will
+    # be tiled to match `event_ndims`.
     if self.scale is None:
-      return log_det_jacobian
-    log_det_jacobian += math_ops.log(math_ops.abs(self.scale))
-    return log_det_jacobian
+      return constant_op.constant(0., dtype=x.dtype.base_dtype)
+
+    return math_ops.log(math_ops.abs(self.scale))
index 33fdd32..224cec8 100644 (file)
@@ -157,7 +157,12 @@ class BatchNormalization(bijector.Bijector):
         gamma_constraint=g_constraint)
     self._validate_bn_layer(self.batchnorm)
     self._training = training
+    if isinstance(self.batchnorm.axis, int):
+      forward_min_event_ndims = 1
+    else:
+      forward_min_event_ndims = len(self.batchnorm.axis)
     super(BatchNormalization, self).__init__(
+        forward_min_event_ndims=forward_min_event_ndims,
         validate_args=validate_args, name=name)
 
   def _validate_bn_layer(self, layer):
@@ -186,7 +191,6 @@ class BatchNormalization(bijector.Bijector):
     input_shape = np.int32(x.shape.as_list())
 
     ndims = len(input_shape)
-    # event_dims = self._compute_event_dims(x)
     reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis]
     # Broadcasting only necessary for single-axis batch norm where the axis is
     # not the last dimension
index 3ce7c26..85ad23e 100644 (file)
@@ -21,6 +21,9 @@ from __future__ import print_function
 import itertools
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.distributions import bijector
 
 
@@ -29,6 +32,91 @@ __all__ = [
 ]
 
 
+def _use_static_shape(input_tensor, ndims):
+  return input_tensor.shape.is_fully_defined() and isinstance(ndims, int)
+
+
+def _maybe_get_event_ndims_statically(event_ndims):
+  static_event_ndims = (event_ndims if isinstance(event_ndims, int)
+                        else tensor_util.constant_value(event_ndims))
+  if static_event_ndims is not None:
+    return static_event_ndims
+
+  return event_ndims
+
+
+def _compute_min_event_ndims(bijector_list, compute_forward=True):
+  """Computes the min_event_ndims associated with the give list of bijectors.
+
+  Given a list `bijector_list` of bijectors, compute the min_event_ndims that is
+  associated with the composition of bijectors in that list.
+
+  min_event_ndims is the # of right most dimensions for which the bijector has
+  done necessary computation on (i.e. the non-broadcastable part of the
+  computation).
+
+  We can derive the min_event_ndims for a chain of bijectors as follows:
+
+  In the case where there are no rank changing bijectors, this will simply be
+  `max(b.forward_min_event_ndims for b in bijector_list)`. This is because the
+  bijector with the most forward_min_event_ndims requires the most dimensions,
+  and hence the chain also requires operating on those dimensions.
+
+  However in the case of rank changing, more care is needed in determining the
+  exact amount of dimensions. Padding dimensions causes subsequent bijectors to
+  operate on the padded dimensions, and Removing dimensions causes bijectors to
+  operate more left.
+
+  Args:
+    bijector_list: List of bijectors to be composed by chain.
+    compute_forward: Boolean. If True, computes the min_event_ndims associated
+      with a forward call to Chain, and otherwise computes the min_event_ndims
+      associated with an inverse call to Chain. The latter is the same as the
+      min_event_ndims associated with a forward call to Invert(Chain(....)).
+
+  Returns:
+    min_event_ndims
+  """
+  min_event_ndims = 0
+  # This is a mouthful, but what this encapsulates is that if not for rank
+  # changing bijectors, we'd only need to compute the largest of the min
+  # required ndims. Hence "max_min". Due to rank changing bijectors, we need to
+  # account for synthetic rank growth / synthetic rank decrease from a rank
+  # changing bijector.
+  rank_changed_adjusted_max_min_event_ndims = 0
+
+  if compute_forward:
+    bijector_list = reversed(bijector_list)
+
+  for b in bijector_list:
+    if compute_forward:
+      current_min_event_ndims = b.forward_min_event_ndims
+      current_inverse_min_event_ndims = b.inverse_min_event_ndims
+    else:
+      current_min_event_ndims = b.inverse_min_event_ndims
+      current_inverse_min_event_ndims = b.forward_min_event_ndims
+
+    # New dimensions were touched.
+    if rank_changed_adjusted_max_min_event_ndims < current_min_event_ndims:
+      min_event_ndims += (
+          current_min_event_ndims - rank_changed_adjusted_max_min_event_ndims)
+    rank_changed_adjusted_max_min_event_ndims = max(
+        current_min_event_ndims, rank_changed_adjusted_max_min_event_ndims)
+
+    # If the number of dimensions has increased via forward, then
+    # inverse_min_event_ndims > forward_min_event_ndims, and hence the
+    # dimensions we computed on, have moved left (so we have operated
+    # on additional dimensions).
+    # Conversely, if the number of dimensions has decreased via forward,
+    # then we have inverse_min_event_ndims < forward_min_event_ndims,
+    # and so we will have operated on fewer right most dimensions.
+
+    number_of_changed_dimensions = (
+        current_min_event_ndims - current_inverse_min_event_ndims)
+    rank_changed_adjusted_max_min_event_ndims -= number_of_changed_dimensions
+  return min_event_ndims
+
+
 class Chain(bijector.Bijector):
   """Bijector which applies a sequence of bijectors.
 
@@ -93,21 +181,24 @@ class Chain(bijector.Bijector):
       raise ValueError("incompatible dtypes: %s" % dtype)
     elif len(dtype) == 2:
       dtype = dtype[1] if dtype[0] is None else dtype[0]
-      event_ndims = bijectors[0].event_ndims
     elif len(dtype) == 1:
       dtype = dtype[0]
-      event_ndims = bijectors[0].event_ndims
     else:
       dtype = None
-      event_ndims = None
+
+    inverse_min_event_ndims = _compute_min_event_ndims(
+        bijectors, compute_forward=False)
+    forward_min_event_ndims = _compute_min_event_ndims(
+        bijectors, compute_forward=True)
 
     super(Chain, self).__init__(
         graph_parents=list(itertools.chain.from_iterable(
             b.graph_parents for b in bijectors)),
+        forward_min_event_ndims=forward_min_event_ndims,
+        inverse_min_event_ndims=inverse_min_event_ndims,
         is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors),
         validate_args=validate_args,
         dtype=dtype,
-        event_ndims=event_ndims,
         name=name or ("identity" if not bijectors else
                       "_of_".join(["chain"] + [b.name for b in bijectors])))
 
@@ -147,10 +238,31 @@ class Chain(bijector.Bijector):
     return y
 
   def _inverse_log_det_jacobian(self, y, **kwargs):
-    ildj = constant_op.constant(0., dtype=y.dtype,
-                                name="inverse_log_det_jacobian")
+    ildj = constant_op.constant(
+        0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian")
+
+    if not self.bijectors:
+      return ildj
+
+    event_ndims = _maybe_get_event_ndims_statically(
+        self.inverse_min_event_ndims)
+
+    if _use_static_shape(y, event_ndims):
+      event_shape = y.shape[y.shape.ndims - event_ndims:]
+    else:
+      event_shape = array_ops.shape(y)[array_ops.rank(y) - event_ndims:]
+
     for b in self.bijectors:
-      ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {}))
+      ildj += b.inverse_log_det_jacobian(
+          y, event_ndims=event_ndims, **kwargs.get(b.name, {}))
+
+      if _use_static_shape(y, event_ndims):
+        event_shape = b.inverse_event_shape(event_shape)
+        event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+      else:
+        event_shape = b.inverse_event_shape_tensor(event_shape)
+        event_ndims = _maybe_get_event_ndims_statically(
+            array_ops.rank(event_shape))
       y = b.inverse(y, **kwargs.get(b.name, {}))
     return ildj
 
@@ -160,9 +272,34 @@ class Chain(bijector.Bijector):
     return x
 
   def _forward_log_det_jacobian(self, x, **kwargs):
-    fldj = constant_op.constant(0., dtype=x.dtype,
-                                name="forward_log_det_jacobian")
+    x = ops.convert_to_tensor(x, name="x")
+
+    fldj = constant_op.constant(
+        0., dtype=x.dtype, name="inverse_log_det_jacobian")
+
+    if not self.bijectors:
+      return fldj
+
+    event_ndims = _maybe_get_event_ndims_statically(
+        self.forward_min_event_ndims)
+
+    if _use_static_shape(x, event_ndims):
+      event_shape = x.shape[x.shape.ndims - event_ndims:]
+    else:
+      event_shape = array_ops.shape(x)[array_ops.rank(x) - event_ndims:]
+
     for b in reversed(self.bijectors):
-      fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {}))
+      fldj += b.forward_log_det_jacobian(
+          x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
+      if _use_static_shape(x, event_ndims):
+        event_shape = b.forward_event_shape(event_shape)
+        event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims)
+      else:
+        event_shape = b.forward_event_shape_tensor(event_shape)
+        event_ndims = _maybe_get_event_ndims_statically(
+            array_ops.rank(event_shape))
+
       x = b.forward(x, **kwargs.get(b.name, {}))
+
     return fldj
+
index 8f09e16..caae2ad 100644 (file)
@@ -80,7 +80,7 @@ class CholeskyOuterProduct(bijector.Bijector):
     self._graph_parents = []
     self._name = name
     super(CholeskyOuterProduct, self).__init__(
-        event_ndims=2,
+        forward_min_event_ndims=2,
         validate_args=validate_args,
         name=name)
 
index ccb1f02..e9e994f 100644 (file)
@@ -44,12 +44,16 @@ class ConditionalBijector(bijector.Bijector):
       "**condition_kwargs":
       "Named arguments forwarded to subclass implementation."})
   def inverse_log_det_jacobian(
-      self, y, name="inverse_log_det_jacobian", **condition_kwargs):
-    return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs)
+      self, y, event_ndims, name="inverse_log_det_jacobian",
+      **condition_kwargs):
+    return self._call_inverse_log_det_jacobian(
+        y, event_ndims, name, **condition_kwargs)
 
   @distribution_util.AppendDocstring(kwargs_dict={
       "**condition_kwargs":
       "Named arguments forwarded to subclass implementation."})
   def forward_log_det_jacobian(
-      self, x, name="forward_log_det_jacobian", **condition_kwargs):
-    return self._call_forward_log_det_jacobian(x, name, **condition_kwargs)
+      self, x, event_ndims, name="forward_log_det_jacobian",
+      **condition_kwargs):
+    return self._call_forward_log_det_jacobian(
+        x, event_ndims, name, **condition_kwargs)
index b1ff840..9fc1bbf 100644 (file)
@@ -33,8 +33,8 @@ class Exp(power_transform.PowerTransform):
 
     ```python
     # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1
-    # batch ndim and 2 event ndims (i.e., vector of matrices).
-    exp = Exp(event_ndims=2)
+    # batch ndim 2.
+    exp = Exp()
     x = [[[1., 2],
            [3, 4]],
           [[5, 6],
@@ -48,19 +48,17 @@ class Exp(power_transform.PowerTransform):
   """
 
   def __init__(self,
-               event_ndims=0,
                validate_args=False,
                name="exp"):
     """Instantiates the `Exp` bijector.
 
     Args:
-      event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions
-        associated with a particular draw from the distribution.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
     """
+    # forward_min_event_ndims = 0.
+    # No forward_min_event_ndims specified as this is done in PowerTransform.
     super(Exp, self).__init__(
-        event_ndims=event_ndims,
         validate_args=validate_args,
         name=name)
index 67f3978..e656a25 100644 (file)
@@ -48,7 +48,6 @@ class Gumbel(bijector.Bijector):
   def __init__(self,
                loc=0.,
                scale=1.,
-               event_ndims=0,
                validate_args=False,
                name="gumbel"):
     """Instantiates the `Gumbel` bijector.
@@ -60,8 +59,6 @@ class Gumbel(bijector.Bijector):
       scale: Positive Float-like `Tensor` that is the same dtype and is
         broadcastable with `loc`.
         This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.
-      event_ndims: Python scalar indicating the number of dimensions associated
-        with a particular draw from the distribution.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -80,7 +77,9 @@ class Gumbel(bijector.Bijector):
         ], self._scale)
 
     super(Gumbel, self).__init__(
-        event_ndims=event_ndims, validate_args=validate_args, name=name)
+        validate_args=validate_args,
+        forward_min_event_ndims=0,
+        name=name)
 
   @property
   def loc(self):
@@ -102,15 +101,11 @@ class Gumbel(bijector.Bijector):
 
   def _inverse_log_det_jacobian(self, y):
     y = self._maybe_assert_valid_y(y)
-    event_dims = self._event_dims_tensor(y)
-    return math_ops.reduce_sum(
-        math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims)
+    return math_ops.log(self.scale / (-math_ops.log(y) * y))
 
   def _forward_log_det_jacobian(self, x):
-    event_dims = self._event_dims_tensor(x)
     z = (x - self.loc) / self.scale
-    return math_ops.reduce_sum(
-        -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims)
+    return -z - math_ops.exp(-z) - math_ops.log(self.scale)
 
   def _maybe_assert_valid_y(self, y):
     if not self.validate_args:
index fab1b22..2bde956 100644 (file)
@@ -40,7 +40,7 @@ class Inline(bijector.Bijector):
     name="exp")
   ```
 
-  The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`.
+  The above example is equivalent to the `Bijector` `Exp()`.
   """
 
   def __init__(self,
@@ -54,6 +54,8 @@ class Inline(bijector.Bijector):
                inverse_event_shape_tensor_fn=None,
                is_constant_jacobian=False,
                validate_args=False,
+               forward_min_event_ndims=None,
+               inverse_min_event_ndims=None,
                name="inline"):
     """Creates a `Bijector` from callables.
 
@@ -76,10 +78,15 @@ class Inline(bijector.Bijector):
         constant for all input arguments.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
+      forward_min_event_ndims: Python `int` indicating the minimal
+        dimensionality this bijector acts on.
+      inverse_min_event_ndims: Python `int` indicating the minimal
+        dimensionality this bijector acts on.
       name: Python `str`, name given to ops managed by this object.
     """
     super(Inline, self).__init__(
-        event_ndims=0,
+        forward_min_event_ndims=forward_min_event_ndims,
+        inverse_min_event_ndims=inverse_min_event_ndims,
         is_constant_jacobian=is_constant_jacobian,
         validate_args=validate_args,
         name=name)
@@ -134,8 +141,8 @@ class Inline(bijector.Bijector):
           "inverse_log_det_jacobian_fn is not a callable function.")
     return self._inverse_log_det_jacobian_fn(y, **kwargs)
 
-  def _forward_log_det_jacobian(self, y, **kwargs):
+  def _forward_log_det_jacobian(self, x, **kwargs):
     if not callable(self._forward_log_det_jacobian_fn):
       raise NotImplementedError(
           "forward_log_det_jacobian_fn is not a callable function.")
-    return self._forward_log_det_jacobian_fn(y, **kwargs)
+    return self._forward_log_det_jacobian_fn(x, **kwargs)
index 2c603fe..1904239 100644 (file)
@@ -66,8 +66,9 @@ class Invert(bijector_lib.Bijector):
 
     self._bijector = bijector
     super(Invert, self).__init__(
-        event_ndims=bijector.event_ndims,
         graph_parents=bijector.graph_parents,
+        forward_min_event_ndims=bijector.inverse_min_event_ndims,
+        inverse_min_event_ndims=bijector.forward_min_event_ndims,
         is_constant_jacobian=bijector.is_constant_jacobian,
         validate_args=validate_args,
         dtype=bijector.dtype,
index f5de052..97000c1 100644 (file)
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
@@ -48,7 +47,6 @@ class Kumaraswamy(bijector.Bijector):
   def __init__(self,
                concentration1=None,
                concentration0=None,
-               event_ndims=0,
                validate_args=False,
                name="kumaraswamy"):
     """Instantiates the `Kumaraswamy` bijector.
@@ -60,31 +58,14 @@ class Kumaraswamy(bijector.Bijector):
       concentration0: Python `float` scalar indicating the transform power,
         i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is
         `concentration0`.
-      event_ndims: Python scalar indicating the number of dimensions associated
-        with a particular draw from the distribution. Currently only zero is
-        supported.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
-
-    Raises:
-      ValueError:  If `event_ndims` is not zero.
     """
     self._graph_parents = []
     self._name = name
     self._validate_args = validate_args
 
-    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
-    event_ndims_const = tensor_util.constant_value(event_ndims)
-    if event_ndims_const is not None and event_ndims_const not in (0,):
-      raise ValueError("event_ndims(%s) was not 0" % event_ndims_const)
-    else:
-      if validate_args:
-        event_ndims = control_flow_ops.with_dependencies(
-            [check_ops.assert_equal(
-                event_ndims, 0, message="event_ndims was not 0")],
-            event_ndims)
-
     with self._name_scope("init", values=[concentration1, concentration0]):
       concentration1 = self._maybe_assert_valid_concentration(
           ops.convert_to_tensor(concentration1, name="concentration1"),
@@ -96,7 +77,7 @@ class Kumaraswamy(bijector.Bijector):
     self._concentration1 = concentration1
     self._concentration0 = concentration0
     super(Kumaraswamy, self).__init__(
-        event_ndims=0,
+        forward_min_event_ndims=0,
         validate_args=validate_args,
         name=name)
 
@@ -123,12 +104,10 @@ class Kumaraswamy(bijector.Bijector):
 
   def _inverse_log_det_jacobian(self, y):
     y = self._maybe_assert_valid(y)
-    event_dims = self._event_dims_tensor(y)
-    return math_ops.reduce_sum(
+    return (
         math_ops.log(self.concentration1) + math_ops.log(self.concentration0) +
         (self.concentration1 - 1) * math_ops.log(y) +
-        (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1),
-        axis=event_dims)
+        (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1))
 
   def _maybe_assert_valid_concentration(self, concentration, validate_args):
     """Checks the validity of a concentration parameter."""
index 84b2340..ef56cf6 100644 (file)
@@ -61,7 +61,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
   this property by zeroing out weights in its `masked_dense` layers.
 
   In the `tf.distributions` framework, a "normalizing flow" is implemented as a
-  `tf.distributions.bijectors.Bijector`. The `forward` "autoregression"
+  `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression"
   is implemented using a `tf.while_loop` and a deep neural network (DNN) with
   masked weights such that the autoregressive property is automatically met in
   the `inverse`.
@@ -220,6 +220,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
     self._shift_and_log_scale_fn = shift_and_log_scale_fn
     self._unroll_loop = unroll_loop
     super(MaskedAutoregressiveFlow, self).__init__(
+        forward_min_event_ndims=1,
         is_constant_jacobian=is_constant_jacobian,
         validate_args=validate_args,
         name=name)
index 8654cc3..4978167 100644 (file)
@@ -114,6 +114,7 @@ class Permute(bijector_lib.Bijector):
         ], permutation)
       self._permutation = permutation
       super(Permute, self).__init__(
+          forward_min_event_ndims=1,
           is_constant_jacobian=True,
           validate_args=validate_args,
           name=name or "permute")
@@ -132,7 +133,10 @@ class Permute(bijector_lib.Bijector):
         axis=-1)
 
   def _inverse_log_det_jacobian(self, y):
-    return constant_op.constant(0., dtype=y.dtype)
+    # is_constant_jacobian = True for this bijector, hence the
+    # `log_det_jacobian` need only be specified for a single input, as this will
+    # be tiled to match `event_ndims`.
+    return constant_op.constant(0., dtype=y.dtype.base_dtype)
 
   def _forward_log_det_jacobian(self, x):
-    return constant_op.constant(0., dtype=x.dtype)
+    return constant_op.constant(0., dtype=x.dtype.base_dtype)
index c37db61..71f123f 100644 (file)
@@ -43,7 +43,6 @@ class PowerTransform(bijector.Bijector):
 
   def __init__(self,
                power=0.,
-               event_ndims=0,
                validate_args=False,
                name="power_transform"):
     """Instantiates the `PowerTransform` bijector.
@@ -51,8 +50,6 @@ class PowerTransform(bijector.Bijector):
     Args:
       power: Python `float` scalar indicating the transform power, i.e.,
         `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`.
-      event_ndims: Python scalar indicating the number of dimensions associated
-        with a particular draw from the distribution.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -70,7 +67,7 @@ class PowerTransform(bijector.Bijector):
       raise ValueError("`power` must be a non-negative TF constant.")
     self._power = power
     super(PowerTransform, self).__init__(
-        event_ndims=event_ndims,
+        forward_min_event_ndims=0,
         validate_args=validate_args,
         name=name)
 
@@ -97,18 +94,13 @@ class PowerTransform(bijector.Bijector):
 
   def _inverse_log_det_jacobian(self, y):
     y = self._maybe_assert_valid_y(y)
-    event_dims = self._event_dims_tensor(y)
-    return (self.power - 1.) * math_ops.reduce_sum(
-        math_ops.log(y), axis=event_dims)
+    return (self.power - 1.) * math_ops.log(y)
 
   def _forward_log_det_jacobian(self, x):
     x = self._maybe_assert_valid_x(x)
-    event_dims = self._event_dims_tensor(x)
     if self.power == 0.:
-      return math_ops.reduce_sum(x, axis=event_dims)
-    return (1. / self.power - 1.) * math_ops.reduce_sum(
-        math_ops.log1p(x * self.power),
-        axis=event_dims)
+      return x
+    return (1. / self.power - 1.) * math_ops.log1p(x * self.power)
 
   def _maybe_assert_valid_x(self, x):
     if not self.validate_args or self.power == 0.:
index 71ab369..f09ab21 100644 (file)
@@ -166,7 +166,7 @@ class RealNVP(bijector_lib.Bijector):
     self._input_depth = None
     self._shift_and_log_scale_fn = shift_and_log_scale_fn
     super(RealNVP, self).__init__(
-        event_ndims=1,
+        forward_min_event_ndims=1,
         is_constant_jacobian=is_constant_jacobian,
         validate_args=validate_args,
         name=name)
@@ -224,7 +224,7 @@ class RealNVP(bijector_lib.Bijector):
     _, log_scale = self._shift_and_log_scale_fn(
         x0, self._input_depth - self._num_masked)
     if log_scale is None:
-      return constant_op.constant(0., dtype=x.dtype, name="ildj")
+      return constant_op.constant(0., dtype=x.dtype, name="fldj")
     return math_ops.reduce_sum(log_scale, axis=-1)
 
 
index 55eca06..82210cd 100644 (file)
@@ -128,9 +128,11 @@ class Reshape(bijector_lib.Bijector):
       self._event_shape_in = event_shape_in
       self._event_shape_out = event_shape_out
 
-      super(Reshape, self).__init__(is_constant_jacobian=True,
-                                    validate_args=validate_args,
-                                    name=name or "reshape")
+      super(Reshape, self).__init__(
+          forward_min_event_ndims=0,
+          is_constant_jacobian=True,
+          validate_args=validate_args,
+          name=name or "reshape")
 
   def _maybe_check_valid_shape(self, shape, validate_args):
     """Check that a shape Tensor is int-type and otherwise sane."""
index a640dfe..5df8c88 100644 (file)
@@ -33,7 +33,9 @@ class Sigmoid(bijector.Bijector):
 
   def __init__(self, validate_args=False, name="sigmoid"):
     super(Sigmoid, self).__init__(
-        event_ndims=0, validate_args=validate_args, name=name)
+        forward_min_event_ndims=0,
+        validate_args=validate_args,
+        name=name)
 
   def _forward(self, x):
     return math_ops.sigmoid(x)
index 3a75e4a..2a32e8a 100644 (file)
@@ -91,7 +91,6 @@ class SinhArcsinh(bijector.Bijector):
   def __init__(self,
                skewness=None,
                tailweight=None,
-               event_ndims=0,
                validate_args=False,
                name="SinhArcsinh"):
     """Instantiates the `SinhArcsinh` bijector.
@@ -101,8 +100,6 @@ class SinhArcsinh(bijector.Bijector):
         of type `float32`.
       tailweight:  Tailweight parameter.  Positive `Tensor` of same `dtype` as
         `skewness` and broadcastable `shape`.  Default is `1` of type `float32`.
-      event_ndims: Python scalar indicating the number of dimensions associated
-        with a particular draw from the distribution.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -125,7 +122,9 @@ class SinhArcsinh(bijector.Bijector):
                 message="Argument tailweight was not positive")
         ], self._tailweight)
     super(SinhArcsinh, self).__init__(
-        event_ndims=event_ndims, validate_args=validate_args, name=name)
+        forward_min_event_ndims=0,
+        validate_args=validate_args,
+        name=name)
 
   @property
   def skewness(self):
@@ -149,31 +148,29 @@ class SinhArcsinh(bijector.Bijector):
     # dx/dy
     # = cosh(arcsinh(y) / tailweight - skewness)
     #     / (tailweight * sqrt(y**2 + 1))
-    event_dims = self._event_dims_tensor(y)
-    return math_ops.reduce_sum(
-        # This is computed inside the log to avoid catastrophic cancellations
-        # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1).
+
+    # This is computed inside the log to avoid catastrophic cancellations
+    # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1).
+    return (
         math_ops.log(math_ops.cosh(
             math_ops.asinh(y) / self.tailweight - self.skewness)
                      # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases
                      # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x).
                      / _sqrtx2p1(y))
-        - math_ops.log(self.tailweight),
-        axis=event_dims)
+        - math_ops.log(self.tailweight))
 
   def _forward_log_det_jacobian(self, x):
     # y = sinh((arcsinh(x) + skewness) * tailweight)
     # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1),
     # dy/dx
     # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1)
-    event_dims = self._event_dims_tensor(x)
-    return math_ops.reduce_sum(
-        # This is computed inside the log to avoid catastrophic cancellations
-        # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1).
+
+    # This is computed inside the log to avoid catastrophic cancellations
+    # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1).
+    return (
         math_ops.log(math_ops.cosh(
             (math_ops.asinh(x) + self.skewness) * self.tailweight)
                      # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases
                      # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x).
                      / _sqrtx2p1(x))
-        + math_ops.log(self.tailweight),
-        axis=event_dims)
+        + math_ops.log(self.tailweight))
index dc94fd0..f52b915 100644 (file)
@@ -66,7 +66,7 @@ class SoftmaxCentered(bijector.Bijector):
     self._graph_parents = []
     self._name = name
     super(SoftmaxCentered, self).__init__(
-        event_ndims=1,
+        forward_min_event_ndims=1,
         validate_args=validate_args,
         name=name)
 
@@ -105,8 +105,6 @@ class SoftmaxCentered(bijector.Bijector):
       y.shape.assert_is_compatible_with(shape)
       y.set_shape(shape)
 
-    # Since we only support event_ndims in [0, 1] and we do padding, we always
-    # reduce over the last dimension, i.e., dim=-1 (which is the default).
     return nn_ops.softmax(y)
 
   def _inverse(self, y):
@@ -162,8 +160,6 @@ class SoftmaxCentered(bijector.Bijector):
     #   -log_normalization + reduce_sum(logits - log_normalization)
     log_normalization = nn_ops.softplus(
         math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True))
-    fldj = (-log_normalization +
-            math_ops.reduce_sum(x - log_normalization,
-                                axis=-1,
-                                keep_dims=True))
-    return array_ops.squeeze(fldj, squeeze_dims=-1)
+    return array_ops.squeeze(
+        (-log_normalization + math_ops.reduce_sum(
+            x - log_normalization, axis=-1, keepdims=True)), axis=-1)
index 81957fc..96a938c 100644 (file)
@@ -62,7 +62,7 @@ class Softplus(bijector.Bijector):
     ```python
     # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1
     # batch ndim and 2 event ndims (i.e., vector of matrices).
-    softplus = Softplus(event_ndims=2)
+    softplus = Softplus()
     x = [[[1., 2],
           [3, 4]],
          [[5, 6],
@@ -81,7 +81,6 @@ class Softplus(bijector.Bijector):
               "Nonzero floating point `Tensor`.  Controls the softness of what "
               "would otherwise be a kink at the origin.  Default is 1.0")})
   def __init__(self,
-               event_ndims=0,
                hinge_softness=None,
                validate_args=False,
                name="softplus"):
@@ -101,7 +100,7 @@ class Softplus(bijector.Bijector):
             [nonzero_check], self.hinge_softness)
 
     super(Softplus, self).__init__(
-        event_ndims=event_ndims,
+        forward_min_event_ndims=0,
         validate_args=validate_args,
         name=name)
 
@@ -130,14 +129,12 @@ class Softplus(bijector.Bijector):
     # 1 - exp{-Y} approx Y.
     if self.hinge_softness is not None:
       y /= math_ops.cast(self.hinge_softness, y.dtype)
-    return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)),
-                                axis=self._event_dims_tensor(y))
+    return -math_ops.log(-math_ops.expm1(-y))
 
   def _forward_log_det_jacobian(self, x):
     if self.hinge_softness is not None:
       x /= math_ops.cast(self.hinge_softness, x.dtype)
-    return -math_ops.reduce_sum(nn_ops.softplus(-x),
-                                axis=self._event_dims_tensor(x))
+    return -nn_ops.softplus(-x)
 
   @property
   def hinge_softness(self):
index 1e9dbf3..2ccfdc9 100644 (file)
@@ -59,7 +59,7 @@ class Square(bijector.Bijector):
     """
     self._name = name
     super(Square, self).__init__(
-        event_ndims=0,
+        forward_min_event_ndims=0,
         validate_args=validate_args,
         name=name)
 
index 00520bc..39129cd 100644 (file)
@@ -50,7 +50,6 @@ class Weibull(bijector.Bijector):
   def __init__(self,
                scale=1.,
                concentration=1.,
-               event_ndims=0,
                validate_args=False,
                name="weibull"):
     """Instantiates the `Weibull` bijector.
@@ -62,8 +61,6 @@ class Weibull(bijector.Bijector):
       concentration: Positive Float-type `Tensor` that is the same dtype and is
         broadcastable with `scale`.
         This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.
-      event_ndims: Python scalar indicating the number of dimensions associated
-        with a particular draw from the distribution.
       validate_args: Python `bool` indicating whether arguments should be
         checked for correctness.
       name: Python `str` name given to ops managed by this object.
@@ -89,7 +86,7 @@ class Weibull(bijector.Bijector):
         ], self._concentration)
 
     super(Weibull, self).__init__(
-        event_ndims=event_ndims,
+        forward_min_event_ndims=0,
         validate_args=validate_args,
         name=name)
 
@@ -113,22 +110,18 @@ class Weibull(bijector.Bijector):
 
   def _inverse_log_det_jacobian(self, y):
     y = self._maybe_assert_valid_y(y)
-    event_dims = self._event_dims_tensor(y)
-    return math_ops.reduce_sum(
+    return (
         -math_ops.log1p(-y) +
         (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) +
-        math_ops.log(self.scale / self.concentration),
-        axis=event_dims)
+        math_ops.log(self.scale / self.concentration))
 
   def _forward_log_det_jacobian(self, x):
     x = self._maybe_assert_valid_x(x)
-    event_dims = self._event_dims_tensor(x)
-    return math_ops.reduce_sum(
+    return (
         -(x / self.scale) ** self.concentration +
         (self.concentration - 1) * math_ops.log(x) +
         math_ops.log(self.concentration) +
-        -self.concentration * math_ops.log(self.scale),
-        axis=event_dims)
+        -self.concentration * math_ops.log(self.scale))
 
   def _maybe_assert_valid_x(self, x):
     if not self.validate_args:
index 1d4c566..10b4536 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.contrib.distributions.python.ops import conditional_distribution
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import transformed_distribution
@@ -105,7 +106,9 @@ class ConditionalTransformedDistribution(
     bijector_kwargs = bijector_kwargs or {}
     distribution_kwargs = distribution_kwargs or {}
     x = self.bijector.inverse(y, **bijector_kwargs)
-    ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs)
+    event_ndims = self._maybe_get_event_ndims_statically()
+    ildj = self.bijector.inverse_log_det_jacobian(
+        y, event_ndims=event_ndims, **bijector_kwargs)
     if self.bijector._is_injective:  # pylint: disable=protected-access
       return self._finish_log_prob_for_one_fiber(y, x, ildj,
                                                  distribution_kwargs)
@@ -128,7 +131,9 @@ class ConditionalTransformedDistribution(
     bijector_kwargs = bijector_kwargs or {}
     distribution_kwargs = distribution_kwargs or {}
     x = self.bijector.inverse(y, **bijector_kwargs)
-    ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs)
+    event_ndims = self._maybe_get_event_ndims_statically()
+    ildj = self.bijector.inverse_log_det_jacobian(
+        y, event_ndims=event_ndims, **bijector_kwargs)
     if self.bijector._is_injective:  # pylint: disable=protected-access
       return self._finish_prob_for_one_fiber(y, x, ildj, distribution_kwargs)
 
@@ -214,3 +219,15 @@ class ConditionalTransformedDistribution(
     # implies the qth quantile of Y is g(x_q).
     inv_cdf = self.distribution.quantile(value, **distribution_kwargs)
     return self.bijector.forward(inv_cdf, **bijector_kwargs)
+
+  def _maybe_get_event_ndims_statically(self):
+    if self.event_shape.ndims is not None:
+      return self.event_shape.ndims
+
+    event_ndims = array_ops.size(self.event_shape_tensor())
+    static_event_ndims = tensor_util.constant_value(event_ndims)
+
+    if static_event_ndims is not None:
+      return static_event_ndims
+
+    return event_ndims
index 92f2bba..3314181 100644 (file)
@@ -114,7 +114,7 @@ def quadrature_scheme_lognormal_quantiles(
     # Create a LogNormal distribution.
     dist = transformed_lib.TransformedDistribution(
         distribution=normal_lib.Normal(loc=loc, scale=scale),
-        bijector=Exp(event_ndims=0),
+        bijector=Exp(),
         validate_args=validate_args)
     batch_ndims = dist.batch_shape.ndims
     if batch_ndims is None:
index f56ba07..02cf3c7 100644 (file)
@@ -409,5 +409,5 @@ class RelaxedOneHotCategorical(
                                        validate_args=validate_args,
                                        allow_nan_stats=allow_nan_stats)
     super(RelaxedOneHotCategorical, self).__init__(dist,
-                                                   bijectors.Exp(event_ndims=1),
+                                                   bijectors.Exp(),
                                                    name=name)
index 0d8a192..cde6d85 100644 (file)
@@ -166,13 +166,13 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
 
       # Make the SAS bijector, 'F'.
       f = bijectors.SinhArcsinh(
-          skewness=skewness, tailweight=tailweight, event_ndims=0)
+          skewness=skewness, tailweight=tailweight)
       if has_default_skewness:
         f_noskew = f
       else:
         f_noskew = bijectors.SinhArcsinh(
             skewness=skewness.dtype.as_numpy_dtype(0.),
-            tailweight=tailweight, event_ndims=0)
+            tailweight=tailweight)
 
       # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
       c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype))
index 971d65c..da271a8 100644 (file)
@@ -427,7 +427,6 @@ class VectorDiffeomixture(distribution_lib.Distribution):
       self._endpoint_affine = [
           AffineLinearOperator(shift=loc_,
                                scale=scale_,
-                               event_ndims=1,
                                validate_args=validate_args,
                                name="endpoint_affine_{}".format(k))
           for k, (loc_, scale_) in enumerate(zip(loc, scale))]
@@ -467,7 +466,6 @@ class VectorDiffeomixture(distribution_lib.Distribution):
       self._interpolated_affine = [
           AffineLinearOperator(shift=loc_,
                                scale=scale_,
-                               event_ndims=1,
                                validate_args=validate_args,
                                name="interpolated_affine_{}".format(k))
           for k, (loc_, scale_) in enumerate(zip(
@@ -621,9 +619,11 @@ class VectorDiffeomixture(distribution_lib.Distribution):
     log_prob = math_ops.reduce_sum(self.distribution.log_prob(y), axis=-2)
     # Because the affine transformation has a constant Jacobian, it is the case
     # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general.
-    fldj = array_ops.stack(
-        [aff.forward_log_det_jacobian(x) for aff in self.interpolated_affine],
-        axis=-1)
+    fldj = array_ops.stack([
+        aff.forward_log_det_jacobian(
+            x,
+            event_ndims=array_ops.rank(self.event_shape_tensor())
+        ) for aff in self.interpolated_affine], axis=-1)
     return math_ops.reduce_logsumexp(
         self.mixture_distribution.logits - fldj + log_prob, axis=-1)
 
index 003c66b..05919be 100644 (file)
@@ -215,13 +215,13 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
       tailweight = ops.convert_to_tensor(
           tailweight, dtype=dtype, name="tailweight")
       f = bijectors.SinhArcsinh(
-          skewness=skewness, tailweight=tailweight, event_ndims=1)
+          skewness=skewness, tailweight=tailweight)
       if has_default_skewness:
         f_noskew = f
       else:
         f_noskew = bijectors.SinhArcsinh(
             skewness=skewness.dtype.as_numpy_dtype(0.),
-            tailweight=tailweight, event_ndims=0)
+            tailweight=tailweight)
 
       # Make the Affine bijector, Z --> loc + C * Z.
       c = 2 * scale_diag_part / f_noskew.forward(
index 9f9fb5c..1858224 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import abc
 
+import numpy as np
 import six
 
 from tensorflow.python.framework import constant_op
@@ -43,11 +44,10 @@ class BaseBijectorTest(test.TestCase):
       """Minimal specification of a `Bijector`."""
 
       def __init__(self):
-        super(_BareBonesBijector, self).__init__()
+        super(_BareBonesBijector, self).__init__(forward_min_event_ndims=0)
 
     with self.test_session() as sess:
       bij = _BareBonesBijector()
-      self.assertEqual(None, bij.event_ndims)
       self.assertEqual([], bij.graph_parents)
       self.assertEqual(False, bij.is_constant_jacobian)
       self.assertEqual(False, bij.validate_args)
@@ -67,13 +67,21 @@ class BaseBijectorTest(test.TestCase):
         self.assertAllEqual(shape, inverse_event_shape_)
         self.assertAllEqual(shape, bij.inverse_event_shape(shape))
 
-      for fn in ["forward",
-                 "inverse",
-                 "inverse_log_det_jacobian",
-                 "forward_log_det_jacobian"]:
-        with self.assertRaisesRegexp(
-            NotImplementedError, fn + " not implemented"):
-          getattr(bij, fn)(0)
+      with self.assertRaisesRegexp(
+          NotImplementedError, "inverse not implemented"):
+        bij.inverse(0)
+
+      with self.assertRaisesRegexp(
+          NotImplementedError, "forward not implemented"):
+        bij.forward(0)
+
+      with self.assertRaisesRegexp(
+          NotImplementedError, "inverse_log_det_jacobian not implemented"):
+        bij.inverse_log_det_jacobian(0, event_ndims=0)
+
+      with self.assertRaisesRegexp(
+          NotImplementedError, "forward_log_det_jacobian not implemented"):
+        bij.forward_log_det_jacobian(0, event_ndims=0)
 
 
 class IntentionallyMissingError(Exception):
@@ -85,7 +93,7 @@ class BrokenBijector(bijector.Bijector):
 
   def __init__(self, forward_missing=False, inverse_missing=False):
     super(BrokenBijector, self).__init__(
-        event_ndims=0, validate_args=False, name="broken")
+        validate_args=False, forward_min_event_ndims=0, name="broken")
     self._forward_missing = forward_missing
     self._inverse_missing = inverse_missing
 
@@ -120,35 +128,42 @@ class BijectorCachingTestBase(object):
 
   def testCachingOfForwardResults(self):
     broken_bijector = self.broken_bijector_cls(inverse_missing=True)
-    with self.test_session():
-      x = constant_op.constant(1.1)
+    x = constant_op.constant(1.1)
+
+    # Call forward and forward_log_det_jacobian one-by-one (not together).
+    y = broken_bijector.forward(x)
+    _ = broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
 
-      # Call forward and forward_log_det_jacobian one-by-one (not together).
-      y = broken_bijector.forward(x)
-      _ = broken_bijector.forward_log_det_jacobian(x)
+    # Now, everything should be cached if the argument is y.
+    broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
+    try:
+      broken_bijector.inverse(y)
+      broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
+    except IntentionallyMissingError:
+      raise AssertionError("Tests failed! Cached values not used.")
 
-      # Now, everything should be cached if the argument is y.
-      try:
-        broken_bijector.inverse(y)
-        broken_bijector.inverse_log_det_jacobian(y)
-      except IntentionallyMissingError:
-        raise AssertionError("Tests failed! Cached values not used.")
+    # Different event_ndims should not be cached.
+    with self.assertRaises(IntentionallyMissingError):
+      broken_bijector.inverse_log_det_jacobian(y, event_ndims=1)
 
   def testCachingOfInverseResults(self):
     broken_bijector = self.broken_bijector_cls(forward_missing=True)
-    with self.test_session():
-      y = constant_op.constant(1.1)
+    y = constant_op.constant(1.1)
 
-      # Call inverse and inverse_log_det_jacobian one-by-one (not together).
-      x = broken_bijector.inverse(y)
-      _ = broken_bijector.inverse_log_det_jacobian(y)
+    # Call inverse and inverse_log_det_jacobian one-by-one (not together).
+    x = broken_bijector.inverse(y)
+    _ = broken_bijector.inverse_log_det_jacobian(y, event_ndims=0)
 
-      # Now, everything should be cached if the argument is x.
-      try:
-        broken_bijector.forward(x)
-        broken_bijector.forward_log_det_jacobian(x)
-      except IntentionallyMissingError:
-        raise AssertionError("Tests failed! Cached values not used.")
+    # Now, everything should be cached if the argument is x.
+    try:
+      broken_bijector.forward(x)
+      broken_bijector.forward_log_det_jacobian(x, event_ndims=0)
+    except IntentionallyMissingError:
+      raise AssertionError("Tests failed! Cached values not used.")
+
+    # Different event_ndims should not be cached.
+    with self.assertRaises(IntentionallyMissingError):
+      broken_bijector.forward_log_det_jacobian(x, event_ndims=1)
 
 
 class BijectorCachingTest(BijectorCachingTestBase, test.TestCase):
@@ -159,5 +174,107 @@ class BijectorCachingTest(BijectorCachingTestBase, test.TestCase):
     return BrokenBijector
 
 
+class ExpOnlyJacobian(bijector.Bijector):
+  """Only used for jacobian calculations."""
+
+  def __init__(self, forward_min_event_ndims=0):
+    super(ExpOnlyJacobian, self).__init__(
+        validate_args=False,
+        is_constant_jacobian=False,
+        forward_min_event_ndims=forward_min_event_ndims,
+        name="exp")
+
+  def _inverse_log_det_jacobian(self, y):
+    return -math_ops.log(y)
+
+  def _forward_log_det_jacobian(self, x):
+    return math_ops.log(x)
+
+
+class ConstantJacobian(bijector.Bijector):
+  """Only used for jacobian calculations."""
+
+  def __init__(self, forward_min_event_ndims=0):
+    super(ConstantJacobian, self).__init__(
+        validate_args=False,
+        is_constant_jacobian=True,
+        forward_min_event_ndims=forward_min_event_ndims,
+        name="c")
+
+  def _inverse_log_det_jacobian(self, y):
+    return constant_op.constant(2., y.dtype)
+
+  def _forward_log_det_jacobian(self, x):
+    return constant_op.constant(-2., x.dtype)
+
+
+class BijectorReduceEventDimsTest(test.TestCase):
+  """Test caching with BrokenBijector."""
+
+  def testReduceEventNdimsForward(self):
+    x = [[[1., 2.], [3., 4.]]]
+    bij = ExpOnlyJacobian()
+    self.assertAllClose(
+        np.log(x),
+        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0)))
+    self.assertAllClose(
+        np.sum(np.log(x), axis=-1),
+        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1)))
+    self.assertAllClose(
+        np.sum(np.log(x), axis=(-1, -2)),
+        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2)))
+
+  def testReduceEventNdimsForwardRaiseError(self):
+    x = [[[1., 2.], [3., 4.]]]
+    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
+    with self.assertRaisesRegexp(ValueError, "must be larger than"):
+      bij.forward_log_det_jacobian(x, event_ndims=0)
+
+  def testReduceEventNdimsInverse(self):
+    x = [[[1., 2.], [3., 4.]]]
+    bij = ExpOnlyJacobian()
+    self.assertAllClose(
+        -np.log(x),
+        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0)))
+    self.assertAllClose(
+        np.sum(-np.log(x), axis=-1),
+        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1)))
+    self.assertAllClose(
+        np.sum(-np.log(x), axis=(-1, -2)),
+        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))
+
+  def testReduceEventNdimsInverseRaiseError(self):
+    x = [[[1., 2.], [3., 4.]]]
+    bij = ExpOnlyJacobian(forward_min_event_ndims=1)
+    with self.assertRaisesRegexp(ValueError, "must be larger than"):
+      bij.inverse_log_det_jacobian(x, event_ndims=0)
+
+  def testReduceEventNdimsForwardConstJacobian(self):
+    x = [[[1., 2.], [3., 4.]]]
+    bij = ConstantJacobian()
+    self.assertAllClose(
+        -2.,
+        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=0)))
+    self.assertAllClose(
+        -4.,
+        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=1)))
+    self.assertAllClose(
+        -8.,
+        self.evaluate(bij.forward_log_det_jacobian(x, event_ndims=2)))
+
+  def testReduceEventNdimsInverseConstJacobian(self):
+    x = [[[1., 2.], [3., 4.]]]
+    bij = ConstantJacobian()
+    self.assertAllClose(
+        2.,
+        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=0)))
+    self.assertAllClose(
+        4.,
+        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=1)))
+    self.assertAllClose(
+        8.,
+        self.evaluate(bij.inverse_log_det_jacobian(x, event_ndims=2)))
+
+
 if __name__ == "__main__":
   test.main()
index e8f9d0b..b347c20 100644 (file)
@@ -27,14 +27,19 @@ class IdentityBijectorTest(test.TestCase):
   """Tests correctness of the Y = g(X) = X transformation."""
 
   def testBijector(self):
-    with self.test_session():
-      bijector = identity_bijector.Identity()
-      self.assertEqual("identity", bijector.name)
-      x = [[[0.], [1.]]]
-      self.assertAllEqual(x, bijector.forward(x).eval())
-      self.assertAllEqual(x, bijector.inverse(x).eval())
-      self.assertAllEqual(0., bijector.inverse_log_det_jacobian(x).eval())
-      self.assertAllEqual(0., bijector.forward_log_det_jacobian(x).eval())
+    bijector = identity_bijector.Identity(validate_args=True)
+    self.assertEqual("identity", bijector.name)
+    x = [[[0.], [1.]]]
+    self.assertAllEqual(x, self.evaluate(bijector.forward(x)))
+    self.assertAllEqual(x, self.evaluate(bijector.inverse(x)))
+    self.assertAllEqual(
+        0.,
+        self.evaluate(
+            bijector.inverse_log_det_jacobian(x, event_ndims=3)))
+    self.assertAllEqual(
+        0.,
+        self.evaluate(
+            bijector.forward_log_det_jacobian(x, event_ndims=3)))
 
   def testScalarCongruency(self):
     with self.test_session():
index ed43555..4ebc600 100644 (file)
@@ -23,7 +23,6 @@ import collections
 import contextlib
 import re
 
-import numpy as np
 import six
 
 from tensorflow.python.framework import dtypes
@@ -31,8 +30,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.util.tf_export import tf_export
 
 
 __all__ = [
@@ -41,23 +40,24 @@ __all__ = [
 
 
 class _Mapping(collections.namedtuple(
-    "_Mapping", ["x", "y", "ildj", "kwargs"])):
+    "_Mapping", ["x", "y", "ildj_map", "kwargs"])):
   """Helper class to make it easier to manage caching in `Bijector`."""
 
-  def __new__(cls, x=None, y=None, ildj=None, kwargs=None):
+  def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None):
     """Custom __new__ so namedtuple items have defaults.
 
     Args:
       x: `Tensor`. Forward.
       y: `Tensor`. Inverse.
-      ildj: `Tensor`. Inverse log det Jacobian.
+      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
+        representing the inverse log det jacobian.
       kwargs: Python dictionary. Extra args supplied to
         forward/inverse/etc functions.
 
     Returns:
       mapping: New instance of _Mapping.
     """
-    return super(_Mapping, cls).__new__(cls, x, y, ildj, kwargs)
+    return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs)
 
   @property
   def x_key(self):
@@ -69,13 +69,14 @@ class _Mapping(collections.namedtuple(
     """Returns key used for caching X=g^{-1}(Y)."""
     return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
 
-  def merge(self, x=None, y=None, ildj=None, kwargs=None, mapping=None):
+  def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
     """Returns new _Mapping with args merged with self.
 
     Args:
       x: `Tensor`. Forward.
       y: `Tensor`. Inverse.
-      ildj: `Tensor`. Inverse log det Jacobian.
+      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
+        representing the inverse log det jacobian.
       kwargs: Python dictionary. Extra args supplied to
         forward/inverse/etc functions.
       mapping: Instance of _Mapping to merge. Can only be specified if no other
@@ -88,15 +89,30 @@ class _Mapping(collections.namedtuple(
       ValueError: if mapping and any other arg is not `None`.
     """
     if mapping is None:
-      mapping = _Mapping(x=x, y=y, ildj=ildj, kwargs=kwargs)
-    elif not all(arg is None for arg in [x, y, ildj, kwargs]):
-      raise ValueError("Cannot specify mapping and individual args.")
+      mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs)
+    elif any(arg is not None for arg in [x, y, ildj_map, kwargs]):
+      raise ValueError("Cannot simultaneously specify mapping and individual "
+                       "arguments.")
+
     return _Mapping(
         x=self._merge(self.x, mapping.x),
         y=self._merge(self.y, mapping.y),
-        ildj=self._merge(self.ildj, mapping.ildj),
+        ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map),
         kwargs=self._merge(self.kwargs, mapping.kwargs))
 
+  def _merge_dicts(self, old=None, new=None):
+    """Helper to merge two dictionaries."""
+    old = dict() if old is None else old
+    new = dict() if new is None else new
+    for k, v in six.iteritems(new):
+      val = old.get(k, None)
+      if val is not None and val != v:
+        raise ValueError("Found different value for existing key "
+                         "(key:{} old_value:{} new_value:{}".format(
+                             k, old[k], v))
+      old[k] = v
+    return old
+
   def _merge(self, old, new):
     """Helper to merge which handles merging one value."""
     if old is None:
@@ -112,7 +128,6 @@ class _Mapping(collections.namedtuple(
 
 
 @six.add_metaclass(abc.ABCMeta)
-@tf_export("distributions.bijectors.Bijector")
 class Bijector(object):
   r"""Interface for transformations of a `Distribution` sample.
 
@@ -137,11 +152,11 @@ class Bijector(object):
   2. Inverse\
      Useful for "reversing" a transformation to compute one probability in
      terms of another.
-  3. `(log o det o Jacobian o inverse)(x)`\
+  3. `log_det_jacobian(x)`\
      "The log of the determinant of the matrix of all first-order partial
      derivatives of the inverse function."\
      Useful for inverting a transformation to compute one probability in terms
-     of another. Geometrically, the det(Jacobian) is the volume of the
+     of another. Geometrically, the Jacobian determinant is the volume of the
      transformation and is used to scale the probability.
 
   By convention, transformations of random variables are named in terms of the
@@ -164,7 +179,7 @@ class Bijector(object):
 
   ```python
   def transformed_log_prob(bijector, log_prob, x):
-    return (bijector.inverse_log_det_jacobian(x) +
+    return (bijector.inverse_log_det_jacobian(x, event_ndims=0) +
             log_prob(bijector.inverse(x)))
   ```
 
@@ -199,9 +214,11 @@ class Bijector(object):
     ```python
       class Exp(Bijector):
 
-        def __init__(self, event_ndims=0, validate_args=False, name="exp"):
+        def __init__(self, validate_args=False, name="exp"):
           super(Exp, self).__init__(
-              event_ndims=event_ndims, validate_args=validate_args, name=name)
+              validate_args=validate_args,
+              forward_min_event_ndims=0,
+              name=name)
 
         def _forward(self, x):
           return math_ops.exp(x)
@@ -213,10 +230,11 @@ class Bijector(object):
           return -self._forward_log_det_jacobian(self._inverse(y))
 
         def _forward_log_det_jacobian(self, x):
-          if self.event_ndims is None:
-            raise ValueError("Jacobian requires known event_ndims.")
-          event_dims = array_ops.shape(x)[-self.event_ndims:]
-          return math_ops.reduce_sum(x, axis=event_dims)
+          # Notice that we needn't do any reducing, even when`event_ndims > 0`.
+          # The base Bijector class will handle reducing for us; it knows how
+          # to do so because we called `super` `__init__` with
+          # `forward_min_event_ndims = 0`.
+          return x
       ```
 
   - "Affine"
@@ -237,18 +255,50 @@ class Bijector(object):
                   MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
       ```
 
-  #### Jacobian
+  #### Min_event_ndims and Naming
+
+  Bijectors are named for the dimensionality of data they act on (i.e. without
+  broadcasting). We can think of bijectors having an intrinsic `min_event_ndims`
+  , which is the minimum number of dimensions for the bijector act on. For
+  instance, a Cholesky decomposition requires a matrix, and hence
+  `min_event_ndims=2`.
+
+  Some examples:
+
+  `AffineScalar:  min_event_ndims=0`
+  `Affine:  min_event_ndims=1`
+  `Cholesky:  min_event_ndims=2`
+  `Exp:  min_event_ndims=0`
+  `Sigmoid:  min_event_ndims=0`
+  `SoftmaxCentered:  min_event_ndims=1`
+
+  Note the difference between `Affine` and `AffineScalar`. `AffineScalar`
+  operates on scalar events, whereas `Affine` operates on vector-valued events.
 
-  The Jacobian is a reduction over event dims. To see this, consider the `Exp`
-  `Bijector` applied to a `Tensor` which has sample, batch, and event (S, B, E)
-  shape semantics. Suppose the `Tensor`'s partitioned-shape is `(S=[4], B=[2],
-  E=[3, 3])`. The shape of the `Tensor` returned by `forward` and `inverse` is
-  unchanged, i.e., `[4, 2, 3, 3]`.  However the shape returned by
-  `inverse_log_det_jacobian` is `[4, 2]` because the Jacobian is a reduction
-  over the event dimensions.
+  More generally, there is a `forward_min_event_ndims` and an
+  `inverse_min_event_ndims`. In most cases, these will be the same.
+  However, for some shape changing bijectors, these will be different
+  (e.g. a bijector which pads an extra dimension at the end, might have
+  `forward_min_event_ndims=0` and `inverse_min_event_ndims=1`.
 
-  It is sometimes useful to implement the inverse Jacobian as the negative
-  forward Jacobian. For example,
+
+  #### Jacobian Determinant
+
+  The Jacobian determinant is a reduction over `event_ndims - min_event_ndims`
+  (`forward_min_event_ndims` for `forward_log_det_jacobian` and
+  `inverse_min_event_ndims` for `inverse_log_det_jacobian`).
+  To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has
+  sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s
+  partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor`
+  returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`.
+  However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because
+  the Jacobian determinant is a reduction over the event dimensions.
+
+  Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the
+  Jacobian determinant reduction is over `event_ndims - 1`.
+
+  It is sometimes useful to implement the inverse Jacobian determinant as the
+  negative forward Jacobian determinant. For example,
 
   ```python
   def _inverse_log_det_jacobian(self, y):
@@ -279,9 +329,54 @@ class Bijector(object):
       The claim follows from [properties of determinant](
   https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).
 
-  Generally its preferable to directly implement the inverse Jacobian. This
-  should have superior numerical stability and will often share subgraphs with
-  the `_inverse` implementation.
+  Generally its preferable to directly implement the inverse Jacobian
+  determinant.  This should have superior numerical stability and will often
+  share subgraphs with the `_inverse` implementation.
+
+  #### Is_constant_jacobian
+
+  Certain bijectors will have constant jacobian matrices. For instance, the
+  `Affine` bijector encodes multiplication by a matrix plus a shift, with
+  jacobian matrix, the same aforementioned matrix.
+
+  `is_constant_jacobian` encodes the fact that the jacobian matrix is constant.
+  The semantics of this argument are the following:
+
+    * Repeated calls to "log_det_jacobian" functions with the same
+      `event_ndims` (but not necessarily same input), will return the first
+      computed jacobian (because the matrix is constant, and hence is input
+      independent).
+    * `log_det_jacobian` implementations are merely broadcastable to the true
+      `log_det_jacobian` (because, again, the jacobian matrix is input
+      independent). Specifically, `log_det_jacobian` is implemented as the
+      log jacobian determinant for a single input.
+
+      ```python
+      class Identity(Bijector):
+
+        def __init__(self, validate_args=False, name="identity"):
+          super(Identity, self).__init__(
+              is_constant_jacobian=True,
+              validate_args=validate_args,
+              forward_min_event_ndims=0,
+              name=name)
+
+        def _forward(self, x):
+          return x
+
+        def _inverse(self, y):
+          return y
+
+        def _inverse_log_det_jacobian(self, y):
+          return -self._forward_log_det_jacobian(self._inverse(y))
+
+        def _forward_log_det_jacobian(self, x):
+          # The full log jacobian determinant would be array_ops.zero_like(x).
+          # However, we circumvent materializing that, since the jacobian
+          # calculation is input independent, and we specify it for one input.
+          return constant_op.constant(0., x.dtype.base_dtype)
+
+      ```
 
   #### Subclass Requirements
 
@@ -364,14 +459,14 @@ class Bijector(object):
   ==> (-1., 1.)
 
   # The |dX/dY| is constant, == 1.  So Log|dX/dY| == 0.
-  abs.inverse_log_det_jacobian(1.)
+  abs.inverse_log_det_jacobian(1., event_ndims=0)
   ==> (0., 0.)
 
   # Special case handling of 0.
   abs.inverse(0.)
   ==> (0., 0.)
 
-  abs.inverse_log_det_jacobian(0.)
+  abs.inverse_log_det_jacobian(0., event_ndims=0)
   ==> (0., 0.)
   ```
 
@@ -379,11 +474,12 @@ class Bijector(object):
 
   @abc.abstractmethod
   def __init__(self,
-               event_ndims=None,
                graph_parents=None,
                is_constant_jacobian=False,
                validate_args=False,
                dtype=None,
+               forward_min_event_ndims=None,
+               inverse_min_event_ndims=None,
                name=None):
     """Constructs Bijector.
 
@@ -392,42 +488,61 @@ class Bijector(object):
     Examples:
 
     ```python
-    # Create the Y = g(X) = X transform which operates on vector events.
-    identity = Identity(event_ndims=1)
+    # Create the Y = g(X) = X transform.
+    identity = Identity()
 
-    # Create the Y = g(X) = exp(X) transform which operates on matrices.
-    exp = Exp(event_ndims=2)
+    # Create the Y = g(X) = exp(X) transform.
+    exp = Exp()
     ```
 
     See `Bijector` subclass docstring for more details and specific examples.
 
     Args:
-      event_ndims: number of dimensions associated with event coordinates.
       graph_parents: Python list of graph prerequisites of this `Bijector`.
-      is_constant_jacobian: Python `bool` indicating that the Jacobian is not a
-        function of the input.
+      is_constant_jacobian: Python `bool` indicating that the Jacobian matrix is
+        not a function of the input.
       validate_args: Python `bool`, default `False`. Whether to validate input
         with asserts. If `validate_args` is `False`, and the inputs are invalid,
         correct behavior is not guaranteed.
       dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
         enforced.
+      forward_min_event_ndims: Python `integer` indicating the minimum number of
+        dimensions `forward` operates on.
+      inverse_min_event_ndims: Python `integer` indicating the minimum number of
+        dimensions `inverse` operates on. Will be set to
+        `forward_min_event_ndims` by default, if no value is provided.
       name: The name to give Ops created by the initializer.
 
     Raises:
+      ValueError:  If neither `forward_min_event_ndims` and
+        `inverse_min_event_ndims` are specified, or if either of them is
+        negative.
       ValueError:  If a member of `graph_parents` is not a `Tensor`.
     """
-    self._event_ndims = (
-        ops.convert_to_tensor(event_ndims, dtype=dtypes.int32)
-        if event_ndims is not None else None)
     self._graph_parents = graph_parents or []
+
+    if forward_min_event_ndims is None and inverse_min_event_ndims is None:
+      raise ValueError("Must specify at least one of `forward_min_event_ndims` "
+                       "and `inverse_min_event_ndims`.")
+    elif inverse_min_event_ndims is None:
+      inverse_min_event_ndims = forward_min_event_ndims
+    elif forward_min_event_ndims is None:
+      forward_min_event_ndims = inverse_min_event_ndims
+
+    if forward_min_event_ndims < 0:
+      raise ValueError("forward_min_event_ndims must be a non-negative "
+                       "integer.")
+    if inverse_min_event_ndims < 0:
+      raise ValueError("inverse_min_event_ndims must be a non-negative "
+                       "integer.")
+    self._forward_min_event_ndims = forward_min_event_ndims
+    self._inverse_min_event_ndims = inverse_min_event_ndims
     self._is_constant_jacobian = is_constant_jacobian
+    self._constant_ildj_map = {}
     self._validate_args = validate_args
     self._dtype = dtype
     self._from_y = {}
     self._from_x = {}
-    # Using abbreviation ildj for "inverse log det Jacobian."
-    # This variable is not `None` iff is_constant_jacobian is `True`.
-    self._constant_ildj = None
     if name:
       self._name = name
     else:
@@ -443,20 +558,26 @@ class Bijector(object):
         raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
 
   @property
-  def event_ndims(self):
-    """Returns then number of event dimensions this bijector operates on."""
-    return self._event_ndims
-
-  @property
   def graph_parents(self):
     """Returns this `Bijector`'s graph_parents as a Python list."""
     return self._graph_parents
 
   @property
+  def forward_min_event_ndims(self):
+    """Returns the minimal number of dimensions bijector.forward operates on."""
+    return self._forward_min_event_ndims
+
+  @property
+  def inverse_min_event_ndims(self):
+    """Returns the minimal number of dimensions bijector.inverse operates on."""
+    return self._inverse_min_event_ndims
+
+  @property
   def is_constant_jacobian(self):
-    """Returns true iff the Jacobian is not a function of x.
+    """Returns true iff the Jacobian matrix is not a function of x.
 
-    Note: Jacobian is either constant for both forward and inverse or neither.
+    Note: Jacobian matrix is either constant for both forward and inverse or
+    neither.
 
     Returns:
       is_constant_jacobian: Python `bool`.
@@ -653,36 +774,57 @@ class Bijector(object):
     return self._call_inverse(y, name)
 
   def _inverse_log_det_jacobian(self, y):
-    """Subclass implementation of `inverse_log_det_jacobian` public function."""
+    """Subclass implementation of `inverse_log_det_jacobian` public function.
+
+    In particular, this method differs from the public function, in that it
+    does not take `event_ndims`. Thus, this implements the minimal Jacobian
+    determinant calculation (i.e. over `inverse_min_event_ndims`).
+
+    Args:
+      y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation.
+    Returns:
+      inverse_log_det_jacobian: `Tensor`, if this bijector is injective.
+        If not injective, returns the k-tuple containing jacobians for the
+        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
+    """
     raise NotImplementedError("inverse_log_det_jacobian not implemented.")
 
-  def _call_inverse_log_det_jacobian(self, y, name, **kwargs):
+  def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
     with self._name_scope(name, [y]):
-      if self._constant_ildj is not None:
-        return self._constant_ildj
+      if event_ndims in self._constant_ildj_map:
+        return self._constant_ildj_map[event_ndims]
       y = ops.convert_to_tensor(y, name="y")
       self._maybe_assert_dtype(y)
       if not self._is_injective:  # No caching for non-injective
-        return self._inverse_log_det_jacobian(y, **kwargs)
+        ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+        return tuple(self._reduce_jacobian_det_over_event(
+            y, ildj, self.inverse_min_event_ndims, event_ndims)
+                     for ildj in ildjs)
       mapping = self._lookup(y=y, kwargs=kwargs)
-      if mapping.ildj is not None:
-        return mapping.ildj
+      if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
+        return mapping.ildj_map[event_ndims]
       try:
         x = None  # Not needed; leave cache as is.
         ildj = self._inverse_log_det_jacobian(y, **kwargs)
+        ildj = self._reduce_jacobian_det_over_event(
+            y, ildj, self.inverse_min_event_ndims, event_ndims)
       except NotImplementedError as original_exception:
         try:
           x = mapping.x if mapping.x is not None else self._inverse(y, **kwargs)
           ildj = -self._forward_log_det_jacobian(x, **kwargs)
+          ildj = self._reduce_jacobian_det_over_event(
+              x, ildj, self.forward_min_event_ndims, event_ndims)
         except NotImplementedError:
           raise original_exception
-      mapping = mapping.merge(x=x, ildj=ildj)
+
+      mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
       self._cache(mapping)
       if self.is_constant_jacobian:
-        self._constant_ildj = mapping.ildj
-      return mapping.ildj
+        self._constant_ildj_map[event_ndims] = ildj
+      return ildj
 
-  def inverse_log_det_jacobian(self, y, name="inverse_log_det_jacobian"):
+  def inverse_log_det_jacobian(
+      self, y, event_ndims, name="inverse_log_det_jacobian"):
     """Returns the (log o det o Jacobian o inverse)(y).
 
     Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)
@@ -691,7 +833,12 @@ class Bijector(object):
     evaluated at `g^{-1}(y)`.
 
     Args:
-      y: `Tensor`. The input to the "inverse" Jacobian evaluation.
+      y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation.
+      event_ndims: Number of dimensions in the probabilistic events being
+        transformed. Must be greater than or equal to
+        `self.inverse_min_event_ndims`. The result is summed over the final
+        dimensions to produce a scalar Jacobian determinant for each event,
+        i.e. it has shape `y.shape.ndims - event_ndims` dimensions.
       name: The name to give this op.
 
     Returns:
@@ -705,45 +852,74 @@ class Bijector(object):
         `self.dtype`.
       NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
     """
-    return self._call_inverse_log_det_jacobian(y, name)
+    with ops.control_dependencies(self._check_valid_event_ndims(
+        min_event_ndims=self.inverse_min_event_ndims, event_ndims=event_ndims)):
+      return self._call_inverse_log_det_jacobian(y, event_ndims, name)
 
   def _forward_log_det_jacobian(self, x):
-    """Subclass implementation of `forward_log_det_jacobian`."""
+    """Subclass implementation of `forward_log_det_jacobian` public function.
+
+    In particular, this method differs from the public function, in that it
+    does not take `event_ndims`. Thus, this implements the minimal Jacobian
+    determinant calculation (i.e. over `forward_min_event_ndims`).
+
+    Args:
+      x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation.
+
+    Returns:
+      forward_log_det_jacobian: `Tensor`, if this bijector is injective.
+        If not injective, returns the k-tuple containing jacobians for the
+        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
+    """
+
     raise NotImplementedError(
         "forward_log_det_jacobian not implemented.")
 
-  def _call_forward_log_det_jacobian(self, x, name, **kwargs):
+  def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
     with self._name_scope(name, [x]):
-      if self._constant_ildj is not None:
+      if event_ndims in self._constant_ildj_map:
         # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
-        return -1. * self._constant_ildj
+        return -1. * self._constant_ildj_map[event_ndims]
       x = ops.convert_to_tensor(x, name="x")
       self._maybe_assert_dtype(x)
       if not self._is_injective:
-        return self._forward_log_det_jacobian(x, **kwargs)  # No caching.
+        fldjs = self._forward_log_det_jacobian(x, **kwargs)  # No caching.
+        return tuple(self._reduce_jacobian_det_over_event(
+            x, fldj, self.forward_min_event_ndims, event_ndims)
+                     for fldj in fldjs)
       mapping = self._lookup(x=x, kwargs=kwargs)
-      if mapping.ildj is not None:
-        return -mapping.ildj
+      if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
+        return -mapping.ildj_map[event_ndims]
       try:
         y = None  # Not needed; leave cache as is.
         ildj = -self._forward_log_det_jacobian(x, **kwargs)
+        ildj = self._reduce_jacobian_det_over_event(
+            x, ildj, self.forward_min_event_ndims, event_ndims)
       except NotImplementedError as original_exception:
         try:
           y = mapping.y if mapping.y is not None else self._forward(x, **kwargs)
           ildj = self._inverse_log_det_jacobian(y, **kwargs)
+          ildj = self._reduce_jacobian_det_over_event(
+              y, ildj, self.inverse_min_event_ndims, event_ndims)
         except NotImplementedError:
           raise original_exception
-      mapping = mapping.merge(y=y, ildj=ildj)
+      mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
       self._cache(mapping)
       if self.is_constant_jacobian:
-        self._constant_ildj = mapping.ildj
-      return -mapping.ildj
+        self._constant_ildj_map[event_ndims] = ildj
+      return -ildj
 
-  def forward_log_det_jacobian(self, x, name="forward_log_det_jacobian"):
+  def forward_log_det_jacobian(
+      self, x, event_ndims, name="forward_log_det_jacobian"):
     """Returns both the forward_log_det_jacobian.
 
     Args:
-      x: `Tensor`. The input to the "forward" Jacobian evaluation.
+      x: `Tensor`. The input to the "forward" Jacobian determinant evaluation.
+      event_ndims: Number of dimensions in the probabilistic events being
+        transformed. Must be greater than or equal to
+        `self.forward_min_event_ndims`. The result is summed over the final
+        dimensions to produce a scalar Jacobian determinant for each event,
+        i.e. it has shape `x.shape.ndims - event_ndims` dimensions.
       name: The name to give this op.
 
     Returns:
@@ -761,7 +937,9 @@ class Bijector(object):
       raise NotImplementedError(
           "forward_log_det_jacobian cannot be implemented for non-injective "
           "transforms.")
-    return self._call_forward_log_det_jacobian(x, name)
+    with ops.control_dependencies(self._check_valid_event_ndims(
+        min_event_ndims=self.forward_min_event_ndims, event_ndims=event_ndims)):
+      return self._call_forward_log_det_jacobian(x, event_ndims, name)
 
   @contextlib.contextmanager
   def _name_scope(self, name=None, values=None):
@@ -779,9 +957,6 @@ class Bijector(object):
 
   def _cache(self, mapping):
     """Helper which stores mapping info in forward/inverse dicts."""
-    if self._constant_ildj is not None:
-      # Fold in ildj if known constant Jacobian.
-      mapping = mapping.merge(ildj=self._constant_ildj)
     # Merging from lookup is an added check that we're not overwriting anything
     # which is not None.
     mapping = mapping.merge(mapping=self._lookup(
@@ -803,22 +978,66 @@ class Bijector(object):
       return self._from_y.get(mapping.y_key, mapping)
     return mapping
 
-  def _event_dims_tensor(self, sample):
-    """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`."""
-    if self.event_ndims is None:
-      raise ValueError("Jacobian cannot be computed with unknown event_ndims")
-    static_event_ndims = tensor_util.constant_value(self.event_ndims)
-    static_rank = sample.get_shape().ndims
-    if static_event_ndims is not None and static_rank is not None:
-      return ops.convert_to_tensor(
-          static_rank + np.arange(-static_event_ndims, 0).astype(np.int32))
-
-    if static_event_ndims is not None:
-      event_range = np.arange(-static_event_ndims, 0).astype(np.int32)
-    else:
-      event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32)
-
-    if static_rank is not None:
-      return event_range + static_rank
+  def _reduce_jacobian_det_over_event(
+      self, y, ildj, min_event_ndims, event_ndims):
+    """Reduce jacobian over event_ndims - min_event_ndims."""
+    if not self.is_constant_jacobian:
+      return math_ops.reduce_sum(
+          ildj,
+          self._get_event_reduce_dims(min_event_ndims, event_ndims))
+
+    # In this case, we need to tile the jacobian over the event and reduce.
+    y_rank = array_ops.rank(y)
+    y_shape = array_ops.shape(y)[
+        y_rank - event_ndims : y_rank - min_event_ndims]
+
+    ones = array_ops.ones(y_shape, ildj.dtype)
+    reduced_ildj = math_ops.reduce_sum(
+        ones * ildj,
+        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
+    # The multiplication by ones can change the inferred static shape so we try
+    # to recover as much as possible.
+    if (isinstance(event_ndims, int) and
+        y.get_shape().ndims and ildj.get_shape().ndims):
+      y_shape = y.get_shape()
+      y_shape = y_shape[y_shape.ndims - event_ndims :
+                        y_shape.ndims - min_event_ndims]
+      ildj_shape = ildj.get_shape()
+      broadcast_shape = array_ops.broadcast_static_shape(
+          ildj_shape, y_shape)
+      reduced_ildj.set_shape(
+          broadcast_shape[: broadcast_shape.ndims - (
+              event_ndims - min_event_ndims)])
+
+    return reduced_ildj
+
+  def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
+    """Compute the reduction dimensions given event_ndims."""
+    min_event_ndims_ = (min_event_ndims if isinstance(min_event_ndims, int)
+                        else tensor_util.constant_value(min_event_ndims))
+    event_ndims_ = (event_ndims if isinstance(event_ndims, int)
+                    else tensor_util.constant_value(event_ndims))
+
+    if min_event_ndims_ is not None and event_ndims_ is not None:
+      return [-index for index in range(1, event_ndims_ - min_event_ndims_ + 1)]
     else:
-      return event_range + array_ops.rank(sample)
+      reduce_ndims = event_ndims - min_event_ndims
+      return math_ops.range(-reduce_ndims, 0)
+
+  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
+    """Check whether event_ndims is atleast min_event_ndims."""
+    min_event_ndims_ = (min_event_ndims if isinstance(min_event_ndims, int)
+                        else tensor_util.constant_value(min_event_ndims))
+    event_ndims_ = (event_ndims if isinstance(event_ndims, int)
+                    else tensor_util.constant_value(event_ndims))
+
+    if min_event_ndims_ is not None and event_ndims_ is not None:
+      if min_event_ndims_ > event_ndims_:
+        raise ValueError("event_ndims ({}) must be larger than "
+                         "min_event_ndims ({})".format(
+                             event_ndims_, min_event_ndims_))
+      return []
+
+    if self.validate_args:
+      return [check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
+    return []
index ff3535c..784bfd5 100644 (file)
@@ -79,9 +79,7 @@ def assert_scalar_congruency(bijector,
   Raises:
     AssertionError:  If tests fail.
   """
-
   # Checks and defaults.
-  assert bijector.event_ndims.eval() == 0
   if sess is None:
     sess = ops.get_default_session()
 
@@ -111,7 +109,10 @@ def assert_scalar_congruency(bijector,
   # (b - a) = \int_a^b dx = \int_{y(a)}^{y(b)} |dx/dy| dy
   # "change_measure_dy_dx" below is a Monte Carlo approximation to the right
   # hand side, which should then be close to the left, which is (b - a).
-  dy_dx = math_ops.exp(bijector.inverse_log_det_jacobian(uniform_y_samps))
+  # We assume event_ndims=0 because we assume scalar -> scalar. The log_det
+  # methods will handle whether they expect event_ndims > 0.
+  dy_dx = math_ops.exp(bijector.inverse_log_det_jacobian(
+      uniform_y_samps, event_ndims=0))
   # E[|dx/dy|] under Uniform[lower_y, upper_y]
   # = \int_{y(a)}^{y(b)} |dx/dy| dP(u), where dP(u) is the uniform measure
   expectation_of_dy_dx_under_uniform = math_ops.reduce_mean(dy_dx)
@@ -121,7 +122,8 @@ def assert_scalar_congruency(bijector,
 
   # We'll also check that dy_dx = 1 / dx_dy.
   dx_dy = math_ops.exp(
-      bijector.forward_log_det_jacobian(bijector.inverse(uniform_y_samps)))
+      bijector.forward_log_det_jacobian(
+          bijector.inverse(uniform_y_samps), event_ndims=0))
 
   [
       forward_on_10_pts_v,
@@ -158,7 +160,8 @@ def assert_scalar_congruency(bijector,
       dy_dx_v, np.divide(1., dx_dy_v), atol=1e-5, rtol=1e-3)
 
 
-def assert_bijective_and_finite(bijector, x, y, atol=0, rtol=1e-5, sess=None):
+def assert_bijective_and_finite(
+    bijector, x, y, event_ndims, atol=0, rtol=1e-5, sess=None):
   """Assert that forward/inverse (along with jacobians) are inverses and finite.
 
   It is recommended to use x and y values that are very very close to the edge
@@ -168,6 +171,8 @@ def assert_bijective_and_finite(bijector, x, y, atol=0, rtol=1e-5, sess=None):
     bijector:  A Bijector instance.
     x:  np.array of values in the domain of bijector.forward.
     y:  np.array of values in the domain of bijector.inverse.
+    event_ndims: Integer describing the number of event dimensions this bijector
+      operates on.
     atol:  Absolute tolerance.
     rtol:  Relative tolerance.
     sess:  TensorFlow session.  Defaults to the default session.
@@ -197,10 +202,10 @@ def assert_bijective_and_finite(bijector, x, y, atol=0, rtol=1e-5, sess=None):
   ] = sess.run([
       bijector.inverse(f_x),
       bijector.forward(g_y),
-      bijector.inverse_log_det_jacobian(f_x),
-      bijector.forward_log_det_jacobian(x),
-      bijector.inverse_log_det_jacobian(y),
-      bijector.forward_log_det_jacobian(g_y),
+      bijector.inverse_log_det_jacobian(f_x, event_ndims=event_ndims),
+      bijector.forward_log_det_jacobian(x, event_ndims=event_ndims),
+      bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims),
+      bijector.forward_log_det_jacobian(g_y, event_ndims=event_ndims),
       f_x,
       g_y,
   ])
diff --git a/tensorflow/python/ops/distributions/bijectors.py b/tensorflow/python/ops/distributions/bijectors.py
deleted file mode 100644 (file)
index 69c3a5d..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Core module for TensorFlow distribution bijectors."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-# go/tf-wildcard-import
-# pylint: disable=wildcard-import,unused-import
-from tensorflow.python.ops.distributions.bijector import Bijector
-from tensorflow.python.ops.distributions.identity_bijector import Identity
-
-# pylint: enable=wildcard-import,unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = ["Bijector", "Identity"]
-
-remove_undocumented(__name__, _allowed_symbols)
index 9df7d14..7c4b869 100644 (file)
@@ -19,7 +19,6 @@ from __future__ import print_function
 
 
 # pylint: disable=wildcard-import,unused-import
-from tensorflow.python.ops.distributions import bijectors
 from tensorflow.python.ops.distributions.bernoulli import Bernoulli
 from tensorflow.python.ops.distributions.beta import Beta
 from tensorflow.python.ops.distributions.categorical import Categorical
@@ -40,7 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented
 
 
 _allowed_symbols = [
-    "bijectors",
     "Bernoulli",
     "Beta",
     "Categorical",
index 2972c35..8628e68 100644 (file)
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops.distributions import bijector
-from tensorflow.python.util.tf_export import tf_export
 
 
 __all__ = [
@@ -28,7 +27,6 @@ __all__ = [
 ]
 
 
-@tf_export("distributions.bijectors.Identity")
 class Identity(bijector.Bijector):
   """Compute Y = g(X) = X.
 
@@ -37,7 +35,7 @@ class Identity(bijector.Bijector):
     ```python
     # Create the Y=g(X)=X transform which is intended for Tensors with 1 batch
     # ndim and 1 event ndim (i.e., vector of vectors).
-    identity = Identity(event_ndims=1)
+    identity = Identity()
     x = [[1., 2],
          [3, 4]]
     x == identity.forward(x) == identity.inverse(x)
@@ -45,10 +43,10 @@ class Identity(bijector.Bijector):
 
   """
 
-  def __init__(self, validate_args=False, event_ndims=0, name="identity"):
+  def __init__(self, validate_args=False, name="identity"):
     super(Identity, self).__init__(
+        forward_min_event_ndims=0,
         is_constant_jacobian=True,
-        event_ndims=event_ndims,
         validate_args=validate_args,
         name=name)
 
index 1efcf9d..1ad63a8 100644 (file)
@@ -197,8 +197,7 @@ class TransformedDistribution(distribution_lib.Distribution):
     distribution=ds.Normal(loc=0., scale=1.),
     bijector=ds.bijectors.Affine(
       shift=-1.,
-      scale_identity_multiplier=2.,
-      event_ndims=0),
+      scale_identity_multiplier=2.)
     name="NormalTransformedDistribution")
   ```
 
@@ -419,48 +418,51 @@ class TransformedDistribution(distribution_lib.Distribution):
     # For caching to work, it is imperative that the bijector is the first to
     # modify the input.
     x = self.bijector.inverse(y)
-    ildj = self.bijector.inverse_log_det_jacobian(y)
+    event_ndims = self._maybe_get_event_ndims_statically()
+
+    ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
     if self.bijector._is_injective:  # pylint: disable=protected-access
-      return self._finish_log_prob_for_one_fiber(y, x, ildj)
+      return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims)
 
     lp_on_fibers = [
-        self._finish_log_prob_for_one_fiber(y, x_i, ildj_i)
+        self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
         for x_i, ildj_i in zip(x, ildj)]
     return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0)
 
-  def _finish_log_prob_for_one_fiber(self, y, x, ildj):
+  def _finish_log_prob_for_one_fiber(self, y, x, ildj, event_ndims):
     """Finish computation of log_prob on one element of the inverse image."""
     x = self._maybe_rotate_dims(x, rotate_right=True)
     log_prob = self.distribution.log_prob(x)
     if self._is_maybe_event_override:
       log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
     log_prob += math_ops.cast(ildj, log_prob.dtype)
-    if self._is_maybe_event_override:
+    if self._is_maybe_event_override and isinstance(event_ndims, int):
       log_prob.set_shape(array_ops.broadcast_static_shape(
-          y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
+          x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
     return log_prob
 
   def _prob(self, y):
     x = self.bijector.inverse(y)
-    ildj = self.bijector.inverse_log_det_jacobian(y)
+    event_ndims = self._maybe_get_event_ndims_statically()
+    ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
     if self.bijector._is_injective:  # pylint: disable=protected-access
-      return self._finish_prob_for_one_fiber(y, x, ildj)
+      return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims)
 
     prob_on_fibers = [
-        self._finish_prob_for_one_fiber(y, x_i, ildj_i)
+        self._finish_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
         for x_i, ildj_i in zip(x, ildj)]
     return sum(prob_on_fibers)
 
-  def _finish_prob_for_one_fiber(self, y, x, ildj):
+  def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims):
     """Finish computation of prob on one element of the inverse image."""
     x = self._maybe_rotate_dims(x, rotate_right=True)
     prob = self.distribution.prob(x)
     if self._is_maybe_event_override:
       prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
     prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
-    if self._is_maybe_event_override:
+    if self._is_maybe_event_override and isinstance(event_ndims, int):
       prob.set_shape(array_ops.broadcast_static_shape(
-          y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
+          y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
     return prob
 
   def _log_cdf(self, y):
@@ -545,10 +547,17 @@ class TransformedDistribution(distribution_lib.Distribution):
           _ones_like(self.distribution.batch_shape_tensor())
       ], 0)
       entropy = array_ops.tile(entropy, multiples)
-    dummy = array_ops.zeros([], self.dtype)
-    entropy -= math_ops.cast(
-        self.bijector.inverse_log_det_jacobian(dummy),
-        entropy.dtype)
+    dummy = array_ops.zeros(
+        shape=array_ops.concat(
+            [self.batch_shape_tensor(), self.event_shape_tensor()],
+            0),
+        dtype=self.dtype)
+    event_ndims = (self.event_shape.ndims if self.event_shape.ndims is not None
+                   else array_ops.size(self.event_shape_tensor()))
+    ildj = self.bijector.inverse_log_det_jacobian(
+        dummy, event_ndims=event_ndims)
+
+    entropy -= math_ops.cast(ildj, entropy.dtype)
     entropy.set_shape(self.batch_shape)
     return entropy
 
@@ -610,3 +619,16 @@ class TransformedDistribution(distribution_lib.Distribution):
     n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
     return array_ops.transpose(
         x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
+
+  def _maybe_get_event_ndims_statically(self):
+    if self.event_shape.ndims is not None:
+      return self.event_shape.ndims
+
+    event_ndims = array_ops.size(self.event_shape_tensor())
+
+    static_event_ndims = tensor_util.constant_value(event_ndims)
+
+    if static_event_ndims is not None:
+      return static_event_ndims
+
+    return event_ndims
diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-bijector.pbtxt
deleted file mode 100644 (file)
index 11565bd..0000000
+++ /dev/null
@@ -1,65 +0,0 @@
-path: "tensorflow.distributions.bijectors.Bijector"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.distributions.bijector_impl.Bijector\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "dtype"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "event_ndims"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "graph_parents"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "is_constant_jacobian"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "name"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "validate_args"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'event_ndims\', \'graph_parents\', \'is_constant_jacobian\', \'validate_args\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "forward"
-    argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward\'], "
-  }
-  member_method {
-    name: "forward_event_shape"
-    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "forward_event_shape_tensor"
-    argspec: "args=[\'self\', \'input_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_event_shape_tensor\'], "
-  }
-  member_method {
-    name: "forward_log_det_jacobian"
-    argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_log_det_jacobian\'], "
-  }
-  member_method {
-    name: "inverse"
-    argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
-  }
-  member_method {
-    name: "inverse_event_shape"
-    argspec: "args=[\'self\', \'output_shape\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "inverse_event_shape_tensor"
-    argspec: "args=[\'self\', \'output_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_event_shape_tensor\'], "
-  }
-  member_method {
-    name: "inverse_log_det_jacobian"
-    argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_log_det_jacobian\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.-identity.pbtxt
deleted file mode 100644 (file)
index 1e5fe62..0000000
+++ /dev/null
@@ -1,66 +0,0 @@
-path: "tensorflow.distributions.bijectors.Identity"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.distributions.identity_bijector.Identity\'>"
-  is_instance: "<class \'tensorflow.python.ops.distributions.bijector_impl.Bijector\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "dtype"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "event_ndims"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "graph_parents"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "is_constant_jacobian"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "name"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "validate_args"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'validate_args\', \'event_ndims\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'0\', \'identity\'], "
-  }
-  member_method {
-    name: "forward"
-    argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward\'], "
-  }
-  member_method {
-    name: "forward_event_shape"
-    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "forward_event_shape_tensor"
-    argspec: "args=[\'self\', \'input_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_event_shape_tensor\'], "
-  }
-  member_method {
-    name: "forward_log_det_jacobian"
-    argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'forward_log_det_jacobian\'], "
-  }
-  member_method {
-    name: "inverse"
-    argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
-  }
-  member_method {
-    name: "inverse_event_shape"
-    argspec: "args=[\'self\', \'output_shape\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "inverse_event_shape_tensor"
-    argspec: "args=[\'self\', \'output_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_event_shape_tensor\'], "
-  }
-  member_method {
-    name: "inverse_log_det_jacobian"
-    argspec: "args=[\'self\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse_log_det_jacobian\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt b/tensorflow/tools/api/golden/tensorflow.distributions.bijectors.pbtxt
deleted file mode 100644 (file)
index 1d0144f..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-path: "tensorflow.distributions.bijectors"
-tf_module {
-  member {
-    name: "Bijector"
-    mtype: "<class \'abc.ABCMeta\'>"
-  }
-  member {
-    name: "Identity"
-    mtype: "<class \'abc.ABCMeta\'>"
-  }
-}
index 2fba7c5..90b60ef 100644 (file)
@@ -68,10 +68,6 @@ tf_module {
     name: "Uniform"
     mtype: "<class \'tensorflow.python.ops.distributions.distribution._DistributionMeta\'>"
   }
-  member {
-    name: "bijectors"
-    mtype: "<type \'module\'>"
-  }
   member_method {
     name: "kl_divergence"
     argspec: "args=[\'distribution_a\', \'distribution_b\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "