BUGFIX: Detect when broadcasting is required and raise NotImplementedError.
author    Joshua V. Dillon <jvdillon@google.com>
          Thu, 5 Apr 2018 00:28:41 +0000 (17:28 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Thu, 5 Apr 2018 00:31:14 +0000 (17:31 -0700)
PiperOrigin-RevId: 191673876
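
The new check means `BatchReshape` refuses arguments to `log_prob` (and
friends) whose shape would have to broadcast against the reshaped batch
shape. A minimal sketch of the behavior, assuming the TF 1.x contrib API
used in this change (the values mirror the new test below):

    import numpy as np
    from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib
    from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib

    # batch_shape=[4], event_shape=[].
    poisson_4 = poisson_lib.Poisson(rate=np.float32([1, 10, 2, 20]))
    # Reshape the batch to [1, 4, 1]; log_prob now expects rank >= 3 inputs.
    poisson_141 = batch_reshape_lib.BatchReshape(
        poisson_4, batch_shape=np.int32([1, 4, 1]), validate_args=True)

    x_4 = np.float32([2, 12, 3, 23])  # Rank 1: would require broadcasting.
    poisson_141.log_prob(x_4)
    # NotImplementedError: Broadcasting is not supported; too few event dims
    # (expected at least 3, saw 1).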

tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
tensorflow/contrib/distributions/python/ops/batch_reshape.py

index 4d2f40e..c6c8d2c 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib
 from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib
+from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
 from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
@@ -514,6 +515,42 @@ class _BatchReshapeTest(object):
               batch_shape=new_batch_shape_ph,
               validate_args=True).sample().eval()
 
+  def test_broadcasting_explicitly_unsupported(self):
+    old_batch_shape = [4]
+    new_batch_shape = [1, 4, 1]
+    rate_ = self.dtype([1, 10, 2, 20])
+
+    rate = array_ops.placeholder_with_default(
+        rate_,
+        shape=old_batch_shape if self.is_static_shape else None)
+    poisson_4 = poisson_lib.Poisson(rate)
+    new_batch_shape_ph = (
+        constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape
+        else array_ops.placeholder_with_default(
+            np.int32(new_batch_shape), shape=None))
+    poisson_141_reshaped = batch_reshape_lib.BatchReshape(
+        poisson_4, new_batch_shape_ph, validate_args=True)
+
+    x_4 = self.dtype([2, 12, 3, 23])
+    x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4)
+
+    if self.is_static_shape:
+      with self.assertRaisesRegexp(NotImplementedError,
+                                   "too few event dims"):
+        poisson_141_reshaped.log_prob(x_4)
+      with self.assertRaisesRegexp(NotImplementedError,
+                                   "unexpected batch and event shape"):
+        poisson_141_reshaped.log_prob(x_114)
+      return
+
+    with self.assertRaisesOpError("too few event dims"):
+      with self.test_session():
+        poisson_141_reshaped.log_prob(x_4).eval()
+
+    with self.assertRaisesOpError("unexpected batch and event shape"):
+      with self.test_session():
+        poisson_141_reshaped.log_prob(x_114).eval()
+
 
 class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase):
 
index c7ee9b2..3e6c35e 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -115,7 +115,7 @@ class BatchReshape(distribution_lib.Distribution):
       self._batch_shape_static = tensor_util.constant_value(self._batch_shape_)
       if self._batch_shape_static is not None:
         self._batch_shape_static = np.int32(self._batch_shape_static)
-      self._runtime_assertions = make_runtime_assertions(
+      self._runtime_assertions = validate_init_args(
           self._distribution,
           self._batch_shape_,
           validate_args,
@@ -229,7 +229,8 @@ class BatchReshape(distribution_lib.Distribution):
 
   def _call_reshape_input_output(self, fn, x):
     """Calls `fn`, appropriately reshaping its input `x` and output."""
-    with ops.control_dependencies(self._runtime_assertions):
+    with ops.control_dependencies(
+        self._runtime_assertions + self._validate_sample_arg(x)):
       sample_shape, static_sample_shape = self._sample_shape(x)
       old_shape = array_ops.concat([
           sample_shape,
@@ -273,61 +274,145 @@
         result.set_shape(static_shape)
       return result
 
-
-def make_runtime_assertions(
+  def _validate_sample_arg(self, x):
+    """Helper which validates sample arg, e.g., input to `log_prob`."""
+    with ops.name_scope(name="validate_sample_arg", values=[x]):
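+      # Prefer statically known ranks; fall back to graph ops when unknown.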
+      x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims)
+      event_ndims = (array_ops.size(self.event_shape_tensor())
+                     if self.event_shape.ndims is None
+                     else self.event_shape.ndims)
+      batch_ndims = (array_ops.size(self.batch_shape_tensor())
+                     if self.batch_shape.ndims is None
+                     else self.batch_shape.ndims)
+      expected_batch_event_ndims = batch_ndims + event_ndims
+
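+      # Both ranks known statically: validate now and raise at graph build.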
+      if (isinstance(x_ndims, int) and
+          isinstance(expected_batch_event_ndims, int)):
+        if x_ndims < expected_batch_event_ndims:
+          raise NotImplementedError(
+              "Broadcasting is not supported; too few event dims "
+              "(expected at least {}, saw {}).".format(
+                  expected_batch_event_ndims, x_ndims))
+        ndims_assertion = []
+      elif self.validate_args:
+        ndims_assertion = [
+            check_ops.assert_greater_equal(
+                x_ndims,
+                expected_batch_event_ndims,
+                message="Broadcasting is not supported; too few event dims.",
+                name="assert_batch_and_event_ndims_large_enough"),
+        ]
+
+      if (self.batch_shape.is_fully_defined() and
+          self.event_shape.is_fully_defined()):
+        expected_batch_event_shape = np.int32(self.batch_shape.concatenate(
+            self.event_shape).as_list())
+      else:
+        expected_batch_event_shape = array_ops.concat([
+            self.batch_shape_tensor(),
+            self.event_shape_tensor(),
+        ], axis=0)
+
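+      # Any dims left of the batch+event block constitute the sample shape.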
+      sample_ndims = x_ndims - expected_batch_event_ndims
+      if isinstance(sample_ndims, int):
+        sample_ndims = max(sample_ndims, 0)
+      if (isinstance(sample_ndims, int) and
+          x.shape[sample_ndims:].is_fully_defined()):
+        actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list())
+      else:
+        sample_ndims = math_ops.maximum(sample_ndims, 0)
+        actual_batch_event_shape = array_ops.shape(x)[sample_ndims:]
+
+      if (isinstance(expected_batch_event_shape, np.ndarray) and
+          isinstance(actual_batch_event_shape, np.ndarray)):
+        if any(expected_batch_event_shape != actual_batch_event_shape):
+          raise NotImplementedError("Broadcasting is not supported; "
+                                    "unexpected batch and event shape "
+                                    "(expected {}, saw {}).".format(
+                                        expected_batch_event_shape,
+                                        actual_batch_event_shape))
+        # We need to set the final runtime assertions to `ndims_assertion`
+        # since it's possible this assertion was created. We could add a
+        # condition to only do so if `self.validate_args == True`, however this
+        # is redundant as `ndims_assertion` already encodes this information.
+        runtime_assertions = ndims_assertion
+      elif self.validate_args:
+        # We need to make the `ndims_assertion` a control dep because otherwise
+        # TF itself might raise an exception owing to this assertion being
+        # ill-defined, i.e., one cannot even compare different-rank Tensors.
+        with ops.control_dependencies(ndims_assertion):
+          shape_assertion = check_ops.assert_equal(
+              expected_batch_event_shape,
+              actual_batch_event_shape,
+              message=("Broadcasting is not supported; "
+                       "unexpected batch and event shape."),
+              name="assert_batch_and_event_shape_same")
+        runtime_assertions = [shape_assertion]
+      else:
+        runtime_assertions = []
+
+      return runtime_assertions
+
+
+def validate_init_args(
     distribution,
     batch_shape,
     validate_args,
     batch_shape_static):
   """Helper to __init__ which makes or raises assertions."""
-  runtime_assertions = []
-
-  if batch_shape.shape.ndims is not None:
-    if batch_shape.shape.ndims != 1:
-      raise ValueError("`batch_shape` must be a vector "
-                       "(saw rank: {}).".format(
-                           batch_shape.shape.ndims))
-  elif validate_args:
-    runtime_assertions += [
-        check_ops.assert_rank(
-            batch_shape,
-            1,
-            message="`batch_shape` must be a vector.",
-            name="assert_batch_shape_is_vector"),
-    ]
-
-  batch_size_static = np.prod(batch_shape_static)
-  dist_batch_size_static = (
-      None if not distribution.batch_shape.is_fully_defined()
-      else np.prod(distribution.batch_shape).value)
-
-  if batch_size_static is not None and dist_batch_size_static is not None:
-    if batch_size_static != dist_batch_size_static:
-      raise ValueError("`batch_shape` size ({}) must match "
-                       "`distribution.batch_shape` size ({}).".format(
-                           batch_size_static,
-                           dist_batch_size_static))
-  elif validate_args:
-    runtime_assertions += [
-        check_ops.assert_equal(
-            math_ops.reduce_prod(batch_shape),
-            math_ops.reduce_prod(distribution.batch_shape_tensor()),
-            message=("`batch_shape` size must match "
-                     "`distributions.batch_shape` size."),
-            name="assert_batch_size"),
-    ]
-
-  if batch_shape_static is not None:
-    if np.any(batch_shape_static < 1):
-      raise ValueError("`batch_shape` elements must be positive "
-                       "(i.e., larger than zero).")
-  elif validate_args:
-    runtime_assertions += [
-        check_ops.assert_positive(
-            batch_shape,
-            message=("`batch_shape` elements must be positive "
-                     "(i.e., larger than zero)."),
-            name="assert_batch_shape_positive")
-    ]
-
-  return runtime_assertions
+  with ops.name_scope(name="validate_init_args",
+                      values=[batch_shape] + distribution._graph_parents):  # pylint: disable=protected-access
+    runtime_assertions = []
+
+    if batch_shape.shape.ndims is not None:
+      if batch_shape.shape.ndims != 1:
+        raise ValueError("`batch_shape` must be a vector "
+                         "(saw rank: {}).".format(
+                             batch_shape.shape.ndims))
+    elif validate_args:
+      runtime_assertions += [
+          check_ops.assert_rank(
+              batch_shape,
+              1,
+              message="`batch_shape` must be a vector.",
+              name="assert_batch_shape_is_vector"),
+      ]
+
+    batch_size_static = np.prod(batch_shape_static)
+    dist_batch_size_static = (
+        None if not distribution.batch_shape.is_fully_defined()
+        else np.prod(distribution.batch_shape).value)
+
+    if batch_size_static is not None and dist_batch_size_static is not None:
+      if batch_size_static != dist_batch_size_static:
+        raise ValueError("`batch_shape` size ({}) must match "
+                         "`distribution.batch_shape` size ({}).".format(
+                             batch_size_static,
+                             dist_batch_size_static))
+    elif validate_args:
+      runtime_assertions += [
+          check_ops.assert_equal(
+              math_ops.reduce_prod(batch_shape),
+              math_ops.reduce_prod(distribution.batch_shape_tensor()),
+              message=("`batch_shape` size must match "
+                       "`distributions.batch_shape` size."),
+              name="assert_batch_size"),
+      ]
+
+    if batch_shape_static is not None:
+      if np.any(batch_shape_static < 1):
+        raise ValueError("`batch_shape` elements must be positive "
+                         "(i.e., larger than zero).")
+    elif validate_args:
+      runtime_assertions += [
+          check_ops.assert_positive(
+              batch_shape,
+              message=("`batch_shape` elements must be positive "
+                       "(i.e., larger than zero)."),
+              name="assert_batch_shape_positive")
+      ]
+
+    return runtime_assertions
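
For completeness, the renamed `validate_init_args` still guards construction
time the same way; a small sketch of a mismatched reshape under the same
assumptions as above:

    poisson_4 = poisson_lib.Poisson(rate=np.float32([1, 10, 2, 20]))
    batch_reshape_lib.BatchReshape(poisson_4, batch_shape=np.int32([3]))
    # ValueError: `batch_shape` size (3) must match
    # `distribution.batch_shape` size (4).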