From: Joshua V. Dillon Date: Fri, 18 May 2018 19:17:05 +0000 (-0700) Subject: Enable `SeedStream` construction from other `SeedStream` instances. X-Git-Tag: upstream/v1.9.0_rc1~94^2^2~11 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=514bb4f3a630612fd6f6aaf62d9bbc0e4c72d0ff;p=platform%2Fupstream%2Ftensorflow.git Enable `SeedStream` construction from other `SeedStream` instances. PiperOrigin-RevId: 197182686 --- diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 9680573..b91a610 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index 056d349..cf505ac 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -169,7 +169,7 @@ class SeedStream(object): and TensorFlow Probability code base. See class docstring for rationale. """ - self._seed = seed + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed self._salt = salt self._counter = 0