From 366fe2dbb24e853662d0b910ee09e1733b863b69 Mon Sep 17 00:00:00 2001 From: Mary Date: Wed, 12 Jan 2022 19:31:08 +0100 Subject: [PATCH] bsd: Revamp API and make socket abstract (#2960) * bsd: Revamp API and make socket abstract This part of the code was really ancient and needed some love. As such this commit aims at separating the socket core logic from the IClient class and make it uses more modern APIs to read/write/parse data. * Address gdkchan's comment * Move TryConvertSocketOption to WinSockHelper * Allow reusing old fds and add missing locks around SocketInternal and ShutdownAllSockets * bsd: ton of changes - Make sockets per process - Implement eventfds - Rework Poll for support of eventfds - Handle protocol auto selection by type (used by gRPC) - Handle IPv6 socket creation * Address most of gdkchan comments * Fix inverted read logic for BSD socket read * bsd: Make Poll abstract via IBsdSocketPollManager * bsd: Improve naming of everything * Fix build issue from last commit (missed to save on VC) * Switch BsdContext registry to a concurrent dictionary * bsd: Implement socket creation flags logic and the non blocking flag * Remove unused enum from previous commit * bsd: Fix poll logic when 0 fds are present for a given poll manager and when timeout is very small (or 0) * Address gdkchan's comment --- .../HOS/Services/Sockets/Bsd/BsdContext.cs | 150 +++ .../HOS/Services/Sockets/Bsd/IClient.cs | 898 ++++++------------ .../Services/Sockets/Bsd/IFileDescriptor.cs | 14 + .../HOS/Services/Sockets/Bsd/ISocket.cs | 47 + .../Sockets/Bsd/Impl/EventFileDescriptor.cs | 130 +++ .../Impl/EventFileDescriptorPollManager.cs | 96 ++ .../Sockets/Bsd/Impl/ManagedSocket.cs | 338 +++++++ .../Bsd/Impl/ManagedSocketPollManager.cs | 129 +++ .../Sockets/Bsd/{Types => Impl}/WSAError.cs | 0 .../Sockets/Bsd/Impl/WinSockHelper.cs | 165 ++++ .../Sockets/Bsd/Types/BsdAddressFamily.cs | 11 + .../Services/Sockets/Bsd/Types/BsdSockAddr.cs | 39 + .../Services/Sockets/Bsd/Types/BsdSocket.cs | 13 - .../Bsd/Types/BsdSocketCreationFlags.cs | 14 + .../Bsd/Types/BsdSocketShutdownFlags.cs | 9 + .../Sockets/Bsd/Types/BsdSocketType.cs | 13 + .../Sockets/Bsd/Types/EventFdFlags.cs | 12 + .../Sockets/Bsd/Types/IPollManager.cs | 11 + .../Services/Sockets/Bsd/Types/PollEvent.cs | 24 +- .../Sockets/Bsd/Types/PollEventData.cs | 13 + .../Sockets/Bsd/Types/PollEventTypeMask.cs | 15 + 21 files changed, 1482 insertions(+), 659 deletions(-) create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptorPollManager.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocketPollManager.cs rename Ryujinx.HLE/HOS/Services/Sockets/Bsd/{Types => Impl}/WSAError.cs (100%) create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdAddressFamily.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSockAddr.cs delete mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocket.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketCreationFlags.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketShutdownFlags.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketType.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/EventFdFlags.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/IPollManager.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventData.cs create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventTypeMask.cs diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs new file mode 100644 index 00000000..071c1317 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/BsdContext.cs @@ -0,0 +1,150 @@ +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class BsdContext + { + private static ConcurrentDictionary _registry = new ConcurrentDictionary(); + + private readonly object _lock = new object(); + + private List _fds; + + private BsdContext() + { + _fds = new List(); + } + + public ISocket RetrieveSocket(int socketFd) + { + IFileDescriptor file = RetrieveFileDescriptor(socketFd); + + if (file is ISocket socket) + { + return socket; + } + + return null; + } + + public IFileDescriptor RetrieveFileDescriptor(int fd) + { + lock (_lock) + { + if (fd >= 0 && _fds.Count > fd) + { + return _fds[fd]; + } + } + + return null; + } + + public int RegisterFileDescriptor(IFileDescriptor file) + { + lock (_lock) + { + for (int fd = 0; fd < _fds.Count; fd++) + { + if (_fds[fd] == null) + { + _fds[fd] = file; + + return fd; + } + } + + _fds.Add(file); + + return _fds.Count - 1; + } + } + + public int DuplicateFileDescriptor(int fd) + { + IFileDescriptor oldFile = RetrieveFileDescriptor(fd); + + if (oldFile != null) + { + lock (_lock) + { + oldFile.Refcount++; + + return RegisterFileDescriptor(oldFile); + } + } + + return -1; + } + + public bool CloseFileDescriptor(int fd) + { + IFileDescriptor file = RetrieveFileDescriptor(fd); + + if (file != null) + { + file.Refcount--; + + if (file.Refcount <= 0) + { + file.Dispose(); + } + + lock (_lock) + { + _fds[fd] = null; + } + + return true; + } + + return false; + } + + public LinuxError ShutdownAllSockets(BsdSocketShutdownFlags how) + { + lock (_lock) + { + foreach (IFileDescriptor file in _fds) + { + if (file is ISocket socket) + { + LinuxError errno = socket.Shutdown(how); + + if (errno != LinuxError.SUCCESS) + { + return errno; + } + } + } + } + + return LinuxError.SUCCESS; + } + + public static BsdContext GetOrRegister(long processId) + { + BsdContext context = GetContext(processId); + + if (context == null) + { + context = new BsdContext(); + + _registry.TryAdd(processId, context); + } + + return context; + } + + public static BsdContext GetContext(long processId) + { + if (!_registry.TryGetValue(processId, out BsdContext processContext)) + { + return null; + } + + return processContext; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs index 76f80f92..ae245ec8 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IClient.cs @@ -1,11 +1,12 @@ -using Ryujinx.Common.Logging; +using Ryujinx.Common; +using Ryujinx.Common.Logging; +using Ryujinx.Memory; using System; -using System.Buffers.Binary; using System.Collections.Generic; using System.Net; using System.Net.Sockets; +using System.Runtime.CompilerServices; using System.Text; -using System.Threading; namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { @@ -13,221 +14,21 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd [Service("bsd:u", false)] class IClient : IpcService { - private static readonly Dictionary _errorMap = new() + private static readonly List _pollManagers = new List { - // WSAEINTR - {WsaError.WSAEINTR, LinuxError.EINTR}, - // WSAEWOULDBLOCK - {WsaError.WSAEWOULDBLOCK, LinuxError.EWOULDBLOCK}, - // WSAEINPROGRESS - {WsaError.WSAEINPROGRESS, LinuxError.EINPROGRESS}, - // WSAEALREADY - {WsaError.WSAEALREADY, LinuxError.EALREADY}, - // WSAENOTSOCK - {WsaError.WSAENOTSOCK, LinuxError.ENOTSOCK}, - // WSAEDESTADDRREQ - {WsaError.WSAEDESTADDRREQ, LinuxError.EDESTADDRREQ}, - // WSAEMSGSIZE - {WsaError.WSAEMSGSIZE, LinuxError.EMSGSIZE}, - // WSAEPROTOTYPE - {WsaError.WSAEPROTOTYPE, LinuxError.EPROTOTYPE}, - // WSAENOPROTOOPT - {WsaError.WSAENOPROTOOPT, LinuxError.ENOPROTOOPT}, - // WSAEPROTONOSUPPORT - {WsaError.WSAEPROTONOSUPPORT, LinuxError.EPROTONOSUPPORT}, - // WSAESOCKTNOSUPPORT - {WsaError.WSAESOCKTNOSUPPORT, LinuxError.ESOCKTNOSUPPORT}, - // WSAEOPNOTSUPP - {WsaError.WSAEOPNOTSUPP, LinuxError.EOPNOTSUPP}, - // WSAEPFNOSUPPORT - {WsaError.WSAEPFNOSUPPORT, LinuxError.EPFNOSUPPORT}, - // WSAEAFNOSUPPORT - {WsaError.WSAEAFNOSUPPORT, LinuxError.EAFNOSUPPORT}, - // WSAEADDRINUSE - {WsaError.WSAEADDRINUSE, LinuxError.EADDRINUSE}, - // WSAEADDRNOTAVAIL - {WsaError.WSAEADDRNOTAVAIL, LinuxError.EADDRNOTAVAIL}, - // WSAENETDOWN - {WsaError.WSAENETDOWN, LinuxError.ENETDOWN}, - // WSAENETUNREACH - {WsaError.WSAENETUNREACH, LinuxError.ENETUNREACH}, - // WSAENETRESET - {WsaError.WSAENETRESET, LinuxError.ENETRESET}, - // WSAECONNABORTED - {WsaError.WSAECONNABORTED, LinuxError.ECONNABORTED}, - // WSAECONNRESET - {WsaError.WSAECONNRESET, LinuxError.ECONNRESET}, - // WSAENOBUFS - {WsaError.WSAENOBUFS, LinuxError.ENOBUFS}, - // WSAEISCONN - {WsaError.WSAEISCONN, LinuxError.EISCONN}, - // WSAENOTCONN - {WsaError.WSAENOTCONN, LinuxError.ENOTCONN}, - // WSAESHUTDOWN - {WsaError.WSAESHUTDOWN, LinuxError.ESHUTDOWN}, - // WSAETOOMANYREFS - {WsaError.WSAETOOMANYREFS, LinuxError.ETOOMANYREFS}, - // WSAETIMEDOUT - {WsaError.WSAETIMEDOUT, LinuxError.ETIMEDOUT}, - // WSAECONNREFUSED - {WsaError.WSAECONNREFUSED, LinuxError.ECONNREFUSED}, - // WSAELOOP - {WsaError.WSAELOOP, LinuxError.ELOOP}, - // WSAENAMETOOLONG - {WsaError.WSAENAMETOOLONG, LinuxError.ENAMETOOLONG}, - // WSAEHOSTDOWN - {WsaError.WSAEHOSTDOWN, LinuxError.EHOSTDOWN}, - // WSAEHOSTUNREACH - {WsaError.WSAEHOSTUNREACH, LinuxError.EHOSTUNREACH}, - // WSAENOTEMPTY - {WsaError.WSAENOTEMPTY, LinuxError.ENOTEMPTY}, - // WSAEUSERS - {WsaError.WSAEUSERS, LinuxError.EUSERS}, - // WSAEDQUOT - {WsaError.WSAEDQUOT, LinuxError.EDQUOT}, - // WSAESTALE - {WsaError.WSAESTALE, LinuxError.ESTALE}, - // WSAEREMOTE - {WsaError.WSAEREMOTE, LinuxError.EREMOTE}, - // WSAEINVAL - {WsaError.WSAEINVAL, LinuxError.EINVAL}, - // WSAEFAULT - {WsaError.WSAEFAULT, LinuxError.EFAULT}, - // NOERROR - {0, 0} - }; - - private static readonly Dictionary _soSocketOptionMap = new() - { - { BsdSocketOption.SoDebug, SocketOptionName.Debug }, - { BsdSocketOption.SoReuseAddr, SocketOptionName.ReuseAddress }, - { BsdSocketOption.SoKeepAlive, SocketOptionName.KeepAlive }, - { BsdSocketOption.SoDontRoute, SocketOptionName.DontRoute }, - { BsdSocketOption.SoBroadcast, SocketOptionName.Broadcast }, - { BsdSocketOption.SoUseLoopBack, SocketOptionName.UseLoopback }, - { BsdSocketOption.SoLinger, SocketOptionName.Linger }, - { BsdSocketOption.SoOobInline, SocketOptionName.OutOfBandInline }, - { BsdSocketOption.SoReusePort, SocketOptionName.ReuseAddress }, - { BsdSocketOption.SoSndBuf, SocketOptionName.SendBuffer }, - { BsdSocketOption.SoRcvBuf, SocketOptionName.ReceiveBuffer }, - { BsdSocketOption.SoSndLoWat, SocketOptionName.SendLowWater }, - { BsdSocketOption.SoRcvLoWat, SocketOptionName.ReceiveLowWater }, - { BsdSocketOption.SoSndTimeo, SocketOptionName.SendTimeout }, - { BsdSocketOption.SoRcvTimeo, SocketOptionName.ReceiveTimeout }, - { BsdSocketOption.SoError, SocketOptionName.Error }, - { BsdSocketOption.SoType, SocketOptionName.Type } - }; - - private static readonly Dictionary _ipSocketOptionMap = new() - { - { BsdSocketOption.IpOptions, SocketOptionName.IPOptions }, - { BsdSocketOption.IpHdrIncl, SocketOptionName.HeaderIncluded }, - { BsdSocketOption.IpTtl, SocketOptionName.IpTimeToLive }, - { BsdSocketOption.IpMulticastIf, SocketOptionName.MulticastInterface }, - { BsdSocketOption.IpMulticastTtl, SocketOptionName.MulticastTimeToLive }, - { BsdSocketOption.IpMulticastLoop, SocketOptionName.MulticastLoopback }, - { BsdSocketOption.IpAddMembership, SocketOptionName.AddMembership }, - { BsdSocketOption.IpDropMembership, SocketOptionName.DropMembership }, - { BsdSocketOption.IpDontFrag, SocketOptionName.DontFragment }, - { BsdSocketOption.IpAddSourceMembership, SocketOptionName.AddSourceMembership }, - { BsdSocketOption.IpDropSourceMembership, SocketOptionName.DropSourceMembership } - }; - - private static readonly Dictionary _tcpSocketOptionMap = new() - { - { BsdSocketOption.TcpNoDelay, SocketOptionName.NoDelay }, - { BsdSocketOption.TcpKeepIdle, SocketOptionName.TcpKeepAliveTime }, - { BsdSocketOption.TcpKeepIntvl, SocketOptionName.TcpKeepAliveInterval }, - { BsdSocketOption.TcpKeepCnt, SocketOptionName.TcpKeepAliveRetryCount } + EventFileDescriptorPollManager.Instance, + ManagedSocketPollManager.Instance }; + private BsdContext _context; private bool _isPrivileged; - private List _sockets = new List(); - public IClient(ServiceCtx context, bool isPrivileged) : base(context.Device.System.BsdServer) { _isPrivileged = isPrivileged; } - private static LinuxError ConvertError(WsaError errorCode) - { - if (!_errorMap.TryGetValue(errorCode, out LinuxError errno)) - { - errno = (LinuxError)errorCode; - } - - return errno; - } - - private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags) - { - SocketFlags socketFlags = SocketFlags.None; - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob)) - { - socketFlags |= SocketFlags.OutOfBand; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek)) - { - socketFlags |= SocketFlags.Peek; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute)) - { - socketFlags |= SocketFlags.DontRoute; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc)) - { - socketFlags |= SocketFlags.Truncated; - } - - if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc)) - { - socketFlags |= SocketFlags.ControlDataTruncated; - } - - bsdSocketFlags &= ~(BsdSocketFlags.Oob | - BsdSocketFlags.Peek | - BsdSocketFlags.DontRoute | - BsdSocketFlags.Trunc | - BsdSocketFlags.CTrunc); - - if (bsdSocketFlags != BsdSocketFlags.None) - { - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}"); - } - - return socketFlags; - } - - private static bool TryConvertSocketOption(BsdSocketOption option, SocketOptionLevel level, out SocketOptionName name) - { - var table = level switch - { - SocketOptionLevel.Socket => _soSocketOptionMap, - SocketOptionLevel.IP => _ipSocketOptionMap, - SocketOptionLevel.Tcp => _tcpSocketOptionMap, - _ => null - }; - - if (table == null) - { - name = default; - return false; - } - - return table.TryGetValue(option, out name); - } - - private ResultCode WriteWinSock2Error(ServiceCtx context, WsaError errorCode) - { - return WriteBsdResult(context, -1, ConvertError(errorCode)); - } - - private ResultCode WriteBsdResult(ServiceCtx context, int result, LinuxError errorCode = 0) + private ResultCode WriteBsdResult(ServiceCtx context, int result, LinuxError errorCode = LinuxError.SUCCESS) { if (errorCode != LinuxError.SUCCESS) { @@ -240,100 +41,96 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd return ResultCode.Success; } - private BsdSocket RetrieveSocket(int socketFd) + private static AddressFamily ConvertBsdAddressFamily(BsdAddressFamily family) { - if (socketFd >= 0 && _sockets.Count > socketFd) + switch (family) { - return _sockets[socketFd]; + case BsdAddressFamily.Unspecified: + return AddressFamily.Unspecified; + case BsdAddressFamily.InterNetwork: + return AddressFamily.InterNetwork; + case BsdAddressFamily.InterNetworkV6: + return AddressFamily.InterNetworkV6; + case BsdAddressFamily.Unknown: + return AddressFamily.Unknown; + default: + throw new NotImplementedException(family.ToString()); } - - return null; } - private LinuxError SetResultErrno(Socket socket, int result) + private LinuxError SetResultErrno(IFileDescriptor socket, int result) { return result == 0 && !socket.Blocking ? LinuxError.EWOULDBLOCK : LinuxError.SUCCESS; } - private AddressFamily ConvertFromBsd(int domain) - { - if (domain == 2) - { - return AddressFamily.InterNetwork; - } - - // FIXME: AF_ROUTE ignored, is that really needed? - return AddressFamily.Unknown; - } - private ResultCode SocketInternal(ServiceCtx context, bool exempt) { - AddressFamily domain = (AddressFamily)context.RequestData.ReadInt32(); - SocketType type = (SocketType)context.RequestData.ReadInt32(); - ProtocolType protocol = (ProtocolType)context.RequestData.ReadInt32(); + BsdAddressFamily domain = (BsdAddressFamily)context.RequestData.ReadInt32(); + BsdSocketType type = (BsdSocketType)context.RequestData.ReadInt32(); + ProtocolType protocol = (ProtocolType)context.RequestData.ReadInt32(); - if (domain == AddressFamily.Unknown) + BsdSocketCreationFlags creationFlags = (BsdSocketCreationFlags)((int)type >> (int)BsdSocketCreationFlags.FlagsShift); + type &= BsdSocketType.TypeMask; + + if (domain == BsdAddressFamily.Unknown) { return WriteBsdResult(context, -1, LinuxError.EPROTONOSUPPORT); } - else if ((type == SocketType.Seqpacket || type == SocketType.Raw) && !_isPrivileged) + else if ((type == BsdSocketType.Seqpacket || type == BsdSocketType.Raw) && !_isPrivileged) { - if (domain != AddressFamily.InterNetwork || type != SocketType.Raw || protocol != ProtocolType.Icmp) + if (domain != BsdAddressFamily.InterNetwork || type != BsdSocketType.Raw || protocol != ProtocolType.Icmp) { return WriteBsdResult(context, -1, LinuxError.ENOENT); } } - BsdSocket newBsdSocket = new BsdSocket - { - Family = (int)domain, - Type = (int)type, - Protocol = (int)protocol, - Handle = new Socket(domain, type, protocol) - }; + AddressFamily netDomain = ConvertBsdAddressFamily(domain); - _sockets.Add(newBsdSocket); + if (protocol == ProtocolType.IP) + { + if (type == BsdSocketType.Stream) + { + protocol = ProtocolType.Tcp; + } + else if (type == BsdSocketType.Dgram) + { + protocol = ProtocolType.Udp; + } + } + + ISocket newBsdSocket = new ManagedSocket(netDomain, (SocketType)type, protocol); + newBsdSocket.Blocking = !creationFlags.HasFlag(BsdSocketCreationFlags.NonBlocking); + + LinuxError errno = LinuxError.SUCCESS; + + int newSockFd = _context.RegisterFileDescriptor(newBsdSocket); + + if (newSockFd == -1) + { + errno = LinuxError.EBADF; + } if (exempt) { - newBsdSocket.Handle.Disconnect(true); + newBsdSocket.Disconnect(); } - return WriteBsdResult(context, _sockets.Count - 1); + return WriteBsdResult(context, newSockFd, errno); } - private IPEndPoint ParseSockAddr(ServiceCtx context, ulong bufferPosition, ulong bufferSize) + private void WriteSockAddr(ServiceCtx context, ulong bufferPosition, ISocket socket, bool isRemote) { - int size = context.Memory.Read(bufferPosition); - int family = context.Memory.Read(bufferPosition + 1); - int port = BinaryPrimitives.ReverseEndianness(context.Memory.Read(bufferPosition + 2)); + IPEndPoint endPoint = isRemote ? socket.RemoteEndPoint : socket.LocalEndPoint; - byte[] rawIp = new byte[4]; - - context.Memory.Read(bufferPosition + 4, rawIp); - - return new IPEndPoint(new IPAddress(rawIp), port); - } - - private void WriteSockAddr(ServiceCtx context, ulong bufferPosition, IPEndPoint endPoint) - { - context.Memory.Write(bufferPosition, (byte)0); - context.Memory.Write(bufferPosition + 1, (byte)endPoint.AddressFamily); - context.Memory.Write(bufferPosition + 2, BinaryPrimitives.ReverseEndianness((ushort)endPoint.Port)); - context.Memory.Write(bufferPosition + 4, endPoint.Address.GetAddressBytes()); - } - - private void WriteSockAddr(ServiceCtx context, ulong bufferPosition, BsdSocket socket, bool isRemote) - { - IPEndPoint endPoint = (isRemote ? socket.Handle.RemoteEndPoint : socket.Handle.LocalEndPoint) as IPEndPoint; - - WriteSockAddr(context, bufferPosition, endPoint); + context.Memory.Write(bufferPosition, BsdSockAddr.FromIPEndPoint(endPoint)); } [CommandHipc(0)] // Initialize(nn::socket::BsdBufferConfig config, u64 pid, u64 transferMemorySize, KObject, pid) -> u32 bsd_errno public ResultCode RegisterClient(ServiceCtx context) { + _context = BsdContext.GetOrRegister(context.Request.HandleDesc.PId); + /* typedef struct { u32 version; // Observed 1 on 2.0 LibAppletWeb, 2 on 3.0. @@ -424,7 +221,6 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x21(); - if (timeout < -1 || fdsCount < 0 || (ulong)(fdsCount * 8) > bufferSize) { return WriteBsdResult(context, -1, LinuxError.EINVAL); @@ -434,75 +230,94 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd for (int i = 0; i < fdsCount; i++) { - int socketFd = context.Memory.Read(bufferPosition + (ulong)i * 8); + PollEventData pollEventData = context.Memory.Read(bufferPosition + (ulong)(i * Unsafe.SizeOf())); - BsdSocket socket = RetrieveSocket(socketFd); + IFileDescriptor fileDescriptor = _context.RetrieveFileDescriptor(pollEventData.SocketFd); - if (socket == null) + if (fileDescriptor == null) { - return WriteBsdResult(context, -1, LinuxError.EBADF);} + return WriteBsdResult(context, -1, LinuxError.EBADF); + } - PollEvent.EventTypeMask inputEvents = (PollEvent.EventTypeMask)context.Memory.Read(bufferPosition + (ulong)i * 8 + 4); - PollEvent.EventTypeMask outputEvents = (PollEvent.EventTypeMask)context.Memory.Read(bufferPosition + (ulong)i * 8 + 6); - - events[i] = new PollEvent(socketFd, socket, inputEvents, outputEvents); + events[i] = new PollEvent(pollEventData, fileDescriptor); } - List readEvents = new List(); - List writeEvents = new List(); - List errorEvents = new List(); + List discoveredEvents = new List(); + List[] eventsByPollManager = new List[_pollManagers.Count]; - foreach (PollEvent Event in events) + for (int i = 0; i < eventsByPollManager.Length; i++) { - bool isValidEvent = false; + eventsByPollManager[i] = new List(); - if ((Event.InputEvents & PollEvent.EventTypeMask.Input) != 0) + foreach (PollEvent evnt in events) { - readEvents.Add(Event.Socket.Handle); - errorEvents.Add(Event.Socket.Handle); - - isValidEvent = true; - } - - if ((Event.InputEvents & PollEvent.EventTypeMask.UrgentInput) != 0) - { - readEvents.Add(Event.Socket.Handle); - errorEvents.Add(Event.Socket.Handle); - - isValidEvent = true; - } - - if ((Event.InputEvents & PollEvent.EventTypeMask.Output) != 0) - { - writeEvents.Add(Event.Socket.Handle); - errorEvents.Add(Event.Socket.Handle); - - isValidEvent = true; - } - - if ((Event.InputEvents & PollEvent.EventTypeMask.Error) != 0) - { - errorEvents.Add(Event.Socket.Handle); - isValidEvent = true; - } - - if (!isValidEvent) - { - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported Poll input event type: {Event.InputEvents}"); - return WriteBsdResult(context, -1, LinuxError.EINVAL); + if (_pollManagers[i].IsCompatible(evnt)) + { + eventsByPollManager[i].Add(evnt); + discoveredEvents.Add(evnt); + } } } + foreach (PollEvent evnt in events) + { + if (!discoveredEvents.Contains(evnt)) + { + Logger.Error?.Print(LogClass.ServiceBsd, $"Poll operation is not supported for {evnt.FileDescriptor.GetType().Name}!"); + + return WriteBsdResult(context, -1, LinuxError.EBADF); + } + } + + int updateCount = 0; + + LinuxError errno = LinuxError.SUCCESS; + if (fdsCount != 0) { - try + bool IsUnexpectedLinuxError(LinuxError error) { - System.Net.Sockets.Socket.Select(readEvents, writeEvents, errorEvents, timeout); + return errno != LinuxError.SUCCESS && errno != LinuxError.ETIMEDOUT; } - catch (SocketException exception) + + // Hybrid approach + long budgetLeftMilliseconds; + + if (timeout == -1) { - return WriteWinSock2Error(context, (WsaError)exception.ErrorCode); + budgetLeftMilliseconds = PerformanceCounter.ElapsedMilliseconds + uint.MaxValue; } + else + { + budgetLeftMilliseconds = PerformanceCounter.ElapsedMilliseconds + timeout; + } + + do + { + for (int i = 0; i < eventsByPollManager.Length; i++) + { + if (eventsByPollManager[i].Count == 0) + { + continue; + } + + errno = _pollManagers[i].Poll(eventsByPollManager[i], 0, out updateCount); + + if (IsUnexpectedLinuxError(errno)) + { + break; + } + + if (updateCount > 0) + { + break; + } + } + + // If we are here, that mean nothing was availaible, sleep for 50ms + context.Device.System.KernelContext.Syscall.SleepThread(50 * 1000000); + } + while (PerformanceCounter.ElapsedMilliseconds < budgetLeftMilliseconds); } else if (timeout == -1) { @@ -511,47 +326,16 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd } else { - // FIXME: We should make the KThread sleep but we can't do much about it yet. - Thread.Sleep(timeout); + context.Device.System.KernelContext.Syscall.SleepThread(timeout); } + // TODO: Spanify for (int i = 0; i < fdsCount; i++) { - PollEvent Event = events[i]; - context.Memory.Write(bufferPosition + (ulong)i * 8, Event.SocketFd); - context.Memory.Write(bufferPosition + (ulong)i * 8 + 4, (short)Event.InputEvents); - - PollEvent.EventTypeMask outputEvents = 0; - - Socket socket = Event.Socket.Handle; - - if (errorEvents.Contains(socket)) - { - outputEvents |= PollEvent.EventTypeMask.Error; - - if (!socket.Connected || !socket.IsBound) - { - outputEvents |= PollEvent.EventTypeMask.Disconnected; - } - } - - if (readEvents.Contains(socket)) - { - if ((Event.InputEvents & PollEvent.EventTypeMask.Input) != 0) - { - outputEvents |= PollEvent.EventTypeMask.Input; - } - } - - if (writeEvents.Contains(socket)) - { - outputEvents |= PollEvent.EventTypeMask.Output; - } - - context.Memory.Write(bufferPosition + (ulong)i * 8 + 6, (short)outputEvents); + context.Memory.Write(bufferPosition + (ulong)(i * Unsafe.SizeOf()), events[i].Data); } - return WriteBsdResult(context, readEvents.Count + writeEvents.Count + errorEvents.Count, LinuxError.SUCCESS); + return WriteBsdResult(context, updateCount, errno); } [CommandHipc(7)] @@ -574,24 +358,21 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd (ulong receivePosition, ulong receiveLength) = context.Request.GetBufferType0x22(); + WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength); + LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); int result = -1; if (socket != null) { - byte[] receivedBuffer = new byte[receiveLength]; + errno = socket.Receive(out result, receiveRegion.Memory.Span, socketFlags); - try + if (errno == LinuxError.SUCCESS) { - result = socket.Handle.Receive(receivedBuffer, ConvertBsdSocketFlags(socketFlags)); - errno = SetResultErrno(socket.Handle, result); + SetResultErrno(socket, result); - context.Memory.Write(receivePosition, receivedBuffer); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); + receiveRegion.Dispose(); } } @@ -605,29 +386,26 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd int socketFd = context.RequestData.ReadInt32(); BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32(); - (ulong receivePosition, ulong receiveLength) = context.Request.GetBufferType0x22(); + (ulong receivePosition, ulong receiveLength) = context.Request.GetBufferType0x22(0); (ulong sockAddrOutPosition, ulong sockAddrOutSize) = context.Request.GetBufferType0x22(1); + WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength); + LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); int result = -1; if (socket != null) { - byte[] receivedBuffer = new byte[receiveLength]; - EndPoint endPoint = new IPEndPoint(IPAddress.Any, 0); + errno = socket.ReceiveFrom(out result, receiveRegion.Memory.Span, receiveRegion.Memory.Span.Length, socketFlags, out IPEndPoint endPoint); - try + if (errno == LinuxError.SUCCESS) { - result = socket.Handle.ReceiveFrom(receivedBuffer, receivedBuffer.Length, ConvertBsdSocketFlags(socketFlags), ref endPoint); - errno = SetResultErrno(socket.Handle, result); + SetResultErrno(socket, result); - context.Memory.Write(receivePosition, receivedBuffer); - WriteSockAddr(context, sockAddrOutPosition, (IPEndPoint)endPoint); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); + receiveRegion.Dispose(); + + context.Memory.Write(sockAddrOutPosition, BsdSockAddr.FromIPEndPoint(endPoint)); } } @@ -643,26 +421,20 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd (ulong sendPosition, ulong sendSize) = context.Request.GetBufferType0x21(); + ReadOnlySpan sendBuffer = context.Memory.GetSpan(sendPosition, (int)sendSize); + LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); int result = -1; if (socket != null) { - byte[] sendBuffer = new byte[sendSize]; + errno = socket.Send(out result, sendBuffer, socketFlags); - context.Memory.Read(sendPosition, sendBuffer); - - try + if (errno == LinuxError.SUCCESS) { - result = socket.Handle.Send(sendBuffer, ConvertBsdSocketFlags(socketFlags)); - errno = SetResultErrno(socket.Handle, result); + SetResultErrno(socket, result); } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - } - } return WriteBsdResult(context, result, errno); @@ -675,31 +447,25 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd int socketFd = context.RequestData.ReadInt32(); BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32(); - (ulong sendPosition, ulong sendSize) = context.Request.GetBufferType0x21(); + (ulong sendPosition, ulong sendSize) = context.Request.GetBufferType0x21(0); (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x21(1); + ReadOnlySpan sendBuffer = context.Memory.GetSpan(sendPosition, (int)sendSize); + LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); int result = -1; if (socket != null) { - byte[] sendBuffer = new byte[sendSize]; + IPEndPoint endPoint = context.Memory.Read(bufferPosition).ToIPEndPoint(); - context.Memory.Read(sendPosition, sendBuffer); + errno = socket.SendTo(out result, sendBuffer, sendBuffer.Length, socketFlags, endPoint); - EndPoint endPoint = ParseSockAddr(context, bufferPosition, bufferSize); - - try + if (errno == LinuxError.SUCCESS) { - result = socket.Handle.SendTo(sendBuffer, sendBuffer.Length, ConvertBsdSocketFlags(socketFlags), endPoint); - errno = SetResultErrno(socket.Handle, result); + SetResultErrno(socket, result); } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - } - } return WriteBsdResult(context, result, errno); @@ -714,22 +480,11 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x22(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { - errno = LinuxError.SUCCESS; - - Socket newSocket = null; - - try - { - newSocket = socket.Handle.Accept(); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - } + errno = socket.Accept(out ISocket newSocket); if (newSocket == null && errno == LinuxError.SUCCESS) { @@ -737,19 +492,18 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd } else if (errno == LinuxError.SUCCESS) { - BsdSocket newBsdSocket = new BsdSocket + int newSockFd = _context.RegisterFileDescriptor(newSocket); + + if (newSockFd == -1) { - Family = (int)newSocket.AddressFamily, - Type = (int)newSocket.SocketType, - Protocol = (int)newSocket.ProtocolType, - Handle = newSocket - }; + errno = LinuxError.EBADF; + } + else + { + WriteSockAddr(context, bufferPos, newSocket, true); + } - _sockets.Add(newBsdSocket); - - WriteSockAddr(context, bufferPos, newBsdSocket, true); - - WriteBsdResult(context, _sockets.Count - 1, errno); + WriteBsdResult(context, newSockFd, errno); context.ResponseData.Write(0x10); @@ -766,25 +520,16 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { int socketFd = context.RequestData.ReadInt32(); - (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x21(); + (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x21(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { - errno = LinuxError.SUCCESS; + IPEndPoint endPoint = context.Memory.Read(bufferPosition).ToIPEndPoint(); - try - { - IPEndPoint endPoint = ParseSockAddr(context, bufferPos, bufferSize); - - socket.Handle.Bind(endPoint); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - } + errno = socket.Bind(endPoint); } return WriteBsdResult(context, 0, errno); @@ -796,31 +541,16 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { int socketFd = context.RequestData.ReadInt32(); - (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x21(); + (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x21(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { - errno = LinuxError.SUCCESS; - try - { - IPEndPoint endPoint = ParseSockAddr(context, bufferPos, bufferSize); + IPEndPoint endPoint = context.Memory.Read(bufferPosition).ToIPEndPoint(); - socket.Handle.Connect(endPoint); - } - catch (SocketException exception) - { - if (!socket.Handle.Blocking && exception.ErrorCode == (int)WsaError.WSAEWOULDBLOCK) - { - errno = LinuxError.EINPROGRESS; - } - else - { - errno = ConvertError((WsaError)exception.ErrorCode); - } - } + errno = socket.Connect(endPoint); } return WriteBsdResult(context, 0, errno); @@ -832,18 +562,18 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd { int socketFd = context.RequestData.ReadInt32(); - (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x22(); + (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x22(); - LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + LinuxError errno = LinuxError.EBADF; + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { errno = LinuxError.SUCCESS; - WriteSockAddr(context, bufferPos, socket, true); + WriteSockAddr(context, bufferPosition, socket, true); WriteBsdResult(context, 0, errno); - context.ResponseData.Write(0x10); + context.ResponseData.Write(Unsafe.SizeOf()); } return WriteBsdResult(context, 0, errno); @@ -858,7 +588,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x22(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { @@ -866,7 +596,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd WriteSockAddr(context, bufferPos, socket, false); WriteBsdResult(context, 0, errno); - context.ResponseData.Write(0x10); + context.ResponseData.Write(Unsafe.SizeOf()); } return WriteBsdResult(context, 0, errno); @@ -881,13 +611,19 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd BsdSocketOption option = (BsdSocketOption)context.RequestData.ReadInt32(); (ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x22(); + WritableRegion optionValue = context.Memory.GetWritableRegion(bufferPosition, (int)bufferSize); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { - errno = HandleGetSocketOption(context, socket, option, level, bufferPosition, bufferSize); + errno = socket.GetSocketOption(option, level, optionValue.Memory.Span); + + if (errno == LinuxError.SUCCESS) + { + optionValue.Dispose(); + } } return WriteBsdResult(context, 0, errno); @@ -901,20 +637,11 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd int backlog = context.RequestData.ReadInt32(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { - errno = LinuxError.SUCCESS; - - try - { - socket.Handle.Listen(backlog); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - } + errno = socket.Listen(backlog); } return WriteBsdResult(context, 0, errno); @@ -929,7 +656,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd int bufferCount = context.RequestData.ReadInt32(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { @@ -965,7 +692,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd int result = 0; LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { @@ -973,11 +700,11 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd if (cmd == 0x3) { - result = !socket.Handle.Blocking ? 0x800 : 0; + result = !socket.Blocking ? 0x800 : 0; } else if (cmd == 0x4 && arg == 0x800) { - socket.Handle.Blocking = false; + socket.Blocking = false; result = 0; } else @@ -989,74 +716,6 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd return WriteBsdResult(context, result, errno); } - private static LinuxError HandleGetSocketOption( - ServiceCtx context, - BsdSocket socket, - BsdSocketOption option, - SocketOptionLevel level, - ulong optionValuePosition, - ulong optionValueSize) - { - try - { - if (!TryConvertSocketOption(option, level, out SocketOptionName optionName)) - { - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}"); - - return LinuxError.EOPNOTSUPP; - } - - byte[] optionValue = new byte[optionValueSize]; - - socket.Handle.GetSocketOption(level, optionName, optionValue); - context.Memory.Write(optionValuePosition, optionValue); - - return LinuxError.SUCCESS; - } - catch (SocketException exception) - { - return ConvertError((WsaError)exception.ErrorCode); - } - } - - private static LinuxError HandleSetSocketOption( - ServiceCtx context, - BsdSocket socket, - BsdSocketOption option, - SocketOptionLevel level, - ulong optionValuePosition, - ulong optionValueSize) - { - try - { - if (!TryConvertSocketOption(option, level, out SocketOptionName optionName)) - { - Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}"); - - return LinuxError.EOPNOTSUPP; - } - - int value = context.Memory.Read((ulong)optionValuePosition); - - if (option == BsdSocketOption.SoLinger) - { - int value2 = context.Memory.Read((ulong)optionValuePosition + 4); - - socket.Handle.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2)); - } - else - { - socket.Handle.SetSocketOption(level, optionName, value); - } - - return LinuxError.SUCCESS; - } - catch (SocketException exception) - { - return ConvertError((WsaError)exception.ErrorCode); - } - } - [CommandHipc(21)] // SetSockOpt(u32 socket, u32 level, u32 option_name, buffer option_value) -> (i32 ret, u32 bsd_errno) public ResultCode SetSockOpt(ServiceCtx context) @@ -1067,12 +726,14 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd (ulong bufferPos, ulong bufferSize) = context.Request.GetBufferType0x21(); + ReadOnlySpan optionValue = context.Memory.GetSpan(bufferPos, (int)bufferSize); + LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { - errno = HandleSetSocketOption(context, socket, option, level, bufferPos, bufferSize); + errno = socket.SetSocketOption(option, level, optionValue); } return WriteBsdResult(context, 0, errno); @@ -1086,7 +747,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd int how = context.RequestData.ReadInt32(); LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + ISocket socket = _context.RetrieveSocket(socketFd); if (socket != null) { @@ -1094,16 +755,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd if (how >= 0 && how <= 2) { - errno = LinuxError.SUCCESS; - - try - { - socket.Handle.Shutdown((SocketShutdown)how); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - } + errno = socket.Shutdown((BsdSocketShutdownFlags)how); } } @@ -1120,54 +772,33 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd if (how >= 0 && how <= 2) { - errno = LinuxError.SUCCESS; - - foreach (BsdSocket socket in _sockets) - { - if (socket != null) - { - try - { - socket.Handle.Shutdown((SocketShutdown)how); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); - break; - } - } - } + errno = _context.ShutdownAllSockets((BsdSocketShutdownFlags)how); } return WriteBsdResult(context, 0, errno); } [CommandHipc(24)] - // Write(u32 socket, buffer message) -> (i32 ret, u32 bsd_errno) + // Write(u32 fd, buffer message) -> (i32 ret, u32 bsd_errno) public ResultCode Write(ServiceCtx context) { - int socketFd = context.RequestData.ReadInt32(); + int fd = context.RequestData.ReadInt32(); (ulong sendPosition, ulong sendSize) = context.Request.GetBufferType0x21(); - LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); - int result = -1; + ReadOnlySpan sendBuffer = context.Memory.GetSpan(sendPosition, (int)sendSize); - if (socket != null) + LinuxError errno = LinuxError.EBADF; + IFileDescriptor file = _context.RetrieveFileDescriptor(fd); + int result = -1; + + if (file != null) { - byte[] sendBuffer = new byte[sendSize]; + errno = file.Write(out result, sendBuffer); - context.Memory.Read(sendPosition, sendBuffer); - - try + if (errno == LinuxError.SUCCESS) { - result = socket.Handle.Send(sendBuffer); - errno = SetResultErrno(socket.Handle, result); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); + SetResultErrno(file, result); } } @@ -1175,30 +806,28 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd } [CommandHipc(25)] - // Read(u32 socket) -> (i32 ret, u32 bsd_errno, buffer message) + // Read(u32 fd) -> (i32 ret, u32 bsd_errno, buffer message) public ResultCode Read(ServiceCtx context) { - int socketFd = context.RequestData.ReadInt32(); + int fd = context.RequestData.ReadInt32(); (ulong receivePosition, ulong receiveLength) = context.Request.GetBufferType0x22(); - LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); - int result = -1; + WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength); - if (socket != null) + LinuxError errno = LinuxError.EBADF; + IFileDescriptor file = _context.RetrieveFileDescriptor(fd); + int result = -1; + + if (file != null) { - byte[] receivedBuffer = new byte[receiveLength]; + errno = file.Read(out result, receiveRegion.Memory.Span); - try + if (errno == LinuxError.SUCCESS) { - result = socket.Handle.Receive(receivedBuffer); - errno = SetResultErrno(socket.Handle, result); - context.Memory.Write(receivePosition, receivedBuffer); - } - catch (SocketException exception) - { - errno = ConvertError((WsaError)exception.ErrorCode); + SetResultErrno(file, result); + + receiveRegion.Dispose(); } } @@ -1206,20 +835,15 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd } [CommandHipc(26)] - // Close(u32 socket) -> (i32 ret, u32 bsd_errno) + // Close(u32 fd) -> (i32 ret, u32 bsd_errno) public ResultCode Close(ServiceCtx context) { - int socketFd = context.RequestData.ReadInt32(); + int fd = context.RequestData.ReadInt32(); - LinuxError errno = LinuxError.EBADF; - BsdSocket socket = RetrieveSocket(socketFd); + LinuxError errno = LinuxError.EBADF; - if (socket != null) + if (_context.CloseFileDescriptor(fd)) { - socket.Handle.Close(); - - _sockets[socketFd] = null; - errno = LinuxError.SUCCESS; } @@ -1227,29 +851,49 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd } [CommandHipc(27)] - // DuplicateSocket(u32 socket, u64 reserved) -> (i32 ret, u32 bsd_errno) + // DuplicateSocket(u32 fd, u64 reserved) -> (i32 ret, u32 bsd_errno) public ResultCode DuplicateSocket(ServiceCtx context) { - int socketFd = context.RequestData.ReadInt32(); + int fd = context.RequestData.ReadInt32(); ulong reserved = context.RequestData.ReadUInt64(); - LinuxError errno = LinuxError.ENOENT; - int newSockFd = -1; + LinuxError errno = LinuxError.ENOENT; + int newSockFd = -1; if (_isPrivileged) { - errno = LinuxError.EBADF; + errno = LinuxError.SUCCESS; - BsdSocket oldSocket = RetrieveSocket(socketFd); + newSockFd = _context.DuplicateFileDescriptor(fd); - if (oldSocket != null) + if (newSockFd == -1) { - _sockets.Add(oldSocket); - newSockFd = _sockets.Count - 1; + errno = LinuxError.EBADF; } } return WriteBsdResult(context, newSockFd, errno); } + + [CommandHipc(31)] // 7.0.0+ + // EventFd(u64 initval, nn::socket::EventFdFlags flags) -> (i32 ret, u32 bsd_errno) + public ResultCode EventFd(ServiceCtx context) + { + ulong initialValue = context.RequestData.ReadUInt64(); + EventFdFlags flags = (EventFdFlags)context.RequestData.ReadUInt32(); + + EventFileDescriptor newEventFile = new EventFileDescriptor(initialValue, flags); + + LinuxError errno = LinuxError.SUCCESS; + + int newSockFd = _context.RegisterFileDescriptor(newEventFile); + + if (newSockFd == -1) + { + errno = LinuxError.EBADF; + } + + return WriteBsdResult(context, newSockFd, errno); + } } } \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs new file mode 100644 index 00000000..56f67539 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/IFileDescriptor.cs @@ -0,0 +1,14 @@ +using System; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + interface IFileDescriptor : IDisposable + { + bool Blocking { get; set; } + int Refcount { get; set; } + + LinuxError Read(out int readSize, Span buffer); + + LinuxError Write(out int writeSize, ReadOnlySpan buffer); + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs new file mode 100644 index 00000000..ee6bd9e8 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/ISocket.cs @@ -0,0 +1,47 @@ +using System; +using System.Net; +using System.Net.Sockets; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + interface ISocket : IDisposable, IFileDescriptor + { + IPEndPoint RemoteEndPoint { get; } + IPEndPoint LocalEndPoint { get; } + + AddressFamily AddressFamily { get; } + + SocketType SocketType { get; } + + ProtocolType ProtocolType { get; } + + IntPtr Handle { get; } + + LinuxError Receive(out int receiveSize, Span buffer, BsdSocketFlags flags); + + LinuxError ReceiveFrom(out int receiveSize, Span buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint); + + LinuxError Send(out int sendSize, ReadOnlySpan buffer, BsdSocketFlags flags); + + LinuxError SendTo(out int sendSize, ReadOnlySpan buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint); + + LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span optionValue); + LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan optionValue); + + bool Poll(int microSeconds, SelectMode mode); + + LinuxError Bind(IPEndPoint localEndPoint); + + LinuxError Connect(IPEndPoint remoteEndPoint); + + LinuxError Listen(int backlog); + + LinuxError Accept(out ISocket newSocket); + + void Disconnect(); + + LinuxError Shutdown(BsdSocketShutdownFlags how); + + void Close(); + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs new file mode 100644 index 00000000..e92b42ef --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptor.cs @@ -0,0 +1,130 @@ +using System; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class EventFileDescriptor : IFileDescriptor + { + private ulong _value; + private readonly EventFdFlags _flags; + private AutoResetEvent _event; + + private object _lock = new object(); + + public bool Blocking { get => !_flags.HasFlag(EventFdFlags.NonBlocking); set => throw new NotSupportedException(); } + + public ManualResetEvent WriteEvent { get; } + public ManualResetEvent ReadEvent { get; } + + public EventFileDescriptor(ulong value, EventFdFlags flags) + { + _value = value; + _flags = flags; + _event = new AutoResetEvent(false); + + WriteEvent = new ManualResetEvent(true); + ReadEvent = new ManualResetEvent(true); + } + + public int Refcount { get; set; } + + public void Dispose() + { + _event.Dispose(); + WriteEvent.Dispose(); + ReadEvent.Dispose(); + } + + public LinuxError Read(out int readSize, Span buffer) + { + if (buffer.Length < sizeof(ulong)) + { + readSize = 0; + + return LinuxError.EINVAL; + } + + ReadEvent.Reset(); + + lock (_lock) + { + ref ulong count = ref MemoryMarshal.Cast(buffer)[0]; + + if (_value == 0) + { + if (Blocking) + { + while (_value == 0) + { + _event.WaitOne(); + } + } + else + { + readSize = 0; + + return LinuxError.EAGAIN; + } + } + + readSize = sizeof(ulong); + + if (_flags.HasFlag(EventFdFlags.Semaphore)) + { + --_value; + + count = 1; + } + else + { + count = _value; + + _value = 0; + } + + ReadEvent.Set(); + + return LinuxError.SUCCESS; + } + } + + public LinuxError Write(out int writeSize, ReadOnlySpan buffer) + { + if (!MemoryMarshal.TryRead(buffer, out ulong count) || count == ulong.MaxValue) + { + writeSize = 0; + + return LinuxError.EINVAL; + } + + WriteEvent.Reset(); + + lock (_lock) + { + if (_value > _value + count) + { + if (Blocking) + { + _event.WaitOne(); + } + else + { + writeSize = 0; + + return LinuxError.EAGAIN; + } + } + + writeSize = sizeof(ulong); + + _value += count; + _event.Set(); + + WriteEvent.Set(); + + return LinuxError.SUCCESS; + } + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptorPollManager.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptorPollManager.cs new file mode 100644 index 00000000..8bd9652b --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/EventFileDescriptorPollManager.cs @@ -0,0 +1,96 @@ +using Ryujinx.Common.Logging; +using System.Collections.Generic; +using System.Threading; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class EventFileDescriptorPollManager : IPollManager + { + private static EventFileDescriptorPollManager _instance; + + public static EventFileDescriptorPollManager Instance + { + get + { + if (_instance == null) + { + _instance = new EventFileDescriptorPollManager(); + } + + return _instance; + } + } + + public bool IsCompatible(PollEvent evnt) + { + return evnt.FileDescriptor is EventFileDescriptor; + } + + public LinuxError Poll(List events, int timeoutMilliseconds, out int updatedCount) + { + updatedCount = 0; + + List waiters = new List(); + + for (int i = 0; i < events.Count; i++) + { + PollEvent evnt = events[i]; + + EventFileDescriptor socket = (EventFileDescriptor)evnt.FileDescriptor; + + bool isValidEvent = false; + + if (evnt.Data.InputEvents.HasFlag(PollEventTypeMask.Input) || + evnt.Data.InputEvents.HasFlag(PollEventTypeMask.UrgentInput)) + { + waiters.Add(socket.ReadEvent); + + isValidEvent = true; + } + if (evnt.Data.InputEvents.HasFlag(PollEventTypeMask.Output)) + { + waiters.Add(socket.WriteEvent); + + isValidEvent = true; + } + + if (!isValidEvent) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported Poll input event type: {evnt.Data.InputEvents}"); + + return LinuxError.EINVAL; + } + } + + int index = WaitHandle.WaitAny(waiters.ToArray(), timeoutMilliseconds); + + if (index != WaitHandle.WaitTimeout) + { + for (int i = 0; i < events.Count; i++) + { + PollEvent evnt = events[i]; + + EventFileDescriptor socket = (EventFileDescriptor)evnt.FileDescriptor; + + if ((evnt.Data.InputEvents.HasFlag(PollEventTypeMask.Input) || + evnt.Data.InputEvents.HasFlag(PollEventTypeMask.UrgentInput)) + && socket.ReadEvent.WaitOne(0)) + { + waiters.Add(socket.ReadEvent); + } + if ((evnt.Data.InputEvents.HasFlag(PollEventTypeMask.Output)) + && socket.WriteEvent.WaitOne(0)) + { + waiters.Add(socket.WriteEvent); + } + } + } + else + { + return LinuxError.ETIMEDOUT; + } + + return LinuxError.SUCCESS; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs new file mode 100644 index 00000000..349dbde0 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocket.cs @@ -0,0 +1,338 @@ +using Ryujinx.Common.Logging; +using System; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class ManagedSocket : ISocket + { + public int Refcount { get; set; } + + public AddressFamily AddressFamily => Socket.AddressFamily; + + public SocketType SocketType => Socket.SocketType; + + public ProtocolType ProtocolType => Socket.ProtocolType; + + public bool Blocking { get => Socket.Blocking; set => Socket.Blocking = value; } + + public IntPtr Handle => Socket.Handle; + + public IPEndPoint RemoteEndPoint => Socket.RemoteEndPoint as IPEndPoint; + + public IPEndPoint LocalEndPoint => Socket.LocalEndPoint as IPEndPoint; + + public Socket Socket { get; } + + public ManagedSocket(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType) + { + Socket = new Socket(addressFamily, socketType, protocolType); + Refcount = 1; + } + + private ManagedSocket(Socket socket) + { + Socket = socket; + Refcount = 1; + } + + private static SocketFlags ConvertBsdSocketFlags(BsdSocketFlags bsdSocketFlags) + { + SocketFlags socketFlags = SocketFlags.None; + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.Oob)) + { + socketFlags |= SocketFlags.OutOfBand; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.Peek)) + { + socketFlags |= SocketFlags.Peek; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.DontRoute)) + { + socketFlags |= SocketFlags.DontRoute; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.Trunc)) + { + socketFlags |= SocketFlags.Truncated; + } + + if (bsdSocketFlags.HasFlag(BsdSocketFlags.CTrunc)) + { + socketFlags |= SocketFlags.ControlDataTruncated; + } + + bsdSocketFlags &= ~(BsdSocketFlags.Oob | + BsdSocketFlags.Peek | + BsdSocketFlags.DontRoute | + BsdSocketFlags.DontWait | + BsdSocketFlags.Trunc | + BsdSocketFlags.CTrunc); + + if (bsdSocketFlags != BsdSocketFlags.None) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported socket flags: {bsdSocketFlags}"); + } + + return socketFlags; + } + + public LinuxError Accept(out ISocket newSocket) + { + try + { + newSocket = new ManagedSocket(Socket.Accept()); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + newSocket = null; + + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError Bind(IPEndPoint localEndPoint) + { + try + { + Socket.Bind(localEndPoint); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public void Close() + { + Socket.Close(); + } + + public LinuxError Connect(IPEndPoint remoteEndPoint) + { + try + { + Socket.Connect(remoteEndPoint); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + if (!Blocking && exception.ErrorCode == (int)WsaError.WSAEWOULDBLOCK) + { + return LinuxError.EINPROGRESS; + } + else + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + } + + public void Disconnect() + { + Socket.Disconnect(true); + } + + public void Dispose() + { + Socket.Close(); + Socket.Dispose(); + } + + public LinuxError Listen(int backlog) + { + try + { + Socket.Listen(backlog); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public bool Poll(int microSeconds, SelectMode mode) + { + return Socket.Poll(microSeconds, mode); + } + + public LinuxError Shutdown(BsdSocketShutdownFlags how) + { + try + { + Socket.Shutdown((SocketShutdown)how); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError Receive(out int receiveSize, Span buffer, BsdSocketFlags flags) + { + try + { + receiveSize = Socket.Receive(buffer, ConvertBsdSocketFlags(flags)); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + receiveSize = -1; + + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError ReceiveFrom(out int receiveSize, Span buffer, int size, BsdSocketFlags flags, out IPEndPoint remoteEndPoint) + { + remoteEndPoint = new IPEndPoint(IPAddress.Any, 0); + + LinuxError result; + + bool shouldBlockAfterOperation = false; + + try + { + EndPoint temp = new IPEndPoint(IPAddress.Any, 0); + + if (Blocking && flags.HasFlag(BsdSocketFlags.DontWait)) + { + Blocking = false; + shouldBlockAfterOperation = true; + } + + receiveSize = Socket.ReceiveFrom(buffer[..size], ConvertBsdSocketFlags(flags), ref temp); + + remoteEndPoint = (IPEndPoint)temp; + result = LinuxError.SUCCESS; + } + catch (SocketException exception) + { + receiveSize = -1; + + result = WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + if (shouldBlockAfterOperation) + { + Blocking = true; + } + + return result; + } + + public LinuxError Send(out int sendSize, ReadOnlySpan buffer, BsdSocketFlags flags) + { + try + { + sendSize = Socket.Send(buffer, ConvertBsdSocketFlags(flags)); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + sendSize = -1; + + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError SendTo(out int sendSize, ReadOnlySpan buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint) + { + try + { + sendSize = Socket.SendTo(buffer[..size], ConvertBsdSocketFlags(flags), remoteEndPoint); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + sendSize = -1; + + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span optionValue) + { + try + { + if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported GetSockOpt Option: {option} Level: {level}"); + + return LinuxError.EOPNOTSUPP; + } + + byte[] tempOptionValue = new byte[optionValue.Length]; + + Socket.GetSocketOption(level, optionName, tempOptionValue); + + tempOptionValue.AsSpan().CopyTo(optionValue); + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan optionValue) + { + try + { + if (!WinSockHelper.TryConvertSocketOption(option, level, out SocketOptionName optionName)) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported SetSockOpt Option: {option} Level: {level}"); + + return LinuxError.EOPNOTSUPP; + } + + int value = MemoryMarshal.Read(optionValue); + + if (option == BsdSocketOption.SoLinger) + { + int value2 = MemoryMarshal.Read(optionValue[4..]); + + Socket.SetSocketOption(level, SocketOptionName.Linger, new LingerOption(value != 0, value2)); + } + else + { + Socket.SetSocketOption(level, optionName, value); + } + + return LinuxError.SUCCESS; + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + } + + public LinuxError Read(out int readSize, Span buffer) + { + return Receive(out readSize, buffer, BsdSocketFlags.None); + } + + public LinuxError Write(out int writeSize, ReadOnlySpan buffer) + { + return Send(out writeSize, buffer, BsdSocketFlags.None); + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocketPollManager.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocketPollManager.cs new file mode 100644 index 00000000..b2414bc1 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/ManagedSocketPollManager.cs @@ -0,0 +1,129 @@ +using Ryujinx.Common.Logging; +using System.Collections.Generic; +using System.Net.Sockets; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + class ManagedSocketPollManager : IPollManager + { + private static ManagedSocketPollManager _instance; + + public static ManagedSocketPollManager Instance + { + get + { + if (_instance == null) + { + _instance = new ManagedSocketPollManager(); + } + + return _instance; + } + } + + public bool IsCompatible(PollEvent evnt) + { + return evnt.FileDescriptor is ManagedSocket; + } + + public LinuxError Poll(List events, int timeoutMilliseconds, out int updatedCount) + { + List readEvents = new List(); + List writeEvents = new List(); + List errorEvents = new List(); + + updatedCount = 0; + + foreach (PollEvent evnt in events) + { + ManagedSocket socket = (ManagedSocket)evnt.FileDescriptor; + + bool isValidEvent = false; + + if ((evnt.Data.InputEvents & PollEventTypeMask.Input) != 0) + { + readEvents.Add(socket.Socket); + errorEvents.Add(socket.Socket); + + isValidEvent = true; + } + + if ((evnt.Data.InputEvents & PollEventTypeMask.UrgentInput) != 0) + { + readEvents.Add(socket.Socket); + errorEvents.Add(socket.Socket); + + isValidEvent = true; + } + + if ((evnt.Data.InputEvents & PollEventTypeMask.Output) != 0) + { + writeEvents.Add(socket.Socket); + errorEvents.Add(socket.Socket); + + isValidEvent = true; + } + + if ((evnt.Data.InputEvents & PollEventTypeMask.Error) != 0) + { + errorEvents.Add(socket.Socket); + + isValidEvent = true; + } + + if (!isValidEvent) + { + Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported Poll input event type: {evnt.Data.InputEvents}"); + return LinuxError.EINVAL; + } + } + + try + { + int actualTimeoutMicroseconds = timeoutMilliseconds == -1 ? -1 : timeoutMilliseconds * 1000; + + Socket.Select(readEvents, writeEvents, errorEvents, actualTimeoutMicroseconds); + } + catch (SocketException exception) + { + return WinSockHelper.ConvertError((WsaError)exception.ErrorCode); + } + + foreach (PollEvent evnt in events) + { + Socket socket = ((ManagedSocket)evnt.FileDescriptor).Socket; + + PollEventTypeMask outputEvents = 0; + + if (errorEvents.Contains(socket)) + { + outputEvents |= PollEventTypeMask.Error; + + if (!socket.Connected || !socket.IsBound) + { + outputEvents |= PollEventTypeMask.Disconnected; + } + } + + if (readEvents.Contains(socket)) + { + if ((evnt.Data.InputEvents & PollEventTypeMask.Input) != 0) + { + outputEvents |= PollEventTypeMask.Input; + } + } + + if (writeEvents.Contains(socket)) + { + outputEvents |= PollEventTypeMask.Output; + } + + evnt.Data.OutputEvents = outputEvents; + } + + updatedCount = readEvents.Count + writeEvents.Count + errorEvents.Count; + + return LinuxError.SUCCESS; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/WSAError.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WSAError.cs similarity index 100% rename from Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/WSAError.cs rename to Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WSAError.cs diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs new file mode 100644 index 00000000..ad12745e --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Impl/WinSockHelper.cs @@ -0,0 +1,165 @@ +using System.Collections.Generic; +using System.Net.Sockets; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + static class WinSockHelper + { + private static readonly Dictionary _errorMap = new() + { + // WSAEINTR + {WsaError.WSAEINTR, LinuxError.EINTR}, + // WSAEWOULDBLOCK + {WsaError.WSAEWOULDBLOCK, LinuxError.EWOULDBLOCK}, + // WSAEINPROGRESS + {WsaError.WSAEINPROGRESS, LinuxError.EINPROGRESS}, + // WSAEALREADY + {WsaError.WSAEALREADY, LinuxError.EALREADY}, + // WSAENOTSOCK + {WsaError.WSAENOTSOCK, LinuxError.ENOTSOCK}, + // WSAEDESTADDRREQ + {WsaError.WSAEDESTADDRREQ, LinuxError.EDESTADDRREQ}, + // WSAEMSGSIZE + {WsaError.WSAEMSGSIZE, LinuxError.EMSGSIZE}, + // WSAEPROTOTYPE + {WsaError.WSAEPROTOTYPE, LinuxError.EPROTOTYPE}, + // WSAENOPROTOOPT + {WsaError.WSAENOPROTOOPT, LinuxError.ENOPROTOOPT}, + // WSAEPROTONOSUPPORT + {WsaError.WSAEPROTONOSUPPORT, LinuxError.EPROTONOSUPPORT}, + // WSAESOCKTNOSUPPORT + {WsaError.WSAESOCKTNOSUPPORT, LinuxError.ESOCKTNOSUPPORT}, + // WSAEOPNOTSUPP + {WsaError.WSAEOPNOTSUPP, LinuxError.EOPNOTSUPP}, + // WSAEPFNOSUPPORT + {WsaError.WSAEPFNOSUPPORT, LinuxError.EPFNOSUPPORT}, + // WSAEAFNOSUPPORT + {WsaError.WSAEAFNOSUPPORT, LinuxError.EAFNOSUPPORT}, + // WSAEADDRINUSE + {WsaError.WSAEADDRINUSE, LinuxError.EADDRINUSE}, + // WSAEADDRNOTAVAIL + {WsaError.WSAEADDRNOTAVAIL, LinuxError.EADDRNOTAVAIL}, + // WSAENETDOWN + {WsaError.WSAENETDOWN, LinuxError.ENETDOWN}, + // WSAENETUNREACH + {WsaError.WSAENETUNREACH, LinuxError.ENETUNREACH}, + // WSAENETRESET + {WsaError.WSAENETRESET, LinuxError.ENETRESET}, + // WSAECONNABORTED + {WsaError.WSAECONNABORTED, LinuxError.ECONNABORTED}, + // WSAECONNRESET + {WsaError.WSAECONNRESET, LinuxError.ECONNRESET}, + // WSAENOBUFS + {WsaError.WSAENOBUFS, LinuxError.ENOBUFS}, + // WSAEISCONN + {WsaError.WSAEISCONN, LinuxError.EISCONN}, + // WSAENOTCONN + {WsaError.WSAENOTCONN, LinuxError.ENOTCONN}, + // WSAESHUTDOWN + {WsaError.WSAESHUTDOWN, LinuxError.ESHUTDOWN}, + // WSAETOOMANYREFS + {WsaError.WSAETOOMANYREFS, LinuxError.ETOOMANYREFS}, + // WSAETIMEDOUT + {WsaError.WSAETIMEDOUT, LinuxError.ETIMEDOUT}, + // WSAECONNREFUSED + {WsaError.WSAECONNREFUSED, LinuxError.ECONNREFUSED}, + // WSAELOOP + {WsaError.WSAELOOP, LinuxError.ELOOP}, + // WSAENAMETOOLONG + {WsaError.WSAENAMETOOLONG, LinuxError.ENAMETOOLONG}, + // WSAEHOSTDOWN + {WsaError.WSAEHOSTDOWN, LinuxError.EHOSTDOWN}, + // WSAEHOSTUNREACH + {WsaError.WSAEHOSTUNREACH, LinuxError.EHOSTUNREACH}, + // WSAENOTEMPTY + {WsaError.WSAENOTEMPTY, LinuxError.ENOTEMPTY}, + // WSAEUSERS + {WsaError.WSAEUSERS, LinuxError.EUSERS}, + // WSAEDQUOT + {WsaError.WSAEDQUOT, LinuxError.EDQUOT}, + // WSAESTALE + {WsaError.WSAESTALE, LinuxError.ESTALE}, + // WSAEREMOTE + {WsaError.WSAEREMOTE, LinuxError.EREMOTE}, + // WSAEINVAL + {WsaError.WSAEINVAL, LinuxError.EINVAL}, + // WSAEFAULT + {WsaError.WSAEFAULT, LinuxError.EFAULT}, + // NOERROR + {0, 0} + }; + + private static readonly Dictionary _soSocketOptionMap = new() + { + { BsdSocketOption.SoDebug, SocketOptionName.Debug }, + { BsdSocketOption.SoReuseAddr, SocketOptionName.ReuseAddress }, + { BsdSocketOption.SoKeepAlive, SocketOptionName.KeepAlive }, + { BsdSocketOption.SoDontRoute, SocketOptionName.DontRoute }, + { BsdSocketOption.SoBroadcast, SocketOptionName.Broadcast }, + { BsdSocketOption.SoUseLoopBack, SocketOptionName.UseLoopback }, + { BsdSocketOption.SoLinger, SocketOptionName.Linger }, + { BsdSocketOption.SoOobInline, SocketOptionName.OutOfBandInline }, + { BsdSocketOption.SoReusePort, SocketOptionName.ReuseAddress }, + { BsdSocketOption.SoSndBuf, SocketOptionName.SendBuffer }, + { BsdSocketOption.SoRcvBuf, SocketOptionName.ReceiveBuffer }, + { BsdSocketOption.SoSndLoWat, SocketOptionName.SendLowWater }, + { BsdSocketOption.SoRcvLoWat, SocketOptionName.ReceiveLowWater }, + { BsdSocketOption.SoSndTimeo, SocketOptionName.SendTimeout }, + { BsdSocketOption.SoRcvTimeo, SocketOptionName.ReceiveTimeout }, + { BsdSocketOption.SoError, SocketOptionName.Error }, + { BsdSocketOption.SoType, SocketOptionName.Type } + }; + + private static readonly Dictionary _ipSocketOptionMap = new() + { + { BsdSocketOption.IpOptions, SocketOptionName.IPOptions }, + { BsdSocketOption.IpHdrIncl, SocketOptionName.HeaderIncluded }, + { BsdSocketOption.IpTtl, SocketOptionName.IpTimeToLive }, + { BsdSocketOption.IpMulticastIf, SocketOptionName.MulticastInterface }, + { BsdSocketOption.IpMulticastTtl, SocketOptionName.MulticastTimeToLive }, + { BsdSocketOption.IpMulticastLoop, SocketOptionName.MulticastLoopback }, + { BsdSocketOption.IpAddMembership, SocketOptionName.AddMembership }, + { BsdSocketOption.IpDropMembership, SocketOptionName.DropMembership }, + { BsdSocketOption.IpDontFrag, SocketOptionName.DontFragment }, + { BsdSocketOption.IpAddSourceMembership, SocketOptionName.AddSourceMembership }, + { BsdSocketOption.IpDropSourceMembership, SocketOptionName.DropSourceMembership } + }; + + private static readonly Dictionary _tcpSocketOptionMap = new() + { + { BsdSocketOption.TcpNoDelay, SocketOptionName.NoDelay }, + { BsdSocketOption.TcpKeepIdle, SocketOptionName.TcpKeepAliveTime }, + { BsdSocketOption.TcpKeepIntvl, SocketOptionName.TcpKeepAliveInterval }, + { BsdSocketOption.TcpKeepCnt, SocketOptionName.TcpKeepAliveRetryCount } + }; + + public static LinuxError ConvertError(WsaError errorCode) + { + if (!_errorMap.TryGetValue(errorCode, out LinuxError errno)) + { + errno = (LinuxError)errorCode; + } + + return errno; + } + + public static bool TryConvertSocketOption(BsdSocketOption option, SocketOptionLevel level, out SocketOptionName name) + { + var table = level switch + { + SocketOptionLevel.Socket => _soSocketOptionMap, + SocketOptionLevel.IP => _ipSocketOptionMap, + SocketOptionLevel.Tcp => _tcpSocketOptionMap, + _ => null + }; + + if (table == null) + { + name = default; + return false; + } + + return table.TryGetValue(option, out name); + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdAddressFamily.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdAddressFamily.cs new file mode 100644 index 00000000..dcc9f0fd --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdAddressFamily.cs @@ -0,0 +1,11 @@ +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + enum BsdAddressFamily : uint + { + Unspecified, + InterNetwork = 2, + InterNetworkV6 = 28, + + Unknown = uint.MaxValue + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSockAddr.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSockAddr.cs new file mode 100644 index 00000000..916ca2bb --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSockAddr.cs @@ -0,0 +1,39 @@ +using Ryujinx.Common.Memory; +using System; +using System.Net; +using System.Runtime.InteropServices; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + [StructLayout(LayoutKind.Sequential, Pack = 1, Size = 0x10)] + struct BsdSockAddr + { + public byte Length; + public byte Family; + public ushort Port; + public Array4 Address; + private Array8 _reserved; + + public IPEndPoint ToIPEndPoint() + { + IPAddress address = new IPAddress(Address.ToSpan()); + int port = (ushort)IPAddress.NetworkToHostOrder((short)Port); + + return new IPEndPoint(address, port); + } + + public static BsdSockAddr FromIPEndPoint(IPEndPoint endpoint) + { + BsdSockAddr result = new BsdSockAddr + { + Length = 0, + Family = (byte)endpoint.AddressFamily, + Port = (ushort)IPAddress.HostToNetworkOrder((short)endpoint.Port) + }; + + endpoint.Address.GetAddressBytes().AsSpan().CopyTo(result.Address.ToSpan()); + + return result; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocket.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocket.cs deleted file mode 100644 index 2d5bf429..00000000 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocket.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System.Net.Sockets; - -namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd -{ - class BsdSocket - { - public int Family; - public int Type; - public int Protocol; - - public Socket Handle; - } -} \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketCreationFlags.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketCreationFlags.cs new file mode 100644 index 00000000..77718800 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketCreationFlags.cs @@ -0,0 +1,14 @@ +using System; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + [Flags] + enum BsdSocketCreationFlags + { + None = 0, + CloseOnExecution = 1, + NonBlocking = 2, + + FlagsShift = 28 + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketShutdownFlags.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketShutdownFlags.cs new file mode 100644 index 00000000..2588376b --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketShutdownFlags.cs @@ -0,0 +1,9 @@ +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + enum BsdSocketShutdownFlags + { + Receive, + Send, + ReceiveAndSend + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketType.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketType.cs new file mode 100644 index 00000000..9b13e669 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdSocketType.cs @@ -0,0 +1,13 @@ +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + enum BsdSocketType + { + Stream = 1, + Dgram, + Raw, + Rdm, + Seqpacket, + + TypeMask = 0xFFFFFFF, + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/EventFdFlags.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/EventFdFlags.cs new file mode 100644 index 00000000..7d08fb24 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/EventFdFlags.cs @@ -0,0 +1,12 @@ +using System; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + [Flags] + enum EventFdFlags : uint + { + None = 0, + Semaphore = 1 << 0, + NonBlocking = 1 << 2 + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/IPollManager.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/IPollManager.cs new file mode 100644 index 00000000..8b0959fd --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/IPollManager.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + interface IPollManager + { + bool IsCompatible(PollEvent evnt); + + LinuxError Poll(List events, int timeoutMilliseconds, out int updatedCount); + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEvent.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEvent.cs index ff47a4c7..8056e7a8 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEvent.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEvent.cs @@ -2,27 +2,13 @@ { class PollEvent { - public enum EventTypeMask - { - Input = 1, - UrgentInput = 2, - Output = 4, - Error = 8, - Disconnected = 0x10, - Invalid = 0x20 - } + public PollEventData Data; + public IFileDescriptor FileDescriptor { get; } - public int SocketFd { get; private set; } - public BsdSocket Socket { get; private set; } - public EventTypeMask InputEvents { get; private set; } - public EventTypeMask OutputEvents { get; private set; } - - public PollEvent(int socketFd, BsdSocket socket, EventTypeMask inputEvents, EventTypeMask outputEvents) + public PollEvent(PollEventData data, IFileDescriptor fileDescriptor) { - SocketFd = socketFd; - Socket = socket; - InputEvents = inputEvents; - OutputEvents = outputEvents; + Data = data; + FileDescriptor = fileDescriptor; } } } \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventData.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventData.cs new file mode 100644 index 00000000..ee400e69 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventData.cs @@ -0,0 +1,13 @@ +using System; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + struct PollEventData + { +#pragma warning disable CS0649 + public int SocketFd; + public PollEventTypeMask InputEvents; +#pragma warning restore CS0649 + public PollEventTypeMask OutputEvents; + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventTypeMask.cs b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventTypeMask.cs new file mode 100644 index 00000000..899072bf --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/PollEventTypeMask.cs @@ -0,0 +1,15 @@ +using System; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd +{ + [Flags] + enum PollEventTypeMask : ushort + { + Input = 1, + UrgentInput = 2, + Output = 4, + Error = 8, + Disconnected = 0x10, + Invalid = 0x20 + } +}