for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]:
method = getattr(b, name)
with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"):
- method(1., event_ndims=0., arg1="b1", arg2="b2")
+ method(1., event_ndims=0, arg1="b1", arg2="b2")
if __name__ == "__main__":
if not self.bijectors:
return ildj
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
self.inverse_min_event_ndims)
if _use_static_shape(y, event_ndims):
if _use_static_shape(y, event_ndims):
event_shape = b.inverse_event_shape(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
event_shape.ndims)
else:
event_shape = b.inverse_event_shape_tensor(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
- array_ops.size(event_shape))
+ event_ndims = array_ops.size(event_shape)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+ if event_ndims_ is not None:
+ event_ndims = event_ndims_
+
y = b.inverse(y, **kwargs.get(b.name, {}))
return ildj
if not self.bijectors:
return fldj
- event_ndims = self._maybe_get_event_ndims_statically(
+ event_ndims = self._maybe_get_static_event_ndims(
self.forward_min_event_ndims)
if _use_static_shape(x, event_ndims):
x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
if _use_static_shape(x, event_ndims):
event_shape = b.forward_event_shape(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims)
+ event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims)
else:
event_shape = b.forward_event_shape_tensor(event_shape)
- event_ndims = self._maybe_get_event_ndims_statically(
- array_ops.size(event_shape))
+ event_ndims = array_ops.size(event_shape)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
+ if event_ndims_ is not None:
+ event_ndims = event_ndims_
x = b.forward(x, **kwargs.get(b.name, {}))
return fldj
-
- def _maybe_get_event_ndims_statically(self, event_ndims):
- event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically(
- event_ndims)
- if event_ndims_ is None:
- return event_ndims
- return event_ndims_
-
-
from tensorflow.contrib.distributions.python.ops import conditional_distribution
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import transformed_distribution
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(
y, event_ndims=event_ndims, **bijector_kwargs)
if self.bijector._is_injective: # pylint: disable=protected-access
bijector_kwargs = bijector_kwargs or {}
distribution_kwargs = distribution_kwargs or {}
x = self.bijector.inverse(y, **bijector_kwargs)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(
y, event_ndims=event_ndims, **bijector_kwargs)
if self.bijector._is_injective: # pylint: disable=protected-access
inv_cdf = self.distribution.quantile(value, **distribution_kwargs)
return self.bijector.forward(inv_cdf, **bijector_kwargs)
- def _maybe_get_event_ndims_statically(self):
+ def _maybe_get_static_event_ndims(self):
if self.event_shape.ndims is not None:
return self.event_shape.ndims
event_ndims = array_ops.size(self.event_shape_tensor())
- static_event_ndims = tensor_util.constant_value(event_ndims)
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- if static_event_ndims is not None:
- return static_event_ndims
+ if event_ndims_ is not None:
+ return event_ndims_
return event_ndims
class BrokenBijector(bijector.Bijector):
"""Forward and inverse are not inverses of each other."""
- def __init__(self, forward_missing=False, inverse_missing=False):
+ def __init__(
+ self, forward_missing=False, inverse_missing=False, validate_args=False):
super(BrokenBijector, self).__init__(
- validate_args=False, forward_min_event_ndims=0, name="broken")
+ validate_args=validate_args, forward_min_event_ndims=0, name="broken")
self._forward_missing = forward_missing
self._inverse_missing = inverse_missing
raise IntentionallyMissingError
return math_ops.log(2.)
+class BijectorTestEventNdims(test.TestCase):
+
+ def testBijectorNonIntegerEventNdims(self):
+ bij = BrokenBijector()
+ with self.assertRaisesRegexp(ValueError, "Expected integer"):
+ bij.forward_log_det_jacobian(1., event_ndims=1.5)
+ with self.assertRaisesRegexp(ValueError, "Expected integer"):
+ bij.inverse_log_det_jacobian(1., event_ndims=1.5)
+
+ def testBijectorArrayEventNdims(self):
+ bij = BrokenBijector()
+ with self.assertRaisesRegexp(ValueError, "Expected scalar"):
+ bij.forward_log_det_jacobian(1., event_ndims=(1, 2))
+ with self.assertRaisesRegexp(ValueError, "Expected scalar"):
+ bij.inverse_log_det_jacobian(1., event_ndims=(1, 2))
+
+ def testBijectorDynamicEventNdims(self):
+ bij = BrokenBijector(validate_args=True)
+ event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
+ with self.test_session():
+ with self.assertRaisesOpError("Expected scalar"):
+ bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
+ event_ndims: (1, 2)})
+ with self.assertRaisesOpError("Expected scalar"):
+ bij.inverse_log_det_jacobian(1., event_ndims=event_ndims).eval({
+ event_ndims: (1, 2)})
+
@six.add_metaclass(abc.ABCMeta)
class BijectorCachingTestBase(object):
axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
# The multiplication by ones can change the inferred static shape so we try
# to recover as much as possible.
- event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
if (event_ndims_ is not None and
y.shape.ndims is not None and
ildj.shape.ndims is not None):
def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
"""Compute the reduction dimensions given event_ndims."""
- event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+ event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
if event_ndims_ is not None:
return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
"""Check whether event_ndims is atleast min_event_ndims."""
- event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims)
+ event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
+ event_ndims_ = tensor_util.constant_value(event_ndims)
assertions = []
+
+ if not event_ndims.dtype.is_integer:
+ raise ValueError("Expected integer dtype, got dtype {}".format(
+ event_ndims.dtype))
+
if event_ndims_ is not None:
+ if event_ndims.shape.ndims != 0:
+ raise ValueError("Expected scalar event_ndims, got shape {}".format(
+ event_ndims.shape))
if min_event_ndims > event_ndims_:
raise ValueError("event_ndims ({}) must be larger than "
"min_event_ndims ({})".format(
elif self.validate_args:
assertions += [
check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
+
+ if event_ndims.shape.is_fully_defined():
+ if event_ndims.shape.ndims != 0:
+ raise ValueError("Expected scalar shape, got ndims {}".format(
+ event_ndims.shape.ndims))
+
+ elif self.validate_args:
+ assertions += [
+ check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
return assertions
- def _maybe_get_event_ndims_statically(self, event_ndims):
+ def _maybe_get_static_event_ndims(self, event_ndims):
"""Helper which returns tries to return an integer static value."""
event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- if isinstance(event_ndims_, np.ndarray):
- if (event_ndims_.dtype not in (np.int32, np.int64) or
- len(event_ndims_.shape)):
+ if isinstance(event_ndims_, (np.generic, np.ndarray)):
+ if event_ndims_.dtype not in (np.int32, np.int64):
+ raise ValueError("Expected integer dtype, got dtype {}".format(
+ event_ndims_.dtype))
+
+ if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
raise ValueError("Expected a scalar integer, got {}".format(
event_ndims_))
- event_ndims_ = event_ndims_.tolist()
+ event_ndims_ = int(event_ndims_)
return event_ndims_
# For caching to work, it is imperative that the bijector is the first to
# modify the input.
x = self.bijector.inverse(y)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
if self.bijector._is_injective: # pylint: disable=protected-access
log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
log_prob += math_ops.cast(ildj, log_prob.dtype)
if self._is_maybe_event_override and isinstance(event_ndims, int):
- log_prob.set_shape(array_ops.broadcast_static_shape(
- x.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
+ log_prob.set_shape(
+ array_ops.broadcast_static_shape(
+ y.get_shape().with_rank_at_least(1)[:-event_ndims],
+ self.batch_shape))
return log_prob
def _prob(self, y):
x = self.bijector.inverse(y)
- event_ndims = self._maybe_get_event_ndims_statically()
+ event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
if self.bijector._is_injective: # pylint: disable=protected-access
return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims)
prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
if self._is_maybe_event_override and isinstance(event_ndims, int):
- prob.set_shape(array_ops.broadcast_static_shape(
- y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
+ prob.set_shape(
+ array_ops.broadcast_static_shape(
+ y.get_shape().with_rank_at_least(1)[:-event_ndims],
+ self.batch_shape))
return prob
def _log_cdf(self, y):
return array_ops.transpose(
x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
- def _maybe_get_event_ndims_statically(self):
+ def _maybe_get_static_event_ndims(self):
if self.event_shape.ndims is not None:
return self.event_shape.ndims
event_ndims = array_ops.size(self.event_shape_tensor())
+ event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
- static_event_ndims = tensor_util.constant_value(event_ndims)
-
- if static_event_ndims is not None:
- return static_event_ndims
+ if event_ndims_ is not None:
+ return event_ndims_
return event_ndims
if x is None:
return x
try:
+ # This returns an np.ndarray.
x_ = tensor_util.constant_value(x)
except TypeError:
x_ = x