sfdnsres: Fix serialization issues (#2992)

* sfdnsres: Fix serialization issues

Fix a crash on Monster Hunter Rise

* Address gdkchan's comments

* Address gdkchan's comments
This commit is contained in:
Mary 2022-01-12 17:43:00 +01:00 committed by GitHub
parent f4bbc019b9
commit d300a5a45b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 244 additions and 54 deletions

View file

@ -1,4 +1,5 @@
using Ryujinx.Common.Logging; using Ryujinx.Common.Logging;
using Ryujinx.Common.Memory;
using Ryujinx.Cpu; using Ryujinx.Cpu;
using Ryujinx.HLE.HOS.Services.Sockets.Nsd.Manager; using Ryujinx.HLE.HOS.Services.Sockets.Nsd.Manager;
using Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Proxy; using Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Proxy;
@ -11,6 +12,7 @@ using System.Linq;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text; using System.Text;
namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
@ -268,7 +270,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
NetDbError netDbErrorCode = NetDbError.Success; NetDbError netDbErrorCode = NetDbError.Success;
GaiError errno = GaiError.Overflow; GaiError errno = GaiError.Overflow;
ulong serializedSize = 0; int serializedSize = 0;
if (host.Length <= byte.MaxValue) if (host.Length <= byte.MaxValue)
{ {
@ -368,7 +370,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
NetDbError netDbErrorCode = NetDbError.Success; NetDbError netDbErrorCode = NetDbError.Success;
GaiError errno = GaiError.AddressFamily; GaiError errno = GaiError.AddressFamily;
ulong serializedSize = 0; int serializedSize = 0;
if (rawIp.Length == 4) if (rawIp.Length == 4)
{ {
@ -400,7 +402,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
return ResultCode.Success; return ResultCode.Success;
} }
private static ulong SerializeHostEntries(ServiceCtx context, ulong outputBufferPosition, ulong outputBufferSize, IPHostEntry hostEntry, IEnumerable<IPAddress> addresses = null) private static int SerializeHostEntries(ServiceCtx context, ulong outputBufferPosition, ulong outputBufferSize, IPHostEntry hostEntry, IEnumerable<IPAddress> addresses = null)
{ {
ulong originalBufferPosition = outputBufferPosition; ulong originalBufferPosition = outputBufferPosition;
ulong bufferPosition = originalBufferPosition; ulong bufferPosition = originalBufferPosition;
@ -443,7 +445,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
} }
} }
return bufferPosition - originalBufferPosition; return (int)(bufferPosition - originalBufferPosition);
} }
private static ResultCode GetAddrInfoRequestImpl( private static ResultCode GetAddrInfoRequestImpl(
@ -470,7 +472,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
} }
// NOTE: We ignore hints for now. // NOTE: We ignore hints for now.
DeserializeAddrInfos(context.Memory, (ulong)context.Request.SendBuff[2].Position, (ulong)context.Request.SendBuff[2].Size); List<AddrInfoSerialized> hints = DeserializeAddrInfos(context.Memory, context.Request.SendBuff[2].Position, context.Request.SendBuff[2].Size);
if (withOptions) if (withOptions)
{ {
@ -484,7 +486,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
NetDbError netDbErrorCode = NetDbError.Success; NetDbError netDbErrorCode = NetDbError.Success;
GaiError errno = GaiError.AddressFamily; GaiError errno = GaiError.AddressFamily;
ulong serializedSize = 0; int serializedSize = 0;
if (host.Length <= byte.MaxValue) if (host.Length <= byte.MaxValue)
{ {
@ -538,32 +540,37 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
return ResultCode.Success; return ResultCode.Success;
} }
private static void DeserializeAddrInfos(IVirtualMemoryManager memory, ulong address, ulong size) private static List<AddrInfoSerialized> DeserializeAddrInfos(IVirtualMemoryManager memory, ulong address, ulong size)
{ {
ulong endAddress = address + size; List<AddrInfoSerialized> result = new List<AddrInfoSerialized>();
while (address < endAddress) ReadOnlySpan<byte> data = memory.GetSpan(address, (int)size);
while (!data.IsEmpty)
{ {
AddrInfoSerializedHeader header = memory.Read<AddrInfoSerializedHeader>(address); AddrInfoSerialized info = AddrInfoSerialized.Read(data, out data);
if (header.Magic != SfdnsresContants.AddrInfoMagic) if (info == null)
{ {
break; break;
} }
address += (ulong)Unsafe.SizeOf<AddrInfoSerializedHeader>() + header.AddressLength; result.Add(info);
// ai_canonname
string canonname = MemoryHelper.ReadAsciiString(memory, address);
}
} }
private static ulong SerializeAddrInfos(ServiceCtx context, ulong responseBufferPosition, ulong responseBufferSize, IPHostEntry hostEntry, int port) return result;
}
private static int SerializeAddrInfos(ServiceCtx context, ulong responseBufferPosition, ulong responseBufferSize, IPHostEntry hostEntry, int port)
{ {
ulong originalBufferPosition = (ulong)responseBufferPosition; ulong originalBufferPosition = responseBufferPosition;
ulong bufferPosition = originalBufferPosition; ulong bufferPosition = originalBufferPosition;
string hostName = hostEntry.HostName + '\0'; byte[] hostName = Encoding.ASCII.GetBytes(hostEntry.HostName + '\0');
using (WritableRegion region = context.Memory.GetWritableRegion(responseBufferPosition, (int)responseBufferSize))
{
Span<byte> data = region.Memory.Span;
for (int i = 0; i < hostEntry.AddressList.Length; i++) for (int i = 0; i < hostEntry.AddressList.Length; i++)
{ {
@ -574,38 +581,32 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
continue; continue;
} }
AddrInfoSerializedHeader header = new AddrInfoSerializedHeader(ip, 0);
// NOTE: 0 = Any // NOTE: 0 = Any
context.Memory.Write(bufferPosition, header); AddrInfoSerializedHeader header = new AddrInfoSerializedHeader(ip, 0);
bufferPosition += (ulong)Unsafe.SizeOf<AddrInfoSerializedHeader>(); AddrInfo4 addr = new AddrInfo4(ip, (short)port);
AddrInfoSerialized info = new AddrInfoSerialized(header, addr, null, hostEntry.HostName);
// addrinfo_in data = info.Write(data);
context.Memory.Write(bufferPosition, new AddrInfo4(ip, (short)port));
bufferPosition += header.AddressLength;
// ai_canonname
context.Memory.Write(bufferPosition, Encoding.ASCII.GetBytes(hostName));
bufferPosition += (ulong)hostName.Length;
} }
// Termination zero value. uint sentinel = 0;
context.Memory.Write(bufferPosition, 0); MemoryMarshal.Write(data, ref sentinel);
bufferPosition += sizeof(int); data = data[sizeof(uint)..];
return bufferPosition - originalBufferPosition; return region.Memory.Span.Length - data.Length;
}
} }
private static void WriteResponse( private static void WriteResponse(
ServiceCtx context, ServiceCtx context,
bool withOptions, bool withOptions,
ulong serializedSize, int serializedSize,
GaiError errno, GaiError errno,
NetDbError netDbErrorCode) NetDbError netDbErrorCode)
{ {
if (withOptions) if (withOptions)
{ {
context.ResponseData.Write((int)serializedSize); context.ResponseData.Write(serializedSize);
context.ResponseData.Write((int)errno); context.ResponseData.Write((int)errno);
context.ResponseData.Write((int)netDbErrorCode); context.ResponseData.Write((int)netDbErrorCode);
context.ResponseData.Write(0); context.ResponseData.Write(0);

View file

@ -2,6 +2,7 @@
using System; using System;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types
@ -16,14 +17,34 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types
public AddrInfo4(IPAddress address, short port) public AddrInfo4(IPAddress address, short port)
{ {
Length = 0; Length = (byte)Unsafe.SizeOf<Array4<byte>>();
Family = (byte)AddressFamily.InterNetwork; Family = (byte)AddressFamily.InterNetwork;
Port = port; Port = port;
Address = default; Address = new Array4<byte>();
Span<byte> outAddress = Address.ToSpan(); address.TryWriteBytes(Address.ToSpan(), out _);
address.TryWriteBytes(outAddress, out _); }
outAddress.Reverse();
public void ToNetworkOrder()
{
Port = IPAddress.HostToNetworkOrder(Port);
RawIpv4AddressNetworkEndianSwap(ref Address);
}
public void ToHostOrder()
{
Port = IPAddress.NetworkToHostOrder(Port);
RawIpv4AddressNetworkEndianSwap(ref Address);
}
public static void RawIpv4AddressNetworkEndianSwap(ref Array4<byte> address)
{
if (BitConverter.IsLittleEndian)
{
address.ToSpan().Reverse();
}
} }
} }
} }

View file

@ -0,0 +1,136 @@
using Ryujinx.Common.Memory;
using Ryujinx.HLE.Utilities;
using System;
using System.Diagnostics;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types
{
class AddrInfoSerialized
{
public AddrInfoSerializedHeader Header;
public AddrInfo4? SocketAddress;
public Array4<byte>? RawIPv4Address;
public string CanonicalName;
public AddrInfoSerialized(AddrInfoSerializedHeader header, AddrInfo4? address, Array4<byte>? rawIPv4Address, string canonicalName)
{
Header = header;
SocketAddress = address;
RawIPv4Address = rawIPv4Address;
CanonicalName = canonicalName;
}
public static AddrInfoSerialized Read(ReadOnlySpan<byte> buffer, out ReadOnlySpan<byte> rest)
{
if (!MemoryMarshal.TryRead(buffer, out AddrInfoSerializedHeader header))
{
rest = buffer;
return null;
}
AddrInfo4? socketAddress = null;
Array4<byte>? rawIPv4Address = null;
string canonicalName = null;
buffer = buffer[Unsafe.SizeOf<AddrInfoSerializedHeader>()..];
header.ToHostOrder();
if (header.Magic != SfdnsresContants.AddrInfoMagic)
{
rest = buffer;
return null;
}
Debug.Assert(header.Magic == SfdnsresContants.AddrInfoMagic);
if (header.Family == (int)AddressFamily.InterNetwork)
{
socketAddress = MemoryMarshal.Read<AddrInfo4>(buffer);
socketAddress.Value.ToHostOrder();
buffer = buffer[Unsafe.SizeOf<AddrInfo4>()..];
}
// AF_INET6
else if (header.Family == 28)
{
throw new NotImplementedException();
}
else
{
// Nintendo hardcode 4 bytes in that case here.
Array4<byte> address = MemoryMarshal.Read<Array4<byte>>(buffer);
AddrInfo4.RawIpv4AddressNetworkEndianSwap(ref address);
rawIPv4Address = address;
buffer = buffer[Unsafe.SizeOf<Array4<byte>>()..];
}
canonicalName = StringUtils.ReadUtf8String(buffer, out int dataRead);
buffer = buffer[dataRead..];
rest = buffer;
return new AddrInfoSerialized(header, socketAddress, rawIPv4Address, canonicalName);
}
public Span<byte> Write(Span<byte> buffer)
{
int familly = Header.Family;
Header.ToNetworkOrder();
MemoryMarshal.Write(buffer, ref Header);
buffer = buffer[Unsafe.SizeOf<AddrInfoSerializedHeader>()..];
if (familly == (int)AddressFamily.InterNetwork)
{
AddrInfo4 socketAddress = SocketAddress.Value;
socketAddress.ToNetworkOrder();
MemoryMarshal.Write(buffer, ref socketAddress);
buffer = buffer[Unsafe.SizeOf<AddrInfo4>()..];
}
// AF_INET6
else if (familly == 28)
{
throw new NotImplementedException();
}
else
{
Array4<byte> rawIPv4Address = RawIPv4Address.Value;
AddrInfo4.RawIpv4AddressNetworkEndianSwap(ref rawIPv4Address);
MemoryMarshal.Write(buffer, ref rawIPv4Address);
buffer = buffer[Unsafe.SizeOf<Array4<byte>>()..];
}
if (CanonicalName == null)
{
buffer[0] = 0;
buffer = buffer[1..];
}
else
{
byte[] canonicalName = Encoding.ASCII.GetBytes(CanonicalName + '\0');
canonicalName.CopyTo(buffer);
buffer = buffer[canonicalName.Length..];
}
return buffer;
}
}
}

View file

@ -1,4 +1,4 @@
using System.Buffers.Binary; using Ryujinx.Common.Memory;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
@ -18,11 +18,11 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types
public AddrInfoSerializedHeader(IPAddress address, SocketType socketType) public AddrInfoSerializedHeader(IPAddress address, SocketType socketType)
{ {
Magic = (uint)BinaryPrimitives.ReverseEndianness(unchecked((int)SfdnsresContants.AddrInfoMagic)); Magic = SfdnsresContants.AddrInfoMagic;
Flags = 0; // Big Endian Flags = 0;
Family = BinaryPrimitives.ReverseEndianness((int)address.AddressFamily); Family = (int)address.AddressFamily;
SocketType = BinaryPrimitives.ReverseEndianness((int)socketType); SocketType = (int)socketType;
Protocol = 0; // Big Endian Protocol = 0;
if (address.AddressFamily == AddressFamily.InterNetwork) if (address.AddressFamily == AddressFamily.InterNetwork)
{ {
@ -30,8 +30,28 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types
} }
else else
{ {
AddressLength = 4; AddressLength = (uint)Unsafe.SizeOf<Array4<byte>>();
} }
}
public void ToNetworkOrder()
{
Magic = (uint)IPAddress.HostToNetworkOrder((int)Magic);
Flags = IPAddress.HostToNetworkOrder(Flags);
Family = IPAddress.HostToNetworkOrder(Family);
SocketType = IPAddress.HostToNetworkOrder(SocketType);
Protocol = IPAddress.HostToNetworkOrder(Protocol);
AddressLength = (uint)IPAddress.HostToNetworkOrder((int)AddressLength);
}
public void ToHostOrder()
{
Magic = (uint)IPAddress.NetworkToHostOrder((int)Magic);
Flags = IPAddress.NetworkToHostOrder(Flags);
Family = IPAddress.NetworkToHostOrder(Family);
SocketType = IPAddress.NetworkToHostOrder(SocketType);
Protocol = IPAddress.NetworkToHostOrder(Protocol);
AddressLength = (uint)IPAddress.NetworkToHostOrder((int)AddressLength);
} }
} }
} }

View file

@ -60,6 +60,18 @@ namespace Ryujinx.HLE.Utilities
return output; return output;
} }
public static string ReadUtf8String(ReadOnlySpan<byte> data, out int dataRead)
{
dataRead = data.IndexOf((byte)0) + 1;
if (dataRead <= 1)
{
return string.Empty;
}
return Encoding.UTF8.GetString(data[..dataRead]);
}
public static string ReadUtf8String(ServiceCtx context, int index = 0) public static string ReadUtf8String(ServiceCtx context, int index = 0)
{ {
ulong position = context.Request.PtrBuff[index].Position; ulong position = context.Request.PtrBuff[index].Position;