From: A. Unique TensorFlower Date: Wed, 16 May 2018 15:45:28 +0000 (-0700) Subject: Modify tf.contrib.distributions.BatchReshape to behave a bit more like X-Git-Tag: upstream/v1.9.0_rc1~106^2^2~55 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=61108cc05e3eb6463fbef5eba9d1ff7b1130263d;p=platform%2Fupstream%2Ftensorflow.git Modify tf.contrib.distributions.BatchReshape to behave a bit more like tf.reshape: accept a single unknown dimension and infer partial shape information statically. PiperOrigin-RevId: 196833267 --- diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 59d549b..f2bb2d3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -448,8 +448,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r"`batch_shape` size must match " - r"`distributions.batch_shape` size"): + with self.assertRaisesOpError(r"Shape sizes do not match."): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -457,8 +456,13 @@ class _BatchReshapeTest(object): def test_non_positive_shape(self): dims = 2 - new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. - old_batch_shape = [2] + old_batch_shape = [4] + if self.is_static_shape: + # Unknown first dimension does not trigger size check. Note that + # any dimension < 0 is treated statically as unknown. + new_batch_shape = [-1, 0] + else: + new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape. new_batch_shape_ph = ( constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape @@ -471,7 +475,7 @@ class _BatchReshapeTest(object): mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) if self.is_static_shape: - with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -479,7 +483,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r".*must be positive.*"): + with self.assertRaisesOpError(r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index 8a4041c..c709318 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -42,9 +42,6 @@ class BatchReshape(distribution_lib.Distribution): This "meta-distribution" reshapes the batch dimensions of another distribution. - Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support - `-1` for flattening. - #### Examples ```python @@ -52,7 +49,7 @@ class BatchReshape(distribution_lib.Distribution): dtype = np.float32 dims = 2 - new_batch_shape = [1, 2, 3] + new_batch_shape = [1, 2, -1] old_batch_shape = [6] scale = np.ones(old_batch_shape + [dims], dtype) @@ -86,8 +83,9 @@ class BatchReshape(distribution_lib.Distribution): Args: distribution: The base distribution instance to reshape. Typically an instance of `Distribution`. - batch_shape: Positive `int`-like vector-shaped `Tensor` representing the - new shape of the batch dimensions. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing + the new shape of the batch dimensions. Up to one dimension may contain + `-1`, meaning the remainder of the batch size. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -107,29 +105,26 @@ class BatchReshape(distribution_lib.Distribution): """ parameters = distribution_util.parent_frame_arguments() name = name or "BatchReshape" + distribution.name - self._distribution = distribution with ops.name_scope(name, values=[batch_shape]) as name: - self._batch_shape_ = ops.convert_to_tensor( - batch_shape, - dtype=dtypes.int32, - name="batch_shape") - self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) - if self._batch_shape_static is not None: - self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = validate_init_args( - self._distribution, - self._batch_shape_, - validate_args, - self._batch_shape_static) + # The unexpanded batch shape may contain up to one dimension of -1. + self._batch_shape_unexpanded = ops.convert_to_tensor( + batch_shape, dtype=dtypes.int32, name="batch_shape") + validate_init_args_statically(distribution, self._batch_shape_unexpanded) + batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( + distribution.batch_shape_tensor(), self._batch_shape_unexpanded, + validate_args) + self._distribution = distribution + self._batch_shape_ = batch_shape + self._batch_shape_static = batch_shape_static + self._runtime_assertions = runtime_assertions super(BatchReshape, self).__init__( - dtype=self._distribution.dtype, - reparameterization_type=self._distribution.reparameterization_type, + dtype=distribution.dtype, + reparameterization_type=distribution.reparameterization_type, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [self._batch_shape_] + - self._distribution._graph_parents), # pylint: disable=protected-access + [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access name=name) @property @@ -141,7 +136,7 @@ class BatchReshape(distribution_lib.Distribution): return array_ops.identity(self._batch_shape_) def _batch_shape(self): - return tensor_shape.TensorShape(self._batch_shape_static) + return self._batch_shape_static def _event_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -153,11 +148,13 @@ class BatchReshape(distribution_lib.Distribution): def _sample_n(self, n, seed=None): with ops.control_dependencies(self._runtime_assertions): x = self.distribution.sample(sample_shape=n, seed=seed) - new_shape = array_ops.concat([ - [n], - self.batch_shape_tensor(), - self.event_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + [n], + self._batch_shape_unexpanded, + self.event_shape_tensor(), + ], + axis=0) return array_ops.reshape(x, new_shape) def _log_prob(self, x): @@ -214,9 +211,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] @@ -239,10 +236,11 @@ class BatchReshape(distribution_lib.Distribution): self.event_shape_tensor(), ], axis=0) result = fn(array_ops.reshape(x, old_shape)) - new_shape = array_ops.concat([ - sample_shape, - self.batch_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + sample_shape, + self._batch_shape_unexpanded, + ], axis=0) result = array_ops.reshape(result, new_shape) if (static_sample_shape.ndims is not None and self.batch_shape.ndims is not None): @@ -262,8 +260,7 @@ class BatchReshape(distribution_lib.Distribution): if static_event_shape_list is None: static_event_shape_list = [self.event_shape] new_shape = array_ops.concat( - [self.batch_shape_tensor()] + event_shape_list, - axis=0) + [self._batch_shape_unexpanded] + event_shape_list, axis=0) result = array_ops.reshape(fn(), new_shape) if (self.batch_shape.ndims is not None and self.event_shape.ndims is not None): @@ -282,9 +279,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and @@ -356,62 +353,56 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions -def validate_init_args( - distribution, - batch_shape, - validate_args, - batch_shape_static): +def calculate_reshape(original_shape, new_shape, validate=False, name=None): + """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" + batch_shape_static = tensor_util.constant_value_as_shape(new_shape) + if batch_shape_static.is_fully_defined(): + return np.int32(batch_shape_static.as_list()), batch_shape_static, [] + with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): + original_size = math_ops.reduce_prod(original_shape) + implicit_dim = math_ops.equal(new_shape, -1) + size_implicit_dim = ( + original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) + new_ndims = array_ops.shape(new_shape) + expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) + validations = [] if not validate else [ + check_ops.assert_rank( + original_shape, 1, message="Original shape must be a vector."), + check_ops.assert_rank( + new_shape, 1, message="New shape must be a vector."), + check_ops.assert_less_equal( + math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), + 1, + message="At most one dimension can be unknown."), + check_ops.assert_positive( + expanded_new_shape, message="Shape elements must be >=-1."), + check_ops.assert_equal( + math_ops.reduce_prod(expanded_new_shape), + original_size, + message="Shape sizes do not match."), + ] + return expanded_new_shape, batch_shape_static, validations + + +def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" - with ops.name_scope(name="validate_init_args", - values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access - runtime_assertions = [] - - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - "(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] - - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) - - if batch_size_static is not None and dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, - message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format(batch_shape.shape.ndims)) + + batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) + batch_size_static = batch_shape_static.num_elements() + dist_batch_size_static = distribution.batch_shape.num_elements() + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, dist_batch_size_static)) + + if batch_shape_static.dims is not None: + if any( + dim.value is not None and dim.value < 1 for dim in batch_shape_static): + raise ValueError("`batch_shape` elements must be >=-1.")