From 90279d96ea7c89df8876798caad106bcf1972762 Mon Sep 17 00:00:00 2001 From: gdkchan Date: Sat, 21 Apr 2018 16:07:16 -0300 Subject: [PATCH] Implement the synchronization primitives like the Horizon kernel does (#97) * Started to work in improving the sync primitives * Some fixes * Check that the mutex address matches before waking a waiting thread * Add MutexOwner field to keep track of the thread owning the mutex, update wait list when priority changes, other tweaks * Add new priority information to the log * SvcSetThreadPriority should update just the WantedPriority --- .../OsHle/Handles/KProcessScheduler.cs | 192 ++++++---- Ryujinx.Core/OsHle/Handles/KThread.cs | 96 ++++- .../OsHle/Kernel/ConditionVariable.cs | 148 -------- Ryujinx.Core/OsHle/Kernel/KernelErr.cs | 10 +- Ryujinx.Core/OsHle/Kernel/MutualExclusion.cs | 95 ----- Ryujinx.Core/OsHle/Kernel/SvcHandler.cs | 7 +- Ryujinx.Core/OsHle/Kernel/SvcThread.cs | 4 +- Ryujinx.Core/OsHle/Kernel/SvcThreadSync.cs | 358 ++++++++++++++++-- Ryujinx.Core/OsHle/Process.cs | 42 +- 9 files changed, 577 insertions(+), 375 deletions(-) delete mode 100644 Ryujinx.Core/OsHle/Kernel/ConditionVariable.cs delete mode 100644 Ryujinx.Core/OsHle/Kernel/MutualExclusion.cs diff --git a/Ryujinx.Core/OsHle/Handles/KProcessScheduler.cs b/Ryujinx.Core/OsHle/Handles/KProcessScheduler.cs index 2f694600..81aa3fdd 100644 --- a/Ryujinx.Core/OsHle/Handles/KProcessScheduler.cs +++ b/Ryujinx.Core/OsHle/Handles/KProcessScheduler.cs @@ -13,17 +13,23 @@ namespace Ryujinx.Core.OsHle.Handles { public KThread Thread { get; private set; } - public AutoResetEvent WaitEvent { get; private set; } + public ManualResetEvent SyncWaitEvent { get; private set; } + public AutoResetEvent SchedWaitEvent { get; private set; } public bool Active { get; set; } + public int SyncTimeout { get; set; } + public SchedulerThread(KThread Thread) { this.Thread = Thread; - WaitEvent = new AutoResetEvent(false); + SyncWaitEvent = new ManualResetEvent(true); + SchedWaitEvent = new AutoResetEvent(false); Active = true; + + SyncTimeout = 0; } public void Dispose() @@ -35,7 +41,8 @@ namespace Ryujinx.Core.OsHle.Handles { if (Disposing) { - WaitEvent.Dispose(); + SyncWaitEvent.Dispose(); + SchedWaitEvent.Dispose(); } } } @@ -71,9 +78,9 @@ namespace Ryujinx.Core.OsHle.Handles { SchedThread = Threads[Index]; - if (HighestPriority > SchedThread.Thread.Priority) + if (HighestPriority > SchedThread.Thread.ActualPriority) { - HighestPriority = SchedThread.Thread.Priority; + HighestPriority = SchedThread.Thread.ActualPriority; HighestPrioIndex = Index; } @@ -194,45 +201,66 @@ namespace Ryujinx.Core.OsHle.Handles throw new InvalidOperationException(); } - lock (SchedLock) + SchedThread.Active = Active; + + UpdateSyncWaitEvent(SchedThread); + + WaitIfNeeded(SchedThread); + } + + public bool EnterWait(KThread Thread, int Timeout = -1) + { + if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread)) { - bool OldState = SchedThread.Active; - - SchedThread.Active = Active; - - if (!OldState && Active) - { - if (ActiveProcessors.Add(Thread.ProcessorId)) - { - RunThread(SchedThread); - } - else - { - WaitingToRun[Thread.ProcessorId].Push(SchedThread); - - PrintDbgThreadInfo(Thread, "entering wait state..."); - } - } - else if (OldState && !Active) - { - if (Thread.Thread.IsCurrentThread()) - { - Suspend(Thread.ProcessorId); - - PrintDbgThreadInfo(Thread, "entering inactive wait state..."); - } - else - { - WaitingToRun[Thread.ProcessorId].Remove(SchedThread); - } - } + throw new InvalidOperationException(); } - if (!Active && Thread.Thread.IsCurrentThread()) - { - SchedThread.WaitEvent.WaitOne(); + SchedThread.SyncTimeout = Timeout; - PrintDbgThreadInfo(Thread, "resuming execution..."); + UpdateSyncWaitEvent(SchedThread); + + return WaitIfNeeded(SchedThread); + } + + public void WakeUp(KThread Thread) + { + if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread)) + { + throw new InvalidOperationException(); + } + + SchedThread.SyncTimeout = 0; + + UpdateSyncWaitEvent(SchedThread); + + WaitIfNeeded(SchedThread); + } + + private void UpdateSyncWaitEvent(SchedulerThread SchedThread) + { + if (SchedThread.Active && SchedThread.SyncTimeout == 0) + { + SchedThread.SyncWaitEvent.Set(); + } + else + { + SchedThread.SyncWaitEvent.Reset(); + } + } + + private bool WaitIfNeeded(SchedulerThread SchedThread) + { + KThread Thread = SchedThread.Thread; + + if (!IsActive(SchedThread) && Thread.Thread.IsCurrentThread()) + { + Suspend(Thread.ProcessorId); + + return Resume(Thread); + } + else + { + return false; } } @@ -261,66 +289,78 @@ namespace Ryujinx.Core.OsHle.Handles lock (SchedLock) { - SchedulerThread SchedThread = WaitingToRun[Thread.ProcessorId].Pop(Thread.Priority); + SchedulerThread SchedThread = WaitingToRun[Thread.ProcessorId].Pop(Thread.ActualPriority); - if (SchedThread == null) + if (IsActive(Thread) && SchedThread == null) { PrintDbgThreadInfo(Thread, "resumed because theres nothing better to run."); return; } - RunThread(SchedThread); + if (SchedThread != null) + { + RunThread(SchedThread); + } } Resume(Thread); } - public void Resume(KThread Thread) + public bool Resume(KThread Thread) { if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread)) { throw new InvalidOperationException(); } - TryResumingExecution(SchedThread); + return TryResumingExecution(SchedThread); } - private void TryResumingExecution(SchedulerThread SchedThread) + private bool TryResumingExecution(SchedulerThread SchedThread) { KThread Thread = SchedThread.Thread; - if (SchedThread.Active) - { - lock (SchedLock) - { - if (ActiveProcessors.Add(Thread.ProcessorId)) - { - PrintDbgThreadInfo(Thread, "resuming execution..."); - - return; - } - - WaitingToRun[Thread.ProcessorId].Push(SchedThread); - - PrintDbgThreadInfo(Thread, "entering wait state..."); - } - } - else + if (!SchedThread.Active || SchedThread.SyncTimeout != 0) { PrintDbgThreadInfo(Thread, "entering inactive wait state..."); } - SchedThread.WaitEvent.WaitOne(); + bool Result = false; + + if (SchedThread.SyncTimeout != 0) + { + Result = SchedThread.SyncWaitEvent.WaitOne(SchedThread.SyncTimeout); + + SchedThread.SyncTimeout = 0; + } + + lock (SchedLock) + { + if (ActiveProcessors.Add(Thread.ProcessorId)) + { + PrintDbgThreadInfo(Thread, "resuming execution..."); + + return Result; + } + + WaitingToRun[Thread.ProcessorId].Push(SchedThread); + + PrintDbgThreadInfo(Thread, "entering wait state..."); + } + + SchedThread.SchedWaitEvent.WaitOne(); PrintDbgThreadInfo(Thread, "resuming execution..."); + + return Result; } private void RunThread(SchedulerThread SchedThread) { if (!SchedThread.Thread.Thread.Execute()) { - SchedThread.WaitEvent.Set(); + SchedThread.SchedWaitEvent.Set(); } else { @@ -328,12 +368,28 @@ namespace Ryujinx.Core.OsHle.Handles } } + private bool IsActive(KThread Thread) + { + if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread)) + { + throw new InvalidOperationException(); + } + + return IsActive(SchedThread); + } + + private bool IsActive(SchedulerThread SchedThread) + { + return SchedThread.Active && SchedThread.SyncTimeout == 0; + } + private void PrintDbgThreadInfo(KThread Thread, string Message) { Logging.Debug(LogClass.KernelScheduler, "(" + - "ThreadId: " + Thread.ThreadId + ", " + - "ProcessorId: " + Thread.ProcessorId + ", " + - "Priority: " + Thread.Priority + ") " + Message); + "ThreadId: " + Thread.ThreadId + ", " + + "ProcessorId: " + Thread.ProcessorId + ", " + + "ActualPriority: " + Thread.ActualPriority + ", " + + "WantedPriority: " + Thread.WantedPriority + ") " + Message); } public void Dispose() diff --git a/Ryujinx.Core/OsHle/Handles/KThread.cs b/Ryujinx.Core/OsHle/Handles/KThread.cs index 9742f492..4286984e 100644 --- a/Ryujinx.Core/OsHle/Handles/KThread.cs +++ b/Ryujinx.Core/OsHle/Handles/KThread.cs @@ -1,4 +1,5 @@ using ChocolArm64; +using System; namespace Ryujinx.Core.OsHle.Handles { @@ -6,10 +7,20 @@ namespace Ryujinx.Core.OsHle.Handles { public AThread Thread { get; private set; } + public KThread MutexOwner { get; set; } + + public KThread NextMutexThread { get; set; } + public KThread NextCondVarThread { get; set; } + + public long MutexAddress { get; set; } + public long CondVarAddress { get; set; } + + public int ActualPriority { get; private set; } + public int WantedPriority { get; private set; } + public int ProcessorId { get; private set; } - public int Priority { get; set; } - public int Handle { get; set; } + public int WaitHandle { get; set; } public int ThreadId => Thread.ThreadId; @@ -17,7 +28,86 @@ namespace Ryujinx.Core.OsHle.Handles { this.Thread = Thread; this.ProcessorId = ProcessorId; - this.Priority = Priority; + + ActualPriority = WantedPriority = Priority; + } + + public void SetPriority(int Priority) + { + WantedPriority = Priority; + + UpdatePriority(); + } + + public void UpdatePriority() + { + int OldPriority = ActualPriority; + + int CurrPriority = WantedPriority; + + if (NextMutexThread != null && CurrPriority > NextMutexThread.WantedPriority) + { + CurrPriority = NextMutexThread.WantedPriority; + } + + if (CurrPriority != OldPriority) + { + ActualPriority = CurrPriority; + + UpdateWaitList(); + + MutexOwner?.UpdatePriority(); + } + } + + private void UpdateWaitList() + { + KThread OwnerThread = MutexOwner; + + if (OwnerThread != null) + { + //The MutexOwner field should only be non null when the thread is + //waiting for the lock, and the lock belongs to another thread. + if (OwnerThread == this) + { + throw new InvalidOperationException(); + } + + lock (OwnerThread) + { + //Remove itself from the list. + KThread CurrThread = OwnerThread; + + while (CurrThread.NextMutexThread != null) + { + if (CurrThread.NextMutexThread == this) + { + CurrThread.NextMutexThread = NextMutexThread; + + break; + } + + CurrThread = CurrThread.NextMutexThread; + } + + //Re-add taking new priority into account. + CurrThread = OwnerThread; + + while (CurrThread.NextMutexThread != null) + { + if (CurrThread.NextMutexThread.ActualPriority < ActualPriority) + { + break; + } + + CurrThread = CurrThread.NextMutexThread; + } + + NextMutexThread = CurrThread.NextMutexThread; + + CurrThread.NextMutexThread = this; + } + } } } } \ No newline at end of file diff --git a/Ryujinx.Core/OsHle/Kernel/ConditionVariable.cs b/Ryujinx.Core/OsHle/Kernel/ConditionVariable.cs deleted file mode 100644 index f7657376..00000000 --- a/Ryujinx.Core/OsHle/Kernel/ConditionVariable.cs +++ /dev/null @@ -1,148 +0,0 @@ -using Ryujinx.Core.OsHle.Handles; -using System.Collections.Generic; -using System.Threading; - -namespace Ryujinx.Core.OsHle.Kernel -{ - class ConditionVariable - { - private Process Process; - - private long CondVarAddress; - - private bool OwnsCondVarValue; - - private List<(KThread Thread, AutoResetEvent WaitEvent)> WaitingThreads; - - public ConditionVariable(Process Process, long CondVarAddress) - { - this.Process = Process; - this.CondVarAddress = CondVarAddress; - - WaitingThreads = new List<(KThread, AutoResetEvent)>(); - } - - public bool WaitForSignal(KThread Thread, ulong Timeout) - { - bool Result = true; - - int Count = Process.Memory.ReadInt32(CondVarAddress); - - if (Count <= 0) - { - using (AutoResetEvent WaitEvent = new AutoResetEvent(false)) - { - lock (WaitingThreads) - { - WaitingThreads.Add((Thread, WaitEvent)); - } - - if (Timeout == ulong.MaxValue) - { - Result = WaitEvent.WaitOne(); - } - else - { - Result = WaitEvent.WaitOne(NsTimeConverter.GetTimeMs(Timeout)); - - lock (WaitingThreads) - { - WaitingThreads.Remove((Thread, WaitEvent)); - } - } - } - } - - AcquireCondVarValue(); - - Count = Process.Memory.ReadInt32(CondVarAddress); - - if (Result && Count > 0) - { - Process.Memory.WriteInt32(CondVarAddress, Count - 1); - } - - ReleaseCondVarValue(); - - return Result; - } - - public void SetSignal(KThread Thread, int Count) - { - lock (WaitingThreads) - { - if (Count < 0) - { - foreach ((_, AutoResetEvent WaitEvent) in WaitingThreads) - { - IncrementCondVarValue(); - - WaitEvent.Set(); - } - - WaitingThreads.Clear(); - } - else - { - while (WaitingThreads.Count > 0 && Count-- > 0) - { - int HighestPriority = WaitingThreads[0].Thread.Priority; - int HighestPrioIndex = 0; - - for (int Index = 1; Index < WaitingThreads.Count; Index++) - { - if (HighestPriority > WaitingThreads[Index].Thread.Priority) - { - HighestPriority = WaitingThreads[Index].Thread.Priority; - - HighestPrioIndex = Index; - } - } - - IncrementCondVarValue(); - - WaitingThreads[HighestPrioIndex].WaitEvent.Set(); - - WaitingThreads.RemoveAt(HighestPrioIndex); - } - } - } - - Process.Scheduler.Yield(Thread); - } - - private void IncrementCondVarValue() - { - AcquireCondVarValue(); - - int Count = Process.Memory.ReadInt32(CondVarAddress); - - Process.Memory.WriteInt32(CondVarAddress, Count + 1); - - ReleaseCondVarValue(); - } - - private void AcquireCondVarValue() - { - if (!OwnsCondVarValue) - { - while (!Process.Memory.AcquireAddress(CondVarAddress)) - { - Thread.Yield(); - } - - OwnsCondVarValue = true; - } - } - - private void ReleaseCondVarValue() - { - if (OwnsCondVarValue) - { - OwnsCondVarValue = false; - - Process.Memory.ReleaseAddress(CondVarAddress); - } - } - } -} \ No newline at end of file diff --git a/Ryujinx.Core/OsHle/Kernel/KernelErr.cs b/Ryujinx.Core/OsHle/Kernel/KernelErr.cs index e7cd72dc..b568405b 100644 --- a/Ryujinx.Core/OsHle/Kernel/KernelErr.cs +++ b/Ryujinx.Core/OsHle/Kernel/KernelErr.cs @@ -2,9 +2,11 @@ namespace Ryujinx.Core.OsHle.Kernel { static class KernelErr { - public const int InvalidMemRange = 110; - public const int InvalidHandle = 114; - public const int Timeout = 117; - public const int InvalidInfo = 120; + public const int InvalidAlignment = 102; + public const int InvalidAddress = 106; + public const int InvalidMemRange = 110; + public const int InvalidHandle = 114; + public const int Timeout = 117; + public const int InvalidInfo = 120; } } \ No newline at end of file diff --git a/Ryujinx.Core/OsHle/Kernel/MutualExclusion.cs b/Ryujinx.Core/OsHle/Kernel/MutualExclusion.cs deleted file mode 100644 index 9f05406b..00000000 --- a/Ryujinx.Core/OsHle/Kernel/MutualExclusion.cs +++ /dev/null @@ -1,95 +0,0 @@ -using Ryujinx.Core.OsHle.Handles; -using System.Collections.Generic; -using System.Threading; - -namespace Ryujinx.Core.OsHle.Kernel -{ - class MutualExclusion - { - private const int MutexHasListenersMask = 0x40000000; - - private Process Process; - - private long MutexAddress; - - private int OwnerThreadHandle; - - private List<(KThread Thread, AutoResetEvent WaitEvent)> WaitingThreads; - - public MutualExclusion(Process Process, long MutexAddress) - { - this.Process = Process; - this.MutexAddress = MutexAddress; - - WaitingThreads = new List<(KThread, AutoResetEvent)>(); - } - - public void WaitForLock(KThread RequestingThread) - { - WaitForLock(RequestingThread, OwnerThreadHandle); - } - - public void WaitForLock(KThread RequestingThread, int OwnerThreadHandle) - { - if (OwnerThreadHandle == RequestingThread.Handle || - OwnerThreadHandle == 0) - { - return; - } - - using (AutoResetEvent WaitEvent = new AutoResetEvent(false)) - { - lock (WaitingThreads) - { - WaitingThreads.Add((RequestingThread, WaitEvent)); - } - - Process.Scheduler.Suspend(RequestingThread.ProcessorId); - - WaitEvent.WaitOne(); - - Process.Scheduler.Resume(RequestingThread); - } - } - - public void Unlock() - { - lock (WaitingThreads) - { - int HasListeners = WaitingThreads.Count > 1 ? MutexHasListenersMask : 0; - - if (WaitingThreads.Count > 0) - { - int HighestPriority = WaitingThreads[0].Thread.Priority; - int HighestPrioIndex = 0; - - for (int Index = 1; Index < WaitingThreads.Count; Index++) - { - if (HighestPriority > WaitingThreads[Index].Thread.Priority) - { - HighestPriority = WaitingThreads[Index].Thread.Priority; - - HighestPrioIndex = Index; - } - } - - int Handle = WaitingThreads[HighestPrioIndex].Thread.Handle; - - WaitingThreads[HighestPrioIndex].WaitEvent.Set(); - - WaitingThreads.RemoveAt(HighestPrioIndex); - - Process.Memory.WriteInt32(MutexAddress, HasListeners | Handle); - - OwnerThreadHandle = Handle; - } - else - { - Process.Memory.WriteInt32(MutexAddress, 0); - - OwnerThreadHandle = 0; - } - } - } - } -} \ No newline at end of file diff --git a/Ryujinx.Core/OsHle/Kernel/SvcHandler.cs b/Ryujinx.Core/OsHle/Kernel/SvcHandler.cs index 16ef8697..c74da061 100644 --- a/Ryujinx.Core/OsHle/Kernel/SvcHandler.cs +++ b/Ryujinx.Core/OsHle/Kernel/SvcHandler.cs @@ -3,7 +3,6 @@ using ChocolArm64.Memory; using ChocolArm64.State; using Ryujinx.Core.OsHle.Handles; using System; -using System.Collections.Concurrent; using System.Collections.Generic; namespace Ryujinx.Core.OsHle.Kernel @@ -18,8 +17,7 @@ namespace Ryujinx.Core.OsHle.Kernel private Process Process; private AMemory Memory; - private ConcurrentDictionary Mutexes; - private ConcurrentDictionary CondVars; + private object CondVarLock; private HashSet<(HSharedMem, long)> MappedSharedMems; @@ -71,8 +69,7 @@ namespace Ryujinx.Core.OsHle.Kernel this.Process = Process; this.Memory = Process.Memory; - Mutexes = new ConcurrentDictionary(); - CondVars = new ConcurrentDictionary(); + CondVarLock = new object(); MappedSharedMems = new HashSet<(HSharedMem, long)>(); } diff --git a/Ryujinx.Core/OsHle/Kernel/SvcThread.cs b/Ryujinx.Core/OsHle/Kernel/SvcThread.cs index 06147b28..1e4d61b4 100644 --- a/Ryujinx.Core/OsHle/Kernel/SvcThread.cs +++ b/Ryujinx.Core/OsHle/Kernel/SvcThread.cs @@ -91,7 +91,7 @@ namespace Ryujinx.Core.OsHle.Kernel if (CurrThread != null) { ThreadState.X0 = 0; - ThreadState.X1 = (ulong)CurrThread.Priority; + ThreadState.X1 = (ulong)CurrThread.ActualPriority; } else { @@ -110,7 +110,7 @@ namespace Ryujinx.Core.OsHle.Kernel if (CurrThread != null) { - CurrThread.Priority = Priority; + CurrThread.SetPriority(Priority); ThreadState.X0 = 0; } diff --git a/Ryujinx.Core/OsHle/Kernel/SvcThreadSync.cs b/Ryujinx.Core/OsHle/Kernel/SvcThreadSync.cs index 38d759d3..6502e4c9 100644 --- a/Ryujinx.Core/OsHle/Kernel/SvcThreadSync.cs +++ b/Ryujinx.Core/OsHle/Kernel/SvcThreadSync.cs @@ -1,5 +1,7 @@ using ChocolArm64.State; using Ryujinx.Core.OsHle.Handles; +using System; +using System.Threading; using static Ryujinx.Core.OsHle.ErrorCode; @@ -7,11 +9,31 @@ namespace Ryujinx.Core.OsHle.Kernel { partial class SvcHandler { + private const int MutexHasListenersMask = 0x40000000; + private void SvcArbitrateLock(AThreadState ThreadState) { - int OwnerThreadHandle = (int)ThreadState.X0; - long MutexAddress = (long)ThreadState.X1; - int RequestingThreadHandle = (int)ThreadState.X2; + int OwnerThreadHandle = (int)ThreadState.X0; + long MutexAddress = (long)ThreadState.X1; + int WaitThreadHandle = (int)ThreadState.X2; + + if (IsPointingInsideKernel(MutexAddress)) + { + Logging.Warn(LogClass.KernelSvc, $"Invalid mutex address 0x{MutexAddress:x16}!"); + + ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAddress); + + return; + } + + if (IsWordAddressUnaligned(MutexAddress)) + { + Logging.Warn(LogClass.KernelSvc, $"Unaligned mutex address 0x{MutexAddress:x16}!"); + + ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAlignment); + + return; + } KThread OwnerThread = Process.HandleTable.GetData(OwnerThreadHandle); @@ -24,20 +46,20 @@ namespace Ryujinx.Core.OsHle.Kernel return; } - KThread RequestingThread = Process.HandleTable.GetData(RequestingThreadHandle); + KThread WaitThread = Process.HandleTable.GetData(WaitThreadHandle); - if (RequestingThread == null) + if (WaitThread == null) { - Logging.Warn(LogClass.KernelSvc, $"Invalid requesting thread handle 0x{RequestingThreadHandle:x8}!"); + Logging.Warn(LogClass.KernelSvc, $"Invalid requesting thread handle 0x{WaitThreadHandle:x8}!"); ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidHandle); return; } - MutualExclusion Mutex = GetMutex(MutexAddress); + KThread CurrThread = Process.GetThread(ThreadState.Tpidr); - Mutex.WaitForLock(RequestingThread, OwnerThreadHandle); + MutexLock(CurrThread, WaitThread, OwnerThreadHandle, WaitThreadHandle, MutexAddress); ThreadState.X0 = 0; } @@ -46,9 +68,28 @@ namespace Ryujinx.Core.OsHle.Kernel { long MutexAddress = (long)ThreadState.X0; - GetMutex(MutexAddress).Unlock(); + if (IsPointingInsideKernel(MutexAddress)) + { + Logging.Warn(LogClass.KernelSvc, $"Invalid mutex address 0x{MutexAddress:x16}!"); - Process.Scheduler.Yield(Process.GetThread(ThreadState.Tpidr)); + ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAddress); + + return; + } + + if (IsWordAddressUnaligned(MutexAddress)) + { + Logging.Warn(LogClass.KernelSvc, $"Unaligned mutex address 0x{MutexAddress:x16}!"); + + ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAlignment); + + return; + } + + if (MutexUnlock(Process.GetThread(ThreadState.Tpidr), MutexAddress)) + { + Process.Scheduler.Yield(Process.GetThread(ThreadState.Tpidr)); + } ThreadState.X0 = 0; } @@ -60,6 +101,24 @@ namespace Ryujinx.Core.OsHle.Kernel int ThreadHandle = (int)ThreadState.X2; ulong Timeout = ThreadState.X3; + if (IsPointingInsideKernel(MutexAddress)) + { + Logging.Warn(LogClass.KernelSvc, $"Invalid mutex address 0x{MutexAddress:x16}!"); + + ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAddress); + + return; + } + + if (IsWordAddressUnaligned(MutexAddress)) + { + Logging.Warn(LogClass.KernelSvc, $"Unaligned mutex address 0x{MutexAddress:x16}!"); + + ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAlignment); + + return; + } + KThread Thread = Process.HandleTable.GetData(ThreadHandle); if (Thread == null) @@ -67,24 +126,22 @@ namespace Ryujinx.Core.OsHle.Kernel Logging.Warn(LogClass.KernelSvc, $"Invalid thread handle 0x{ThreadHandle:x8}!"); ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidHandle); + + return; } - Process.Scheduler.Suspend(Thread.ProcessorId); + KThread CurrThread = Process.GetThread(ThreadState.Tpidr); - MutualExclusion Mutex = GetMutex(MutexAddress); + MutexUnlock(CurrThread, MutexAddress); - Mutex.Unlock(); - - if (!GetCondVar(CondVarAddress).WaitForSignal(Thread, Timeout)) + if (!CondVarWait(CurrThread, ThreadHandle, MutexAddress, CondVarAddress, Timeout)) { ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.Timeout); return; } - Mutex.WaitForLock(Thread); - - Process.Scheduler.Resume(Thread); + Process.Scheduler.Yield(Process.GetThread(ThreadState.Tpidr)); ThreadState.X0 = 0; } @@ -94,31 +151,274 @@ namespace Ryujinx.Core.OsHle.Kernel long CondVarAddress = (long)ThreadState.X0; int Count = (int)ThreadState.X1; - KThread CurrThread = Process.GetThread(ThreadState.Tpidr); - - GetCondVar(CondVarAddress).SetSignal(CurrThread, Count); + CondVarSignal(CondVarAddress, Count); ThreadState.X0 = 0; } - private MutualExclusion GetMutex(long MutexAddress) + private void MutexLock( + KThread CurrThread, + KThread WaitThread, + int OwnerThreadHandle, + int WaitThreadHandle, + long MutexAddress) { - MutualExclusion MutexFactory(long Key) + int MutexValue = Process.Memory.ReadInt32(MutexAddress); + + if (MutexValue != (OwnerThreadHandle | MutexHasListenersMask)) { - return new MutualExclusion(Process, MutexAddress); + return; } - return Mutexes.GetOrAdd(MutexAddress, MutexFactory); + CurrThread.WaitHandle = WaitThreadHandle; + CurrThread.MutexAddress = MutexAddress; + + InsertWaitingMutexThread(OwnerThreadHandle, WaitThread); + + Process.Scheduler.EnterWait(WaitThread); } - private ConditionVariable GetCondVar(long CondVarAddress) + private bool MutexUnlock(KThread CurrThread, long MutexAddress) { - ConditionVariable CondVarFactory(long Key) + if (CurrThread == null) { - return new ConditionVariable(Process, CondVarAddress); + Logging.Warn(LogClass.KernelSvc, $"Invalid mutex 0x{MutexAddress:x16}!"); + + return false; } - return CondVars.GetOrAdd(CondVarAddress, CondVarFactory); + lock (CurrThread) + { + //This is the new thread that will not own the mutex. + //If no threads are waiting for the lock, then it should be null. + KThread OwnerThread = CurrThread.NextMutexThread; + + while (OwnerThread != null && OwnerThread.MutexAddress != MutexAddress) + { + OwnerThread = OwnerThread.NextMutexThread; + } + + CurrThread.NextMutexThread = null; + + if (OwnerThread != null) + { + int HasListeners = OwnerThread.NextMutexThread != null ? MutexHasListenersMask : 0; + + Process.Memory.WriteInt32(MutexAddress, HasListeners | OwnerThread.WaitHandle); + + OwnerThread.WaitHandle = 0; + OwnerThread.MutexAddress = 0; + OwnerThread.CondVarAddress = 0; + + OwnerThread.MutexOwner = null; + + OwnerThread.UpdatePriority(); + + Process.Scheduler.WakeUp(OwnerThread); + + return true; + } + else + { + Process.Memory.WriteInt32(MutexAddress, 0); + + return false; + } + } + } + + private bool CondVarWait( + KThread WaitThread, + int WaitThreadHandle, + long MutexAddress, + long CondVarAddress, + ulong Timeout) + { + WaitThread.WaitHandle = WaitThreadHandle; + WaitThread.MutexAddress = MutexAddress; + WaitThread.CondVarAddress = CondVarAddress; + + lock (CondVarLock) + { + KThread CurrThread = Process.ThreadArbiterList; + + if (CurrThread != null) + { + bool DoInsert = CurrThread != WaitThread; + + while (CurrThread.NextCondVarThread != null) + { + if (CurrThread.NextCondVarThread.ActualPriority < WaitThread.ActualPriority) + { + break; + } + + CurrThread = CurrThread.NextCondVarThread; + + DoInsert &= CurrThread != WaitThread; + } + + //Only insert if the node doesn't already exist in the list. + //This prevents circular references. + if (DoInsert) + { + if (WaitThread.NextCondVarThread != null) + { + throw new InvalidOperationException(); + } + + WaitThread.NextCondVarThread = CurrThread.NextCondVarThread; + CurrThread.NextCondVarThread = WaitThread; + } + } + else + { + Process.ThreadArbiterList = WaitThread; + } + } + + if (Timeout != ulong.MaxValue) + { + return Process.Scheduler.EnterWait(WaitThread, NsTimeConverter.GetTimeMs(Timeout)); + } + else + { + return Process.Scheduler.EnterWait(WaitThread); + } + } + + private void CondVarSignal(long CondVarAddress, int Count) + { + lock (CondVarLock) + { + KThread PrevThread = null; + KThread CurrThread = Process.ThreadArbiterList; + + while (CurrThread != null && (Count == -1 || Count > 0)) + { + if (CurrThread.CondVarAddress == CondVarAddress) + { + if (PrevThread != null) + { + PrevThread.NextCondVarThread = CurrThread.NextCondVarThread; + } + else + { + Process.ThreadArbiterList = CurrThread.NextCondVarThread; + } + + CurrThread.NextCondVarThread = null; + + AcquireMutexValue(CurrThread.MutexAddress); + + int MutexValue = Process.Memory.ReadInt32(CurrThread.MutexAddress); + + MutexValue &= ~MutexHasListenersMask; + + if (MutexValue == 0) + { + //Give the lock to this thread. + Process.Memory.WriteInt32(CurrThread.MutexAddress, CurrThread.WaitHandle); + + CurrThread.WaitHandle = 0; + CurrThread.MutexAddress = 0; + CurrThread.CondVarAddress = 0; + + CurrThread.MutexOwner = null; + + CurrThread.UpdatePriority(); + + Process.Scheduler.WakeUp(CurrThread); + } + else + { + //Wait until the lock is released. + InsertWaitingMutexThread(MutexValue, CurrThread); + + MutexValue |= MutexHasListenersMask; + + Process.Memory.WriteInt32(CurrThread.MutexAddress, MutexValue); + } + + ReleaseMutexValue(CurrThread.MutexAddress); + + Count--; + } + + PrevThread = CurrThread; + CurrThread = CurrThread.NextCondVarThread; + } + } + } + + private void InsertWaitingMutexThread(int OwnerThreadHandle, KThread WaitThread) + { + KThread OwnerThread = Process.HandleTable.GetData(OwnerThreadHandle); + + if (OwnerThread == null) + { + Logging.Warn(LogClass.KernelSvc, $"Invalid thread handle 0x{OwnerThreadHandle:x8}!"); + + return; + } + + WaitThread.MutexOwner = OwnerThread; + + lock (OwnerThread) + { + KThread CurrThread = OwnerThread; + + while (CurrThread.NextMutexThread != null) + { + if (CurrThread == WaitThread) + { + return; + } + + if (CurrThread.NextMutexThread.ActualPriority < WaitThread.ActualPriority) + { + break; + } + + CurrThread = CurrThread.NextMutexThread; + } + + if (CurrThread != WaitThread) + { + if (WaitThread.NextCondVarThread != null) + { + throw new InvalidOperationException(); + } + + WaitThread.NextMutexThread = CurrThread.NextMutexThread; + CurrThread.NextMutexThread = WaitThread; + } + } + + OwnerThread.UpdatePriority(); + } + + private void AcquireMutexValue(long MutexAddress) + { + while (!Process.Memory.AcquireAddress(MutexAddress)) + { + Thread.Yield(); + } + } + + private void ReleaseMutexValue(long MutexAddress) + { + Process.Memory.ReleaseAddress(MutexAddress); + } + + private bool IsPointingInsideKernel(long Address) + { + return ((ulong)Address + 0x1000000000) < 0xffffff000; + } + + private bool IsWordAddressUnaligned(long Address) + { + return (Address & 3) != 0; } } } \ No newline at end of file diff --git a/Ryujinx.Core/OsHle/Process.cs b/Ryujinx.Core/OsHle/Process.cs index bacca9a3..bd4ff1ff 100644 --- a/Ryujinx.Core/OsHle/Process.cs +++ b/Ryujinx.Core/OsHle/Process.cs @@ -35,6 +35,8 @@ namespace Ryujinx.Core.OsHle public KProcessScheduler Scheduler { get; private set; } + public KThread ThreadArbiterList { get; set; } + public KProcessHandleTable HandleTable { get; private set; } public AppletStateMgr AppletState { get; private set; } @@ -43,7 +45,7 @@ namespace Ryujinx.Core.OsHle private ConcurrentDictionary TlsSlots; - private ConcurrentDictionary ThreadsByTpidr; + private ConcurrentDictionary Threads; private List Executables; @@ -71,7 +73,7 @@ namespace Ryujinx.Core.OsHle TlsSlots = new ConcurrentDictionary(); - ThreadsByTpidr = new ConcurrentDictionary(); + Threads = new ConcurrentDictionary(); Executables = new List(); @@ -185,34 +187,32 @@ namespace Ryujinx.Core.OsHle throw new ObjectDisposedException(nameof(Process)); } - AThread Thread = new AThread(GetTranslator(), Memory, EntryPoint); + AThread CpuThread = new AThread(GetTranslator(), Memory, EntryPoint); - KThread KernelThread = new KThread(Thread, ProcessorId, Priority); + KThread Thread = new KThread(CpuThread, ProcessorId, Priority); - int Handle = HandleTable.OpenHandle(KernelThread); + int Handle = HandleTable.OpenHandle(Thread); - KernelThread.Handle = Handle; - - int ThreadId = GetFreeTlsSlot(Thread); + int ThreadId = GetFreeTlsSlot(CpuThread); long Tpidr = MemoryRegions.TlsPagesAddress + ThreadId * TlsSize; - Thread.ThreadState.ProcessId = ProcessId; - Thread.ThreadState.ThreadId = ThreadId; - Thread.ThreadState.CntfrqEl0 = TickFreq; - Thread.ThreadState.Tpidr = Tpidr; + CpuThread.ThreadState.ProcessId = ProcessId; + CpuThread.ThreadState.ThreadId = ThreadId; + CpuThread.ThreadState.CntfrqEl0 = TickFreq; + CpuThread.ThreadState.Tpidr = Tpidr; - Thread.ThreadState.X0 = (ulong)ArgsPtr; - Thread.ThreadState.X1 = (ulong)Handle; - Thread.ThreadState.X31 = (ulong)StackTop; + CpuThread.ThreadState.X0 = (ulong)ArgsPtr; + CpuThread.ThreadState.X1 = (ulong)Handle; + CpuThread.ThreadState.X31 = (ulong)StackTop; - Thread.ThreadState.Break += BreakHandler; - Thread.ThreadState.SvcCall += SvcHandler.SvcCall; - Thread.ThreadState.Undefined += UndefinedHandler; + CpuThread.ThreadState.Break += BreakHandler; + CpuThread.ThreadState.SvcCall += SvcHandler.SvcCall; + CpuThread.ThreadState.Undefined += UndefinedHandler; - Thread.WorkFinished += ThreadFinished; + CpuThread.WorkFinished += ThreadFinished; - ThreadsByTpidr.TryAdd(Thread.ThreadState.Tpidr, KernelThread); + Threads.TryAdd(CpuThread.ThreadState.Tpidr, Thread); return Handle; } @@ -324,7 +324,7 @@ namespace Ryujinx.Core.OsHle public KThread GetThread(long Tpidr) { - if (!ThreadsByTpidr.TryGetValue(Tpidr, out KThread Thread)) + if (!Threads.TryGetValue(Tpidr, out KThread Thread)) { Logging.Error(LogClass.KernelScheduler, $"Thread with TPIDR 0x{Tpidr:x16} not found!"); }