while (true)
{
- int readBytes = await ReadAllAsync(adapter, _readHeader).ConfigureAwait(false);
+ int readBytes = await ReadAllAsync(adapter, _readHeader, allowZeroRead: true).ConfigureAwait(false);
if (readBytes == 0)
{
return 0;
{
_readBuffer = new byte[readBytes];
}
- readBytes = await ReadAllAsync(adapter, new Memory<byte>(_readBuffer, 0, readBytes)).ConfigureAwait(false);
- if (readBytes == 0)
- {
- // We already checked that the frame body is bigger than 0 bytes. Hence, this is an EOF.
- throw new IOException(SR.net_io_eof);
- }
+
+ readBytes = await ReadAllAsync(adapter, new Memory<byte>(_readBuffer, 0, readBytes), allowZeroRead: false).ConfigureAwait(false);
// Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_
// Decrypted data start from zero offset, the size can be shrunk after decryption.
_readInProgress = 0;
}
- static async ValueTask<int> ReadAllAsync(TAdapter adapter, Memory<byte> buffer)
+ static async ValueTask<int> ReadAllAsync(TAdapter adapter, Memory<byte> buffer, bool allowZeroRead)
{
- int length = buffer.Length;
+ int read = 0;
do
{
int bytes = await adapter.ReadAsync(buffer).ConfigureAwait(false);
if (bytes == 0)
{
- if (!buffer.IsEmpty)
+ if (read != 0 || !allowZeroRead)
{
throw new IOException(SR.net_io_eof);
}
}
buffer = buffer.Slice(bytes);
+ read += bytes;
}
while (!buffer.IsEmpty);
- return length;
+ return read;
}
}
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
}
}
+
+ [ConditionalFact(nameof(IsNtlmInstalled))]
+ public async Task NegotiateStream_ReadToEof_Returns0()
+ {
+ (Stream stream1, Stream stream2) = TestHelper.GetConnectedStreams();
+ using (var client = new NegotiateStream(stream1))
+ using (var server = new NegotiateStream(stream2))
+ {
+ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+ AuthenticateAsClientAsync(client, CredentialCache.DefaultNetworkCredentials, string.Empty),
+ AuthenticateAsServerAsync(server));
+
+ client.Write(Encoding.UTF8.GetBytes("hello"));
+ client.Dispose();
+
+ Assert.Equal('h', server.ReadByte());
+ Assert.Equal('e', server.ReadByte());
+ Assert.Equal('l', server.ReadByte());
+ Assert.Equal('l', server.ReadByte());
+ Assert.Equal('o', server.ReadByte());
+ Assert.Equal(-1, server.ReadByte());
+ }
+ }
}
public sealed class NegotiateStreamStreamToStreamTest_Async_Array : NegotiateStreamStreamToStreamTest