diff --git a/Source/Core/Core/Boot/Boot.cpp b/Source/Core/Core/Boot/Boot.cpp index be4ef2636c..3b004eca78 100644 --- a/Source/Core/Core/Boot/Boot.cpp +++ b/Source/Core/Core/Boot/Boot.cpp @@ -5,7 +5,10 @@ #include "Core/Boot/Boot.h" #include +#include +#include #include +#include #include #include #include @@ -25,6 +28,7 @@ #include "Core/Boot/DolReader.h" #include "Core/Boot/ElfReader.h" +#include "Core/CommonTitles.h" #include "Core/ConfigManager.h" #include "Core/FifoPlayer/FifoPlayer.h" #include "Core/HLE/HLE.h" @@ -418,3 +422,38 @@ BootExecutableReader::BootExecutableReader(const std::vector& bytes) : m_byt } BootExecutableReader::~BootExecutableReader() = default; + +void StateFlags::UpdateChecksum() +{ + constexpr size_t length_in_bytes = sizeof(StateFlags) - 4; + constexpr size_t num_elements = length_in_bytes / sizeof(u32); + std::array flag_data; + std::memcpy(flag_data.data(), &flags, length_in_bytes); + checksum = std::accumulate(flag_data.cbegin(), flag_data.cend(), 0U); +} + +void UpdateStateFlags(std::function update_function) +{ + const std::string file_path = + Common::GetTitleDataPath(Titles::SYSTEM_MENU, Common::FROM_SESSION_ROOT) + WII_STATE; + + File::IOFile file; + StateFlags state; + if (File::Exists(file_path)) + { + file.Open(file_path, "r+b"); + file.ReadBytes(&state, sizeof(state)); + } + else + { + File::CreateFullPath(file_path); + file.Open(file_path, "a+b"); + memset(&state, 0, sizeof(state)); + } + + update_function(&state); + state.UpdateChecksum(); + + file.Seek(0, SEEK_SET); + file.WriteBytes(&state, sizeof(state)); +} diff --git a/Source/Core/Core/Boot/Boot.h b/Source/Core/Core/Boot/Boot.h index f5428616af..96dd2002e4 100644 --- a/Source/Core/Core/Boot/Boot.h +++ b/Source/Core/Core/Boot/Boot.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include #include @@ -128,3 +129,18 @@ public: protected: std::vector m_bytes; }; + +struct StateFlags +{ + void UpdateChecksum(); + u32 checksum; + u8 flags; + u8 type; + u8 discstate; + u8 returnto; + u32 unknown[6]; +}; + +// Reads the state file from the NAND, then calls the passed update function to update the struct, +// and finally writes the updated state file to the NAND. +void UpdateStateFlags(std::function update_function); diff --git a/Source/Core/Core/Boot/Boot_WiiWAD.cpp b/Source/Core/Core/Boot/Boot_WiiWAD.cpp index 145ae48a1f..44d2b64262 100644 --- a/Source/Core/Core/Boot/Boot_WiiWAD.cpp +++ b/Source/Core/Core/Boot/Boot_WiiWAD.cpp @@ -2,15 +2,12 @@ // Licensed under GPLv2+ // Refer to the license.txt file included. -#include #include #include -#include #include #include "Common/CommonPaths.h" #include "Common/CommonTypes.h" -#include "Common/File.h" #include "Common/FileUtil.h" #include "Common/MsgHandler.h" #include "Common/NandPaths.h" @@ -24,55 +21,11 @@ #include "DiscIO/NANDContentLoader.h" -struct StateFlags -{ - u32 checksum; - u8 flags; - u8 type; - u8 discstate; - u8 returnto; - u32 unknown[6]; -}; - -static u32 StateChecksum(const StateFlags& flags) -{ - constexpr size_t length_in_bytes = sizeof(StateFlags) - 4; - constexpr size_t num_elements = length_in_bytes / sizeof(u32); - std::array flag_data; - - std::memcpy(flag_data.data(), &flags.flags, length_in_bytes); - - return std::accumulate(flag_data.cbegin(), flag_data.cend(), 0U); -} - bool CBoot::Boot_WiiWAD(const std::string& _pFilename) { - std::string state_filename( - Common::GetTitleDataPath(Titles::SYSTEM_MENU, Common::FROM_SESSION_ROOT) + WII_STATE); - - if (File::Exists(state_filename)) - { - File::IOFile state_file(state_filename, "r+b"); - StateFlags state; - state_file.ReadBytes(&state, sizeof(StateFlags)); - - state.type = 0x03; // TYPE_RETURN - state.checksum = StateChecksum(state); - - state_file.Seek(0, SEEK_SET); - state_file.WriteBytes(&state, sizeof(StateFlags)); - } - else - { - File::CreateFullPath(state_filename); - File::IOFile state_file(state_filename, "a+b"); - StateFlags state; - memset(&state, 0, sizeof(StateFlags)); - state.type = 0x03; // TYPE_RETURN - state.discstate = 0x01; // DISCSTATE_WII - state.checksum = StateChecksum(state); - state_file.WriteBytes(&state, sizeof(StateFlags)); - } + UpdateStateFlags([](StateFlags* state) { + state->type = 0x03; // TYPE_RETURN + }); const DiscIO::NANDContentLoader& ContentLoader = DiscIO::NANDContentManager::Access().GetNANDLoader(_pFilename);