From: A. Unique TensorFlower Date: Thu, 24 May 2018 03:03:20 +0000 (-0700) Subject: Set the correct shape in transformed distribution. X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~135 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=42e50daa384183d2f64e0ab5ae3f9bed07128e07;p=platform%2Fupstream%2Ftensorflow.git Set the correct shape in transformed distribution. Also add distribution_util.maybe_get_static_event_ndims to be reused in bijector and transformed distribution classes. PiperOrigin-RevId: 197831651 --- diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 8b279eb..f8a5261 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase): for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1., event_ndims=0., arg1="b1", arg2="b2") + method(1., event_ndims=0, arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index b158a51..16f9595 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -234,7 +234,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return ildj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -248,12 +248,15 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -270,7 +273,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return fldj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -283,21 +286,14 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ x = b.forward(x, **kwargs.get(b.name, {})) return fldj - - def _maybe_get_event_ndims_statically(self, event_ndims): - event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( - event_ndims) - if event_ndims_ is None: - return event_ndims - return event_ndims_ - - diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 10b4536..3598c8d 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -106,7 +105,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -131,7 +130,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -220,14 +219,14 @@ class ConditionalTransformedDistribution( inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) - static_event_ndims = tensor_util.constant_value(event_ndims) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if static_event_ndims is not None: - return static_event_ndims + if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py index a7fe336..8b11556 100644 --- a/tensorflow/python/kernel_tests/distributions/bijector_test.py +++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py @@ -90,9 +90,10 @@ class IntentionallyMissingError(Exception): class BrokenBijector(bijector.Bijector): """Forward and inverse are not inverses of each other.""" - def __init__(self, forward_missing=False, inverse_missing=False): + def __init__( + self, forward_missing=False, inverse_missing=False, validate_args=False): super(BrokenBijector, self).__init__( - validate_args=False, forward_min_event_ndims=0, name="broken") + validate_args=validate_args, forward_min_event_ndims=0, name="broken") self._forward_missing = forward_missing self._inverse_missing = inverse_missing @@ -116,6 +117,33 @@ class BrokenBijector(bijector.Bijector): raise IntentionallyMissingError return math_ops.log(2.) +class BijectorTestEventNdims(test.TestCase): + + def testBijectorNonIntegerEventNdims(self): + bij = BrokenBijector() + with self.assertRaisesRegexp(ValueError, "Expected integer"): + bij.forward_log_det_jacobian(1., event_ndims=1.5) + with self.assertRaisesRegexp(ValueError, "Expected integer"): + bij.inverse_log_det_jacobian(1., event_ndims=1.5) + + def testBijectorArrayEventNdims(self): + bij = BrokenBijector() + with self.assertRaisesRegexp(ValueError, "Expected scalar"): + bij.forward_log_det_jacobian(1., event_ndims=(1, 2)) + with self.assertRaisesRegexp(ValueError, "Expected scalar"): + bij.inverse_log_det_jacobian(1., event_ndims=(1, 2)) + + def testBijectorDynamicEventNdims(self): + bij = BrokenBijector(validate_args=True) + event_ndims = array_ops.placeholder(dtype=np.int32, shape=None) + with self.test_session(): + with self.assertRaisesOpError("Expected scalar"): + bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({ + event_ndims: (1, 2)}) + with self.assertRaisesOpError("Expected scalar"): + bij.inverse_log_det_jacobian(1., event_ndims=event_ndims).eval({ + event_ndims: (1, 2)}) + @six.add_metaclass(abc.ABCMeta) class BijectorCachingTestBase(object): diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py index caceadf..969553b 100644 --- a/tensorflow/python/ops/distributions/bijector_impl.py +++ b/tensorflow/python/ops/distributions/bijector_impl.py @@ -1021,7 +1021,7 @@ class Bijector(object): axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) # The multiplication by ones can change the inferred static shape so we try # to recover as much as possible. - event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if (event_ndims_ is not None and y.shape.ndims is not None and ildj.shape.ndims is not None): @@ -1036,7 +1036,7 @@ class Bijector(object): def _get_event_reduce_dims(self, min_event_ndims, event_ndims): """Compute the reduction dimensions given event_ndims.""" - event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if event_ndims_ is not None: return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)] @@ -1046,9 +1046,18 @@ class Bijector(object): def _check_valid_event_ndims(self, min_event_ndims, event_ndims): """Check whether event_ndims is atleast min_event_ndims.""" - event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_ = tensor_util.constant_value(event_ndims) assertions = [] + + if not event_ndims.dtype.is_integer: + raise ValueError("Expected integer dtype, got dtype {}".format( + event_ndims.dtype)) + if event_ndims_ is not None: + if event_ndims.shape.ndims != 0: + raise ValueError("Expected scalar event_ndims, got shape {}".format( + event_ndims.shape)) if min_event_ndims > event_ndims_: raise ValueError("event_ndims ({}) must be larger than " "min_event_ndims ({})".format( @@ -1056,17 +1065,29 @@ class Bijector(object): elif self.validate_args: assertions += [ check_ops.assert_greater_equal(event_ndims, min_event_ndims)] + + if event_ndims.shape.is_fully_defined(): + if event_ndims.shape.ndims != 0: + raise ValueError("Expected scalar shape, got ndims {}".format( + event_ndims.shape.ndims)) + + elif self.validate_args: + assertions += [ + check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")] return assertions - def _maybe_get_event_ndims_statically(self, event_ndims): + def _maybe_get_static_event_ndims(self, event_ndims): """Helper which returns tries to return an integer static value.""" event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if isinstance(event_ndims_, np.ndarray): - if (event_ndims_.dtype not in (np.int32, np.int64) or - len(event_ndims_.shape)): + if isinstance(event_ndims_, (np.generic, np.ndarray)): + if event_ndims_.dtype not in (np.int32, np.int64): + raise ValueError("Expected integer dtype, got dtype {}".format( + event_ndims_.dtype)) + + if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape): raise ValueError("Expected a scalar integer, got {}".format( event_ndims_)) - event_ndims_ = event_ndims_.tolist() + event_ndims_ = int(event_ndims_) return event_ndims_ diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py index 9392464..c2674bd 100644 --- a/tensorflow/python/ops/distributions/transformed_distribution.py +++ b/tensorflow/python/ops/distributions/transformed_distribution.py @@ -416,7 +416,7 @@ class TransformedDistribution(distribution_lib.Distribution): # For caching to work, it is imperative that the bijector is the first to # modify the input. x = self.bijector.inverse(y) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) if self.bijector._is_injective: # pylint: disable=protected-access @@ -435,13 +435,15 @@ class TransformedDistribution(distribution_lib.Distribution): log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) log_prob += math_ops.cast(ildj, log_prob.dtype) if self._is_maybe_event_override and isinstance(event_ndims, int): - log_prob.set_shape(array_ops.broadcast_static_shape( - x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape)) + log_prob.set_shape( + array_ops.broadcast_static_shape( + y.get_shape().with_rank_at_least(1)[:-event_ndims], + self.batch_shape)) return log_prob def _prob(self, y): x = self.bijector.inverse(y) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims) @@ -459,8 +461,10 @@ class TransformedDistribution(distribution_lib.Distribution): prob = math_ops.reduce_prod(prob, self._reduce_event_indices) prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype)) if self._is_maybe_event_override and isinstance(event_ndims, int): - prob.set_shape(array_ops.broadcast_static_shape( - y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape)) + prob.set_shape( + array_ops.broadcast_static_shape( + y.get_shape().with_rank_at_least(1)[:-event_ndims], + self.batch_shape)) return prob def _log_cdf(self, y): @@ -618,15 +622,14 @@ class TransformedDistribution(distribution_lib.Distribution): return array_ops.transpose( x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n))) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - static_event_ndims = tensor_util.constant_value(event_ndims) - - if static_event_ndims is not None: - return static_event_ndims + if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 59c89d2..728fda2 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -179,6 +179,7 @@ def maybe_get_static_value(x, dtype=None): if x is None: return x try: + # This returns an np.ndarray. x_ = tensor_util.constant_value(x) except TypeError: x_ = x