From 4b563ed0008953519a0ad9ec09a3261f1d3759dd Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Wed, 4 Apr 2018 17:28:41 -0700 Subject: [PATCH] BUGFIX: Detect when broadcasting is required and raise NotImplementedError. PiperOrigin-RevId: 191673876 --- .../python/kernel_tests/batch_reshape_test.py | 37 ++++ .../distributions/python/ops/batch_reshape.py | 192 +++++++++++++++------ 2 files changed, 174 insertions(+), 55 deletions(-) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 4d2f40e..c6c8d2c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib +from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops @@ -514,6 +515,42 @@ class _BatchReshapeTest(object): batch_shape=new_batch_shape_ph, validate_args=True).sample().eval() + def test_broadcasting_explicitly_unsupported(self): + old_batch_shape = [4] + new_batch_shape = [1, 4, 1] + rate_ = self.dtype([1, 10, 2, 20]) + + rate = array_ops.placeholder_with_default( + rate_, + shape=old_batch_shape if self.is_static_shape else None) + poisson_4 = poisson_lib.Poisson(rate) + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + poisson_141_reshaped = batch_reshape_lib.BatchReshape( + poisson_4, new_batch_shape_ph, validate_args=True) + + x_4 = self.dtype([2, 12, 3, 23]) + x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4) + + if self.is_static_shape: + with self.assertRaisesRegexp(NotImplementedError, + "too few event dims"): + poisson_141_reshaped.log_prob(x_4) + with self.assertRaisesRegexp(NotImplementedError, + "unexpected batch and event shape"): + poisson_141_reshaped.log_prob(x_114) + return + + with self.assertRaisesOpError("too few event dims"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_4).eval() + + with self.assertRaisesOpError("unexpected batch and event shape"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_114).eval() + class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase): diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index c7ee9b2..3e6c35e 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -115,7 +115,7 @@ class BatchReshape(distribution_lib.Distribution): self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) if self._batch_shape_static is not None: self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = make_runtime_assertions( + self._runtime_assertions = validate_init_args( self._distribution, self._batch_shape_, validate_args, @@ -229,7 +229,8 @@ class BatchReshape(distribution_lib.Distribution): def _call_reshape_input_output(self, fn, x): """Calls `fn`, appropriately reshaping its input `x` and output.""" - with ops.control_dependencies(self._runtime_assertions): + with ops.control_dependencies( + self._runtime_assertions + self._validate_sample_arg(x)): sample_shape, static_sample_shape = self._sample_shape(x) old_shape = array_ops.concat([ sample_shape, @@ -273,61 +274,142 @@ class BatchReshape(distribution_lib.Distribution): result.set_shape(static_shape) return result - -def make_runtime_assertions( + def _validate_sample_arg(self, x): + """Helper which validates sample arg, e.g., input to `log_prob`.""" + with ops.name_scope(name="validate_sample_arg", values=[x]): + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = (array_ops.size(self.batch_shape_tensor()) + if self.batch_shape.ndims is None + else self.batch_shape.ndims) + expected_batch_event_ndims = batch_ndims + event_ndims + + if (isinstance(x_ndims, int) and + isinstance(expected_batch_event_ndims, int)): + if x_ndims < expected_batch_event_ndims: + raise NotImplementedError( + "Broadcasting is not supported; too few event dims " + "(expected at least {}, saw {}).".format( + expected_batch_event_ndims, x_ndims)) + ndims_assertion = [] + elif self.validate_args: + ndims_assertion = [ + check_ops.assert_greater_equal( + x_ndims, + expected_batch_event_ndims, + message="Broadcasting is not supported; too few event dims.", + name="assert_batch_and_event_ndims_large_enough"), + ] + + if (self.batch_shape.is_fully_defined() and + self.event_shape.is_fully_defined()): + expected_batch_event_shape = np.int32(self.batch_shape.concatenate( + self.event_shape).as_list()) + else: + expected_batch_event_shape = array_ops.concat([ + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + + sample_ndims = x_ndims - expected_batch_event_ndims + if isinstance(sample_ndims, int): + sample_ndims = max(sample_ndims, 0) + if (isinstance(sample_ndims, int) and + x.shape[sample_ndims:].is_fully_defined()): + actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list()) + else: + sample_ndims = math_ops.maximum(sample_ndims, 0) + actual_batch_event_shape = array_ops.shape(x)[sample_ndims:] + + if (isinstance(expected_batch_event_shape, np.ndarray) and + isinstance(actual_batch_event_shape, np.ndarray)): + if any(expected_batch_event_shape != actual_batch_event_shape): + raise NotImplementedError("Broadcasting is not supported; " + "unexpected batch and event shape " + "(expected {}, saw {}).".format( + expected_batch_event_shape, + actual_batch_event_shape)) + # We need to set the final runtime-assertions to `ndims_assertion` since + # its possible this assertion was created. We could add a condition to + # only do so if `self.validate_args == True`, however this is redundant + # as `ndims_assertion` already encodes this information. + runtime_assertions = ndims_assertion + elif self.validate_args: + # We need to make the `ndims_assertion` a control dep because otherwise + # TF itself might raise an exception owing to this assertion being + # ill-defined, ie, one cannot even compare different rank Tensors. + with ops.control_dependencies(ndims_assertion): + shape_assertion = check_ops.assert_equal( + expected_batch_event_shape, + actual_batch_event_shape, + message=("Broadcasting is not supported; " + "unexpected batch and event shape."), + name="assert_batch_and_event_shape_same") + runtime_assertions = [shape_assertion] + else: + runtime_assertions = [] + + return runtime_assertions + + +def validate_init_args( distribution, batch_shape, validate_args, batch_shape_static): """Helper to __init__ which makes or raises assertions.""" - runtime_assertions = [] - - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - "(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] - - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) - - if batch_size_static is not None and dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, - message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + with ops.name_scope(name="validate_init_args", + values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access + runtime_assertions = [] + + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format( + batch_shape.shape.ndims)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_rank( + batch_shape, + 1, + message="`batch_shape` must be a vector.", + name="assert_batch_shape_is_vector"), + ] + + batch_size_static = np.prod(batch_shape_static) + dist_batch_size_static = ( + None if not distribution.batch_shape.is_fully_defined() + else np.prod(distribution.batch_shape).value) + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, + dist_batch_size_static)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_equal( + math_ops.reduce_prod(batch_shape), + math_ops.reduce_prod(distribution.batch_shape_tensor()), + message=("`batch_shape` size must match " + "`distributions.batch_shape` size."), + name="assert_batch_size"), + ] + + if batch_shape_static is not None: + if np.any(batch_shape_static < 1): + raise ValueError("`batch_shape` elements must be positive " + "(i.e., larger than zero).") + elif validate_args: + runtime_assertions += [ + check_ops.assert_positive( + batch_shape, + message=("`batch_shape` elements must be positive " + "(i.e., larger than zero)."), + name="assert_batch_shape_positive") + ] + + return runtime_assertions -- 2.7.4