Define PRNG seeding style for new code in Distributions and TF Probability, with...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 14:13:49 +0000 (07:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 14:16:16 +0000 (07:16 -0700)
Implement lightweight PRNG for seed generation in that style.

Enables incremental refactoring of existing code into this style.

PiperOrigin-RevId: 191884573

tensorflow/contrib/distributions/BUILD
tensorflow/contrib/distributions/__init__.py
tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py [new file with mode: 0644]
tensorflow/contrib/distributions/python/ops/seed_stream.py [new file with mode: 0644]

index 9799901..fec6eaf 100644 (file)
@@ -491,6 +491,16 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "seed_stream_test",
+    size = "small",
+    srcs = ["python/kernel_tests/seed_stream_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+cuda_py_test(
     name = "statistical_testing_test",
     size = "medium",
     srcs = [
index 4d44894..ddf5989 100644 (file)
@@ -59,6 +59,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
 from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
 from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
 from tensorflow.contrib.distributions.python.ops.sample_stats import *
+from tensorflow.contrib.distributions.python.ops.seed_stream import *
 from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import *
 from tensorflow.contrib.distributions.python.ops.test_util import *
 from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import *
@@ -126,6 +127,7 @@ _allowed_symbols = [
     'NormalWithSoftplusScale',
     'Poisson',
     'PoissonLogNormalQuadratureCompound',
+    'SeedStream',
     'SinhArcsinh',
     'StudentT',
     'StudentTWithAbsDfSoftplusScale',
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
new file mode 100644 (file)
index 0000000..9680573
--- /dev/null
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SeedStream class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import seed_stream
+from tensorflow.python.platform import test
+
+
+class SeedStreamTest(test.TestCase):
+
+  def assertAllUnique(self, items):
+    self.assertEqual(len(items), len(set(items)))
+
+  def testNonRepetition(self):
+    # The probability of repetitions in a short stream from a correct
+    # PRNG is negligible; this test catches bugs that prevent state
+    # updates.
+    strm = seed_stream.SeedStream(seed=4, salt="salt")
+    output = [strm() for _ in range(50)]
+    self.assertEqual(sorted(output), sorted(list(set(output))))
+
+  def testReproducibility(self):
+    strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+    strm2 = seed_stream.SeedStream(seed=4, salt="salt")
+    strm3 = seed_stream.SeedStream(seed=4, salt="salt")
+    outputs = [strm1() for _ in range(50)]
+    self.assertEqual(outputs, [strm2() for _ in range(50)])
+    self.assertEqual(outputs, [strm3() for _ in range(50)])
+
+  def testSeededDistinctness(self):
+    strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+    strm2 = seed_stream.SeedStream(seed=5, salt="salt")
+    self.assertAllUnique(
+        [strm1() for _ in range(50)] + [strm2() for _ in range(50)])
+
+  def testSaltedDistinctness(self):
+    strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+    strm2 = seed_stream.SeedStream(seed=4, salt="another salt")
+    self.assertAllUnique(
+        [strm1() for _ in range(50)] + [strm2() for _ in range(50)])
+
+  def testNestingRobustness(self):
+    # SeedStreams started from generated seeds should not collide with
+    # the master or with each other, even if the salts are the same.
+    strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+    strm2 = seed_stream.SeedStream(strm1(), salt="salt")
+    strm3 = seed_stream.SeedStream(strm1(), salt="salt")
+    outputs = [strm1() for _ in range(50)]
+    self.assertAllUnique(
+        outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)])
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py
new file mode 100644 (file)
index 0000000..056d349
--- /dev/null
@@ -0,0 +1,228 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Local PRNG for amplifying seed entropy into seeds for base operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+
+
+class SeedStream(object):
+  """Local PRNG for amplifying seed entropy into seeds for base operations.
+
+  Writing sampling code which correctly sets the pseudo-random number
+  generator (PRNG) seed is surprisingly difficult.  This class serves as
+  a helper for the TensorFlow Probability coding pattern designed to
+  avoid common mistakes.
+
+  # Motivating Example
+
+  A common first-cut implementation of a sampler for the beta
+  distribution is to compute the ratio of a gamma with itself plus
+  another gamma.  This code snippet tries to do that, but contains a
+  surprisingly common error:
+
+  ```python
+  def broken_beta(shape, alpha, beta, seed):
+    x = tf.random_gamma(shape, alpha, seed=seed)
+    y = tf.random_gamma(shape, beta, seed=seed)
+    return x / (x + y)
+  ```
+
+  The mistake is that the two gamma draws are seeded with the same
+  seed.  This causes them to always produce the same results, which,
+  in turn, leads this code snippet to always return `0.5`.  Because it
+  can happen across abstraction boundaries, this kind of error is
+  surprisingly easy to make when handling immutable seeds.
+
+  # Goals
+
+  TensorFlow Probability adopts a code style designed to eliminate the
+  above class of error, without exacerbating others.  The goals of
+  this code style are:
+
+  - Support reproducibility of results (by encouraging seeding of all
+    pseudo-random operations).
+
+  - Avoid shared-write global state (by not relying on a global PRNG).
+
+  - Prevent accidental seed reuse by TF Probability implementers.  This
+    goal is served with the local pseudo-random seed generator provided
+    in this module.
+
+  - Mitigate potential accidental seed reuse by TF Probability clients
+    (with a salting scheme).
+
+  - Prevent accidental resonances with downstream PRNGs (by hashing the
+    output).
+
+  ## Non-goals
+
+  - Implementing a high-performance PRNG for generating large amounts of
+    entropy.  That's the job of the underlying TensorFlow PRNG we are
+    seeding.
+
+  - Avoiding random seed collisions, aka "birthday attacks".
+
+  # Code pattern
+
+  ```python
+  def random_beta(shape, alpha, beta, seed):        # (a)
+    seed = SeedStream(seed, salt="random_beta")     # (b)
+    x = tf.random_gamma(shape, alpha, seed=seed())  # (c)
+    y = tf.random_gamma(shape, beta, seed=seed())   # (c)
+    return x / (x + y)
+  ```
+
+  The elements of this pattern are:
+
+  - Accept an explicit seed (line a) as an argument in all public
+    functions, and write the function to be deterministic (up to any
+    numerical issues) for fixed seed.
+
+    - Rationale: This provides the client with the ability to reproduce
+      results.  Accepting an immutable seed rather than a mutable PRNG
+      object reduces code coupling, permitting different sections to be
+      reproducible independently.
+
+  - Use that seed only to initialize a local `SeedStream` instance (line b).
+
+    - Rationale: Avoids accidental seed reuse.
+
+  - Supply the name of the function being implemented as a salt to the
+    `SeedStream` instance (line b).  This serves to keep the salts
+    unique; unique salts ensure that clients of TF Probability will see
+    different functions always produce independent results even if
+    called with the same seeds.
+
+  - Seed each callee operation with the output of a unique call to the
+    `SeedStream` instance (lines c).  This ensures reproducibility of
+    results while preventing seed reuse across callee invocations.
+
+  # Why salt?
+
+  Salting the `SeedStream` instances (with unique salts) is defensive
+  programming against a client accidentally committing a mistake
+  similar to our motivating example.  Consider the following situation
+  that might arise without salting:
+
+  ```python
+  def tfp_foo(seed):
+    seed = SeedStream(seed, salt="")
+    foo_stuff = tf.random_normal(seed=seed())
+    ...
+
+  def tfp_bar(seed):
+    seed = SeedStream(seed, salt="")
+    bar_stuff = tf.random_normal(seed=seed())
+    ...
+
+  def client_baz(seed):
+    foo = tfp_foo(seed=seed)
+    bar = tfp_bar(seed=seed)
+    ...
+  ```
+
+  The client should have used different seeds as inputs to `foo` and
+  `bar`.  However, because they didn't, *and because `foo` and `bar`
+  both sample a Gaussian internally as their first action*, the
+  internal `foo_stuff` and `bar_stuff` will be the same, and the
+  returned `foo` and `bar` will not be independent, leading to subtly
+  incorrect answers from the client's simulation.  This kind of bug is
+  particularly insidious for the client, because it depends on a
+  Distributions implementation detail, namely the order in which `foo`
+  and `bar` invoke the samplers they depend on.  In particular, a
+  Bayesflow team member can introduce such a bug in previously
+  (accidentally) correct client code by performing an internal
+  refactoring that causes this operation order alignment.
+
+  A salting discipline eliminates this problem by making sure that the
+  seeds seen by `foo`'s callees will differ from those seen by `bar`'s
+  callees, even if `foo` and `bar` are invoked with the same input
+  seed.
+  """
+
+  def __init__(self, seed, salt):
+    """Initializes a `SeedStream`.
+
+    Args:
+      seed: Any Python object convertible to string, supplying the
+        initial entropy.  If `None`, operations seeded with seeds
+        drawn from this `SeedStream` will follow TensorFlow semantics
+        for not being seeded.
+      salt: Any Python object convertible to string, supplying
+        auxiliary entropy.  Must be unique across the Distributions
+        and TensorFlow Probability code base.  See class docstring for
+        rationale.
+    """
+    self._seed = seed
+    self._salt = salt
+    self._counter = 0
+
+  def __call__(self):
+    """Returns a fresh integer usable as a seed in downstream operations.
+
+    If this `SeedStream` was initialized with `seed=None`, returns
+    `None`.  This has the effect that downstream operations (both
+    `SeedStream`s and primitive TensorFlow ops) will behave as though
+    they were unseeded.
+
+    The returned integer is non-negative, and uniformly distributed in
+    the half-open interval `[0, 2**512)`.  This is consistent with
+    TensorFlow, as TensorFlow operations internally use the residue of
+    the given seed modulo `2**31 - 1` (see
+    `tensorflow/python/framework/random_seed.py`).
+
+    Returns:
+      seed: A fresh integer usable as a seed in downstream operations,
+        or `None`.
+    """
+    self._counter += 1
+    if self._seed is None:
+      return None
+    composite = str((self._seed, self._counter, self._salt)).encode("utf-8")
+    return int(hashlib.sha512(composite).hexdigest(), 16)
+
+  @property
+  def original_seed(self):
+    return self._seed
+
+  @property
+  def salt(self):
+    return self._salt
+
+# Design rationales for the SeedStream class
+#
+# - Salts are accepted for the reason given above to supply them.
+#
+# - A `None` seed propagates to downstream seeds, so they exhibit
+#   their "unseeded" behavior.
+#
+# - The return value is a Python int so it can be passed directly to
+#   TensorFlow operations as a seed.  It is large to avoid losing seed
+#   space needlessly (TF will internally read only the last 31 bits).
+#
+# - The output is hashed with a crypto-grade hash function as a form
+#   of defensive programming: this reliably prevents all possible
+#   accidental resonances with all possible downstream PRNGs.  The
+#   specific function used is not important; SHA512 was ready to hand.
+#
+# - The internal state update is a simple counter because (a) given
+#   that the output is hashed anyway, this is enough, and (b) letting
+#   it be this predictable permits a future "generate many seeds in
+#   parallel" operation whose results would agree with running
+#   sequentially.