// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Test.Common;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
-
using Xunit;
namespace System.Net.Security.Tests
{
private readonly byte[] _sampleMsg = Encoding.UTF8.GetBytes("Sample Test Message");
- protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream);
+ protected static async Task WithServerCertificate(X509Certificate serverCertificate, Func<X509Certificate, string, Task> func)
+ {
+ X509Certificate certificate = serverCertificate ?? Configuration.Certificates.GetServerCertificate();
+ try
+ {
+ string name;
+ if (certificate is X509Certificate2 cert2)
+ {
+ name = cert2.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
+ }
+ else
+ {
+ using (cert2 = new X509Certificate2(certificate))
+ {
+ name = cert2.GetNameInfo(X509NameType.SimpleName, forIssuer: false);
+ }
+ }
+
+ await func(certificate, name).ConfigureAwait(false);
+ }
+ finally
+ {
+ if (certificate != serverCertificate)
+ {
+ certificate.Dispose();
+ }
+ }
+ }
+
+ protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null);
protected abstract Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
protected abstract Task WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
- [Fact]
- public async Task SslStream_StreamToStream_Authentication_Success()
+ public static IEnumerable<object[]> SslStream_StreamToStream_Authentication_Success_MemberData()
{
- VirtualNetwork network = new VirtualNetwork();
+ using (X509Certificate2 serverCert = Configuration.Certificates.GetServerCertificate())
+ using (X509Certificate2 clientCert = Configuration.Certificates.GetClientCertificate())
+ {
+ yield return new object[] { new X509Certificate2(serverCert), new X509Certificate2(clientCert) };
+ yield return new object[] { new X509Certificate(serverCert.Export(X509ContentType.Pfx)), new X509Certificate(clientCert.Export(X509ContentType.Pfx)) };
+ }
+ }
- using (var clientStream = new VirtualNetworkStream(network, isServer: false))
- using (var serverStream = new VirtualNetworkStream(network, isServer: true))
- using (var client = new SslStream(clientStream, false, AllowAnyServerCertificate))
- using (var server = new SslStream(serverStream))
- using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+ [Theory]
+ [MemberData(nameof(SslStream_StreamToStream_Authentication_Success_MemberData))]
+ public async Task SslStream_StreamToStream_Authentication_Success(X509Certificate serverCert = null, X509Certificate clientCert = null)
+ {
+ var network = new VirtualNetwork();
+ using (var client = new SslStream(new VirtualNetworkStream(network, isServer: false), false, AllowAnyServerCertificate))
+ using (var server = new SslStream(new VirtualNetworkStream(network, isServer: true), false, delegate { return true; }))
{
- await DoHandshake(client, server);
+ await DoHandshake(client, server, serverCert, clientCert);
Assert.True(client.IsAuthenticated);
Assert.True(server.IsAuthenticated);
}
+
+ clientCert?.Dispose();
+ serverCert?.Dispose();
}
[Fact]
public sealed class SslStreamStreamToStreamTest_Async : SslStreamStreamToStreamTest_CancelableReadWriteAsync
{
- protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+ protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
{
- using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+ X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+ await WithServerCertificate(serverCertificate, async(certificate, name) =>
{
- Task t1 = clientSslStream.AuthenticateAsClientAsync(certificate.GetNameInfo(X509NameType.SimpleName, false));
- Task t2 = serverSslStream.AuthenticateAsServerAsync(certificate);
+ Task t1 = clientSslStream.AuthenticateAsClientAsync(name, clientCerts, SslProtocols.None, checkCertificateRevocation: false);
+ Task t2 = serverSslStream.AuthenticateAsServerAsync(certificate, clientCertificateRequired: clientCertificate != null, checkCertificateRevocation: false);
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
- }
+ });
}
protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
public sealed class SslStreamStreamToStreamTest_BeginEnd : SslStreamStreamToStreamTest
{
- protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+ protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
{
- using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+ X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+ await WithServerCertificate(serverCertificate, async (certificate, name) =>
{
- Task t1 = Task.Factory.FromAsync(clientSslStream.BeginAuthenticateAsClient(certificate.GetNameInfo(X509NameType.SimpleName, false), null, null), clientSslStream.EndAuthenticateAsClient);
- Task t2 = Task.Factory.FromAsync(serverSslStream.BeginAuthenticateAsServer(certificate, null, null), serverSslStream.EndAuthenticateAsServer);
+ Task t1 = Task.Factory.FromAsync(clientSslStream.BeginAuthenticateAsClient(name, clientCerts, SslProtocols.None, checkCertificateRevocation: false, null, null), clientSslStream.EndAuthenticateAsClient);
+ Task t2 = Task.Factory.FromAsync(serverSslStream.BeginAuthenticateAsServer(certificate, clientCertificateRequired: clientCertificate != null, checkCertificateRevocation: false, null, null), serverSslStream.EndAuthenticateAsServer);
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
- }
+ });
}
protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
public sealed class SslStreamStreamToStreamTest_Sync : SslStreamStreamToStreamTest
{
- protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream)
+ protected override async Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream, X509Certificate serverCertificate = null, X509Certificate clientCertificate = null)
{
- using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate())
+ X509CertificateCollection clientCerts = clientCertificate != null ? new X509CertificateCollection() { clientCertificate } : null;
+ await WithServerCertificate(serverCertificate, async (certificate, name) =>
{
- Task t1 = Task.Run(() => clientSslStream.AuthenticateAsClient(certificate.GetNameInfo(X509NameType.SimpleName, false)));
- Task t2 = Task.Run(() => serverSslStream.AuthenticateAsServer(certificate));
+ Task t1 = Task.Run(() => clientSslStream.AuthenticateAsClient(name, clientCerts, SslProtocols.None, checkCertificateRevocation: false));
+ Task t2 = Task.Run(() => serverSslStream.AuthenticateAsServer(certificate, clientCertificateRequired: clientCertificate != null, checkCertificateRevocation: false));
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
- }
+ });
}
protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken)