/// </summary>
protected virtual bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => false;
+ protected virtual bool ExtraZeroByteReadsAllowed => false;
+
[Theory]
[InlineData(false)]
[InlineData(true)]
using StreamPair innerStreams = ConnectedStreams.CreateBidirectional();
(Stream innerWriteable, Stream innerReadable) = GetReadWritePair(innerStreams);
- var tracker = new ZeroByteReadTrackingStream(innerReadable);
+ var tracker = new ZeroByteReadTrackingStream(innerReadable, ExtraZeroByteReadsAllowed);
using StreamPair streams = await CreateWrappedConnectedStreamsAsync((innerWriteable, tracker));
(Stream writeable, Stream readable) = GetReadWritePair(streams);
private sealed class ZeroByteReadTrackingStream : DelegatingStream
{
private TaskCompletionSource? _signal;
+ private bool _extraZeroByteReadsAllowed;
- public ZeroByteReadTrackingStream(Stream innerStream) : base(innerStream)
+ public ZeroByteReadTrackingStream(Stream innerStream, bool extraZeroByteReadsAllowed = false) : base(innerStream)
{
+ _extraZeroByteReadsAllowed = extraZeroByteReadsAllowed;
}
public Task WaitForZeroByteReadAsync()
if (bufferLength == 0)
{
var signal = _signal;
- if (signal is null)
+ if (signal is null && !_extraZeroByteReadsAllowed)
{
throw new Exception("Unexpected zero byte read");
}
_signal = null;
- signal.SetResult();
+ signal?.SetResult();
}
}
private const int HandshakeTypeOffsetSsl2 = 2; // Offset of HelloType in Sslv2 and Unified frames
private const int HandshakeTypeOffsetTls = 5; // Offset of HelloType in Sslv3 and TLS frames
+ private const int UnknownTlsFrameLength = int.MaxValue; // frame too short to determine length
+
private bool _receivedEOF;
// Used by Telemetry to ensure we log connection close exactly once
throw SslStreamPal.GetException(status);
}
- _buffer.EnsureAvailableSpace(InitialHandshakeBufferSize);
-
ProtocolToken message;
do
{
- int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+ int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
ProcessTlsFrame(frameSize, out message);
if (message.Size > 0)
while (!handshakeCompleted)
{
- int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+ int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
ProcessTlsFrame(frameSize, out message);
ReadOnlyMemory<byte> payload = default;
}
// This method will make sure we have at least one full TLS frame buffered.
- private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
+ private async ValueTask<int> ReceiveHandshakeFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
- int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+ int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, InitialHandshakeBufferSize).ConfigureAwait(false);
if (frameSize == 0)
{
private bool HaveFullTlsFrame(out int frameSize)
{
- if (_buffer.EncryptedLength < TlsFrameHelper.HeaderSize)
- {
- frameSize = int.MaxValue;
- return false;
- }
-
frameSize = GetFrameSize(_buffer.EncryptedReadOnlySpan);
return _buffer.EncryptedLength >= frameSize;
}
[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
- private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
+ private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken, int estimatedSize)
where TIOAdapter : IReadWriteAdapter
{
- int frameSize;
- if (HaveFullTlsFrame(out frameSize))
+ if (HaveFullTlsFrame(out int frameSize))
{
return frameSize;
}
- if (frameSize != int.MaxValue)
- {
- // make sure we have space for the whole frame
- _buffer.EnsureAvailableSpace(frameSize - _buffer.EncryptedLength);
- }
- else
- {
- // move existing data to the beginning of the buffer (they will
- // be couple of bytes only, otherwise we would have entire
- // header and know exact size)
- _buffer.EnsureAvailableSpace(_buffer.Capacity - _buffer.EncryptedLength);
- }
+ await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
+
+ // If we don't have enough data to determine the frame size, use the provided estimate
+ // (e.g. a full TLS frame for reads, and a somewhat shorter frame for handshake / renegotiation).
+ // If we do know the frame size, ensure we have space for the whole frame.
+ _buffer.EnsureAvailableSpace(frameSize == UnknownTlsFrameLength ?
+ estimatedSize :
+ frameSize - _buffer.EncryptedLength);
while (_buffer.EncryptedLength < frameSize)
{
private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
+
// Throw first if we already have exception.
// Check for disposal is not atomic so we will check again below.
ThrowIfExceptionalOrNotAuthenticated();
try
{
int processedLength = 0;
+ int nextTlsFrameLength = UnknownTlsFrameLength;
if (_buffer.DecryptedLength != 0)
{
processedLength = CopyDecryptedData(buffer);
- if (processedLength == buffer.Length || !HaveFullTlsFrame(out _))
+ if (processedLength == buffer.Length || !HaveFullTlsFrame(out nextTlsFrameLength))
{
// We either filled whole buffer or used all buffered frames.
return processedLength;
buffer = buffer.Slice(processedLength);
}
- if (_receivedEOF)
+ if (_receivedEOF && nextTlsFrameLength == UnknownTlsFrameLength)
{
+ // there should be no frames waiting for processing
Debug.Assert(_buffer.EncryptedLength == 0);
// We received EOF during previous read but had buffered data to return.
return 0;
}
- if (buffer.Length == 0 && _buffer.ActiveLength == 0)
- {
- // User requested a zero-byte read, and we have no data available in the buffer for processing.
- // This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
- // for reduced memory consumption when data is not immediately available.
- // So, we will issue our own zero-byte read against the underlying stream and defer buffer allocation
- // until data is actually available from the underlying stream.
- // Note that if the underlying stream does not supporting blocking on zero byte reads, then this will
- // complete immediately and won't save any memory, but will still function correctly.
- await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
- }
-
Debug.Assert(_buffer.DecryptedLength == 0);
- _buffer.EnsureAvailableSpace(ReadBufferSize - _buffer.ActiveLength);
-
while (true)
{
- int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
+ int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, ReadBufferSize).ConfigureAwait(false);
if (payloadBytes == 0)
{
_receivedEOF = true;
// Returns TLS Frame size including header size.
private int GetFrameSize(ReadOnlySpan<byte> buffer)
{
+ if (buffer.Length < TlsFrameHelper.HeaderSize)
+ {
+ return UnknownTlsFrameLength;
+ }
+
if (!TlsFrameHelper.TryGetFrameHeader(buffer, ref _lastFrame.Header))
{
throw new IOException(SR.net_ssl_io_frame);