From: Edward Yang Date: Mon, 17 Dec 2018 21:25:31 +0000 (-0800) Subject: Tighten up invariants regarding StreamId. (#15125) X-Git-Tag: submit/tizen/20210715.075526~2211 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3df79f403e8b9621d5adb0447266becd10d633b0;p=platform%2Fupstream%2Fpytorch.git Tighten up invariants regarding StreamId. (#15125) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15125 I realized that it is really bad juju if you fake a StreamId out of thin air, because in general this isn't going to work. So, make the constructor a lot scarier. Most "faking StreamId out of thin air" happens because someone just wants to put something on the default stream. Reviewed By: dzhulgakov Differential Revision: D13432800 fbshipit-source-id: a86991d6fc1d8aa4e54e8175e5f06f90856238e6 --- diff --git a/aten/src/ATen/detail/CPUGuardImpl.h b/aten/src/ATen/detail/CPUGuardImpl.h index 7370068177..f47cdc355f 100644 --- a/aten/src/ATen/detail/CPUGuardImpl.h +++ b/aten/src/ATen/detail/CPUGuardImpl.h @@ -27,12 +27,12 @@ struct CPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { } Stream getStream(Device d) const noexcept override { // no-op - return Stream(Device(DeviceType::CPU, -1), 0); + return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1)); } // NB: These do NOT set the current device Stream exchangeStream(Stream s) const noexcept override { // no-op - return Stream(Device(DeviceType::CPU, -1), 0); + return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1)); } }; diff --git a/c10/Stream.h b/c10/Stream.h index c128c7a28b..aacfe078e7 100644 --- a/c10/Stream.h +++ b/c10/Stream.h @@ -7,6 +7,11 @@ namespace c10 { /// An index representing a specific stream. A StreamId is not independently /// meaningful without knowing the Device it is associated with; try to /// use Stream rather than StreamId directly. +/// +/// StreamIds are opaque; they are assigned by some DeviceType-specific +/// numbering system which is not visible to the user. HOWEVER, we +/// guarantee that StreamId 0 is always a valid stream, and corresponds +/// to some sort of "default" stream. using StreamId = int32_t; // NB: I decided not to call the above StreamIndex to avoid confusion with @@ -54,10 +59,27 @@ private: Device device_; StreamId id_; public: - explicit Stream(Device device, StreamId id) + enum Unsafe { UNSAFE }; + enum Default { DEFAULT }; + + /// Unsafely construct a stream from a Device and a StreamId. In + /// general, only specific implementations of streams for a + /// backend should manufacture Stream directly in this way; other users + /// should use the provided APIs to get a stream. In particular, + /// we don't require backends to give any guarantees about non-zero + /// StreamIds; they are welcome to allocate in whatever way they like. + explicit Stream(Unsafe, Device device, StreamId id) : device_(device) , id_(id) {} + /// Construct the default stream of a Device. The default stream is + /// NOT the same as the current stream; default stream is a fixed stream + /// that never changes, whereas the current stream may be changed by + /// StreamGuard. + explicit Stream(Default, Device device) + : device_(device) + , id_(0) {} + bool operator==(const Stream& other) const noexcept { return this->device_ == other.device_ && this->id_ == other.id_; } @@ -99,7 +121,9 @@ public: bits >>= 16; auto device_type = static_cast(bits); AT_CHECK(isValidDeviceType(device_type)); - return Stream(Device(device_type, device_index), stream_id); + // Unfortunately, we can't check if the StreamId is valid here; it + // will be checked upon first use. + return Stream(UNSAFE, Device(device_type, device_index), stream_id); } // I decided NOT to provide setters on this class, because really, diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index a8ba415d68..13cc0df9c1 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -74,8 +74,10 @@ static std::vector> high_priori // This is not really for efficiency; it's just easier to write the code // to extract the index if we do this with bitmasks :) // -// This is entirely an internal implementation detail, we reserve the right to -// renumber streams however we like. +// We are obligated to treat the stream ID 0 as the default stream, per the +// invariant specified in c10::Stream. However, all other numbers are entirely +// an internal implementation detail, we reserve the right to renumber streams +// however we like. // // Note that it is really important that the MSB is zero; StreamId is a // *signed* integer, and unsigned to signed conversion outside of the @@ -256,7 +258,9 @@ CUDAStreamInternals* CUDAStream_internals(CUDAStream s) { switch (st) { case StreamIdType::DEFAULT: AT_ASSERTM(si == 0, "Unrecognized stream ", s.unwrap(), - " (I think this should be the default stream, but I got a non-zero index ", si, ")"); + " (I think this should be the default stream, but I got a non-zero index ", si, ").", + " Did you manufacture the StreamId yourself? Don't do that; use the", + " official API like c10::cuda::getStreamFromPool() to get a new stream."); return &default_streams[device_index]; case StreamIdType::LOW: return &low_priority_streams[device_index][si]; @@ -269,7 +273,8 @@ CUDAStreamInternals* CUDAStream_internals(CUDAStream s) { CUDAStream CUDAStream_fromInternals(const CUDAStreamInternals* ptr) { return CUDAStream(CUDAStream::UNCHECKED, - Stream(c10::Device(DeviceType::CUDA, ptr->device_index), + Stream(Stream::UNSAFE, + c10::Device(DeviceType::CUDA, ptr->device_index), CUDAStream_getStreamId(ptr))); } diff --git a/c10/impl/FakeGuardImpl.h b/c10/impl/FakeGuardImpl.h index 72d8bd356f..fd0c0fab08 100644 --- a/c10/impl/FakeGuardImpl.h +++ b/c10/impl/FakeGuardImpl.h @@ -47,12 +47,12 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface { current_device_ = d.index(); } Stream getStream(Device d) const noexcept override { - return Stream(d, current_streams_[d.index()]); + return Stream(Stream::UNSAFE, d, current_streams_[d.index()]); } Stream exchangeStream(Stream s) const noexcept override { auto old_id = current_streams_[s.device_index()]; current_streams_[s.device_index()] = s.id(); - return Stream(s.device(), old_id); + return Stream(Stream::UNSAFE, s.device(), old_id); } // Convenience methods for testing static DeviceIndex getDeviceIndex() { diff --git a/c10/test/impl/InlineStreamGuard_test.cpp b/c10/test/impl/InlineStreamGuard_test.cpp index 276ffc7fc2..bf99658255 100644 --- a/c10/test/impl/InlineStreamGuard_test.cpp +++ b/c10/test/impl/InlineStreamGuard_test.cpp @@ -14,7 +14,7 @@ static Device dev(DeviceIndex index) { } static Stream stream(DeviceIndex index, StreamId sid) { - return Stream(dev(index), sid); + return Stream(Stream::UNSAFE, dev(index), sid); } // -- InlineStreamGuard -------------------------------------------------------