}
Stream getStream(Device d) const noexcept override {
// no-op
- return Stream(Device(DeviceType::CPU, -1), 0);
+ return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1));
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override {
// no-op
- return Stream(Device(DeviceType::CPU, -1), 0);
+ return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1));
}
};
/// An index representing a specific stream. A StreamId is not independently
/// meaningful without knowing the Device it is associated with; try to
/// use Stream rather than StreamId directly.
+///
+/// StreamIds are opaque; they are assigned by some DeviceType-specific
+/// numbering system which is not visible to the user. HOWEVER, we
+/// guarantee that StreamId 0 is always a valid stream, and corresponds
+/// to some sort of "default" stream.
using StreamId = int32_t;
// NB: I decided not to call the above StreamIndex to avoid confusion with
Device device_;
StreamId id_;
public:
- explicit Stream(Device device, StreamId id)
+ enum Unsafe { UNSAFE };
+ enum Default { DEFAULT };
+
+ /// Unsafely construct a stream from a Device and a StreamId. In
+ /// general, only specific implementations of streams for a
+ /// backend should manufacture Stream directly in this way; other users
+ /// should use the provided APIs to get a stream. In particular,
+ /// we don't require backends to give any guarantees about non-zero
+ /// StreamIds; they are welcome to allocate in whatever way they like.
+ explicit Stream(Unsafe, Device device, StreamId id)
: device_(device)
, id_(id) {}
+ /// Construct the default stream of a Device. The default stream is
+ /// NOT the same as the current stream; default stream is a fixed stream
+ /// that never changes, whereas the current stream may be changed by
+ /// StreamGuard.
+ explicit Stream(Default, Device device)
+ : device_(device)
+ , id_(0) {}
+
bool operator==(const Stream& other) const noexcept {
return this->device_ == other.device_ && this->id_ == other.id_;
}
bits >>= 16;
auto device_type = static_cast<DeviceType>(bits);
AT_CHECK(isValidDeviceType(device_type));
- return Stream(Device(device_type, device_index), stream_id);
+ // Unfortunately, we can't check if the StreamId is valid here; it
+ // will be checked upon first use.
+ return Stream(UNSAFE, Device(device_type, device_index), stream_id);
}
// I decided NOT to provide setters on this class, because really,
// This is not really for efficiency; it's just easier to write the code
// to extract the index if we do this with bitmasks :)
//
-// This is entirely an internal implementation detail, we reserve the right to
-// renumber streams however we like.
+// We are obligated to treat the stream ID 0 as the default stream, per the
+// invariant specified in c10::Stream. However, all other numbers are entirely
+// an internal implementation detail, we reserve the right to renumber streams
+// however we like.
//
// Note that it is really important that the MSB is zero; StreamId is a
// *signed* integer, and unsigned to signed conversion outside of the
switch (st) {
case StreamIdType::DEFAULT:
AT_ASSERTM(si == 0, "Unrecognized stream ", s.unwrap(),
- " (I think this should be the default stream, but I got a non-zero index ", si, ")");
+ " (I think this should be the default stream, but I got a non-zero index ", si, ").",
+ " Did you manufacture the StreamId yourself? Don't do that; use the",
+ " official API like c10::cuda::getStreamFromPool() to get a new stream.");
return &default_streams[device_index];
case StreamIdType::LOW:
return &low_priority_streams[device_index][si];
CUDAStream CUDAStream_fromInternals(const CUDAStreamInternals* ptr) {
return CUDAStream(CUDAStream::UNCHECKED,
- Stream(c10::Device(DeviceType::CUDA, ptr->device_index),
+ Stream(Stream::UNSAFE,
+ c10::Device(DeviceType::CUDA, ptr->device_index),
CUDAStream_getStreamId(ptr)));
}
current_device_ = d.index();
}
Stream getStream(Device d) const noexcept override {
- return Stream(d, current_streams_[d.index()]);
+ return Stream(Stream::UNSAFE, d, current_streams_[d.index()]);
}
Stream exchangeStream(Stream s) const noexcept override {
auto old_id = current_streams_[s.device_index()];
current_streams_[s.device_index()] = s.id();
- return Stream(s.device(), old_id);
+ return Stream(Stream::UNSAFE, s.device(), old_id);
}
// Convenience methods for testing
static DeviceIndex getDeviceIndex() {
}
static Stream stream(DeviceIndex index, StreamId sid) {
- return Stream(dev(index), sid);
+ return Stream(Stream::UNSAFE, dev(index), sid);
}
// -- InlineStreamGuard -------------------------------------------------------