private SafeSocketHandle _handle;
- // _rightEndPoint is null if the socket has not been bound. Otherwise, it is any EndPoint of the
- // correct type (IPEndPoint, etc).
+ // _rightEndPoint is null if the socket has not been bound. Otherwise, it is an EndPoint of the
+ // correct type (IPEndPoint, etc). The Bind operation sets _rightEndPoint. Other operations must only set
+ // it when the value is still null.
+ // This enables tracking the file created by UnixDomainSocketEndPoint when the Socket is bound,
+ // and to delete that file when the Socket gets disposed.
internal EndPoint? _rightEndPoint;
internal EndPoint? _remoteEndPoint;
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
- _rightEndPoint = _nonBlockingConnectRightEndPoint;
+ _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
}
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
- _rightEndPoint = _nonBlockingConnectRightEndPoint;
+ _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
}
{
// Update the state if we've become connected after a non-blocking connect.
_isConnected = true;
- _rightEndPoint = _nonBlockingConnectRightEndPoint;
+ _rightEndPoint ??= _nonBlockingConnectRightEndPoint;
UpdateLocalEndPointOnConnect();
_nonBlockingConnectInProgress = false;
}
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
}
- if (_rightEndPoint == null)
- {
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint = endPointSnapshot;
- }
+ // Save a copy of the EndPoint so we can use it for Create().
+ // For UnixDomainSocketEndPoint, track the file to delete on Dispose.
+ _rightEndPoint = endPointSnapshot is UnixDomainSocketEndPoint unixEndPoint ?
+ unixEndPoint.CreateBoundEndPoint() :
+ endPointSnapshot;
}
// Establishes a connection to a remote system.
if (SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramSent();
}
- if (_rightEndPoint == null)
- {
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint = remoteEP;
- }
+ // Save a copy of the EndPoint so we can use it for Create().
+ _rightEndPoint ??= remoteEP;
if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer, offset, size);
return bytesTransferred;
catch
{
}
- if (_rightEndPoint == null)
- {
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint = endPointSnapshot;
- }
+ // Save a copy of the EndPoint so we can use it for Create().
+ _rightEndPoint ??= endPointSnapshot;
}
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode);
catch
{
}
- if (_rightEndPoint == null)
- {
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint = endPointSnapshot;
- }
+ // Save a copy of the EndPoint so we can use it for Create().
+ _rightEndPoint ??= endPointSnapshot;
}
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode);
catch
{
}
- if (_rightEndPoint == null)
- {
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint = endPointSnapshot;
- }
+ // Save a copy of the EndPoint so we can use it for Create().
+ _rightEndPoint ??= endPointSnapshot;
}
if (socketException != null)
e.StartOperationCommon(this, SocketAsyncOperation.SendTo);
EndPoint? oldEndPoint = _rightEndPoint;
- if (_rightEndPoint == null)
- {
- _rightEndPoint = endPointSnapshot;
- }
+ _rightEndPoint ??= endPointSnapshot;
SocketError socketError;
try
}
catch
{
- _rightEndPoint = null;
+ _rightEndPoint = oldEndPoint;
_localEndPoint = null;
// Clear in-use flag on event args object.
e.Complete();
if (SocketsTelemetry.Log.IsEnabled()) SocketsTelemetry.Log.AfterConnect(SocketError.Success);
- if (_rightEndPoint == null)
- {
- // Save a copy of the EndPoint so we can use it for Create().
- _rightEndPoint = endPointSnapshot;
- }
+ // Save a copy of the EndPoint so we can use it for Create().
+ _rightEndPoint ??= endPointSnapshot;
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"connection to:{endPointSnapshot}");
{
}
}
+
+ // Delete file of bound UnixDomainSocketEndPoint.
+ if (_rightEndPoint is UnixDomainSocketEndPoint unixEndPoint &&
+ unixEndPoint.BoundFileName is not null)
+ {
+ try
+ {
+ File.Delete(unixEndPoint.BoundFileName);
+ }
+ catch
+ { }
+ }
}
// Clean up any cached data
socket._addressFamily = _addressFamily;
socket._socketType = _socketType;
socket._protocolType = _protocolType;
- socket._rightEndPoint = _rightEndPoint;
socket._remoteEndPoint = remoteEP;
+ // If the _rightEndpoint tracks a UnixDomainSocketEndPoint to delete
+ // then create a new EndPoint.
+ if (_rightEndPoint is UnixDomainSocketEndPoint unixEndPoint &&
+ unixEndPoint.BoundFileName is not null)
+ {
+ socket._rightEndPoint = unixEndPoint.CreateUnboundEndPoint();
+ }
+ else
+ {
+ socket._rightEndPoint = _rightEndPoint;
+ }
+
// If the listener socket was bound to a wildcard address, then the `accept` system call
// will assign a specific address to the accept socket's local endpoint instead of a
// wildcard address. In that case we should not copy listener's wildcard local endpoint.
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.DotNet.RemoteExecutor;
using Xunit;
using Xunit.Abstractions;
{
server.Dispose();
- try { File.Delete(path); }
- catch { }
+ Assert.False(File.Exists(path));
}
}
if (willRaiseEvent)
{
await complete.Task;
+
+ Assert.Equal(
+ OperatingSystem.IsWindows() ? SocketError.ConnectionRefused : SocketError.AddressNotAvailable,
+ args.SocketError);
}
Assert.Equal(
- OperatingSystem.IsWindows() ? SocketError.ConnectionRefused : SocketError.AddressNotAvailable,
+ RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.ConnectionRefused : SocketError.AddressNotAvailable,
args.SocketError);
}
}
{
string path = GetRandomNonExistingFilePath();
var endPoint = new UnixDomainSocketEndPoint(path);
- try
+ using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
{
- using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- {
- server.Bind(endPoint);
- server.Listen(1);
+ server.Bind(endPoint);
+ server.Listen(1);
- client.Connect(endPoint);
- using (Socket accepted = server.Accept())
+ client.Connect(endPoint);
+ using (Socket accepted = server.Accept())
+ {
+ var data = new byte[1];
+ for (int i = 0; i < 10; i++)
{
- var data = new byte[1];
- for (int i = 0; i < 10; i++)
- {
- data[0] = (byte)i;
+ data[0] = (byte)i;
- accepted.Send(data);
- data[0] = 0;
+ accepted.Send(data);
+ data[0] = 0;
- Assert.Equal(1, client.Receive(data));
- Assert.Equal(i, data[0]);
- }
+ Assert.Equal(1, client.Receive(data));
+ Assert.Equal(i, data[0]);
}
}
}
- finally
- {
- try { File.Delete(path); }
- catch { }
- }
+
+ Assert.False(File.Exists(path));
}
[ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
{
string path = GetRandomNonExistingFilePath();
var endPoint = new UnixDomainSocketEndPoint(path);
- try
+ using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
{
- using var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
- using var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
- {
- server.Bind(endPoint);
- server.Listen(1);
- client.Connect(endPoint);
+ server.Bind(endPoint);
+ server.Listen(1);
+ client.Connect(endPoint);
- using (Socket accepted = server.Accept())
- {
- using var clientClone = new Socket(client.SafeHandle);
- using var acceptedClone = new Socket(accepted.SafeHandle);
+ using (Socket accepted = server.Accept())
+ {
+ using var clientClone = new Socket(client.SafeHandle);
+ using var acceptedClone = new Socket(accepted.SafeHandle);
- _log.WriteLine($"accepted: LocalEndPoint={accepted.LocalEndPoint} RemoteEndPoint={accepted.RemoteEndPoint}");
- _log.WriteLine($"acceptedClone: LocalEndPoint={acceptedClone.LocalEndPoint} RemoteEndPoint={acceptedClone.RemoteEndPoint}");
+ _log.WriteLine($"accepted: LocalEndPoint={accepted.LocalEndPoint} RemoteEndPoint={accepted.RemoteEndPoint}");
+ _log.WriteLine($"acceptedClone: LocalEndPoint={acceptedClone.LocalEndPoint} RemoteEndPoint={acceptedClone.RemoteEndPoint}");
- Assert.True(clientClone.Connected);
- Assert.True(acceptedClone.Connected);
- Assert.Equal(client.LocalEndPoint.ToString(), clientClone.LocalEndPoint.ToString());
- Assert.Equal(client.RemoteEndPoint.ToString(), clientClone.RemoteEndPoint.ToString());
- Assert.Equal(accepted.LocalEndPoint.ToString(), acceptedClone.LocalEndPoint.ToString());
- Assert.Equal(accepted.RemoteEndPoint.ToString(), acceptedClone.RemoteEndPoint.ToString());
+ Assert.True(clientClone.Connected);
+ Assert.True(acceptedClone.Connected);
+ Assert.Equal(client.LocalEndPoint.ToString(), clientClone.LocalEndPoint.ToString());
+ Assert.Equal(client.RemoteEndPoint.ToString(), clientClone.RemoteEndPoint.ToString());
+ Assert.Equal(accepted.LocalEndPoint.ToString(), acceptedClone.LocalEndPoint.ToString());
+ Assert.Equal(accepted.RemoteEndPoint.ToString(), acceptedClone.RemoteEndPoint.ToString());
- var data = new byte[1];
- for (int i = 0; i < 10; i++)
- {
- data[0] = (byte)i;
+ var data = new byte[1];
+ for (int i = 0; i < 10; i++)
+ {
+ data[0] = (byte)i;
- acceptedClone.Send(data);
- data[0] = 0;
+ acceptedClone.Send(data);
+ data[0] = 0;
- Assert.Equal(1, clientClone.Receive(data));
- Assert.Equal(i, data[0]);
- }
+ Assert.Equal(1, clientClone.Receive(data));
+ Assert.Equal(i, data[0]);
}
}
}
- finally
- {
- try { File.Delete(path); }
- catch { }
- }
+
+ Assert.False(File.Exists(path));
}
[ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
{
string path = GetRandomNonExistingFilePath();
var endPoint = new UnixDomainSocketEndPoint(path);
- try
+ using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
{
- using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- {
- server.Bind(endPoint);
- server.Listen(1);
+ server.Bind(endPoint);
+ server.Listen(1);
- await client.ConnectAsync(endPoint);
- using (Socket accepted = await server.AcceptAsync())
+ await client.ConnectAsync(endPoint);
+ using (Socket accepted = await server.AcceptAsync())
+ {
+ var data = new byte[1];
+ for (int i = 0; i < 10; i++)
{
- var data = new byte[1];
- for (int i = 0; i < 10; i++)
- {
- data[0] = (byte)i;
+ data[0] = (byte)i;
- await accepted.SendAsync(new ArraySegment<byte>(data), SocketFlags.None);
- data[0] = 0;
+ await accepted.SendAsync(new ArraySegment<byte>(data), SocketFlags.None);
+ data[0] = 0;
- Assert.Equal(1, await client.ReceiveAsync(new ArraySegment<byte>(data), SocketFlags.None));
- Assert.Equal(i, data[0]);
- }
+ Assert.Equal(1, await client.ReceiveAsync(new ArraySegment<byte>(data), SocketFlags.None));
+ Assert.Equal(i, data[0]);
}
}
}
- finally
- {
- try { File.Delete(path); }
- catch { }
- }
+
+ Assert.False(File.Exists(path));
}
[ActiveIssue("https://github.com/dotnet/runtime/issues/26189", TestPlatforms.Windows)]
string path = GetRandomNonExistingFilePath();
var endPoint = new UnixDomainSocketEndPoint(path);
- try
+ using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
{
- using (var server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- using (var client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- {
- server.Bind(endPoint);
- server.Listen(1);
+ server.Bind(endPoint);
+ server.Listen(1);
- Task<Socket> serverAccept = server.AcceptAsync();
- await Task.WhenAll(serverAccept, client.ConnectAsync(endPoint));
+ Task<Socket> serverAccept = server.AcceptAsync();
+ await Task.WhenAll(serverAccept, client.ConnectAsync(endPoint));
- Task clientReceives = Task.Run(async () =>
+ Task clientReceives = Task.Run(async () =>
+ {
+ byte[] buffer = new byte[readBufferSize];
+ while (true)
{
- byte[] buffer = new byte[readBufferSize];
- while (true)
+ int bytesRead = await client.ReceiveAsync(new Memory<byte>(buffer), SocketFlags.None);
+ if (bytesRead == 0)
{
- int bytesRead = await client.ReceiveAsync(new Memory<byte>(buffer), SocketFlags.None);
- if (bytesRead == 0)
- {
- break;
- }
- Assert.InRange(bytesRead, 1, writeBuffer.Length - readData.Length);
- readData.Write(buffer, 0, bytesRead);
+ break;
}
- });
+ Assert.InRange(bytesRead, 1, writeBuffer.Length - readData.Length);
+ readData.Write(buffer, 0, bytesRead);
+ }
+ });
- using (Socket accepted = await serverAccept)
+ using (Socket accepted = await serverAccept)
+ {
+ for (int iter = 0; iter < iterations; iter++)
{
- for (int iter = 0; iter < iterations; iter++)
- {
- Task<int> sendTask = accepted.SendAsync(new ArraySegment<byte>(writeBuffer, iter * writeBufferSize, writeBufferSize), SocketFlags.None);
- await await Task.WhenAny(clientReceives, sendTask);
- Assert.Equal(writeBufferSize, await sendTask);
- }
+ Task<int> sendTask = accepted.SendAsync(new ArraySegment<byte>(writeBuffer, iter * writeBufferSize, writeBufferSize), SocketFlags.None);
+ await await Task.WhenAny(clientReceives, sendTask);
+ Assert.Equal(writeBufferSize, await sendTask);
}
-
- await clientReceives;
}
- Assert.Equal(writeBuffer.Length, readData.Length);
- AssertExtensions.SequenceEqual(writeBuffer, readData.ToArray());
- }
- finally
- {
- try { File.Delete(path); }
- catch { }
+ await clientReceives;
}
+
+ Assert.Equal(writeBuffer.Length, readData.Length);
+ AssertExtensions.SequenceEqual(writeBuffer, readData.ToArray());
+
+ Assert.False(File.Exists(path));
}
[ConditionalTheory(nameof(PlatformSupportsUnixDomainSockets))]
expectedClientAddress = clientAddress;
}
- try
+ using (Socket server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
{
- using (Socket server = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
- {
- server.Bind(new UnixDomainSocketEndPoint(serverAddress));
- server.Listen(1);
+ server.Bind(new UnixDomainSocketEndPoint(serverAddress));
+ server.Listen(1);
- using (Socket client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ using (Socket client = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ {
+ // Bind the client.
+ client.Bind(new UnixDomainSocketEndPoint(clientAddress));
+ client.Connect(new UnixDomainSocketEndPoint(serverAddress));
+ using (Socket acceptedClient = server.Accept())
{
- // Bind the client.
- client.Bind(new UnixDomainSocketEndPoint(clientAddress));
- client.Connect(new UnixDomainSocketEndPoint(serverAddress));
- using (Socket acceptedClient = server.Accept())
- {
- // Verify the client address on the server.
- EndPoint clientAddressOnServer = acceptedClient.RemoteEndPoint;
- Assert.True(string.CompareOrdinal(expectedClientAddress, clientAddressOnServer.ToString()) == 0);
- }
+ // Verify the client address on the server.
+ EndPoint clientAddressOnServer = acceptedClient.RemoteEndPoint;
+ Assert.True(string.CompareOrdinal(expectedClientAddress, clientAddressOnServer.ToString()) == 0);
}
}
}
- finally
- {
- if (!abstractAddress)
- {
- try { File.Delete(serverAddress); }
- catch { }
- try { File.Delete(clientAddress); }
- catch { }
- }
- }
+
+ Assert.False(File.Exists(serverAddress));
+ Assert.False(File.Exists(clientAddress));
}
[ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
Assert.Throws<PlatformNotSupportedException>(() => new UnixDomainSocketEndPoint("hello"));
}
+ [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+ public void UnixDomainSocketEndPoint_RelativePathDeletesFile()
+ {
+ if (!PlatformSupportsUnixDomainSockets)
+ {
+ return;
+ }
+ RemoteExecutor.Invoke(() =>
+ {
+ using (Socket socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified))
+ {
+ // Bind to a relative path.
+ string path = GetRandomNonExistingFilePath();
+ string wd = Path.GetDirectoryName(path);
+ Directory.SetCurrentDirectory(wd);
+ socket.Bind(new UnixDomainSocketEndPoint(Path.GetFileName(path)));
+ Assert.True(File.Exists(path));
+
+ string otherDir = GetRandomNonExistingFilePath();
+ Directory.CreateDirectory(otherDir);
+ try
+ {
+ // Change to another directory.
+ Directory.SetCurrentDirectory(Path.GetDirectoryName(path));
+
+ // Dispose deletes file from original path.
+ socket.Dispose();
+ Assert.False(File.Exists(path));
+ }
+ finally
+ {
+ Directory.SetCurrentDirectory(wd);
+ Directory.Delete(otherDir);
+ }
+ }
+ }).Dispose();
+ }
+
private static string GetRandomNonExistingFilePath()
{
string result;