from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib
+from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
batch_shape=new_batch_shape_ph,
validate_args=True).sample().eval()
+ def test_broadcasting_explicitly_unsupported(self):
+ old_batch_shape = [4]
+ new_batch_shape = [1, 4, 1]
+ rate_ = self.dtype([1, 10, 2, 20])
+
+ rate = array_ops.placeholder_with_default(
+ rate_,
+ shape=old_batch_shape if self.is_static_shape else None)
+ poisson_4 = poisson_lib.Poisson(rate)
+ new_batch_shape_ph = (
+ constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+ else array_ops.placeholder_with_default(
+ np.int32(new_batch_shape), shape=None))
+ poisson_141_reshaped = batch_reshape_lib.BatchReshape(
+ poisson_4, new_batch_shape_ph, validate_args=True)
+
+ x_4 = self.dtype([2, 12, 3, 23])
+ x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4)
+
+ if self.is_static_shape:
+ with self.assertRaisesRegexp(NotImplementedError,
+ "too few event dims"):
+ poisson_141_reshaped.log_prob(x_4)
+ with self.assertRaisesRegexp(NotImplementedError,
+ "unexpected batch and event shape"):
+ poisson_141_reshaped.log_prob(x_114)
+ return
+
+ with self.assertRaisesOpError("too few event dims"):
+ with self.test_session():
+ poisson_141_reshaped.log_prob(x_4).eval()
+
+ with self.assertRaisesOpError("unexpected batch and event shape"):
+ with self.test_session():
+ poisson_141_reshaped.log_prob(x_114).eval()
+
class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase):
self._batch_shape_static = tensor_util.constant_value(self._batch_shape_)
if self._batch_shape_static is not None:
self._batch_shape_static = np.int32(self._batch_shape_static)
- self._runtime_assertions = make_runtime_assertions(
+ self._runtime_assertions = validate_init_args(
self._distribution,
self._batch_shape_,
validate_args,
def _call_reshape_input_output(self, fn, x):
"""Calls `fn`, appropriately reshaping its input `x` and output."""
- with ops.control_dependencies(self._runtime_assertions):
+ with ops.control_dependencies(
+ self._runtime_assertions + self._validate_sample_arg(x)):
sample_shape, static_sample_shape = self._sample_shape(x)
old_shape = array_ops.concat([
sample_shape,
result.set_shape(static_shape)
return result
-
-def make_runtime_assertions(
+ def _validate_sample_arg(self, x):
+ """Helper which validates sample arg, e.g., input to `log_prob`."""
+ with ops.name_scope(name="validate_sample_arg", values=[x]):
+ x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims)
+ event_ndims = (array_ops.size(self.event_shape_tensor())
+ if self.event_shape.ndims is None
+ else self.event_shape.ndims)
+ batch_ndims = (array_ops.size(self.batch_shape_tensor())
+ if self.batch_shape.ndims is None
+ else self.batch_shape.ndims)
+ expected_batch_event_ndims = batch_ndims + event_ndims
+
+ if (isinstance(x_ndims, int) and
+ isinstance(expected_batch_event_ndims, int)):
+ if x_ndims < expected_batch_event_ndims:
+ raise NotImplementedError(
+ "Broadcasting is not supported; too few event dims "
+ "(expected at least {}, saw {}).".format(
+ expected_batch_event_ndims, x_ndims))
+ ndims_assertion = []
+ elif self.validate_args:
+ ndims_assertion = [
+ check_ops.assert_greater_equal(
+ x_ndims,
+ expected_batch_event_ndims,
+ message="Broadcasting is not supported; too few event dims.",
+ name="assert_batch_and_event_ndims_large_enough"),
+ ]
+
+ if (self.batch_shape.is_fully_defined() and
+ self.event_shape.is_fully_defined()):
+ expected_batch_event_shape = np.int32(self.batch_shape.concatenate(
+ self.event_shape).as_list())
+ else:
+ expected_batch_event_shape = array_ops.concat([
+ self.batch_shape_tensor(),
+ self.event_shape_tensor(),
+ ], axis=0)
+
+ sample_ndims = x_ndims - expected_batch_event_ndims
+ if isinstance(sample_ndims, int):
+ sample_ndims = max(sample_ndims, 0)
+ if (isinstance(sample_ndims, int) and
+ x.shape[sample_ndims:].is_fully_defined()):
+ actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list())
+ else:
+ sample_ndims = math_ops.maximum(sample_ndims, 0)
+ actual_batch_event_shape = array_ops.shape(x)[sample_ndims:]
+
+ if (isinstance(expected_batch_event_shape, np.ndarray) and
+ isinstance(actual_batch_event_shape, np.ndarray)):
+ if any(expected_batch_event_shape != actual_batch_event_shape):
+ raise NotImplementedError("Broadcasting is not supported; "
+ "unexpected batch and event shape "
+ "(expected {}, saw {}).".format(
+ expected_batch_event_shape,
+ actual_batch_event_shape))
+ # We need to set the final runtime-assertions to `ndims_assertion` since
+ # its possible this assertion was created. We could add a condition to
+ # only do so if `self.validate_args == True`, however this is redundant
+ # as `ndims_assertion` already encodes this information.
+ runtime_assertions = ndims_assertion
+ elif self.validate_args:
+ # We need to make the `ndims_assertion` a control dep because otherwise
+ # TF itself might raise an exception owing to this assertion being
+ # ill-defined, ie, one cannot even compare different rank Tensors.
+ with ops.control_dependencies(ndims_assertion):
+ shape_assertion = check_ops.assert_equal(
+ expected_batch_event_shape,
+ actual_batch_event_shape,
+ message=("Broadcasting is not supported; "
+ "unexpected batch and event shape."),
+ name="assert_batch_and_event_shape_same")
+ runtime_assertions = [shape_assertion]
+ else:
+ runtime_assertions = []
+
+ return runtime_assertions
+
+
+def validate_init_args(
distribution,
batch_shape,
validate_args,
batch_shape_static):
"""Helper to __init__ which makes or raises assertions."""
- runtime_assertions = []
-
- if batch_shape.shape.ndims is not None:
- if batch_shape.shape.ndims != 1:
- raise ValueError("`batch_shape` must be a vector "
- "(saw rank: {}).".format(
- batch_shape.shape.ndims))
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_rank(
- batch_shape,
- 1,
- message="`batch_shape` must be a vector.",
- name="assert_batch_shape_is_vector"),
- ]
-
- batch_size_static = np.prod(batch_shape_static)
- dist_batch_size_static = (
- None if not distribution.batch_shape.is_fully_defined()
- else np.prod(distribution.batch_shape).value)
-
- if batch_size_static is not None and dist_batch_size_static is not None:
- if batch_size_static != dist_batch_size_static:
- raise ValueError("`batch_shape` size ({}) must match "
- "`distribution.batch_shape` size ({}).".format(
- batch_size_static,
- dist_batch_size_static))
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_equal(
- math_ops.reduce_prod(batch_shape),
- math_ops.reduce_prod(distribution.batch_shape_tensor()),
- message=("`batch_shape` size must match "
- "`distributions.batch_shape` size."),
- name="assert_batch_size"),
- ]
-
- if batch_shape_static is not None:
- if np.any(batch_shape_static < 1):
- raise ValueError("`batch_shape` elements must be positive "
- "(i.e., larger than zero).")
- elif validate_args:
- runtime_assertions += [
- check_ops.assert_positive(
- batch_shape,
- message=("`batch_shape` elements must be positive "
- "(i.e., larger than zero)."),
- name="assert_batch_shape_positive")
- ]
-
- return runtime_assertions
+ with ops.name_scope(name="validate_init_args",
+ values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access
+ runtime_assertions = []
+
+ if batch_shape.shape.ndims is not None:
+ if batch_shape.shape.ndims != 1:
+ raise ValueError("`batch_shape` must be a vector "
+ "(saw rank: {}).".format(
+ batch_shape.shape.ndims))
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_rank(
+ batch_shape,
+ 1,
+ message="`batch_shape` must be a vector.",
+ name="assert_batch_shape_is_vector"),
+ ]
+
+ batch_size_static = np.prod(batch_shape_static)
+ dist_batch_size_static = (
+ None if not distribution.batch_shape.is_fully_defined()
+ else np.prod(distribution.batch_shape).value)
+
+ if batch_size_static is not None and dist_batch_size_static is not None:
+ if batch_size_static != dist_batch_size_static:
+ raise ValueError("`batch_shape` size ({}) must match "
+ "`distribution.batch_shape` size ({}).".format(
+ batch_size_static,
+ dist_batch_size_static))
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_equal(
+ math_ops.reduce_prod(batch_shape),
+ math_ops.reduce_prod(distribution.batch_shape_tensor()),
+ message=("`batch_shape` size must match "
+ "`distributions.batch_shape` size."),
+ name="assert_batch_size"),
+ ]
+
+ if batch_shape_static is not None:
+ if np.any(batch_shape_static < 1):
+ raise ValueError("`batch_shape` elements must be positive "
+ "(i.e., larger than zero).")
+ elif validate_args:
+ runtime_assertions += [
+ check_ops.assert_positive(
+ batch_shape,
+ message=("`batch_shape` elements must be positive "
+ "(i.e., larger than zero)."),
+ name="assert_batch_shape_positive")
+ ]
+
+ return runtime_assertions