}
//
- internal ProtocolToken NextMessage(byte[] incoming, int offset, int count)
+ internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
{
if (NetEventSource.IsEnabled)
NetEventSource.Enter(this);
byte[] nextmsg = null;
- SecurityStatusPal status = GenerateToken(incoming, offset, count, ref nextmsg);
+ SecurityStatusPal status = GenerateToken(incomingBuffer, ref nextmsg);
if (!_sslAuthenticationOptions.IsServer && status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded)
{
NetEventSource.Info(this, "NextMessage() returned SecurityStatusPal.CredentialsNeeded");
SetRefreshCredentialNeeded();
- status = GenerateToken(incoming, offset, count, ref nextmsg);
+ status = GenerateToken(incomingBuffer, ref nextmsg);
}
ProtocolToken token = new ProtocolToken(nextmsg, status);
Return:
status - error information
--*/
- private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref byte[] output)
+ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte[] output)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, $"_refreshCredentialNeeded = {_refreshCredentialNeeded}");
- if (offset < 0 || offset > (input == null ? 0 : input.Length))
- {
- NetEventSource.Fail(this, "Argument 'offset' out of range.");
- throw new ArgumentOutOfRangeException(nameof(offset));
- }
-
- if (count < 0 || count > (input == null ? 0 : input.Length - offset))
- {
- NetEventSource.Fail(this, "Argument 'count' out of range.");
- throw new ArgumentOutOfRangeException(nameof(count));
- }
-
byte[] result = Array.Empty<byte>();
SecurityStatusPal status = default;
bool cachedCreds = false;
byte[] thumbPrint = null;
- ReadOnlySpan<byte> inputBuffer = new ReadOnlySpan<byte>(input, offset, count);
//
// Looping through ASC or ISC with potentially cached credential that could have been
byte[] nextmsg = null;
SecurityStatusPal status;
- status = GenerateToken(null, 0, 0, ref nextmsg);
+ status = GenerateToken(default, ref nextmsg);
ProtocolToken token = new ProtocolToken(nextmsg, status);
private const int FrameOverhead = 32;
private const int ReadBufferSize = 4096 * 4 + FrameOverhead; // We read in 16K chunks + headers.
+ private const int InitialHandshakeBufferSize = 4096 + FrameOverhead; // try to fit at least 4K ServerCertificate
+ private ArrayBuffer _handshakeBuffer;
private int _lockWriteState;
private int _lockReadState;
if (reAuthenticationData == null)
{
- // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation tansparently.
+ // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently.
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate"));
if (!receiveFirst)
{
- message = _context.NextMessage(reAuthenticationData, 0, (reAuthenticationData == null ? 0 : reAuthenticationData.Length));
+ message = _context.NextMessage(reAuthenticationData);
if (message.Size > 0)
{
await adapter.WriteAsync(message.Payload, 0, message.Size).ConfigureAwait(false);
}
}
+ _handshakeBuffer = new ArrayBuffer(InitialHandshakeBufferSize);
do
{
message = await ReceiveBlobAsync(adapter).ConfigureAwait(false);
}
} while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK);
+ if (_handshakeBuffer.ActiveLength > 0)
+ {
+ // If we read more than we needed for handshake, move it to input buffer for further processing.
+ ResetReadBuffer();
+ _handshakeBuffer.ActiveSpan.CopyTo(_internalBuffer);
+ _internalBufferCount = _handshakeBuffer.ActiveLength;
+ }
+
ProtocolToken alertToken = null;
if (!CompleteHandshake(ref alertToken))
{
}
finally
{
+ _handshakeBuffer.Dispose();
if (reAuthenticationData == null)
{
_nestedAuth = 0;
private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter adapter)
where TIOAdapter : ISslIOAdapter
{
- ResetReadBuffer();
- int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
+ int readBytes = await FillHandshakeBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
if (readBytes == 0)
{
throw new IOException(SR.net_io_eof);
if (_framing == Framing.Unified || _framing == Framing.Unknown)
{
- _framing = DetectFraming(_internalBuffer, readBytes);
+ _framing = DetectFraming(_handshakeBuffer.ActiveReadOnlySpan);
}
- int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes);
- if (payloadBytes < 0)
+ int frameSize= GetFrameSize(_handshakeBuffer.ActiveReadOnlySpan);
+ if (frameSize < 0)
{
throw new IOException(SR.net_frame_read_size);
}
- int frameSize = SecureChannel.ReadHeaderSize + payloadBytes;
-
- if (readBytes < frameSize)
+ if (_handshakeBuffer.ActiveLength < frameSize)
{
- readBytes = await FillBufferAsync(adapter, frameSize).ConfigureAwait(false);
- Debug.Assert(readBytes >= 0);
- if (readBytes == 0)
- {
- throw new IOException(SR.net_io_eof);
- }
+ await FillHandshakeBufferAsync(adapter, frameSize).ConfigureAwait(false);
}
- ProtocolToken token = _context.NextMessage(_internalBuffer, _internalOffset, frameSize);
- ConsumeBufferedBytes(frameSize);
+ ProtocolToken token = _context.NextMessage(_handshakeBuffer.ActiveReadOnlySpan.Slice(0, frameSize));
+ _handshakeBuffer.Discard(frameSize);
return token;
}
{
// Re-handshake status is not supported.
ArrayPool<byte>.Shared.Return(rentedBuffer);
- ProtocolToken message = new ProtocolToken(null, status);
- return new ValueTask(Task.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(SR.net_io_encrypt, message.GetException()))));
+ return new ValueTask(Task.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(SR.net_io_encrypt, SslStreamPal.GetException(status)))));
}
ValueTask t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes);
return 0;
}
- int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes);
+ int payloadBytes = GetFrameSize(new ReadOnlySpan<byte>(_internalBuffer, _internalOffset, readBytes));
if (payloadBytes < 0)
{
throw new IOException(SR.net_frame_read_size);
}
- readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize + payloadBytes).ConfigureAwait(false);
+ readBytes = await FillBufferAsync(adapter, payloadBytes).ConfigureAwait(false);
Debug.Assert(readBytes >= 0);
if (readBytes == 0)
{
// Set _decrytpedBytesOffset/Count to the current frame we have (including header)
// DecryptData will decrypt in-place and modify these to point to the actual decrypted data, which may be smaller.
_decryptedBytesOffset = _internalOffset;
- _decryptedBytesCount = readBytes;
+ _decryptedBytesCount = payloadBytes;
SecurityStatusPal status = DecryptData();
// Treat the bytes we just decrypted as consumed
_decryptedBytesCount = 0;
}
- ProtocolToken message = new ProtocolToken(null, status);
if (NetEventSource.IsEnabled)
- NetEventSource.Info(null, $"***Processing an error Status = {message.Status}");
+ NetEventSource.Info(null, $"***Processing an error Status = {status}");
- if (message.Renegotiate)
+ if (status.ErrorCode == SecurityStatusPalErrorCode.Renegotiate)
{
if (!_sslAuthenticationOptions.AllowRenegotiation)
{
continue;
}
- if (message.CloseConnection)
+ if (status.ErrorCode == SecurityStatusPalErrorCode.ContextExpired)
{
return 0;
}
- throw new IOException(SR.net_io_decrypt, message.GetException());
+ throw new IOException(SR.net_io_decrypt, SslStreamPal.GetException(status));
}
}
}
}
}
+ // This function tries to make sure buffer has at least minSize bytes available.
+ // If we have enough data, it returns synchronously. If not, it will try to read
+ // remaining bytes from given stream.
+ private ValueTask<int> FillHandshakeBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
+ where TIOAdapter : ISslIOAdapter
+ {
+ if (_handshakeBuffer.ActiveLength >= minSize)
+ {
+ return new ValueTask<int>(minSize);
+ }
+
+ int bytesNeeded = minSize - _handshakeBuffer.ActiveLength;
+ _handshakeBuffer.EnsureAvailableSpace(bytesNeeded);
+
+ while (_handshakeBuffer.ActiveLength < minSize)
+ {
+ ValueTask<int> t = adapter.ReadAsync(_handshakeBuffer.AvailableMemory);
+ if (!t.IsCompletedSuccessfully)
+ {
+ return InternalFillHandshakeBufferAsync(adapter, t, minSize);
+ }
+ int bytesRead = t.Result;
+ if (bytesRead == 0)
+ {
+ return new ValueTask<int>(0);
+ }
+
+ _handshakeBuffer.Commit(bytesRead);
+ }
+
+ return new ValueTask<int>(minSize);
+
+ async ValueTask<int> InternalFillHandshakeBufferAsync(TIOAdapter adap, ValueTask<int> task, int minSize)
+ {
+ while (true)
+ {
+ int bytesRead = await task.ConfigureAwait(false);
+ if (bytesRead == 0)
+ {
+ throw new IOException(SR.net_io_eof);
+ }
+
+ _handshakeBuffer.Commit(bytesRead);
+ if (_handshakeBuffer.ActiveLength >= minSize)
+ {
+ return minSize;
+ }
+
+ int bytesNeeded = minSize - _handshakeBuffer.ActiveLength;
+ task = adap.ReadAsync(_handshakeBuffer.AvailableMemory);
+ }
+ }
+ }
+
private ValueTask<int> FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
where TIOAdapter : ISslIOAdapter
{
int initialCount = _internalBufferCount;
do
{
- ValueTask<int> t = adapter.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount);
+ ValueTask<int> t = adapter.ReadAsync(new Memory<byte>(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount));
if (!t.IsCompletedSuccessfully)
{
return InternalFillBufferAsync(adapter, t, minSize, initialCount);
return min;
}
- task = adap.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount);
+ task = adap.ReadAsync(new Memory<byte>(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount));
}
}
}
}
// We need at least 5 bytes to determine what we have.
- private Framing DetectFraming(byte[] bytes, int length)
+ private Framing DetectFraming(ReadOnlySpan<byte> bytes)
{
/* PCTv1.0 Hello starts with
* RECORD_LENGTH_MSB (ignore)
int version = -1;
- if ((bytes == null || bytes.Length <= 0))
+ if (bytes.Length == 0)
{
NetEventSource.Fail(this, "Header buffer is not allocated.");
}
if (bytes[0] == (byte)FrameType.Handshake || bytes[0] == (byte)FrameType.AppData
|| bytes[0] == (byte)FrameType.Alert)
{
- if (length < 3)
+ if (bytes.Length < 3)
{
return Framing.Invalid;
}
}
#endif
- if (length < 3)
+ if (bytes.Length < 3)
{
return Framing.Invalid;
}
if (bytes[2] == 0x1) // SSL_MT_CLIENT_HELLO
{
- if (length >= 5)
+ if (bytes.Length >= 5)
{
version = (bytes[3] << 8) | bytes[4];
}
}
else if (bytes[2] == 0x4) // SSL_MT_SERVER_HELLO
{
- if (length >= 7)
+ if (bytes.Length >= 7)
{
version = (bytes[5] << 8) | bytes[6];
}
//
// This is called from SslStream class too.
- private int GetRemainingFrameSize(byte[] buffer, int offset, int dataSize)
+ // Returns TLS Frame size.
+ //
+ private int GetFrameSize(ReadOnlySpan<byte> buffer)
{
if (NetEventSource.IsEnabled)
- NetEventSource.Enter(this, buffer, offset, dataSize);
+ NetEventSource.Enter(this, buffer.Length);
int payloadSize = -1;
switch (_framing)
{
case Framing.Unified:
case Framing.BeforeSSL3:
- if (dataSize < 2)
+ if (buffer.Length < 2)
{
throw new System.IO.IOException(SR.net_ssl_io_frame);
}
// Note: Cannot detect version mismatch for <= SSL2
- if ((buffer[offset] & 0x80) != 0)
+ if ((buffer[0] & 0x80) != 0)
{
// Two bytes
- payloadSize = (((buffer[offset] & 0x7f) << 8) | buffer[offset + 1]) + 2;
- payloadSize -= dataSize;
+ payloadSize = (((buffer[0] & 0x7f) << 8) | buffer[1]) + 2;
}
else
{
// Three bytes
- payloadSize = (((buffer[offset] & 0x3f) << 8) | buffer[offset + 1]) + 3;
- payloadSize -= dataSize;
+ payloadSize = (((buffer[0] & 0x3f) << 8) | buffer[1]) + 3;
}
break;
case Framing.SinceSSL3:
- if (dataSize < 5)
+ if (buffer.Length < 5)
{
throw new System.IO.IOException(SR.net_ssl_io_frame);
}
- payloadSize = ((buffer[offset + 3] << 8) | buffer[offset + 4]) + 5;
- payloadSize -= dataSize;
+ payloadSize = ((buffer[3] << 8) | buffer[4]) + 5;
break;
default:
break;
if (NetEventSource.IsEnabled)
NetEventSource.Exit(this, payloadSize);
+
return payloadSize;
}
}