Tighten up invariants regarding StreamId. (#15125)
authorEdward Yang <ezyang@fb.com>
Mon, 17 Dec 2018 21:25:31 +0000 (13:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 21:30:54 +0000 (13:30 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15125

I realized that it is really bad juju if you fake a StreamId
out of thin air, because in general this isn't going to work.
So, make the constructor a lot scarier.

Most "faking StreamId out of thin air" happens because someone
just wants to put something on the default stream.

Reviewed By: dzhulgakov

Differential Revision: D13432800

fbshipit-source-id: a86991d6fc1d8aa4e54e8175e5f06f90856238e6

aten/src/ATen/detail/CPUGuardImpl.h
c10/Stream.h
c10/cuda/CUDAStream.cpp
c10/impl/FakeGuardImpl.h
c10/test/impl/InlineStreamGuard_test.cpp

index 737006817720521bb5e2adf76cfeb85469f1c694..f47cdc355f4ab9e4ac57edabae05d3f0669f865c 100644 (file)
@@ -27,12 +27,12 @@ struct CPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
   }
   Stream getStream(Device d) const noexcept override {
     // no-op
-    return Stream(Device(DeviceType::CPU, -1), 0);
+    return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1));
   }
   // NB: These do NOT set the current device
   Stream exchangeStream(Stream s) const noexcept override {
     // no-op
-    return Stream(Device(DeviceType::CPU, -1), 0);
+    return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1));
   }
 };
 
index c128c7a28bfc7d3fd5ac4693332093cfccbfbbaf..aacfe078e74ff464a3c4dc6eb413e957dfad2934 100644 (file)
@@ -7,6 +7,11 @@ namespace c10 {
 /// An index representing a specific stream.  A StreamId is not independently
 /// meaningful without knowing the Device it is associated with; try to
 /// use Stream rather than StreamId directly.
+///
+/// StreamIds are opaque; they are assigned by some DeviceType-specific
+/// numbering system which is not visible to the user.  HOWEVER, we
+/// guarantee that StreamId 0 is always a valid stream, and corresponds
+/// to some sort of "default" stream.
 using StreamId = int32_t;
 
 // NB: I decided not to call the above StreamIndex to avoid confusion with
@@ -54,10 +59,27 @@ private:
   Device device_;
   StreamId id_;
 public:
-  explicit Stream(Device device, StreamId id)
+  enum Unsafe { UNSAFE };
+  enum Default { DEFAULT };
+
+  /// Unsafely construct a stream from a Device and a StreamId.  In
+  /// general, only specific implementations of streams for a
+  /// backend should manufacture Stream directly in this way; other users
+  /// should use the provided APIs to get a stream.  In particular,
+  /// we don't require backends to give any guarantees about non-zero
+  /// StreamIds; they are welcome to allocate in whatever way they like.
+  explicit Stream(Unsafe, Device device, StreamId id)
     : device_(device)
     , id_(id) {}
 
+  /// Construct the default stream of a Device.  The default stream is
+  /// NOT the same as the current stream; default stream is a fixed stream
+  /// that never changes, whereas the current stream may be changed by
+  /// StreamGuard.
+  explicit Stream(Default, Device device)
+    : device_(device)
+    , id_(0) {}
+
   bool operator==(const Stream& other) const noexcept {
     return this->device_ == other.device_ && this->id_ == other.id_;
   }
@@ -99,7 +121,9 @@ public:
     bits >>= 16;
     auto device_type = static_cast<DeviceType>(bits);
     AT_CHECK(isValidDeviceType(device_type));
-    return Stream(Device(device_type, device_index), stream_id);
+    // Unfortunately, we can't check if the StreamId is valid here; it
+    // will be checked upon first use.
+    return Stream(UNSAFE, Device(device_type, device_index), stream_id);
   }
 
   // I decided NOT to provide setters on this class, because really,
index a8ba415d681405559220afb878ed47bf759f0d22..13cc0df9c1dab7fd21032906a6215661e3297d3b 100644 (file)
@@ -74,8 +74,10 @@ static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> high_priori
 // This is not really for efficiency; it's just easier to write the code
 // to extract the index if we do this with bitmasks :)
 //
-// This is entirely an internal implementation detail, we reserve the right to
-// renumber streams however we like.
+// We are obligated to treat the stream ID 0 as the default stream, per the
+// invariant specified in c10::Stream.  However, all other numbers are entirely
+// an internal implementation detail, we reserve the right to renumber streams
+// however we like.
 //
 // Note that it is really important that the MSB is zero; StreamId is a
 // *signed* integer, and unsigned to signed conversion outside of the
@@ -256,7 +258,9 @@ CUDAStreamInternals* CUDAStream_internals(CUDAStream s) {
   switch (st) {
     case StreamIdType::DEFAULT:
       AT_ASSERTM(si == 0, "Unrecognized stream ", s.unwrap(),
-                          " (I think this should be the default stream, but I got a non-zero index ", si, ")");
+                          " (I think this should be the default stream, but I got a non-zero index ", si, ").",
+                          " Did you manufacture the StreamId yourself?  Don't do that; use the",
+                          " official API like c10::cuda::getStreamFromPool() to get a new stream.");
       return &default_streams[device_index];
     case StreamIdType::LOW:
       return &low_priority_streams[device_index][si];
@@ -269,7 +273,8 @@ CUDAStreamInternals* CUDAStream_internals(CUDAStream s) {
 
 CUDAStream CUDAStream_fromInternals(const CUDAStreamInternals* ptr) {
   return CUDAStream(CUDAStream::UNCHECKED,
-                    Stream(c10::Device(DeviceType::CUDA, ptr->device_index),
+                    Stream(Stream::UNSAFE,
+                           c10::Device(DeviceType::CUDA, ptr->device_index),
                            CUDAStream_getStreamId(ptr)));
 }
 
index 72d8bd356f4983c05edf2e6ab76a03016ddad8b7..fd0c0fab08fc6983cbace61d685ab74163dde4ea 100644 (file)
@@ -47,12 +47,12 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface {
     current_device_ = d.index();
   }
   Stream getStream(Device d) const noexcept override {
-    return Stream(d, current_streams_[d.index()]);
+    return Stream(Stream::UNSAFE, d, current_streams_[d.index()]);
   }
   Stream exchangeStream(Stream s) const noexcept override {
     auto old_id = current_streams_[s.device_index()];
     current_streams_[s.device_index()] = s.id();
-    return Stream(s.device(), old_id);
+    return Stream(Stream::UNSAFE, s.device(), old_id);
   }
   // Convenience methods for testing
   static DeviceIndex getDeviceIndex() {
index 276ffc7fc29a2faa57cc7301b85954d9b073135c..bf99658255f240f615c8071d64570626b8288808 100644 (file)
@@ -14,7 +14,7 @@ static Device dev(DeviceIndex index) {
 }
 
 static Stream stream(DeviceIndex index, StreamId sid) {
-  return Stream(dev(index), sid);
+  return Stream(Stream::UNSAFE, dev(index), sid);
 }
 
 // -- InlineStreamGuard -------------------------------------------------------