Set the correct shape in transformed distribution.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 03:03:20 +0000 (20:03 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 03:05:39 +0000 (20:05 -0700)
Also add distribution_util.maybe_get_static_event_ndims to be reused in bijector and transformed distribution classes.

PiperOrigin-RevId: 197831651

tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py
tensorflow/contrib/distributions/python/ops/bijectors/chain.py
tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py
tensorflow/python/kernel_tests/distributions/bijector_test.py
tensorflow/python/ops/distributions/bijector_impl.py
tensorflow/python/ops/distributions/transformed_distribution.py
tensorflow/python/ops/distributions/util.py

index 8b279eb..f8a5261 100644 (file)
@@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase):
     for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]:
       method = getattr(b, name)
       with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"):
-        method(1., event_ndims=0., arg1="b1", arg2="b2")
+        method(1., event_ndims=0, arg1="b1", arg2="b2")
 
 
 if __name__ == "__main__":
index b158a51..16f9595 100644 (file)
@@ -234,7 +234,7 @@ class Chain(bijector.Bijector):
     if not self.bijectors:
       return ildj
 
-    event_ndims = self._maybe_get_event_ndims_statically(
+    event_ndims = self._maybe_get_static_event_ndims(
         self.inverse_min_event_ndims)
 
     if _use_static_shape(y, event_ndims):
@@ -248,12 +248,15 @@ class Chain(bijector.Bijector):
 
       if _use_static_shape(y, event_ndims):
         event_shape = b.inverse_event_shape(event_shape)
-        event_ndims = self._maybe_get_event_ndims_statically(
+        event_ndims = self._maybe_get_static_event_ndims(
             event_shape.ndims)
       else:
         event_shape = b.inverse_event_shape_tensor(event_shape)
-        event_ndims = self._maybe_get_event_ndims_statically(
-            array_ops.size(event_shape))
+        event_ndims = array_ops.size(event_shape)
+        event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+        if event_ndims_ is not None:
+          event_ndims = event_ndims_
+
       y = b.inverse(y, **kwargs.get(b.name, {}))
     return ildj
 
@@ -270,7 +273,7 @@ class Chain(bijector.Bijector):
     if not self.bijectors:
       return fldj
 
-    event_ndims = self._maybe_get_event_ndims_statically(
+    event_ndims = self._maybe_get_static_event_ndims(
         self.forward_min_event_ndims)
 
     if _use_static_shape(x, event_ndims):
@@ -283,21 +286,14 @@ class Chain(bijector.Bijector):
           x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
       if _use_static_shape(x, event_ndims):
         event_shape = b.forward_event_shape(event_shape)
-        event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
+        event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims)
       else:
         event_shape = b.forward_event_shape_tensor(event_shape)
-        event_ndims = self._maybe_get_event_ndims_statically(
-            array_ops.size(event_shape))
+        event_ndims = array_ops.size(event_shape)
+        event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+        if event_ndims_ is not None:
+          event_ndims = event_ndims_
 
       x = b.forward(x, **kwargs.get(b.name, {}))
 
     return fldj
-
-  def _maybe_get_event_ndims_statically(self, event_ndims):
-    event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
-        event_ndims)
-    if event_ndims_ is None:
-      return event_ndims
-    return event_ndims_
-
-
index 10b4536..3598c8d 100644 (file)
@@ -20,7 +20,6 @@ from __future__ import print_function
 from tensorflow.contrib.distributions.python.ops import conditional_distribution
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import transformed_distribution
@@ -106,7 +105,7 @@ class ConditionalTransformedDistribution(
     bijector_kwargs = bijector_kwargs or {}
     distribution_kwargs = distribution_kwargs or {}
     x = self.bijector.inverse(y, **bijector_kwargs)
-    event_ndims = self._maybe_get_event_ndims_statically()
+    event_ndims = self._maybe_get_static_event_ndims()
     ildj = self.bijector.inverse_log_det_jacobian(
         y, event_ndims=event_ndims, **bijector_kwargs)
     if self.bijector._is_injective:  # pylint: disable=protected-access
@@ -131,7 +130,7 @@ class ConditionalTransformedDistribution(
     bijector_kwargs = bijector_kwargs or {}
     distribution_kwargs = distribution_kwargs or {}
     x = self.bijector.inverse(y, **bijector_kwargs)
-    event_ndims = self._maybe_get_event_ndims_statically()
+    event_ndims = self._maybe_get_static_event_ndims()
     ildj = self.bijector.inverse_log_det_jacobian(
         y, event_ndims=event_ndims, **bijector_kwargs)
     if self.bijector._is_injective:  # pylint: disable=protected-access
@@ -220,14 +219,14 @@ class ConditionalTransformedDistribution(
     inv_cdf = self.distribution.quantile(value, **distribution_kwargs)
     return self.bijector.forward(inv_cdf, **bijector_kwargs)
 
-  def _maybe_get_event_ndims_statically(self):
+  def _maybe_get_static_event_ndims(self):
     if self.event_shape.ndims is not None:
       return self.event_shape.ndims
 
     event_ndims = array_ops.size(self.event_shape_tensor())
-    static_event_ndims = tensor_util.constant_value(event_ndims)
+    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
 
-    if static_event_ndims is not None:
-      return static_event_ndims
+    if event_ndims_ is not None:
+      return event_ndims_
 
     return event_ndims
index a7fe336..8b11556 100644 (file)
@@ -90,9 +90,10 @@ class IntentionallyMissingError(Exception):
 class BrokenBijector(bijector.Bijector):
   """Forward and inverse are not inverses of each other."""
 
-  def __init__(self, forward_missing=False, inverse_missing=False):
+  def __init__(
+      self, forward_missing=False, inverse_missing=False, validate_args=False):
     super(BrokenBijector, self).__init__(
-        validate_args=False, forward_min_event_ndims=0, name="broken")
+        validate_args=validate_args, forward_min_event_ndims=0, name="broken")
     self._forward_missing = forward_missing
     self._inverse_missing = inverse_missing
 
@@ -116,6 +117,33 @@ class BrokenBijector(bijector.Bijector):
       raise IntentionallyMissingError
     return math_ops.log(2.)
 
+class BijectorTestEventNdims(test.TestCase):
+
+  def testBijectorNonIntegerEventNdims(self):
+    bij = BrokenBijector()
+    with self.assertRaisesRegexp(ValueError, "Expected integer"):
+      bij.forward_log_det_jacobian(1., event_ndims=1.5)
+    with self.assertRaisesRegexp(ValueError, "Expected integer"):
+      bij.inverse_log_det_jacobian(1., event_ndims=1.5)
+
+  def testBijectorArrayEventNdims(self):
+    bij = BrokenBijector()
+    with self.assertRaisesRegexp(ValueError, "Expected scalar"):
+      bij.forward_log_det_jacobian(1., event_ndims=(1, 2))
+    with self.assertRaisesRegexp(ValueError, "Expected scalar"):
+      bij.inverse_log_det_jacobian(1., event_ndims=(1, 2))
+
+  def testBijectorDynamicEventNdims(self):
+    bij = BrokenBijector(validate_args=True)
+    event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
+    with self.test_session():
+      with self.assertRaisesOpError("Expected scalar"):
+        bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
+            event_ndims: (1, 2)})
+      with self.assertRaisesOpError("Expected scalar"):
+        bij.inverse_log_det_jacobian(1., event_ndims=event_ndims).eval({
+            event_ndims: (1, 2)})
+
 
 @six.add_metaclass(abc.ABCMeta)
 class BijectorCachingTestBase(object):
index caceadf..969553b 100644 (file)
@@ -1021,7 +1021,7 @@ class Bijector(object):
         axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
     # The multiplication by ones can change the inferred static shape so we try
     # to recover as much as possible.
-    event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
     if (event_ndims_ is not None and
         y.shape.ndims is not None and
         ildj.shape.ndims is not None):
@@ -1036,7 +1036,7 @@ class Bijector(object):
 
   def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
     """Compute the reduction dimensions given event_ndims."""
-    event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
 
     if event_ndims_ is not None:
       return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
@@ -1046,9 +1046,18 @@ class Bijector(object):
 
   def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
     """Check whether event_ndims is atleast min_event_ndims."""
-    event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+    event_ndims_ = tensor_util.constant_value(event_ndims)
     assertions = []
+
+    if not event_ndims.dtype.is_integer:
+      raise ValueError("Expected integer dtype, got dtype {}".format(
+          event_ndims.dtype))
+
     if event_ndims_ is not None:
+      if event_ndims.shape.ndims != 0:
+        raise ValueError("Expected scalar event_ndims, got shape {}".format(
+            event_ndims.shape))
       if min_event_ndims > event_ndims_:
         raise ValueError("event_ndims ({}) must be larger than "
                          "min_event_ndims ({})".format(
@@ -1056,17 +1065,29 @@ class Bijector(object):
     elif self.validate_args:
       assertions += [
           check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
+
+    if event_ndims.shape.is_fully_defined():
+      if event_ndims.shape.ndims != 0:
+        raise ValueError("Expected scalar shape, got ndims {}".format(
+            event_ndims.shape.ndims))
+
+    elif self.validate_args:
+      assertions += [
+          check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
     return assertions
 
-  def _maybe_get_event_ndims_statically(self, event_ndims):
+  def _maybe_get_static_event_ndims(self, event_ndims):
     """Helper which returns tries to return an integer static value."""
     event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
 
-    if isinstance(event_ndims_, np.ndarray):
-      if (event_ndims_.dtype not in (np.int32, np.int64) or
-          len(event_ndims_.shape)):
+    if isinstance(event_ndims_, (np.generic, np.ndarray)):
+      if event_ndims_.dtype not in (np.int32, np.int64):
+        raise ValueError("Expected integer dtype, got dtype {}".format(
+            event_ndims_.dtype))
+
+      if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
         raise ValueError("Expected a scalar integer, got {}".format(
             event_ndims_))
-      event_ndims_ = event_ndims_.tolist()
+      event_ndims_ = int(event_ndims_)
 
     return event_ndims_
index 9392464..c2674bd 100644 (file)
@@ -416,7 +416,7 @@ class TransformedDistribution(distribution_lib.Distribution):
     # For caching to work, it is imperative that the bijector is the first to
     # modify the input.
     x = self.bijector.inverse(y)
-    event_ndims = self._maybe_get_event_ndims_statically()
+    event_ndims = self._maybe_get_static_event_ndims()
 
     ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
     if self.bijector._is_injective:  # pylint: disable=protected-access
@@ -435,13 +435,15 @@ class TransformedDistribution(distribution_lib.Distribution):
       log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
     log_prob += math_ops.cast(ildj, log_prob.dtype)
     if self._is_maybe_event_override and isinstance(event_ndims, int):
-      log_prob.set_shape(array_ops.broadcast_static_shape(
-          x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
+      log_prob.set_shape(
+          array_ops.broadcast_static_shape(
+              y.get_shape().with_rank_at_least(1)[:-event_ndims],
+              self.batch_shape))
     return log_prob
 
   def _prob(self, y):
     x = self.bijector.inverse(y)
-    event_ndims = self._maybe_get_event_ndims_statically()
+    event_ndims = self._maybe_get_static_event_ndims()
     ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
     if self.bijector._is_injective:  # pylint: disable=protected-access
       return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims)
@@ -459,8 +461,10 @@ class TransformedDistribution(distribution_lib.Distribution):
       prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
     prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
     if self._is_maybe_event_override and isinstance(event_ndims, int):
-      prob.set_shape(array_ops.broadcast_static_shape(
-          y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
+      prob.set_shape(
+          array_ops.broadcast_static_shape(
+              y.get_shape().with_rank_at_least(1)[:-event_ndims],
+              self.batch_shape))
     return prob
 
   def _log_cdf(self, y):
@@ -618,15 +622,14 @@ class TransformedDistribution(distribution_lib.Distribution):
     return array_ops.transpose(
         x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
 
-  def _maybe_get_event_ndims_statically(self):
+  def _maybe_get_static_event_ndims(self):
     if self.event_shape.ndims is not None:
       return self.event_shape.ndims
 
     event_ndims = array_ops.size(self.event_shape_tensor())
+    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
 
-    static_event_ndims = tensor_util.constant_value(event_ndims)
-
-    if static_event_ndims is not None:
-      return static_event_ndims
+    if event_ndims_ is not None:
+      return event_ndims_
 
     return event_ndims
index 59c89d2..728fda2 100644 (file)
@@ -179,6 +179,7 @@ def maybe_get_static_value(x, dtype=None):
   if x is None:
     return x
   try:
+    # This returns an np.ndarray.
     x_ = tensor_util.constant_value(x)
   except TypeError:
     x_ = x