From 302b47f5e63443166d8edc50a5df76252b2dcfaf Mon Sep 17 00:00:00 2001
From: JosJuice <josjuice@gmail.com>
Date: Tue, 13 Jul 2021 19:09:35 +0200
Subject: [PATCH] JitArm64: Add temp reg parameter to Arm64RegCache::Flush

We currently have a bug when calling Arm64GPRCache::Flush with
FlushMode::MaintainState, zero free host registers, and at least
one guest register containing an immediate. We end up grabbing
a temporary register from the register cache in order to be
able to write the immediate to memory, but grabbing a temporary
register when there are zero free registers causes the least
recently used register to be flushed in a way which does not
maintain the state of the register cache.

To get around this, require callers to pass in a temporary
register in the GPR MaintainState case. In other cases,
passing in a temporary register is not required but can help
avoid spilling a register (if the caller already had a
temporary register at hand anyway, which in particular will
be the case in my upcoming memcheck pull request).
---
 Source/Core/Core/PowerPC/JitArm64/Jit.cpp     | 36 +++++----
 .../Core/PowerPC/JitArm64/JitArm64_Branch.cpp | 52 ++++++------
 .../PowerPC/JitArm64/JitArm64_RegCache.cpp    | 80 +++++++++++++------
 .../Core/PowerPC/JitArm64/JitArm64_RegCache.h | 53 +++++++-----
 .../JitArm64/JitArm64_SystemRegisters.cpp     | 15 ++--
 5 files changed, 143 insertions(+), 93 deletions(-)

diff --git a/Source/Core/Core/PowerPC/JitArm64/Jit.cpp b/Source/Core/Core/PowerPC/JitArm64/Jit.cpp
index 7b18330e0f..3191be64eb 100644
--- a/Source/Core/Core/PowerPC/JitArm64/Jit.cpp
+++ b/Source/Core/Core/PowerPC/JitArm64/Jit.cpp
@@ -145,8 +145,8 @@ void JitArm64::Shutdown()
 void JitArm64::FallBackToInterpreter(UGeckoInstruction inst)
 {
   FlushCarry();
-  gpr.Flush(FlushMode::All, js.op);
-  fpr.Flush(FlushMode::All, js.op);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   if (js.op->opinfo->flags & FL_ENDBLOCK)
   {
@@ -204,8 +204,8 @@ void JitArm64::FallBackToInterpreter(UGeckoInstruction inst)
     SwitchToFarCode();
     SetJumpTarget(handleException);
 
-    gpr.Flush(FlushMode::MaintainState);
-    fpr.Flush(FlushMode::MaintainState);
+    gpr.Flush(FlushMode::MaintainState, WA);
+    fpr.Flush(FlushMode::MaintainState, ARM64Reg::INVALID_REG);
 
     WriteExceptionExit(js.compilerPC, false, true);
 
@@ -218,8 +218,8 @@ void JitArm64::FallBackToInterpreter(UGeckoInstruction inst)
 void JitArm64::HLEFunction(u32 hook_index)
 {
   FlushCarry();
-  gpr.Flush(FlushMode::All);
-  fpr.Flush(FlushMode::All);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   MOVP2R(ARM64Reg::X8, &HLE::Execute);
   MOVI2R(ARM64Reg::W0, js.compilerPC);
@@ -741,8 +741,8 @@ void JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
       TST(ARM64Reg::W30, LogicalImm(cause_mask, 32));
       B(CC_EQ, done_here);
 
-      gpr.Flush(FlushMode::MaintainState);
-      fpr.Flush(FlushMode::MaintainState);
+      gpr.Flush(FlushMode::MaintainState, ARM64Reg::W30);
+      fpr.Flush(FlushMode::MaintainState, ARM64Reg::INVALID_REG);
       WriteExceptionExit(js.compilerPC, true, true);
       SwitchToNearCode();
       SetJumpTarget(no_ext_exception);
@@ -759,6 +759,7 @@ void JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
     {
       ARM64Reg WA = gpr.GetReg();
       ARM64Reg XA = EncodeRegTo64(WA);
+
       LDR(IndexType::Unsigned, WA, PPC_REG, PPCSTATE_OFF(Exceptions));
       FixupBranch no_ext_exception = TBZ(WA, IntLog2(EXCEPTION_EXTERNAL_INT));
       FixupBranch exception = B();
@@ -775,14 +776,15 @@ void JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
                                  ProcessorInterface::INT_CAUSE_PE_FINISH;
       TST(WA, LogicalImm(cause_mask, 32));
       B(CC_EQ, done_here);
-      gpr.Unlock(WA);
 
-      gpr.Flush(FlushMode::MaintainState);
-      fpr.Flush(FlushMode::MaintainState);
+      gpr.Flush(FlushMode::MaintainState, WA);
+      fpr.Flush(FlushMode::MaintainState, ARM64Reg::INVALID_REG);
       WriteExceptionExit(js.compilerPC, true, true);
       SwitchToNearCode();
       SetJumpTarget(no_ext_exception);
       SetJumpTarget(exit);
+
+      gpr.Unlock(WA);
     }
 
     if (HandleFunctionHooking(op.address))
@@ -801,8 +803,8 @@ void JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
         SwitchToFarCode();
         SetJumpTarget(far_addr);
 
-        gpr.Flush(FlushMode::MaintainState);
-        fpr.Flush(FlushMode::MaintainState);
+        gpr.Flush(FlushMode::MaintainState, WA);
+        fpr.Flush(FlushMode::MaintainState, ARM64Reg::INVALID_REG);
 
         LDR(IndexType::Unsigned, WA, PPC_REG, PPCSTATE_OFF(Exceptions));
         ORR(WA, WA, LogicalImm(EXCEPTION_FPU_UNAVAILABLE, 32));
@@ -821,8 +823,8 @@ void JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
 
       if (SConfig::GetInstance().bJITRegisterCacheOff)
       {
-        gpr.Flush(FlushMode::All);
-        fpr.Flush(FlushMode::All);
+        gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+        fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
         FlushCarry();
       }
 
@@ -855,8 +857,8 @@ void JitArm64::DoJit(u32 em_address, JitBlock* b, u32 nextPC)
 
   if (code_block.m_broken)
   {
-    gpr.Flush(FlushMode::All);
-    fpr.Flush(FlushMode::All);
+    gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+    fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
     WriteExit(nextPC);
   }
 
diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_Branch.cpp b/Source/Core/Core/PowerPC/JitArm64/JitArm64_Branch.cpp
index 3f9c3f32e4..c0682b616f 100644
--- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_Branch.cpp
+++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_Branch.cpp
@@ -18,8 +18,8 @@ void JitArm64::sc(UGeckoInstruction inst)
   INSTRUCTION_START
   JITDISABLE(bJITBranchOff);
 
-  gpr.Flush(FlushMode::All);
-  fpr.Flush(FlushMode::All);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   ARM64Reg WA = gpr.GetReg();
 
@@ -37,8 +37,8 @@ void JitArm64::rfi(UGeckoInstruction inst)
   INSTRUCTION_START
   JITDISABLE(bJITBranchOff);
 
-  gpr.Flush(FlushMode::All);
-  fpr.Flush(FlushMode::All);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   // See Interpreter rfi for details
   const u32 mask = 0x87C0FFFF;
@@ -95,8 +95,8 @@ void JitArm64::bx(UGeckoInstruction inst)
     return;
   }
 
-  gpr.Flush(FlushMode::All);
-  fpr.Flush(FlushMode::All);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   if (js.op->branchIsIdleLoop)
   {
@@ -151,20 +151,17 @@ void JitArm64::bcx(UGeckoInstruction inst)
     MOVI2R(WA, js.compilerPC + 4);
     STR(IndexType::Unsigned, WA, PPC_REG, PPCSTATE_OFF_SPR(SPR_LR));
   }
-  gpr.Unlock(WA);
 
-  gpr.Flush(FlushMode::MaintainState);
-  fpr.Flush(FlushMode::MaintainState);
+  gpr.Flush(FlushMode::MaintainState, WA);
+  fpr.Flush(FlushMode::MaintainState, ARM64Reg::INVALID_REG);
 
   if (js.op->branchIsIdleLoop)
   {
     // make idle loops go faster
-    ARM64Reg WA2 = gpr.GetReg();
-    ARM64Reg XA2 = EncodeRegTo64(WA2);
+    ARM64Reg XA = EncodeRegTo64(WA);
 
-    MOVP2R(XA2, &CoreTiming::Idle);
-    BLR(XA2);
-    gpr.Unlock(WA2);
+    MOVP2R(XA, &CoreTiming::Idle);
+    BLR(XA);
 
     WriteExceptionExit(js.op->branchTo);
   }
@@ -182,10 +179,12 @@ void JitArm64::bcx(UGeckoInstruction inst)
 
   if (!analyzer.HasOption(PPCAnalyst::PPCAnalyzer::OPTION_CONDITIONAL_CONTINUE))
   {
-    gpr.Flush(FlushMode::All);
-    fpr.Flush(FlushMode::All);
+    gpr.Flush(FlushMode::All, WA);
+    fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
     WriteExit(js.compilerPC + 4);
   }
+
+  gpr.Unlock(WA);
 }
 
 void JitArm64::bcctrx(UGeckoInstruction inst)
@@ -205,8 +204,8 @@ void JitArm64::bcctrx(UGeckoInstruction inst)
   // BO_2 == 1z1zz -> b always
 
   // NPC = CTR & 0xfffffffc;
-  gpr.Flush(FlushMode::All);
-  fpr.Flush(FlushMode::All);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   if (inst.LK_3)
   {
@@ -235,7 +234,7 @@ void JitArm64::bclrx(UGeckoInstruction inst)
       (inst.BO & BO_DONT_DECREMENT_FLAG) == 0 || (inst.BO & BO_DONT_CHECK_CONDITION) == 0;
 
   ARM64Reg WA = gpr.GetReg();
-  ARM64Reg WB = inst.LK ? gpr.GetReg() : ARM64Reg::INVALID_REG;
+  ARM64Reg WB = conditional || inst.LK ? gpr.GetReg() : ARM64Reg::INVALID_REG;
 
   FixupBranch pCTRDontBranch;
   if ((inst.BO & BO_DONT_DECREMENT_FLAG) == 0)  // Decrement and test CTR
@@ -271,11 +270,10 @@ void JitArm64::bclrx(UGeckoInstruction inst)
   {
     MOVI2R(WB, js.compilerPC + 4);
     STR(IndexType::Unsigned, WB, PPC_REG, PPCSTATE_OFF_SPR(SPR_LR));
-    gpr.Unlock(WB);
   }
 
-  gpr.Flush(conditional ? FlushMode::MaintainState : FlushMode::All);
-  fpr.Flush(conditional ? FlushMode::MaintainState : FlushMode::All);
+  gpr.Flush(conditional ? FlushMode::MaintainState : FlushMode::All, WB);
+  fpr.Flush(conditional ? FlushMode::MaintainState : FlushMode::All, ARM64Reg::INVALID_REG);
 
   if (js.op->branchIsIdleLoop)
   {
@@ -292,8 +290,6 @@ void JitArm64::bclrx(UGeckoInstruction inst)
     WriteBLRExit(WA);
   }
 
-  gpr.Unlock(WA);
-
   if (conditional)
     SwitchToNearCode();
 
@@ -304,8 +300,12 @@ void JitArm64::bclrx(UGeckoInstruction inst)
 
   if (!analyzer.HasOption(PPCAnalyst::PPCAnalyzer::OPTION_CONDITIONAL_CONTINUE))
   {
-    gpr.Flush(FlushMode::All);
-    fpr.Flush(FlushMode::All);
+    gpr.Flush(FlushMode::All, WA);
+    fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
     WriteExit(js.compilerPC + 4);
   }
+
+  gpr.Unlock(WA);
+  if (WB != ARM64Reg::INVALID_REG)
+    gpr.Unlock(WB);
 }
diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.cpp b/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.cpp
index 7bd95f74ea..fdd4790807 100644
--- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.cpp
+++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.cpp
@@ -123,7 +123,7 @@ void Arm64RegCache::FlushMostStaleRegister()
     }
   }
 
-  FlushRegister(most_stale_preg, false);
+  FlushRegister(most_stale_preg, false, ARM64Reg::INVALID_REG);
 }
 
 void Arm64RegCache::DiscardRegister(size_t preg)
@@ -197,7 +197,7 @@ Arm64GPRCache::GuestRegInfo Arm64GPRCache::GetGuestByIndex(size_t index)
   return GetGuestGPR(0);
 }
 
-void Arm64GPRCache::FlushRegister(size_t index, bool maintain_state)
+void Arm64GPRCache::FlushRegister(size_t index, bool maintain_state, ARM64Reg tmp_reg)
 {
   GuestRegInfo guest_reg = GetGuestByIndex(index);
   OpArg& reg = guest_reg.reg;
@@ -224,12 +224,26 @@ void Arm64GPRCache::FlushRegister(size_t index, bool maintain_state)
     }
     else
     {
-      ARM64Reg host_reg = bitsize != 64 ? GetReg() : EncodeRegTo64(GetReg());
+      bool allocated_tmp_reg = false;
+      if (tmp_reg != ARM64Reg::INVALID_REG)
+      {
+        ASSERT(IsGPR(tmp_reg));
+      }
+      else
+      {
+        ASSERT_MSG(DYNA_REC, !maintain_state,
+                   "Flushing immediate while maintaining state requires temporary register");
+        tmp_reg = GetReg();
+        allocated_tmp_reg = true;
+      }
 
-      m_emit->MOVI2R(host_reg, reg.GetImm());
-      m_emit->STR(IndexType::Unsigned, host_reg, PPC_REG, u32(guest_reg.ppc_offset));
+      const ARM64Reg encoded_tmp_reg = bitsize != 64 ? tmp_reg : EncodeRegTo64(tmp_reg);
 
-      UnlockRegister(EncodeRegTo32(host_reg));
+      m_emit->MOVI2R(encoded_tmp_reg, reg.GetImm());
+      m_emit->STR(IndexType::Unsigned, encoded_tmp_reg, PPC_REG, u32(guest_reg.ppc_offset));
+
+      if (allocated_tmp_reg)
+        UnlockRegister(tmp_reg);
     }
 
     if (!maintain_state)
@@ -237,7 +251,7 @@ void Arm64GPRCache::FlushRegister(size_t index, bool maintain_state)
   }
 }
 
-void Arm64GPRCache::FlushRegisters(BitSet32 regs, bool maintain_state)
+void Arm64GPRCache::FlushRegisters(BitSet32 regs, bool maintain_state, ARM64Reg tmp_reg)
 {
   for (size_t i = 0; i < GUEST_GPR_COUNT; ++i)
   {
@@ -270,26 +284,26 @@ void Arm64GPRCache::FlushRegisters(BitSet32 regs, bool maintain_state)
         }
       }
 
-      FlushRegister(GUEST_GPR_OFFSET + i, maintain_state);
+      FlushRegister(GUEST_GPR_OFFSET + i, maintain_state, tmp_reg);
     }
   }
 }
 
-void Arm64GPRCache::FlushCRRegisters(BitSet32 regs, bool maintain_state)
+void Arm64GPRCache::FlushCRRegisters(BitSet32 regs, bool maintain_state, ARM64Reg tmp_reg)
 {
   for (size_t i = 0; i < GUEST_CR_COUNT; ++i)
   {
     if (regs[i])
     {
-      FlushRegister(GUEST_CR_OFFSET + i, maintain_state);
+      FlushRegister(GUEST_CR_OFFSET + i, maintain_state, tmp_reg);
     }
   }
 }
 
-void Arm64GPRCache::Flush(FlushMode mode, PPCAnalyst::CodeOp* op)
+void Arm64GPRCache::Flush(FlushMode mode, ARM64Reg tmp_reg)
 {
-  FlushRegisters(BitSet32(~0U), mode == FlushMode::MaintainState);
-  FlushCRRegisters(BitSet32(~0U), mode == FlushMode::MaintainState);
+  FlushRegisters(BitSet32(~0U), mode == FlushMode::MaintainState, tmp_reg);
+  FlushCRRegisters(BitSet32(~0U), mode == FlushMode::MaintainState, tmp_reg);
 }
 
 ARM64Reg Arm64GPRCache::R(const GuestRegInfo& guest_reg)
@@ -417,14 +431,14 @@ BitSet32 Arm64GPRCache::GetCallerSavedUsed() const
   return registers;
 }
 
-void Arm64GPRCache::FlushByHost(ARM64Reg host_reg)
+void Arm64GPRCache::FlushByHost(ARM64Reg host_reg, ARM64Reg tmp_reg)
 {
   for (size_t i = 0; i < m_guest_registers.size(); ++i)
   {
     const OpArg& reg = m_guest_registers[i];
     if (reg.GetType() == RegType::Register && DecodeReg(reg.GetReg()) == DecodeReg(host_reg))
     {
-      FlushRegister(i, false);
+      FlushRegister(i, false, tmp_reg);
       return;
     }
   }
@@ -437,7 +451,7 @@ Arm64FPRCache::Arm64FPRCache() : Arm64RegCache(GUEST_FPR_COUNT)
 {
 }
 
-void Arm64FPRCache::Flush(FlushMode mode, PPCAnalyst::CodeOp* op)
+void Arm64FPRCache::Flush(FlushMode mode, ARM64Reg tmp_reg)
 {
   for (size_t i = 0; i < m_guest_registers.size(); ++i)
   {
@@ -446,7 +460,7 @@ void Arm64FPRCache::Flush(FlushMode mode, PPCAnalyst::CodeOp* op)
     if (reg_type != RegType::NotLoaded && reg_type != RegType::Discarded &&
         reg_type != RegType::Immediate)
     {
-      FlushRegister(i, mode == FlushMode::MaintainState);
+      FlushRegister(i, mode == FlushMode::MaintainState, tmp_reg);
     }
   }
 }
@@ -695,7 +709,7 @@ void Arm64FPRCache::GetAllocationOrder()
     m_host_registers.push_back(HostReg(reg));
 }
 
-void Arm64FPRCache::FlushByHost(ARM64Reg host_reg)
+void Arm64FPRCache::FlushByHost(ARM64Reg host_reg, ARM64Reg tmp_reg)
 {
   for (size_t i = 0; i < m_guest_registers.size(); ++i)
   {
@@ -705,7 +719,7 @@ void Arm64FPRCache::FlushByHost(ARM64Reg host_reg)
     if (reg_type != RegType::NotLoaded && reg_type != RegType::Discarded &&
         reg_type != RegType::Immediate && reg.GetReg() == host_reg)
     {
-      FlushRegister(i, false);
+      FlushRegister(i, false, tmp_reg);
       return;
     }
   }
@@ -728,15 +742,31 @@ bool Arm64FPRCache::IsCalleeSaved(ARM64Reg reg) const
   return std::find(callee_regs.begin(), callee_regs.end(), reg) != callee_regs.end();
 }
 
-void Arm64FPRCache::FlushRegister(size_t preg, bool maintain_state)
+void Arm64FPRCache::FlushRegister(size_t preg, bool maintain_state, ARM64Reg tmp_reg)
 {
   OpArg& reg = m_guest_registers[preg];
   const ARM64Reg host_reg = reg.GetReg();
   const bool dirty = reg.IsDirty();
   RegType type = reg.GetType();
 
-  // If FlushRegister calls GetReg with all registers locked, we can get infinite recursion
-  const ARM64Reg tmp_reg = GetUnlockedRegisterCount() > 0 ? GetReg() : ARM64Reg::INVALID_REG;
+  bool allocated_tmp_reg = false;
+  if (tmp_reg != ARM64Reg::INVALID_REG)
+  {
+    ASSERT(IsVector(tmp_reg));
+  }
+  else if (GetUnlockedRegisterCount() > 0)
+  {
+    // Calling GetReg here with 0 registers free could cause problems for two reasons:
+    //
+    // 1. When GetReg needs to flush, it calls this function, which can lead to infinite recursion
+    // 2. When GetReg needs to flush, it does not respect maintain_state == true
+    //
+    // So if we have 0 registers free, just don't allocate a temporary register.
+    // The emitted code will still work but might be a little less efficient.
+
+    tmp_reg = GetReg();
+    allocated_tmp_reg = true;
+  }
 
   // If we're in single mode, just convert it back to a double.
   if (type == RegType::Single)
@@ -801,14 +831,14 @@ void Arm64FPRCache::FlushRegister(size_t preg, bool maintain_state)
     }
   }
 
-  if (tmp_reg != ARM64Reg::INVALID_REG)
+  if (allocated_tmp_reg)
     UnlockRegister(tmp_reg);
 }
 
-void Arm64FPRCache::FlushRegisters(BitSet32 regs, bool maintain_state)
+void Arm64FPRCache::FlushRegisters(BitSet32 regs, bool maintain_state, ARM64Reg tmp_reg)
 {
   for (int j : regs)
-    FlushRegister(j, maintain_state);
+    FlushRegister(j, maintain_state, tmp_reg);
 }
 
 BitSet32 Arm64FPRCache::GetCallerSavedUsed() const
diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.h b/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.h
index a62094aea6..a8a110b1e1 100644
--- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.h
+++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_RegCache.h
@@ -156,8 +156,10 @@ public:
   virtual void Start(PPCAnalyst::BlockRegStats& stats) {}
   void DiscardRegisters(BitSet32 regs);
   void ResetRegisters(BitSet32 regs);
-  // Flushes the register cache in different ways depending on the mode
-  virtual void Flush(FlushMode mode, PPCAnalyst::CodeOp* op) = 0;
+  // Flushes the register cache in different ways depending on the mode.
+  // A temporary register must be supplied when flushing GPRs with FlushMode::MaintainState,
+  // but in other cases it can be set to ARM64Reg::INVALID_REG when convenient for the caller.
+  virtual void Flush(FlushMode mode, Arm64Gen::ARM64Reg tmp_reg) = 0;
 
   virtual BitSet32 GetCallerSavedUsed() const = 0;
 
@@ -208,10 +210,11 @@ protected:
   void UnlockRegister(Arm64Gen::ARM64Reg host_reg);
 
   // Flushes a guest register by host provided
-  virtual void FlushByHost(Arm64Gen::ARM64Reg host_reg) = 0;
+  virtual void FlushByHost(Arm64Gen::ARM64Reg host_reg,
+                           Arm64Gen::ARM64Reg tmp_reg = Arm64Gen::ARM64Reg::INVALID_REG) = 0;
 
   void DiscardRegister(size_t preg);
-  virtual void FlushRegister(size_t preg, bool maintain_state) = 0;
+  virtual void FlushRegister(size_t preg, bool maintain_state, Arm64Gen::ARM64Reg tmp_reg) = 0;
 
   void IncrementAllUsed()
   {
@@ -246,8 +249,10 @@ public:
 
   void Start(PPCAnalyst::BlockRegStats& stats) override;
 
-  // Flushes the register cache in different ways depending on the mode
-  void Flush(FlushMode mode, PPCAnalyst::CodeOp* op = nullptr) override;
+  // Flushes the register cache in different ways depending on the mode.
+  // A temporary register must be supplied when flushing GPRs with FlushMode::MaintainState,
+  // but in other cases it can be set to ARM64Reg::INVALID_REG when convenient for the caller.
+  void Flush(FlushMode mode, Arm64Gen::ARM64Reg tmp_reg) override;
 
   // Returns a guest GPR inside of a host register
   // Will dump an immediate to the host register as well
@@ -266,17 +271,24 @@ public:
   void BindCRToRegister(size_t preg, bool do_load) { BindToRegister(GetGuestCR(preg), do_load); }
   BitSet32 GetCallerSavedUsed() const override;
 
-  void StoreRegisters(BitSet32 regs) { FlushRegisters(regs, false); }
-  void StoreCRRegisters(BitSet32 regs) { FlushCRRegisters(regs, false); }
+  void StoreRegisters(BitSet32 regs, Arm64Gen::ARM64Reg tmp_reg = Arm64Gen::ARM64Reg::INVALID_REG)
+  {
+    FlushRegisters(regs, false, tmp_reg);
+  }
+  void StoreCRRegisters(BitSet32 regs, Arm64Gen::ARM64Reg tmp_reg = Arm64Gen::ARM64Reg::INVALID_REG)
+  {
+    FlushCRRegisters(regs, false, tmp_reg);
+  }
 
 protected:
   // Get the order of the host registers
   void GetAllocationOrder() override;
 
   // Flushes a guest register by host provided
-  void FlushByHost(Arm64Gen::ARM64Reg host_reg) override;
+  void FlushByHost(Arm64Gen::ARM64Reg host_reg,
+                   Arm64Gen::ARM64Reg tmp_reg = Arm64Gen::ARM64Reg::INVALID_REG) override;
 
-  void FlushRegister(size_t index, bool maintain_state) override;
+  void FlushRegister(size_t index, bool maintain_state, Arm64Gen::ARM64Reg tmp_reg) override;
 
 private:
   bool IsCalleeSaved(Arm64Gen::ARM64Reg reg) const;
@@ -297,8 +309,8 @@ private:
   void SetImmediate(const GuestRegInfo& guest_reg, u32 imm);
   void BindToRegister(const GuestRegInfo& guest_reg, bool do_load);
 
-  void FlushRegisters(BitSet32 regs, bool maintain_state);
-  void FlushCRRegisters(BitSet32 regs, bool maintain_state);
+  void FlushRegisters(BitSet32 regs, bool maintain_state, Arm64Gen::ARM64Reg tmp_reg);
+  void FlushCRRegisters(BitSet32 regs, bool maintain_state, Arm64Gen::ARM64Reg tmp_reg);
 };
 
 class Arm64FPRCache : public Arm64RegCache
@@ -306,8 +318,9 @@ class Arm64FPRCache : public Arm64RegCache
 public:
   Arm64FPRCache();
 
-  // Flushes the register cache in different ways depending on the mode
-  void Flush(FlushMode mode, PPCAnalyst::CodeOp* op = nullptr) override;
+  // Flushes the register cache in different ways depending on the mode.
+  // The temporary register can be set to ARM64Reg::INVALID_REG when convenient for the caller.
+  void Flush(FlushMode mode, Arm64Gen::ARM64Reg tmp_reg) override;
 
   // Returns a guest register inside of a host register
   // Will dump an immediate to the host register as well
@@ -321,19 +334,23 @@ public:
 
   void FixSinglePrecision(size_t preg);
 
-  void StoreRegisters(BitSet32 regs) { FlushRegisters(regs, false); }
+  void StoreRegisters(BitSet32 regs, Arm64Gen::ARM64Reg tmp_reg = Arm64Gen::ARM64Reg::INVALID_REG)
+  {
+    FlushRegisters(regs, false, tmp_reg);
+  }
 
 protected:
   // Get the order of the host registers
   void GetAllocationOrder() override;
 
   // Flushes a guest register by host provided
-  void FlushByHost(Arm64Gen::ARM64Reg host_reg) override;
+  void FlushByHost(Arm64Gen::ARM64Reg host_reg,
+                   Arm64Gen::ARM64Reg tmp_reg = Arm64Gen::ARM64Reg::INVALID_REG) override;
 
-  void FlushRegister(size_t preg, bool maintain_state) override;
+  void FlushRegister(size_t preg, bool maintain_state, Arm64Gen::ARM64Reg tmp_reg) override;
 
 private:
   bool IsCalleeSaved(Arm64Gen::ARM64Reg reg) const;
 
-  void FlushRegisters(BitSet32 regs, bool maintain_state);
+  void FlushRegisters(BitSet32 regs, bool maintain_state, Arm64Gen::ARM64Reg tmp_reg);
 };
diff --git a/Source/Core/Core/PowerPC/JitArm64/JitArm64_SystemRegisters.cpp b/Source/Core/Core/PowerPC/JitArm64/JitArm64_SystemRegisters.cpp
index 3b3dff85d1..18a9404502 100644
--- a/Source/Core/Core/PowerPC/JitArm64/JitArm64_SystemRegisters.cpp
+++ b/Source/Core/Core/PowerPC/JitArm64/JitArm64_SystemRegisters.cpp
@@ -56,8 +56,8 @@ void JitArm64::mtmsr(UGeckoInstruction inst)
   gpr.BindToRegister(inst.RS, true);
   STR(IndexType::Unsigned, gpr.R(inst.RS), PPC_REG, PPCSTATE_OFF(msr));
 
-  gpr.Flush(FlushMode::All);
-  fpr.Flush(FlushMode::All);
+  gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
+  fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
 
   // Our jit cache also stores some MSR bits, as they have changed, we either
   // have to validate them in the BLR/RET check, or just flush the stack here.
@@ -213,13 +213,12 @@ void JitArm64::twx(UGeckoInstruction inst)
   SwitchToFarCode();
   SetJumpTarget(far_addr);
 
-  gpr.Flush(FlushMode::MaintainState);
-  fpr.Flush(FlushMode::MaintainState);
+  gpr.Flush(FlushMode::MaintainState, WA);
+  fpr.Flush(FlushMode::MaintainState, ARM64Reg::INVALID_REG);
 
   LDR(IndexType::Unsigned, WA, PPC_REG, PPCSTATE_OFF(Exceptions));
   ORR(WA, WA, LogicalImm(EXCEPTION_PROGRAM, 32));
   STR(IndexType::Unsigned, WA, PPC_REG, PPCSTATE_OFF(Exceptions));
-  gpr.Unlock(WA);
 
   WriteExceptionExit(js.compilerPC, false, true);
 
@@ -229,10 +228,12 @@ void JitArm64::twx(UGeckoInstruction inst)
 
   if (!analyzer.HasOption(PPCAnalyst::PPCAnalyzer::OPTION_CONDITIONAL_CONTINUE))
   {
-    gpr.Flush(FlushMode::All);
-    fpr.Flush(FlushMode::All);
+    gpr.Flush(FlushMode::All, WA);
+    fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
     WriteExit(js.compilerPC + 4);
   }
+
+  gpr.Unlock(WA);
 }
 
 void JitArm64::mfspr(UGeckoInstruction inst)