return context;
}
- internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int recvOffset, int recvCount, out byte[] sendBuf, out int sendCount)
+ internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, out byte[] sendBuf, out int sendCount)
{
sendBuf = null;
sendCount = 0;
Exception handshakeException = null;
- if ((recvBuf != null) && (recvCount > 0))
+ if (input.Length > 0)
{
- if (BioWrite(context.InputBio, recvBuf, recvOffset, recvCount) <= 0)
+ if (Ssl.BioWrite(context.InputBio, ref MemoryMarshal.GetReference(input), input.Length) != input.Length)
{
// Make sure we clear out the error that is stored in the queue
throw Crypto.CreateOpenSslCryptographicException();
return stateOk;
}
- internal static int Encrypt(SafeSslHandle context, ReadOnlyMemory<byte> input, ref byte[] output, out Ssl.SslErrorCode errorCode)
+ internal static int Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref byte[] output, out Ssl.SslErrorCode errorCode)
{
#if DEBUG
ulong assertNoError = Crypto.ErrPeekError();
lock (context)
{
- unsafe
- {
- using (MemoryHandle handle = input.Pin())
- {
- retVal = Ssl.SslWrite(context, (byte*)handle.Pointer, input.Length);
- }
- }
+ retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length);
if (retVal != input.Length)
{
}
[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslWrite")]
- internal static extern unsafe int SslWrite(SafeSslHandle ssl, byte* buf, int num);
+ internal static extern unsafe int SslWrite(SafeSslHandle ssl, ref byte buf, int num);
[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslRead")]
internal static extern unsafe int SslRead(SafeSslHandle ssl, byte* buf, int num);
[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")]
internal static extern unsafe int BioWrite(SafeBioHandle b, byte* data, int len);
+ [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")]
+ internal static extern unsafe int BioWrite(SafeBioHandle b, ref byte data, int len);
+
[DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertificate")]
internal static extern SafeX509Handle SslGetPeerCertificate(SafeSslHandle ssl);
{
return true;
}
+
+ return false;
+ }
+
+ public static bool SecurityForceSocketStreams()
+ {
+ string value = Configuration.Security.SecurityForceSocketStreams;
+ if (value != null && (value.Equals("true", StringComparison.OrdinalIgnoreCase) || value.Equals("1")))
+ {
+ return true;
+ }
+
return false;
}
// 127.0.0.1 testclienteku.contoso.com
public static string HostsFileNamesInstalled => GetValue("COREFX_NET_SECURITY_HOSTS_FILE_INSTALLED");
+ // Allows packet captures.
+ public static string SecurityForceSocketStreams => GetValue("COREFX_NET_SECURITY_FORCE_SOCKET_STREAMS");
}
}
}
//
internal class AsyncProtocolRequest
{
-#if DEBUG
- internal object _DebugAsyncChain; // Optionally used to track chains of async calls.
-#endif
-
private AsyncProtocolCallback _callback;
private int _completionStatus;
public LazyAsyncResult UserAsyncResult;
public int Result;
- public object AsyncState;
public readonly CancellationToken CancellationToken;
public byte[] Buffer; // Temporary buffer reused by a protocol.
Debug.Assert(count >= 0);
Debug.Assert(count <= buf.Length - offset);
+ Write(buf.AsSpan(offset, count));
+ }
+ internal void Write(ReadOnlySpan<byte> buf)
+ {
lock (_fromConnection)
{
- for (int i = 0; i < count; i++)
+ foreach (byte b in buf)
{
- _fromConnection.Enqueue(buf[offset + i]);
+ _fromConnection.Enqueue(b);
}
}
-
}
internal int BytesReadyForConnection => _toConnection.Count;
//
// Acquire Server Side Certificate information and set it on the class.
//
- private bool AcquireServerCredentials(ref byte[] thumbPrint, byte[] clientHello)
+ private bool AcquireServerCredentials(ref byte[] thumbPrint, ReadOnlySpan<byte> clientHello)
{
if (NetEventSource.IsEnabled)
NetEventSource.Enter(this);
if (_refreshCredentialNeeded)
{
cachedCreds = _sslAuthenticationOptions.IsServer
- ? AcquireServerCredentials(ref thumbPrint, input)
+ ? AcquireServerCredentials(ref thumbPrint, new ReadOnlySpan<byte>(input, offset, count))
: AcquireClientCredentials(ref thumbPrint);
}
status = SslStreamPal.AcceptSecurityContext(
ref _credentialsHandle,
ref _securityContext,
- input != null ? new ArraySegment<byte>(input, offset, count) : default,
+ input, offset, count,
ref result,
_sslAuthenticationOptions);
}
ref _credentialsHandle,
ref _securityContext,
_sslAuthenticationOptions.TargetHost,
- input != null ? new ArraySegment<byte>(input, offset, count) : default,
+ input, offset, count,
ref result,
_sslAuthenticationOptions);
}
private static readonly IdnMapping s_idnMapping = CreateIdnMapping();
private static readonly Encoding s_encoding = CreateEncoding();
- public static string GetServerName(byte[] clientHello)
- {
- return GetSniFromSslPlainText(clientHello);
- }
-
- private static string GetSniFromSslPlainText(ReadOnlySpan<byte> sslPlainText)
+ public static string GetServerName(ReadOnlySpan<byte> sslPlainText)
{
// https://tools.ietf.org/html/rfc6101#section-5.2.1
// struct {
public partial class SslStream
{
private static int s_uniqueNameInteger = 123;
- private static readonly AsyncProtocolCallback s_partialFrameCallback = new AsyncProtocolCallback(PartialFrameCallback);
- private static readonly AsyncProtocolCallback s_readFrameCallback = new AsyncProtocolCallback(ReadFrameCallback);
- private static readonly AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback);
private SslAuthenticationOptions _sslAuthenticationOptions;
}
private CachedSessionStatus _CachedSession;
+ private enum Framing
+ {
+ Unknown = 0,
+ BeforeSSL3,
+ SinceSSL3,
+ Unified,
+ Invalid
+ }
+
+ // This is set on the first packet to figure out the framing style.
+ private Framing _framing = Framing.Unknown;
+
+ // SSL3/TLS protocol frames definitions.
+ private enum FrameType : byte
+ {
+ ChangeCipherSpec = 20,
+ Alert = 21,
+ Handshake = 22,
+ AppData = 23
+ }
+
// This block is used by re-handshake code to buffer data decrypted with the old key.
private byte[] _queuedReadData;
private int _queuedReadCount;
- private bool _pendingReHandshake;
- private const int MaxQueuedReadBytes = 1024 * 128;
//
// This block is used to rule the >>re-handshakes<< that are concurrent with read/write I/O requests.
}
//
- // Called by re-handshake if found data decrypted with the old key
- //
- private Exception EnqueueOldKeyDecryptedData(byte[] buffer, int offset, int count)
- {
- lock (SyncLock)
- {
- if (_queuedReadCount + count > MaxQueuedReadBytes)
- {
- return ExceptionDispatchInfo.SetCurrentStackTrace(
- new IOException(SR.Format(SR.net_auth_ignored_reauth, MaxQueuedReadBytes.ToString(NumberFormatInfo.CurrentInfo))));
- }
-
- if (count != 0)
- {
- // This is inefficient yet simple and that should be a rare case of receiving data encrypted with "old" key.
- _queuedReadData = EnsureBufferSize(_queuedReadData, _queuedReadCount, _queuedReadCount + count);
- Buffer.BlockCopy(buffer, offset, _queuedReadData, _queuedReadCount, count);
- _queuedReadCount += count;
- FinishHandshakeRead(LockHandshake);
- }
- }
- return null;
- }
-
- //
// When re-handshaking the "old" key decrypted data are queued until the handshake is done.
// When stream calls for decryption we will feed it queued data left from "old" encryption key.
//
// This method assumes that a SSPI context is already in a good shape.
// For example it is either a fresh context or already authenticated context that needs renegotiation.
//
- private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken cancellationToken)
+ private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default)
{
+ Task result = null;
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
{
- throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, lazyResult == null ? "BeginAuthenticate" : "Authenticate", "authenticate"));
+ throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate"));
}
try
{
ThrowIfExceptional();
- AsyncProtocolRequest asyncRequest = null;
- if (lazyResult != null)
- {
- asyncRequest = new AsyncProtocolRequest(lazyResult, cancellationToken);
- asyncRequest.Buffer = null;
-#if DEBUG
- lazyResult._debugAsyncChain = asyncRequest;
-#endif
- }
// A trick to discover and avoid cached sessions.
_CachedSession = CachedSessionStatus.Unknown;
- ForceAuthentication(_context.IsServer, null, asyncRequest);
-
- // Not aync so the connection is completed at this point.
- if (lazyResult == null && NetEventSource.IsEnabled)
+ if (isAsync)
{
+ result = ForceAuthenticationAsync(_context.IsServer, null, cancellationToken);
+ }
+ else
+ {
+ ForceAuthentication(_context.IsServer, null);
+
if (NetEventSource.IsEnabled)
NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication),
SslProtocol,
KeyExchangeStrength);
}
}
- catch (Exception)
- {
- // If an exception emerges synchronously, the asynchronous operation was not
- // initiated, so no operation is in progress.
- _nestedAuth = 0;
- throw;
- }
finally
{
- // For synchronous operations, the operation has completed.
- if (lazyResult == null)
- {
- _nestedAuth = 0;
- }
+ // Operation has completed.
+ _nestedAuth = 0;
}
+
+ return result;
}
//
// This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read().
//
- private void ReplyOnReAuthentication(byte[] buffer, CancellationToken cancellationToken)
+ private async Task ReplyOnReAuthenticationAsync(byte[] buffer, CancellationToken cancellationToken)
{
lock (SyncLock)
{
// Note we are already inside the read, so checking for already going concurrent handshake.
_lockReadState = LockHandshake;
-
- if (_pendingReHandshake)
- {
- // A concurrent handshake is pending, resume.
- FinishRead(buffer);
- return;
- }
}
- // Start rehandshake from here.
-
- // Forcing async mode. The caller will queue another Read as soon as we return using its preferred
- // calling convention, which will be woken up when the handshake completes. The callback is just
- // to capture any SocketErrors that happen during the handshake so they can be surfaced from the Read.
- AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(new LazyAsyncResult(this, null, new AsyncCallback(RehandshakeCompleteCallback)), cancellationToken);
- // Buffer contains a result from DecryptMessage that will be passed to ISC/ASC
- asyncRequest.Buffer = buffer;
- ForceAuthentication(false, buffer, asyncRequest);
+ await ForceAuthenticationAsync(false, buffer, cancellationToken).ConfigureAwait(false);
+ FinishHandshakeRead(LockNone);
}
//
// Incoming buffer is either null or is the result of "renegotiate" decrypted message
// If write is in progress the method will either wait or be put on hold
//
- private void ForceAuthentication(bool receiveFirst, byte[] buffer, AsyncProtocolRequest asyncRequest)
+ private void ForceAuthentication(bool receiveFirst, byte[] buffer)
{
- if (CheckEnqueueHandshake(buffer, asyncRequest))
- {
- // Async handshake is enqueued and will resume later.
- return;
- }
- // Either Sync handshake is ready to go or async handshake won the race over write.
-
// This will tell that we don't know the framing yet (what SSL version is)
- _Framing = Framing.Unknown;
+ _framing = Framing.Unknown;
try
{
if (receiveFirst)
{
// Listen for a client blob.
- StartReceiveBlob(buffer, asyncRequest);
+ ReceiveBlob(buffer);
}
else
{
// We start with the first blob.
- StartSendBlob(buffer, (buffer == null ? 0 : buffer.Length), asyncRequest);
+ SendBlob(buffer, (buffer == null ? 0 : buffer.Length));
}
}
catch (Exception e)
{
// Failed auth, reset the framing if any.
- _Framing = Framing.Unknown;
+ _framing = Framing.Unknown;
_handshakeCompleted = false;
SetException(e);
if (_exception != null)
{
// This a failed handshake. Release waiting IO if any.
- FinishHandshake(null, null);
+ FinishHandshake(null);
}
}
}
- private void EndProcessAuthentication(IAsyncResult result)
+ internal async Task ForceAuthenticationAsync(bool receiveFirst, byte[] buffer, CancellationToken cancellationToken)
{
- if (result == null)
- {
- throw new ArgumentNullException("asyncResult");
- }
+ _framing = Framing.Unknown;
+ ProtocolToken message;
+ SslReadAsync adapter = new SslReadAsync(this, cancellationToken);
- LazyAsyncResult lazyResult = result as LazyAsyncResult;
- if (lazyResult == null)
+ if (!receiveFirst)
{
- throw new ArgumentException(SR.Format(SR.net_io_async_result, result.GetType().FullName), "asyncResult");
+ message = _context.NextMessage(buffer, 0, (buffer == null ? 0 : buffer.Length));
+ if (message.Failed)
+ {
+ // tracing done in NextMessage()
+ throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+ }
+
+ await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false);
}
- if (Interlocked.Exchange(ref _nestedAuth, 0) == 0)
+ do
{
- throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndAuthenticate"));
- }
+ message = await ReceiveBlobAsync(adapter, buffer, cancellationToken).ConfigureAwait(false);
+ if (message.Size > 0)
+ {
+ // If there is message send it out even if call failed. It may contain TLS Alert.
+ await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false);
+ }
- InternalEndProcessAuthentication(lazyResult);
+ if (message.Failed)
+ {
+ throw new AuthenticationException(SR.net_auth_SSPI, message.GetException());
+ }
+ } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK);
- // Connection is completed at this point.
- if (NetEventSource.IsEnabled)
+ ProtocolToken alertToken = null;
+ if (!CompleteHandshake(ref alertToken))
{
- if (NetEventSource.IsEnabled)
- NetEventSource.Log.SspiSelectedCipherSuite(nameof(EndProcessAuthentication),
- SslProtocol,
- CipherAlgorithm,
- CipherStrength,
- HashAlgorithm,
- HashStrength,
- KeyExchangeAlgorithm,
- KeyExchangeStrength);
+ SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
}
- }
-
- private void InternalEndProcessAuthentication(LazyAsyncResult lazyResult)
- {
- // No "artificial" timeouts implemented so far, InnerStream controls that.
- lazyResult.InternalWaitForCompletion();
- Exception e = lazyResult.Result as Exception;
- if (e != null)
- {
- // Failed auth, reset the framing if any.
- _Framing = Framing.Unknown;
- _handshakeCompleted = false;
+ if (NetEventSource.IsEnabled)
+ NetEventSource.Log.SspiSelectedCipherSuite(nameof(ForceAuthenticationAsync),
+ SslProtocol,
+ CipherAlgorithm,
+ CipherStrength,
+ HashAlgorithm,
+ HashStrength,
+ KeyExchangeAlgorithm,
+ KeyExchangeStrength);
- SetException(e);
- ThrowIfExceptional();
- }
}
//
// Client side starts here, but server also loops through this method.
//
- private void StartSendBlob(byte[] incoming, int count, AsyncProtocolRequest asyncRequest)
+ private void SendBlob(byte[] incoming, int count)
{
ProtocolToken message = _context.NextMessage(incoming, 0, count);
_securityStatus = message.Status;
_CachedSession = message.Size < 200 ? CachedSessionStatus.IsCached : CachedSessionStatus.IsNotCached;
}
- if (_Framing == Framing.Unified)
+ if (_framing == Framing.Unified)
{
- _Framing = DetectFraming(message.Payload, message.Payload.Length);
+ _framing = DetectFraming(message.Payload, message.Payload.Length);
}
- if (asyncRequest == null)
- {
- InnerStream.Write(message.Payload, 0, message.Size);
- }
- else
- {
- asyncRequest.AsyncState = message;
- Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size, asyncRequest.CancellationToken);
- if (t.IsCompleted)
- {
- t.GetAwaiter().GetResult();
- }
- else
- {
- IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest);
- if (!ar.CompletedSynchronously)
- {
-#if DEBUG
- asyncRequest._DebugAsyncChain = ar;
-#endif
- return;
- }
- TaskToApm.End(ar);
- }
- }
+ InnerStream.Write(message.Payload, 0, message.Size);
}
- CheckCompletionBeforeNextReceive(message, asyncRequest);
+ CheckCompletionBeforeNextReceive(message);
}
//
// This will check and logically complete / fail the auth handshake.
//
- private void CheckCompletionBeforeNextReceive(ProtocolToken message, AsyncProtocolRequest asyncRequest)
+ private void CheckCompletionBeforeNextReceive(ProtocolToken message)
{
if (message.Failed)
{
- StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException())));
+ SendAuthResetSignal(null, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException())));
return;
}
- else if (message.Done && !_pendingReHandshake)
+ else if (message.Done)
{
ProtocolToken alertToken = null;
if (!CompleteHandshake(ref alertToken))
{
- StartSendAuthResetSignal(alertToken, asyncRequest, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
+ SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null)));
return;
}
// Release waiting IO if any. Presumably it should not throw.
// Otherwise application may get not expected type of the exception.
- FinishHandshake(null, asyncRequest);
+ FinishHandshake(null);
return;
}
- StartReceiveBlob(message.Payload, asyncRequest);
+ ReceiveBlob(message.Payload);
}
//
// Server side starts here, but client also loops through this method.
//
- private void StartReceiveBlob(byte[] buffer, AsyncProtocolRequest asyncRequest)
+ private void ReceiveBlob(byte[] buffer)
{
- if (_pendingReHandshake)
- {
- if (CheckEnqueueHandshakeRead(ref buffer, asyncRequest))
- {
- return;
- }
-
- if (!_pendingReHandshake)
- {
- // Renegotiate: proceed to the next step.
- ProcessReceivedBlob(buffer, buffer.Length, asyncRequest);
- return;
- }
- }
-
//This is first server read.
buffer = EnsureBufferSize(buffer, 0, SecureChannel.ReadHeaderSize);
- int readBytes = 0;
- if (asyncRequest == null)
- {
- readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize);
- }
- else
- {
- asyncRequest.SetNextRequest(buffer, 0, SecureChannel.ReadHeaderSize, s_partialFrameCallback);
- _ = FixedSizeReader.ReadPacketAsync(_innerStream, asyncRequest);
- if (!asyncRequest.MustCompleteSynchronously)
- {
- return;
- }
-
- readBytes = asyncRequest.Result;
- }
+ int readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize);
- StartReadFrame(buffer, readBytes, asyncRequest);
- }
-
- //
- private void StartReadFrame(byte[] buffer, int readBytes, AsyncProtocolRequest asyncRequest)
- {
if (readBytes == 0)
{
// EOF received
throw new IOException(SR.net_auth_eof);
}
- if (_Framing == Framing.Unknown)
+ if (_framing == Framing.Unknown)
{
- _Framing = DetectFraming(buffer, readBytes);
+ _framing = DetectFraming(buffer, readBytes);
}
int restBytes = GetRemainingFrameSize(buffer, 0, readBytes);
buffer = EnsureBufferSize(buffer, readBytes, readBytes + restBytes);
- if (asyncRequest == null)
- {
- restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes);
- }
- else
- {
- asyncRequest.SetNextRequest(buffer, readBytes, restBytes, s_readFrameCallback);
- _ = FixedSizeReader.ReadPacketAsync(_innerStream, asyncRequest);
- if (!asyncRequest.MustCompleteSynchronously)
- {
- return;
- }
+ restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes);
- restBytes = asyncRequest.Result;
- if (restBytes == 0)
- {
- //EOF received: fail.
- readBytes = 0;
- }
- }
- ProcessReceivedBlob(buffer, readBytes + restBytes, asyncRequest);
+ SendBlob(buffer, readBytes + restBytes);
}
- private void ProcessReceivedBlob(byte[] buffer, int count, AsyncProtocolRequest asyncRequest)
+ private async ValueTask<ProtocolToken> ReceiveBlobAsync(SslReadAsync adapter, byte[] buffer, CancellationToken cancellationToken)
{
- if (count == 0)
+ ResetReadBuffer();
+ int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
+ if (readBytes == 0)
{
- // EOF received.
- throw new AuthenticationException(SR.net_auth_eof, null);
+ throw new IOException(SR.net_io_eof);
}
- if (_pendingReHandshake)
+ if (_framing == Framing.Unified || _framing == Framing.Unknown)
{
- int offset = 0;
- SecurityStatusPal status = PrivateDecryptData(buffer, ref offset, ref count);
+ _framing = DetectFraming(_internalBuffer, readBytes);
+ }
- if (status.ErrorCode == SecurityStatusPalErrorCode.OK)
- {
- Exception e = EnqueueOldKeyDecryptedData(buffer, offset, count);
- if (e != null)
- {
- StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(e));
- return;
- }
+ int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes);
+ if (payloadBytes < 0)
+ {
+ throw new IOException(SR.net_frame_read_size);
+ }
- _Framing = Framing.Unknown;
- StartReceiveBlob(buffer, asyncRequest);
- return;
- }
- else if (status.ErrorCode != SecurityStatusPalErrorCode.Renegotiate)
- {
- // Fail re-handshake.
- ProtocolToken message = new ProtocolToken(null, status);
- StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(
- ExceptionDispatchInfo.SetCurrentStackTrace(new AuthenticationException(SR.net_auth_SSPI, message.GetException()))));
- return;
- }
+ int frameSize = SecureChannel.ReadHeaderSize + payloadBytes;
- // We expect only handshake messages from now.
- _pendingReHandshake = false;
- if (offset != 0)
+ if (readBytes < frameSize)
+ {
+ readBytes = await FillBufferAsync(adapter, frameSize).ConfigureAwait(false);
+ Debug.Assert(readBytes >= 0);
+ if (readBytes == 0)
{
- Buffer.BlockCopy(buffer, offset, buffer, 0, count);
+ throw new IOException(SR.net_io_eof);
}
}
- StartSendBlob(buffer, count, asyncRequest);
+ ProtocolToken token = _context.NextMessage(_internalBuffer, _internalOffset, frameSize);
+ ConsumeBufferedBytes(frameSize);
+
+ return token;
}
//
// This is to reset auth state on remote side.
// If this write succeeds we will allow auth retrying.
//
- private void StartSendAuthResetSignal(ProtocolToken message, AsyncProtocolRequest asyncRequest, ExceptionDispatchInfo exception)
+ private void SendAuthResetSignal(ProtocolToken message, ExceptionDispatchInfo exception)
{
+ SetException(exception.SourceException);
+
if (message == null || message.Size == 0)
{
//
exception.Throw();
}
- if (asyncRequest == null)
- {
- InnerStream.Write(message.Payload, 0, message.Size);
- }
- else
- {
- asyncRequest.AsyncState = exception;
- Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size, asyncRequest.CancellationToken);
- if (t.IsCompleted)
- {
- t.GetAwaiter().GetResult();
- }
- else
- {
- IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest);
- if (!ar.CompletedSynchronously)
- {
- return;
- }
- TaskToApm.End(ar);
- }
- }
+ InnerStream.Write(message.Payload, 0, message.Size);
exception.Throw();
}
return true;
}
- private static void WriteCallback(IAsyncResult transportResult)
- {
- if (transportResult.CompletedSynchronously)
- {
- return;
- }
-
- AsyncProtocolRequest asyncRequest;
- SslStream sslState;
-
-#if DEBUG
- try
- {
-#endif
- asyncRequest = (AsyncProtocolRequest)transportResult.AsyncState;
- sslState = (SslStream)asyncRequest.AsyncObject;
-#if DEBUG
- }
- catch (Exception exception) when (!ExceptionCheck.IsFatal(exception))
- {
- NetEventSource.Fail(null, $"Exception while decoding context: {exception}");
- throw;
- }
-#endif
-
- // Async completion.
- try
- {
- TaskToApm.End(transportResult);
-
- // Special case for an error notification.
- object asyncState = asyncRequest.AsyncState;
- ExceptionDispatchInfo exception = asyncState as ExceptionDispatchInfo;
- if (exception != null)
- {
- exception.Throw();
- }
-
- sslState.CheckCompletionBeforeNextReceive((ProtocolToken)asyncState, asyncRequest);
- }
- catch (Exception e)
- {
- if (asyncRequest.IsUserCompleted)
- {
- // This will throw on a worker thread.
- throw;
- }
-
- sslState.FinishHandshake(e, asyncRequest);
- }
- }
-
- private static void PartialFrameCallback(AsyncProtocolRequest asyncRequest)
- {
- if (NetEventSource.IsEnabled)
- NetEventSource.Enter(null);
-
- // Async ONLY completion.
- SslStream sslState = (SslStream)asyncRequest.AsyncObject;
- try
- {
- sslState.StartReadFrame(asyncRequest.Buffer, asyncRequest.Result, asyncRequest);
- }
- catch (Exception e)
- {
- if (asyncRequest.IsUserCompleted)
- {
- // This will throw on a worker thread.
- throw;
- }
-
- sslState.FinishHandshake(e, asyncRequest);
- }
- }
-
- //
- //
- private static void ReadFrameCallback(AsyncProtocolRequest asyncRequest)
- {
- if (NetEventSource.IsEnabled)
- NetEventSource.Enter(null);
-
- // Async ONLY completion.
- SslStream sslState = (SslStream)asyncRequest.AsyncObject;
- try
- {
- if (asyncRequest.Result == 0)
- {
- //EOF received: will fail.
- asyncRequest.Offset = 0;
- }
-
- sslState.ProcessReceivedBlob(asyncRequest.Buffer, asyncRequest.Offset + asyncRequest.Result, asyncRequest);
- }
- catch (Exception e)
- {
- if (asyncRequest.IsUserCompleted)
- {
- // This will throw on a worker thread.
- throw;
- }
-
- sslState.FinishHandshake(e, asyncRequest);
- }
- }
-
- private bool CheckEnqueueHandshakeRead(ref byte[] buffer, AsyncProtocolRequest request)
- {
- LazyAsyncResult lazyResult = null;
- lock (SyncLock)
- {
- if (_lockReadState == LockPendingRead)
- {
- return false;
- }
-
- int lockState = Interlocked.Exchange(ref _lockReadState, LockHandshake);
- if (lockState != LockRead)
- {
- return false;
- }
-
- if (request != null)
- {
- _queuedReadStateRequest = request;
- return true;
- }
-
- lazyResult = new LazyAsyncResult(null, null, /*must be */ null);
- _queuedReadStateRequest = lazyResult;
- }
-
- // Need to exit from lock before waiting.
- lazyResult.InternalWaitForCompletion();
- buffer = (byte[])lazyResult.Result;
- return false;
- }
-
private void FinishHandshakeRead(int newState)
{
lock (SyncLock)
}
_lockReadState = LockRead;
- HandleQueuedCallback(ref _queuedReadStateRequest);
}
}
}
}
- private void FinishRead(byte[] renegotiateBuffer)
- {
- int lockState = Interlocked.CompareExchange(ref _lockReadState, LockNone, LockRead);
-
- if (lockState != LockHandshake)
- {
- return;
- }
-
- lock (SyncLock)
- {
- LazyAsyncResult ar = _queuedReadStateRequest as LazyAsyncResult;
- if (ar != null)
- {
- _queuedReadStateRequest = null;
- ar.InvokeCallback(renegotiateBuffer);
- }
- else
- {
- AsyncProtocolRequest request = (AsyncProtocolRequest)_queuedReadStateRequest;
- request.Buffer = renegotiateBuffer;
- _queuedReadStateRequest = null;
- ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshakeRead(s.request), (sslState: this, request), preferLocal: false);
- }
- }
- }
-
private Task CheckEnqueueWriteAsync()
{
// Clear previous request.
{
return;
}
-
- lock (SyncLock)
- {
- HandleQueuedCallback(ref _queuedWriteStateRequest);
- }
- }
-
- private void HandleQueuedCallback(ref object queuedStateRequest)
- {
- object obj = queuedStateRequest;
- if (obj == null)
- {
- return;
- }
- queuedStateRequest = null;
-
- switch (obj)
- {
- case LazyAsyncResult lazy:
- lazy.InvokeCallback();
- break;
- case TaskCompletionSource<int> taskCompletionSource when taskCompletionSource.Task.AsyncState != null:
- Memory<byte> array = (Memory<byte>)taskCompletionSource.Task.AsyncState;
- int oldKeyResult = -1;
- try
- {
- oldKeyResult = CheckOldKeyDecryptedData(array);
- }
- catch (Exception exc)
- {
- taskCompletionSource.SetException(exc);
- break;
- }
- taskCompletionSource.SetResult(oldKeyResult);
- break;
- case TaskCompletionSource<int> taskCompletionSource:
- taskCompletionSource.SetResult(0);
- break;
- default:
- ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshake(s.obj), (sslState: this, obj), preferLocal: false);
- break;
- }
}
- // Returns:
- // true - operation queued
- // false - operation can proceed
- private bool CheckEnqueueHandshake(byte[] buffer, AsyncProtocolRequest asyncRequest)
+ private void FinishHandshake(Exception e)
{
- LazyAsyncResult lazyResult = null;
-
lock (SyncLock)
{
- if (_lockWriteState == LockPendingWrite)
+ if (e != null)
{
- return false;
+ SetException(e);
}
- int lockState = Interlocked.Exchange(ref _lockWriteState, LockHandshake);
- if (lockState != LockWrite)
- {
- // Proceed with handshake.
- return false;
- }
+ // Release read if any.
+ FinishHandshakeRead(LockNone);
- if (asyncRequest != null)
+ // If there is a pending write we want to keep it's lock state.
+ int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake);
+ if (lockState != LockPendingWrite)
{
- asyncRequest.Buffer = buffer;
- _queuedWriteStateRequest = asyncRequest;
- return true;
+ return;
}
- lazyResult = new LazyAsyncResult(null, null, /*must be*/null);
- _queuedWriteStateRequest = lazyResult;
- }
- lazyResult.InternalWaitForCompletion();
- return false;
- }
-
- private void FinishHandshake(Exception e, AsyncProtocolRequest asyncRequest)
- {
- try
- {
- lock (SyncLock)
- {
- if (e != null)
- {
- SetException(e);
- }
-
- // Release read if any.
- FinishHandshakeRead(LockNone);
-
- // If there is a pending write we want to keep it's lock state.
- int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake);
- if (lockState != LockPendingWrite)
- {
- return;
- }
-
- _lockWriteState = LockWrite;
- HandleQueuedCallback(ref _queuedWriteStateRequest);
- }
- }
- finally
- {
- if (asyncRequest != null)
- {
- if (e != null)
- {
- asyncRequest.CompleteUserWithError(e);
- }
- else
- {
- asyncRequest.CompleteUser();
- }
- }
+ _lockWriteState = LockWrite;
}
}
{
copyBytes = CopyDecryptedData(buffer);
- FinishRead(null);
-
return copyBytes;
}
{
if (!_sslAuthenticationOptions.AllowRenegotiation)
{
+ if (NetEventSource.IsEnabled) NetEventSource.Fail(this, "Renegotiation was requested but it is disallowed");
throw new IOException(SR.net_ssl_io_renego);
}
- ReplyOnReAuthentication(extraBuffer, adapter.CancellationToken);
-
+ await ReplyOnReAuthenticationAsync(extraBuffer, adapter.CancellationToken).ConfigureAwait(false);
// Loop on read.
continue;
}
if (message.CloseConnection)
{
- FinishRead(null);
return 0;
}
}
catch (Exception e)
{
- FinishRead(null);
-
if (e is IOException || (e is OperationCanceledException && adapter.CancellationToken.IsCancellationRequested))
{
throw;
_decryptedBytesOffset += copyBytes;
_decryptedBytesCount -= copyBytes;
}
+
ReturnReadBufferIfEmpty();
return copyBytes;
}
return buffer;
}
- private enum Framing
- {
- Unknown = 0,
- BeforeSSL3,
- SinceSSL3,
- Unified,
- Invalid
- }
-
- // This is set on the first packet to figure out the framing style.
- private Framing _Framing = Framing.Unknown;
-
- // SSL3/TLS protocol frames definitions.
- private enum FrameType : byte
- {
- ChangeCipherSpec = 20,
- Alert = 21,
- Handshake = 22,
- AppData = 23
- }
-
// We need at least 5 bytes to determine what we have.
private Framing DetectFraming(byte[] bytes, int length)
{
// If this is the first packet, the client may start with an SSL2 packet
// but stating that the version is 3.x, so check the full range.
// For the subsequent packets we assume that an SSL2 packet should have a 2.x version.
- if (_Framing == Framing.Unknown)
+ if (_framing == Framing.Unknown)
{
if (version != 0x0002 && (version < 0x200 || version >= 0x500))
{
}
// When server has replied the framing is already fixed depending on the prior client packet
- if (!_context.IsServer || _Framing == Framing.Unified)
+ if (!_context.IsServer || _framing == Framing.Unified)
{
return Framing.BeforeSSL3;
}
NetEventSource.Enter(this, buffer, offset, dataSize);
int payloadSize = -1;
- switch (_Framing)
+ switch (_framing)
{
case Framing.Unified:
case Framing.BeforeSSL3:
NetEventSource.Exit(this, payloadSize);
return payloadSize;
}
-
- //
- // Called with no user stack.
- //
- private void AsyncResumeHandshake(object state)
- {
- AsyncProtocolRequest request = state as AsyncProtocolRequest;
- Debug.Assert(request != null, "Expected an AsyncProtocolRequest reference.");
-
- try
- {
- ForceAuthentication(_context.IsServer, request.Buffer, request);
- }
- catch (Exception e)
- {
- request.CompleteUserWithError(e);
- }
- }
-
- //
- // Called with no user stack.
- //
- private void AsyncResumeHandshakeRead(AsyncProtocolRequest asyncRequest)
- {
- try
- {
- if (_pendingReHandshake)
- {
- // Resume as read a blob.
- StartReceiveBlob(asyncRequest.Buffer, asyncRequest);
- }
- else
- {
- // Resume as process the blob.
- ProcessReceivedBlob(asyncRequest.Buffer, asyncRequest.Buffer == null ? 0 : asyncRequest.Buffer.Length, asyncRequest);
- }
- }
- catch (Exception e)
- {
- if (asyncRequest.IsUserCompleted)
- {
- // This will throw on a worker thread.
- throw;
- }
-
- FinishHandshake(e, asyncRequest);
- }
- }
-
- private void RehandshakeCompleteCallback(IAsyncResult result)
- {
- LazyAsyncResult lazyAsyncResult = (LazyAsyncResult)result;
- if (lazyAsyncResult == null)
- {
- NetEventSource.Fail(this, "result is null!");
- }
-
- if (!lazyAsyncResult.InternalPeekCompleted)
- {
- NetEventSource.Fail(this, "result is not completed!");
- }
-
- // If the rehandshake succeeded, FinishHandshake has already been called; if there was a SocketException
- // during the handshake, this gets called directly from FixedSizeReader, and we need to call
- // FinishHandshake to wake up the Read that triggered this rehandshake so the error gets back to the caller
- Exception exception = lazyAsyncResult.InternalWaitForCompletion() as Exception;
- if (exception != null)
- {
- // We may be calling FinishHandshake reentrantly, as FinishHandshake can call
- // asyncRequest.CompleteWithError, which will result in this method being called.
- // This is not a problem because:
- //
- // 1. We pass null as the asyncRequest parameter, so this second call to FinishHandshake won't loop
- // back here.
- //
- // 2. _QueuedWriteStateRequest and _QueuedReadStateRequest are set to null after the first call,
- // so, we won't invoke their callbacks again.
- //
- // 3. SetException won't overwrite an already-set _Exception.
- //
- // 4. There are three possibilities for _LockReadState and _LockWriteState:
- //
- // a. They were set back to None by the first call to FinishHandshake, and this will set them to
- // None again: a no-op.
- //
- // b. They were set to None by the first call to FinishHandshake, but as soon as the lock was given
- // up, another thread took a read/write lock. Calling FinishHandshake again will set them back
- // to None, but that's fine because that thread will be throwing _Exception before it actually
- // does any reading or writing and setting them back to None in a catch block anyways.
- //
- // c. If there is a Read/Write going on another thread, and the second FinishHandshake clears its
- // read/write lock, it's fine because no other Read/Write can look at the lock until the current
- // one gives up _SslStream._NestedRead/Write, and no handshake will look at the lock because
- // handshakes are only triggered in response to successful reads (which won't happen once
- // _Exception is set).
-
- FinishHandshake(exception, null);
- }
- }
}
}
return BeginAuthenticateAsClient(options, CancellationToken.None, asyncCallback, asyncState);
}
- internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState)
- {
- SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
- SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
-
- ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+ internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) =>
+ TaskToApm.Begin(AuthenticateAsClientApm(sslClientAuthenticationOptions, cancellationToken), asyncCallback, asyncState);
- LazyAsyncResult result = new LazyAsyncResult(this, asyncState, asyncCallback);
- ProcessAuthentication(result, cancellationToken);
- return result;
- }
-
- public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => EndProcessAuthentication(asyncResult);
+ public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
//
// Server side auth.
{
return BeginAuthenticateAsServer(serverCertificate, false, SecurityProtocol.SystemDefaultSecurityProtocols, false,
asyncCallback,
- asyncState);
+ asyncState);
}
public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCertificate, bool clientCertificateRequired,
return BeginAuthenticateAsServer(options, CancellationToken.None, asyncCallback, asyncState);
}
- private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState)
- {
- SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
-
- ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
-
- LazyAsyncResult result = new LazyAsyncResult(this, asyncState, asyncCallback);
- ProcessAuthentication(result, cancellationToken);
- return result;
- }
-
- public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => EndProcessAuthentication(asyncResult);
-
- internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState)
- {
- ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
+ private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) =>
+ TaskToApm.Begin(AuthenticateAsServerApm(sslServerAuthenticationOptions, cancellationToken), asyncCallback, asyncState);
- ProtocolToken message = _context.CreateShutdownToken();
- return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState);
- }
+ public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
- internal void EndShutdown(IAsyncResult asyncResult)
- {
- ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
+ internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState) => TaskToApm.Begin(ShutdownAsync(), asyncCallback, asyncState);
- TaskToApm.End(asyncResult);
- _shutdown = true;
- }
+ internal void EndShutdown(IAsyncResult asyncResult) => TaskToApm.End(asyncResult);
public TransportContext TransportContext => new SslStreamContext(this);
SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
- ProcessAuthentication(null, default);
+ ProcessAuthentication();
}
public virtual void AuthenticateAsServer(X509Certificate serverCertificate)
SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
- ProcessAuthentication(null, default);
+ ProcessAuthentication();
}
#endregion
#region Task-based async public methods
- public virtual Task AuthenticateAsClientAsync(string targetHost) =>
- Task.Factory.FromAsync(
- (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, callback, state),
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
- targetHost,
- this);
-
- public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, bool checkCertificateRevocation) =>
- Task.Factory.FromAsync(
- (arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, SecurityProtocol.SystemDefaultSecurityProtocols, arg3, callback, state),
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
- targetHost, clientCertificates, checkCertificateRevocation,
- this);
+ public virtual Task AuthenticateAsClientAsync(string targetHost) => AuthenticateAsClientAsync(targetHost, null, false);
+
+ public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, bool checkCertificateRevocation) => AuthenticateAsClientAsync(targetHost, clientCertificates, SecurityProtocol.SystemDefaultSecurityProtocols, checkCertificateRevocation);
public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, SslProtocols enabledSslProtocols, bool checkCertificateRevocation)
{
- var beginMethod = checkCertificateRevocation ? (Func<string, X509CertificateCollection, SslProtocols, AsyncCallback, object, IAsyncResult>)
- ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, arg3, true, callback, state)) :
- ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, arg3, false, callback, state));
- return Task.Factory.FromAsync(
- beginMethod,
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
- targetHost, clientCertificates, enabledSslProtocols,
- this);
+ SslClientAuthenticationOptions options = new SslClientAuthenticationOptions()
+ {
+ TargetHost = targetHost,
+ ClientCertificates = clientCertificates,
+ EnabledSslProtocols = enabledSslProtocols,
+ CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
+ EncryptionPolicy = _encryptionPolicy,
+ };
+
+ return AuthenticateAsClientAsync(options);
}
public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default)
{
- return Task.Factory.FromAsync(
- (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, callback, state),
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar),
- sslClientAuthenticationOptions, cancellationToken,
- this);
+ SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
+ SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
+
+ ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+
+ return ProcessAuthentication(true, false, cancellationToken);
+ }
+
+ private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default)
+ {
+ SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback);
+ SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback);
+
+ ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate);
+
+ return ProcessAuthentication(true, true, cancellationToken);
}
public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate) =>
- Task.Factory.FromAsync(
- (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, callback, state),
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
- serverCertificate,
- this);
-
- public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) =>
- Task.Factory.FromAsync(
- (arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, SecurityProtocol.SystemDefaultSecurityProtocols, arg3, callback, state),
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
- serverCertificate, clientCertificateRequired, checkCertificateRevocation,
- this);
+ AuthenticateAsServerAsync(serverCertificate, false, SecurityProtocol.SystemDefaultSecurityProtocols, false);
+
+ public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation)
+ {
+ SslServerAuthenticationOptions options = new SslServerAuthenticationOptions
+ {
+ ServerCertificate = serverCertificate,
+ ClientCertificateRequired = clientCertificateRequired,
+ CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
+ EncryptionPolicy = _encryptionPolicy,
+ };
+
+ return AuthenticateAsServerAsync(options);
+ }
public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, SslProtocols enabledSslProtocols, bool checkCertificateRevocation)
{
- var beginMethod = checkCertificateRevocation ? (Func<X509Certificate, bool, SslProtocols, AsyncCallback, object, IAsyncResult>)
- ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, arg3, true, callback, state)) :
- ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, arg3, false, callback, state));
- return Task.Factory.FromAsync(
- beginMethod,
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
- serverCertificate, clientCertificateRequired, enabledSslProtocols,
- this);
+ SslServerAuthenticationOptions options = new SslServerAuthenticationOptions
+ {
+ ServerCertificate = serverCertificate,
+ ClientCertificateRequired = clientCertificateRequired,
+ EnabledSslProtocols = enabledSslProtocols,
+ CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck,
+ EncryptionPolicy = _encryptionPolicy,
+ };
+
+ return AuthenticateAsServerAsync(options);
}
public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default)
{
- return Task.Factory.FromAsync(
- (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, callback, state),
- iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar),
- sslServerAuthenticationOptions, cancellationToken,
- this);
+ SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
+ ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
+
+ return ProcessAuthentication(true, false, cancellationToken);
}
- public virtual Task ShutdownAsync() =>
- Task.Factory.FromAsync(
- (callback, state) => ((SslStream)state).BeginShutdown(callback, state),
- iar => ((SslStream)iar.AsyncState).EndShutdown(iar),
- this);
+ private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default)
+ {
+ SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback);
+ ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions));
+
+ return ProcessAuthentication(true, true, cancellationToken);
+ }
+
+ public virtual Task ShutdownAsync()
+ {
+ ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
+
+ ProtocolToken message = _context.CreateShutdownToken();
+ _shutdown = true;
+ return InnerStream.WriteAsync(message.Payload, default).AsTask();
+ }
#endregion
public override bool IsAuthenticated => _context != null && _context.IsValidContext && _exception == null && _handshakeCompleted;
public static SecurityStatusPal AcceptSecurityContext(
ref SafeFreeCredentials credential,
ref SafeDeleteSslContext context,
- ArraySegment<byte> inputBuffer,
+ byte[] inputBuffer, int offset, int count,
ref byte[] outputBuffer,
SslAuthenticationOptions sslAuthenticationOptions)
{
- return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+ return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
}
public static SecurityStatusPal InitializeSecurityContext(
ref SafeFreeCredentials credential,
ref SafeDeleteSslContext context,
string targetName,
- ArraySegment<byte> inputBuffer,
+ byte[] inputBuffer, int offset, int count,
ref byte[] outputBuffer,
SslAuthenticationOptions sslAuthenticationOptions)
{
- return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+ return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
}
public static SafeFreeCredentials AcquireCredentialsHandle(
private static SecurityStatusPal HandshakeInternal(
SafeFreeCredentials credential,
ref SafeDeleteSslContext context,
- ArraySegment<byte> inputBuffer,
+ ReadOnlySpan<byte> inputBuffer,
ref byte[] outputBuffer,
SslAuthenticationOptions sslAuthenticationOptions)
{
}
}
- if (inputBuffer.Array != null && inputBuffer.Count > 0)
+ if (inputBuffer.Length > 0)
{
- sslContext.Write(inputBuffer.Array, inputBuffer.Offset, inputBuffer.Count);
+ sslContext.Write(inputBuffer);
}
SafeSslHandle sslHandle = sslContext.SslContext;
}
public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context,
- ArraySegment<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+ byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
{
- return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+ return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
}
public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName,
- ArraySegment<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+ byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
{
- return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions);
+ return HandshakeInternal(credential, ref context, new ReadOnlySpan<byte>(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions);
}
public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate,
}
private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credential, ref SafeDeleteSslContext context,
- ArraySegment<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+ ReadOnlySpan<byte> inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
{
Debug.Assert(!credential.IsInvalid);
context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions);
}
- bool done;
-
- if (inputBuffer.Array == null)
- {
- done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, null, 0, 0, out output, out outputSize);
- }
- else
- {
- done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer.Array, inputBuffer.Offset, inputBuffer.Count, out output, out outputSize);
- }
+ bool done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer, out output, out outputSize);
// When the handshake is done, and the context is server, check if the alpnHandle target was set to null during ALPN.
// If it was, then that indicates ALPN failed, send failure.
if (encrypt)
{
- resultSize = Interop.OpenSsl.Encrypt(scHandle, input, ref output, out errorCode);
+ resultSize = Interop.OpenSsl.Encrypt(scHandle, input.Span, ref output, out errorCode);
}
else
{
return Interop.Sec_Application_Protocols.ToByteArray(protocols);
}
- public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, ArraySegment<byte> input, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+ public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
{
Interop.SspiCli.ContextFlags unusedAttributes = default;
+ ArraySegment<byte> input = inputBuffer != null ? new ArraySegment<byte>(inputBuffer, offset, count) : default;
ThreeSecurityBuffers threeSecurityBuffers = default;
SecurityBuffer? incomingSecurity = input.Array != null ?
return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode);
}
- public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, ArraySegment<byte> input, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
+ public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
{
Interop.SspiCli.ContextFlags unusedAttributes = default;
+ ArraySegment<byte> input = inputBuffer != null ? new ArraySegment<byte>(inputBuffer, offset, count) : default;
ThreeSecurityBuffers threeSecurityBuffers = default;
SecurityBuffer? incomingSecurity = input.Array != null ?
return new CipherSuitesPolicy(cipherSuites);
}
- private static async Task<Exception> WaitForSecureConnection(VirtualNetwork connection, Func<Task> server, Func<Task> client)
+ private static async Task<Exception> WaitForSecureConnection(SslStream client, SslClientAuthenticationOptions clientOptions, SslStream server, SslServerAuthenticationOptions serverOptions)
{
Task serverTask = null;
Task clientTask = null;
// check if failed synchronously
try
{
- serverTask = server();
- clientTask = client();
+ serverTask = server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None);
+ clientTask = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
}
catch (Exception e)
{
- connection.BreakConnection();
+ client.Close();
+ server.Close();
if (!(e is AuthenticationException || e is Win32Exception))
{
catch (AuthenticationException) { }
catch (Win32Exception) { }
catch (VirtualNetwork.VirtualNetworkConnectionBroken) { }
+ catch (IOException) { }
}
return e;
// Now we expect both sides to fail or both to succeed
Exception failure = null;
+ Task task = null;
try
{
- await serverTask.ConfigureAwait(false);
+ task = await Task.WhenAny(serverTask, clientTask).TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds).ConfigureAwait(false);
+ await task ;
}
catch (Exception e) when (e is AuthenticationException || e is Win32Exception)
{
failure = e;
-
// avoid client waiting for server's response
- connection.BreakConnection();
+ if (task == serverTask)
+ {
+ server.Close();
+ }
+ else
+ {
+ client.Close();
+ }
}
try
{
- await clientTask.ConfigureAwait(false);
+ // Now wait for the other task to finish.
+ task = (task == serverTask ? clientTask : serverTask);
+ await task.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds).ConfigureAwait(false);
// Fail if server has failed but client has succeeded
Assert.Null(failure);
}
- catch (Exception e) when (e is VirtualNetwork.VirtualNetworkConnectionBroken || e is AuthenticationException || e is Win32Exception)
+ catch (Exception e) when (e is VirtualNetwork.VirtualNetworkConnectionBroken || e is AuthenticationException || e is Win32Exception || e is IOException)
{
// Fail if server has succeeded but client has failed
Assert.NotNull(failure);
- if (e.GetType() != typeof(VirtualNetwork.VirtualNetworkConnectionBroken))
+ if (e.GetType() != typeof(VirtualNetwork.VirtualNetworkConnectionBroken) && e.GetType() != typeof(IOException))
{
failure = new AggregateException(new Exception[] { failure, e });
}
private static NegotiatedParams ConnectAndGetNegotiatedParams(ConnectionParams serverParams, ConnectionParams clientParams)
{
- VirtualNetwork vn = new VirtualNetwork();
- using (VirtualNetworkStream serverStream = new VirtualNetworkStream(vn, isServer: true),
- clientStream = new VirtualNetworkStream(vn, isServer: false))
+ (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
+
+ using (clientStream)
+ using (serverStream)
using (SslStream server = new SslStream(serverStream, leaveInnerStreamOpen: false),
client = new SslStream(clientStream, leaveInnerStreamOpen: false))
{
return true;
});
- Func<Task> serverTask = () => server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None);
- Func<Task> clientTask = () => client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None);
-
- Exception failure = WaitForSecureConnection(vn, serverTask, clientTask).Result;
+ Exception failure = WaitForSecureConnection(client, clientOptions, server, serverOptions).GetAwaiter().GetResult();
if (failure == null)
{
[Fact]
public async Task AuthenticateAsClientAsync_Sockets_CanceledAfterStart_ThrowsOperationCanceledException()
{
- using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
- using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
- {
- listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
- listener.Listen(1);
+ (Stream client, Stream server) = TestHelper.GetConnectedTcpStreams();
- await client.ConnectAsync(listener.LocalEndPoint);
- using (Socket server = await listener.AcceptAsync())
- using (var clientSslStream = new SslStream(new NetworkStream(client), false, AllowAnyServerCertificate))
- using (var serverSslStream = new SslStream(new NetworkStream(server)))
- using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
- {
- var cts = new CancellationTokenSource();
- Task t = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, cts.Token);
- cts.Cancel();
- await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
- }
+ using (client)
+ using (server)
+ using (var clientSslStream = new SslStream(client, false, AllowAnyServerCertificate))
+ using (var serverSslStream = new SslStream(server))
+ using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+ {
+ var cts = new CancellationTokenSource();
+ Task t = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, cts.Token);
+ cts.Cancel();
+ await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
}
}
[Fact]
public async Task AuthenticateAsServerAsync_Sockets_CanceledAfterStart_ThrowsOperationCanceledException()
{
- using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
- using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
- {
- listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
- listener.Listen(1);
+ (Stream client, Stream server) = TestHelper.GetConnectedTcpStreams();
- await client.ConnectAsync(listener.LocalEndPoint);
- using (Socket server = await listener.AcceptAsync())
- using (var clientSslStream = new SslStream(new NetworkStream(client), false, AllowAnyServerCertificate))
- using (var serverSslStream = new SslStream(new NetworkStream(server)))
- using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
- {
- var cts = new CancellationTokenSource();
- Task t = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, cts.Token);
- cts.Cancel();
- await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
- }
+ using (client)
+ using (server)
+ using (var clientSslStream = new SslStream(client, false, AllowAnyServerCertificate))
+ using (var serverSslStream = new SslStream(server))
+ using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+ {
+ var cts = new CancellationTokenSource();
+ Task t = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, cts.Token);
+ cts.Cancel();
+ await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
}
}
}
<Compile Include="NotifyReadVirtualNetworkStream.cs" />
<Compile Include="DummyTcpServer.cs" />
<Compile Include="TestConfiguration.cs" />
+ <Compile Include="TestHelper.cs" />
<!-- SslStream Tests -->
<Compile Include="CertificateValidationClientServer.cs" />
<Compile Include="CertificateValidationRemoteServer.cs" />
</ItemGroup>
</When>
</Choose>
-</Project>
\ No newline at end of file
+</Project>
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Net;
+using System.Net.Sockets;
+using System.Net.Test.Common;
+
+namespace System.Net.Security.Tests
+{
+ public static class TestHelper
+ {
+ public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()
+ {
+ if (Capability.SecurityForceSocketStreams())
+ {
+ return GetConnectedTcpStreams();
+ }
+
+ return GetConnectedVirtualStreams();
+ }
+
+ internal static (NetworkStream ClientStream, NetworkStream ServerStream) GetConnectedTcpStreams()
+ {
+ using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+ listener.Listen(1);
+
+ var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ clientSocket.Connect(listener.LocalEndPoint);
+ Socket serverSocket = listener.Accept();
+
+ return (new NetworkStream(clientSocket, ownsSocket: true), new NetworkStream(serverSocket, ownsSocket: true));
+ }
+
+ }
+
+ internal static (VirtualNetworkStream ClientStream, VirtualNetworkStream ServerStream) GetConnectedVirtualStreams()
+ {
+ VirtualNetwork vn = new VirtualNetwork();
+
+ return (new VirtualNetworkStream(vn, isServer: false), new VirtualNetworkStream(vn, isServer: true));
+ }
+ }
+}
// This method assumes that a SSPI context is already in a good shape.
// For example it is either a fresh context or already authenticated context that needs renegotiation.
//
- private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken cancellationToken)
- {
- }
-
- private void EndProcessAuthentication(IAsyncResult result)
+ private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default)
{
+ return Task.Run(() => {});
}
private void ReturnReadBufferIfEmpty()