// This contains adapters to allow a single code path for sync/async logic
public partial class SslStream
{
- private interface ISslWriteAdapter
- {
- Task LockAsync();
- ValueTask WriteAsync(byte[] buffer, int offset, int count);
- CancellationToken CancellationToken { get; }
- }
-
- private interface ISslReadAdapter
+ private interface ISslIOAdapter
{
ValueTask<int> ReadAsync(byte[] buffer, int offset, int count);
- ValueTask<int> LockAsync(Memory<byte> buffer);
+ ValueTask<int> ReadLockAsync(Memory<byte> buffer);
+ Task WriteLockAsync();
+ ValueTask WriteAsync(byte[] buffer, int offset, int count);
CancellationToken CancellationToken { get; }
}
- private readonly struct SslReadAsync : ISslReadAdapter
+ private readonly struct AsyncSslIOAdapter : ISslIOAdapter
{
private readonly SslStream _sslStream;
private readonly CancellationToken _cancellationToken;
- public SslReadAsync(SslStream sslStream, CancellationToken cancellationToken)
+ public AsyncSslIOAdapter(SslStream sslStream, CancellationToken cancellationToken)
{
_cancellationToken = cancellationToken;
_sslStream = sslStream;
public ValueTask<int> ReadAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.ReadAsync(new Memory<byte>(buffer, offset, count), _cancellationToken);
- public ValueTask<int> LockAsync(Memory<byte> buffer) => _sslStream.CheckEnqueueReadAsync(buffer);
+ public ValueTask<int> ReadLockAsync(Memory<byte> buffer) => _sslStream.CheckEnqueueReadAsync(buffer);
+
+ public Task WriteLockAsync() => _sslStream.CheckEnqueueWriteAsync();
+
+ public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), _cancellationToken);
public CancellationToken CancellationToken => _cancellationToken;
}
- private readonly struct SslReadSync : ISslReadAdapter
+ private readonly struct SyncSslIOAdapter : ISslIOAdapter
{
private readonly SslStream _sslStream;
- public SslReadSync(SslStream sslStream) => _sslStream = sslStream;
+ public SyncSslIOAdapter(SslStream sslStream) => _sslStream = sslStream;
public ValueTask<int> ReadAsync(byte[] buffer, int offset, int count) => new ValueTask<int>(_sslStream.InnerStream.Read(buffer, offset, count));
- public ValueTask<int> LockAsync(Memory<byte> buffer) => new ValueTask<int>(_sslStream.CheckEnqueueRead(buffer));
+ public ValueTask<int> ReadLockAsync(Memory<byte> buffer) => new ValueTask<int>(_sslStream.CheckEnqueueRead(buffer));
- public CancellationToken CancellationToken => default;
- }
-
- private readonly struct SslWriteAsync : ISslWriteAdapter
- {
- private readonly SslStream _sslStream;
- private readonly CancellationToken _cancellationToken;
-
- public SslWriteAsync(SslStream sslStream, CancellationToken cancellationToken)
+ public ValueTask WriteAsync(byte[] buffer, int offset, int count)
{
- _sslStream = sslStream;
- _cancellationToken = cancellationToken;
+ _sslStream.InnerStream.Write(buffer, offset, count);
+ return default;
}
- public Task LockAsync() => _sslStream.CheckEnqueueWriteAsync();
-
- public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), _cancellationToken);
-
- public CancellationToken CancellationToken => _cancellationToken;
- }
-
- private readonly struct SslWriteSync : ISslWriteAdapter
- {
- private readonly SslStream _sslStream;
-
- public SslWriteSync(SslStream sslStream) => _sslStream = sslStream;
-
- public Task LockAsync()
+ public Task WriteLockAsync()
{
_sslStream.CheckEnqueueWrite();
return Task.CompletedTask;
}
- public ValueTask WriteAsync(byte[] buffer, int offset, int count)
- {
- _sslStream.InnerStream.Write(buffer, offset, count);
- return default;
- }
-
public CancellationToken CancellationToken => default;
}
}
private int _nestedAuth;
- private SecurityStatusPal _securityStatus;
-
- private enum CachedSessionStatus : byte
- {
- Unknown = 0,
- IsNotCached = 1,
- IsCached = 2,
- Renegotiated = 3
- }
- private CachedSessionStatus _CachedSession;
-
private enum Framing
{
Unknown = 0,
private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default)
{
Task result = null;
- if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
- {
- throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate"));
- }
-
- try
- {
- ThrowIfExceptional();
- // A trick to discover and avoid cached sessions.
- _CachedSession = CachedSessionStatus.Unknown;
-
- if (isAsync)
- {
- result = ForceAuthenticationAsync(_context.IsServer, null, cancellationToken);
- }
- else
- {
- ForceAuthentication(_context.IsServer, null);
+ ThrowIfExceptional();
- if (NetEventSource.IsEnabled)
- NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication),
- SslProtocol,
- CipherAlgorithm,
- CipherStrength,
- HashAlgorithm,
- HashStrength,
- KeyExchangeAlgorithm,
- KeyExchangeStrength);
- }
+ if (isAsync)
+ {
+ result = ForceAuthenticationAsync(new AsyncSslIOAdapter(this, cancellationToken), _context.IsServer, null, isApm);
}
- finally
+ else
{
- // Operation has completed.
- _nestedAuth = 0;
+ ForceAuthenticationAsync(new SyncSslIOAdapter(this), _context.IsServer, null).GetAwaiter().GetResult();
}
return result;
//
// This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read().
//
- private async Task ReplyOnReAuthenticationAsync(byte[] buffer, CancellationToken cancellationToken)
+ private async Task ReplyOnReAuthenticationAsync<TIOAdapter>(TIOAdapter adapter, byte[] buffer)
+ where TIOAdapter : ISslIOAdapter
{
lock (SyncLock)
{
_lockReadState = LockHandshake;
}
- await ForceAuthenticationAsync(false, buffer, cancellationToken).ConfigureAwait(false);
+ await ForceAuthenticationAsync(adapter, receiveFirst: false, buffer).ConfigureAwait(false);
FinishHandshakeRead(LockNone);
}
- //
- // This method attempts to start authentication.
- // Incoming buffer is either null or is the result of "renegotiate" decrypted message
- // If write is in progress the method will either wait or be put on hold
- //
- private void ForceAuthentication(bool receiveFirst, byte[] buffer)
+ // reAuthenticationData is only used on Windows in case of renegotiation.
+ private async Task ForceAuthenticationAsync<TIOAdapter>(TIOAdapter adapter, bool receiveFirst, byte[] reAuthenticationData, bool isApm = false)
+ where TIOAdapter : ISslIOAdapter
{
- // This will tell that we don't know the framing yet (what SSL version is)
_framing = Framing.Unknown;
+ ProtocolToken message;
- try
+ if (reAuthenticationData == null)
{
- if (receiveFirst)
- {
- // Listen for a client blob.
- ReceiveBlob(buffer);
- }
- else
+ // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation tansparently.
+ if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
{
- // We start with the first blob.
- SendBlob(buffer, (buffer == null ? 0 : buffer.Length));
+ throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate"));
}
}
- catch (Exception e)
+
+ try
{
- // Failed auth, reset the framing if any.
- _framing = Framing.Unknown;
- _handshakeCompleted = false;
- SetException(e);
- if (_exception.SourceException != e)
+ if (!receiveFirst)
{
- ThrowIfExceptional();
+ message = _context.NextMessage(reAuthenticationData, 0, (reAuthenticationData == null ? 0 : reAuthenticationData.Length));
+ if (message.Size > 0)
+ {
+ await adapter.WriteAsync(message.Payload, 0, message.Size).ConfigureAwait(false);
+ }
+
+ if (message.Failed)
+ {
+ // tracing done in NextMessage()
+ throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+ }
}
- throw;
- }
- finally
- {
- if (_exception != null)
+
+ do
{
- // This a failed handshake. Release waiting IO if any.
- FinishHandshake(null);
- }
- }
- }
+ message = await ReceiveBlobAsync(adapter).ConfigureAwait(false);
+ if (message.Size > 0)
+ {
+ // If there is message send it out even if call failed. It may contain TLS Alert.
+ await adapter.WriteAsync(message.Payload, 0, message.Size).ConfigureAwait(false);
+ }
- internal async Task ForceAuthenticationAsync(bool receiveFirst, byte[] buffer, CancellationToken cancellationToken)
- {
- _framing = Framing.Unknown;
- ProtocolToken message;
- SslReadAsync adapter = new SslReadAsync(this, cancellationToken);
+ if (message.Failed)
+ {
+ throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+ }
+ } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK);
- if (!receiveFirst)
- {
- message = _context.NextMessage(buffer, 0, (buffer == null ? 0 : buffer.Length));
- if (message.Failed)
+ ProtocolToken alertToken = null;
+ if (!CompleteHandshake(ref alertToken))
{
- // tracing done in NextMessage()
- throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+ SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
}
-
- await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false);
}
-
- do
+ finally
{
- message = await ReceiveBlobAsync(adapter, buffer, cancellationToken).ConfigureAwait(false);
- if (message.Size > 0)
- {
- // If there is message send it out even if call failed. It may contain TLS Alert.
- await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false);
- }
-
- if (message.Failed)
+ if (reAuthenticationData == null)
{
- throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+ _nestedAuth = 0;
}
- } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK);
-
- ProtocolToken alertToken = null;
- if (!CompleteHandshake(ref alertToken))
- {
- SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
}
if (NetEventSource.IsEnabled)
}
- //
- // Client side starts here, but server also loops through this method.
- //
- private void SendBlob(byte[] incoming, int count)
- {
- ProtocolToken message = _context.NextMessage(incoming, 0, count);
- _securityStatus = message.Status;
-
- if (message.Size != 0)
- {
- if (_context.IsServer && _CachedSession == CachedSessionStatus.Unknown)
- {
- //
- //[Schannel] If the first call to ASC returns a token less than 200 bytes,
- // then it's a reconnect (a handshake based on a cache entry).
- //
- _CachedSession = message.Size < 200 ? CachedSessionStatus.IsCached : CachedSessionStatus.IsNotCached;
- }
-
- if (_framing == Framing.Unified)
- {
- _framing = DetectFraming(message.Payload, message.Payload.Length);
- }
-
- InnerStream.Write(message.Payload, 0, message.Size);
- }
-
- CheckCompletionBeforeNextReceive(message);
- }
-
- //
- // This will check and logically complete / fail the auth handshake.
- //
- private void CheckCompletionBeforeNextReceive(ProtocolToken message)
- {
- if (message.Failed)
- {
- SendAuthResetSignal(null, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException())));
- return;
- }
- else if (message.Done)
- {
- ProtocolToken alertToken = null;
-
- if (!CompleteHandshake(ref alertToken))
- {
- SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
- return;
- }
-
- // Release waiting IO if any. Presumably it should not throw.
- // Otherwise application may get not expected type of the exception.
- FinishHandshake(null);
- return;
- }
-
- ReceiveBlob(message.Payload);
- }
-
- //
- // Server side starts here, but client also loops through this method.
- //
- private void ReceiveBlob(byte[] buffer)
- {
- //This is first server read.
- buffer = EnsureBufferSize(buffer, 0, SecureChannel.ReadHeaderSize);
-
- int readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize);
-
- if (readBytes == 0)
- {
- // EOF received
- throw new IOException(SR.net_auth_eof);
- }
-
- if (_framing == Framing.Unknown)
- {
- _framing = DetectFraming(buffer, readBytes);
- }
-
- int restBytes = GetRemainingFrameSize(buffer, 0, readBytes);
-
- if (restBytes < 0)
- {
- throw new IOException(SR.net_ssl_io_frame);
- }
-
- if (restBytes == 0)
- {
- // EOF received
- throw new AuthenticationException(SR.net_auth_eof, null);
- }
-
- buffer = EnsureBufferSize(buffer, readBytes, readBytes + restBytes);
-
- restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes);
-
- SendBlob(buffer, readBytes + restBytes);
- }
-
- private async ValueTask<ProtocolToken> ReceiveBlobAsync(SslReadAsync adapter, byte[] buffer, CancellationToken cancellationToken)
+ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter adapter)
+ where TIOAdapter : ISslIOAdapter
{
ResetReadBuffer();
int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
}
}
- private async ValueTask WriteAsyncChunked<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
- where TWriteAdapter : struct, ISslWriteAdapter
+ private async ValueTask WriteAsyncChunked<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
+ where TIOAdapter : struct, ISslIOAdapter
{
do
{
} while (buffer.Length != 0);
}
- private ValueTask WriteSingleChunk<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
- where TWriteAdapter : struct, ISslWriteAdapter
+ private ValueTask WriteSingleChunk<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
+ where TIOAdapter : struct, ISslIOAdapter
{
// Request a write IO slot.
- Task ioSlot = writeAdapter.LockAsync();
+ Task ioSlot = writeAdapter.WriteLockAsync();
if (!ioSlot.IsCompletedSuccessfully)
{
// Operation is async and has been queued, return.
return CompleteAsync(t, rentedBuffer);
}
- async ValueTask WaitForWriteIOSlot(TWriteAdapter wAdapter, Task lockTask, ReadOnlyMemory<byte> buff)
+ async ValueTask WaitForWriteIOSlot(TIOAdapter wAdapter, Task lockTask, ReadOnlyMemory<byte> buff)
{
await lockTask.ConfigureAwait(false);
await WriteSingleChunk(wAdapter, buff).ConfigureAwait(false);
}
}
- private async ValueTask<int> ReadAsyncInternal<TReadAdapter>(TReadAdapter adapter, Memory<byte> buffer)
- where TReadAdapter : ISslReadAdapter
+ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(TIOAdapter adapter, Memory<byte> buffer)
+ where TIOAdapter : ISslIOAdapter
{
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
{
return copyBytes;
}
- copyBytes = await adapter.LockAsync(buffer).ConfigureAwait(false);
+ copyBytes = await adapter.ReadLockAsync(buffer).ConfigureAwait(false);
if (copyBytes > 0)
{
return copyBytes;
throw new IOException(SR.net_ssl_io_renego);
}
- await ReplyOnReAuthenticationAsync(extraBuffer, adapter.CancellationToken).ConfigureAwait(false);
+ await ReplyOnReAuthenticationAsync(adapter, extraBuffer).ConfigureAwait(false);
// Loop on read.
continue;
}
}
}
- private ValueTask<int> FillBufferAsync<TReadAdapter>(TReadAdapter adapter, int minSize)
- where TReadAdapter : ISslReadAdapter
+ private ValueTask<int> FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
+ where TIOAdapter : ISslIOAdapter
{
if (_internalBufferCount >= minSize)
{
return new ValueTask<int>(minSize);
- async ValueTask<int> InternalFillBufferAsync(TReadAdapter adap, ValueTask<int> task, int min, int initial)
+ async ValueTask<int> InternalFillBufferAsync(TIOAdapter adap, ValueTask<int> task, int min, int initial)
{
while (true)
{
}
}
- private async ValueTask WriteAsyncInternal<TWriteAdapter>(TWriteAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
- where TWriteAdapter : struct, ISslWriteAdapter
+ private async ValueTask WriteAsyncInternal<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
+ where TIOAdapter : struct, ISslIOAdapter
{
ThrowIfExceptionalOrNotAuthenticatedOrShutdown();