Enable `SeedStream` construction from other `SeedStream` instances.
authorJoshua V. Dillon <jvdillon@google.com>
Fri, 18 May 2018 19:17:05 +0000 (12:17 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 19:20:34 +0000 (12:20 -0700)
PiperOrigin-RevId: 197182686

tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
tensorflow/contrib/distributions/python/ops/seed_stream.py

index 9680573..b91a610 100644 (file)
@@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase):
     self.assertAllUnique(
         outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)])
 
+  def testInitFromOtherSeedStream(self):
+    strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+    strm2 = seed_stream.SeedStream(strm1, salt="salt")
+    strm3 = seed_stream.SeedStream(strm1, salt="another salt")
+    out1 = [strm1() for _ in range(50)]
+    out2 = [strm2() for _ in range(50)]
+    out3 = [strm3() for _ in range(50)]
+    self.assertAllEqual(out1, out2)
+    self.assertAllUnique(out1 + out3)
+
 
 if __name__ == "__main__":
   test.main()
index 056d349..cf505ac 100644 (file)
@@ -169,7 +169,7 @@ class SeedStream(object):
         and TensorFlow Probability code base.  See class docstring for
         rationale.
     """
-    self._seed = seed
+    self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed
     self._salt = salt
     self._counter = 0