New feature: Override player input with machine learning models (#17407)

* Add dummy game ai subsystem

* First working prototype of a machine learning model that can override player input

* Update README.md

* Update README.md

* Fix loading path on Windows

* Change ai override to player 2

* Added quick menu show game ai option

* Implemented Quick Menu entry for Game AI options

* Redirect debug logs to retroarch log system + properly support player override

* Added support to use framebuffer as input to the AI

* Added pixel format parameter to API

* Fix game name

* code clean-up of game_ai.cpp

* Update README.md - Windows Build

* Update README.md

* Update README.md

* Update README.md

* Update config.params.sh

turn off GAME_AI feature by default

* Fix compile error in menu_displaylist.c

* Add missing #define in menu_cbs_title.c

* Added new game_ai entry in griffin_cpp

* Remove GAME_AI entry in  msg_hash_us.c

* Fix compile error in menu_displaylist.h

* Removed GAME AI references from README.md

* Fixes coding style + add GameAI lib API header

* Convert comment to legacy + remove unused code

* Additional coding style fixes to game_ai.cpp

* Fix identation issues in game_ai.cpp

* Removed some debug code in game_ai.cpp

* Add game_ai_lib in deps

* Replace assert with retro_assert

* Update Makefile.common

* Converting game_ai from cpp to c. First step.

* Convert game_ai from CPP to C. STEP 2: add C function calls

* Convert game_ai from CPP to C. Final Step

* Added shutdown function for game ai lib

* Update game_ai_lib README

* Fix crash when loading/unloading multiple games
This commit is contained in:
Mathieu Poliquin 2025-01-21 20:05:43 +08:00 committed by GitHub
parent 3797d4deb6
commit 66e23fca79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 20810 additions and 1 deletions

View File

@ -2648,6 +2648,11 @@ ifeq ($(HAVE_ODROIDGO2), 1)
gfx/drivers/oga_gfx.o
endif
ifeq ($(HAVE_GAME_AI),1)
DEFINES += -DHAVE_GAME_AI
OBJ += ai/game_ai.o
endif
# Detect the operating system
UNAME := $(shell uname -s)
@ -2673,7 +2678,6 @@ else
$(warning Windows NT version macro (_WIN32_WINNT) is not defined.)
endif
endif
##################################

220
ai/game_ai.c Normal file
View File

@ -0,0 +1,220 @@
#include "game_ai.h"
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <retro_assert.h>
#include "../deps/game_ai_lib/GameAI.h"
#define GAME_AI_MAX_PLAYERS 2
void * ga = NULL;
volatile void * g_ram_ptr = NULL;
volatile int g_ram_size = 0;
volatile signed short int g_buttons_bits[GAME_AI_MAX_PLAYERS] = {0};
volatile int g_frameCount = 0;
volatile char game_ai_lib_path[1024] = {0};
volatile char g_game_name[1024] = {0};
retro_log_printf_t g_log = NULL;
#ifdef _WIN32
HINSTANCE g_lib_handle = NULL;
#else
void * g_lib_handle = NULL;
#endif
/* GameAI Lib API*/
create_game_ai_t create_game_ai = NULL;
destroy_game_ai_t destroy_game_ai = NULL;
game_ai_lib_init_t game_ai_lib_init = NULL;
game_ai_lib_think_t game_ai_lib_think = NULL;
game_ai_lib_set_show_debug_t game_ai_lib_set_show_debug = NULL;
game_ai_lib_set_debug_log_t game_ai_lib_set_debug_log = NULL;
/* Helper functions */
void game_ai_debug_log(int level, const char *fmt, ...)
{
va_list vp;
va_start(vp, fmt);
if (g_log)
g_log((enum retro_log_level)level, fmt, vp);
va_end(vp);
}
void array_to_bits_16(volatile signed short * result, const bool b[16])
{
for (int bit = 0; bit <= 15; bit++)
{
*result |= b[bit] ? (1 << bit) : 0;
}
}
/* Interface to RA */
extern signed short int game_ai_input(unsigned int port, unsigned int device, unsigned int idx, unsigned int id, signed short int result)
{
if (ga == NULL)
return 0;
if (port < GAME_AI_MAX_PLAYERS)
return g_buttons_bits[port];
return 0;
}
extern void game_ai_init()
{
if (create_game_ai == NULL)
{
#ifdef _WIN32
BOOL fFreeResult, fRunTimeLinkSuccess = FALSE;
g_lib_handle = LoadLibrary(TEXT("game_ai.dll"));
retro_assert(hinstLib);
char full_module_path[MAX_PATH];
DWORD dwLen = GetModuleFileNameA(g_lib_handle, static_cast<char*>(&full_module_path), MAX_PATH);
if (hinstLib != NULL)
{
create_game_ai = (create_game_ai_t) GetProcAddress(hinstLib, "create_game_ai");
retro_assert(create_game_ai);
destroy_game_ai = (destroy_game_ai_t) GetProcAddress(hinstLib, "destroy_game_ai");
retro_assert(destroy_game_ai);
game_ai_lib_init = (game_ai_lib_init_t) GetProcAddress(hinstLib, "game_ai_lib_init");
retro_assert(game_ai_lib_init);
game_ai_lib_think = (game_ai_lib_think_t) GetProcAddress(hinstLib, "game_ai_lib_think");
retro_assert(game_ai_lib_think);
game_ai_lib_set_show_debug = (game_ai_lib_set_show_debug_t) GetProcAddress(hinstLib, "game_ai_lib_set_show_debug");
retro_assert(game_ai_lib_set_show_debug);
game_ai_lib_set_debug_log = (game_ai_lib_set_debug_log_t) GetProcAddress(hinstLib, "game_ai_lib_set_debug_log");
retro_assert(game_ai_lib_set_debug_log);
}
#else
g_lib_handle = dlopen("libgame_ai.so", RTLD_NOW);
retro_assert(g_lib_handle);
if(g_lib_handle != NULL)
{
dlinfo(g_lib_handle, RTLD_DI_ORIGIN, (void *) &game_ai_lib_path);
create_game_ai = (create_game_ai_t)(dlsym(g_lib_handle, "create_game_ai"));
retro_assert(create_game_ai);
destroy_game_ai = (destroy_game_ai_t)(dlsym(g_lib_handle, "destroy_game_ai"));
retro_assert(destroy_game_ai);
game_ai_lib_init = (game_ai_lib_init_t)(dlsym(g_lib_handle, "game_ai_lib_init"));
retro_assert(game_ai_lib_init);
game_ai_lib_think = (game_ai_lib_think_t)(dlsym(g_lib_handle, "game_ai_lib_think"));
retro_assert(game_ai_lib_think);
game_ai_lib_set_show_debug = (game_ai_lib_set_show_debug_t)(dlsym(g_lib_handle, "game_ai_lib_set_show_debug"));
retro_assert(game_ai_lib_set_show_debug);
game_ai_lib_set_debug_log = (game_ai_lib_set_debug_log_t)(dlsym(g_lib_handle, "game_ai_lib_set_debug_log"));
retro_assert(game_ai_lib_set_debug_log);
}
#endif
}
}
extern void game_ai_shutdown()
{
if (g_lib_handle)
{
if (ga)
{
destroy_game_ai(ga);
ga = NULL;
}
#ifdef _WIN32
FreeLibrary(g_lib_handle);
#else
dlclose(g_lib_handle);
#endif
}
}
extern void game_ai_load(const char * name, void * ram_ptr, int ram_size, retro_log_printf_t log)
{
strcpy((char *) &g_game_name[0], name);
g_ram_ptr = ram_ptr;
g_ram_size = ram_size;
g_log = log;
if (ga)
{
destroy_game_ai(ga);
ga = NULL;
}
}
extern void game_ai_think(bool override_p1, bool override_p2, bool show_debug, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format)
{
if (ga)
game_ai_lib_set_show_debug(ga, show_debug);
if (ga == NULL && g_ram_ptr != NULL)
{
ga = create_game_ai((char *) &g_game_name[0]);
retro_assert(ga);
if (ga)
{
char data_path[1024] = {0};
strcpy(&data_path[0], (char *)game_ai_lib_path);
strcat(&data_path[0], "/data/");
strcat(&data_path[0], (char *)g_game_name);
game_ai_lib_init(ga, (void *) g_ram_ptr, g_ram_size);
game_ai_lib_set_debug_log(ga, game_ai_debug_log);
}
}
if (g_frameCount >= (GAMEAI_SKIPFRAMES - 1))
{
if (ga)
{
bool b[GAMEAI_MAX_BUTTONS] = {0};
g_buttons_bits[0]=0;
g_buttons_bits[1]=0;
if (override_p1)
{
game_ai_lib_think(ga, b, 0, frame_data, frame_width, frame_height, frame_pitch, pixel_format);
array_to_bits_16(&g_buttons_bits[0], b);
}
if (override_p2)
{
game_ai_lib_think(ga, b, 1, frame_data, frame_width, frame_height, frame_pitch, pixel_format);
array_to_bits_16(&g_buttons_bits[1], b);
}
}
g_frameCount=0;
}
else
{
g_frameCount++;
}
}

9
ai/game_ai.h Normal file
View File

@ -0,0 +1,9 @@
#pragma once
#include <libretro.h>
signed short int game_ai_input(unsigned int port, unsigned int device, unsigned int idx, unsigned int id, signed short int result);
void game_ai_init();
void game_ai_shutdown();
void game_ai_load(const char * name, void * ram_ptr, int ram_size, retro_log_printf_t log);
void game_ai_think(bool override_p1, bool override_p2, bool show_debug, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format);

View File

@ -2227,6 +2227,10 @@ static struct config_bool_setting *populate_settings_bool(
SETTING_BOOL("gcdwebserver_alert", &settings->bools.gcdwebserver_alert, true, true, false);
#endif
#ifdef HAVE_GAME_AI
SETTING_BOOL("quick_menu_show_game_ai", &settings->bools.quick_menu_show_game_ai, true, 1, false);
#endif
*size = count;
return tmp;

View File

@ -1095,6 +1095,14 @@ typedef struct settings
#if defined(HAVE_COCOATOUCH)
bool gcdwebserver_alert;
#endif
#ifdef HAVE_GAME_AI
bool quick_menu_show_game_ai;
bool game_ai_override_p1;
bool game_ai_override_p2;
bool game_ai_show_debug;
#endif
} bools;
uint8_t flags;

8
deps/game_ai_lib/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
libtorch/
build/
CMakeFiles/
Debug/
libtorch/
win/
*.zip
.vscode/

25
deps/game_ai_lib/CMakeLists.txt vendored Normal file
View File

@ -0,0 +1,25 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})
add_executable(test test.cpp RetroModel.cpp)
target_link_libraries(test "${TORCH_LIBRARIES}" ${OpenCV_LIBS})
add_library(game_ai SHARED GameAILocal.cpp RetroModel.cpp games/NHL94GameAI.cpp games/NHL94GameData.cpp games/DefaultGameAI.cpp utils/data.cpp utils/memory.cpp utils/utils.cpp)
target_link_libraries(game_ai "${TORCH_LIBRARIES}" ${OpenCV_LIBS})
set_property(TARGET test PROPERTY CXX_STANDARD 17)
set_property(TARGET game_ai PROPERTY CXX_STANDARD 17)
if (MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
add_custom_command(TARGET test
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${TORCH_DLLS}
$<TARGET_FILE_DIR:test>)
endif (MSVC)

41
deps/game_ai_lib/GameAI.h vendored Normal file
View File

@ -0,0 +1,41 @@
#pragma once
#ifdef __cplusplus
#include <bitset>
#include <string>
#include <filesystem>
#include <vector>
#include <queue>
#endif
typedef void (*debug_log_t)(int level, const char *fmt, ...);
#define GAMEAI_MAX_BUTTONS 16
#define GAMEAI_SKIPFRAMES 4
#ifdef __cplusplus
class GameAI {
public:
virtual void Init(void * ram_ptr, int ram_size) {};
virtual void Think(bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format) {};
void SetShowDebug(const bool show){ this->showDebug = show; };
void SetDebugLog(debug_log_t func){debugLogFunc = func;};
private:
bool showDebug;
debug_log_t debugLogFunc;
};
#endif
typedef void * (*create_game_ai_t)(const char *);
typedef void (*destroy_game_ai_t)(void * obj_ptr);
typedef void (*game_ai_lib_init_t)(void * obj_ptr, void * ram_ptr, int ram_size);
typedef void (*game_ai_lib_think_t)(void * obj_ptr, bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format);
typedef void (*game_ai_lib_set_show_debug_t)(void * obj_ptr, const bool show);
typedef void (*game_ai_lib_set_debug_log_t)(void * obj_ptr, debug_log_t func);

215
deps/game_ai_lib/GameAILocal.cpp vendored Normal file
View File

@ -0,0 +1,215 @@
#include <stdexcept>
#include "GameAI.h"
#include "./games/NHL94GameAI.h"
#include "./games/DefaultGameAI.h"
#if _WIN32
#define DllExport __declspec( dllexport )
#else
#define DllExport
#endif
//=======================================================
// C API
//=======================================================
extern "C" DllExport void game_ai_lib_init(void * obj_ptr, void * ram_ptr, int ram_size)
{
if (obj_ptr)
static_cast<GameAI*>(obj_ptr)->Init(ram_ptr, ram_size);
}
extern "C" DllExport void game_ai_lib_think(void * obj_ptr,bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format)
{
if (obj_ptr)
static_cast<GameAI*>(obj_ptr)->Think(buttons, player, frame_data, frame_width, frame_height, frame_pitch, pixel_format);
}
extern "C" DllExport void game_ai_lib_set_show_debug(void * obj_ptr,const bool show)
{
if (obj_ptr)
static_cast<GameAI*>(obj_ptr)->SetShowDebug(show);
}
extern "C" DllExport void game_ai_lib_set_debug_log(void * obj_ptr,debug_log_t func)
{
if (obj_ptr)
static_cast<GameAI*>(obj_ptr)->SetDebugLog(func);
}
extern "C" DllExport void * create_game_ai(const char * name)
{
std::filesystem::path path = name;
std::string game_name = path.parent_path().filename().string();
GameAILocal * ptr = nullptr;
if(game_name == "NHL941on1-Genesis")
{
ptr = new NHL94GameAI();
}
else
{
ptr = new DefaultGameAI();
}
if (ptr)
{
ptr->full_path = path.string();
ptr->dir_path = path.parent_path().string();
ptr->game_name = game_name;
ptr->DebugPrint("CreateGameAI");
ptr->DebugPrint(name);
ptr->DebugPrint(game_name.c_str());
}
return (void *) ptr;
}
extern "C" DllExport void destroy_game_ai(void * obj_ptr)
{
if (obj_ptr)
{
GameAILocal * gaLocal = nullptr;
gaLocal = static_cast<GameAILocal*>(obj_ptr);
delete gaLocal;
}
}
//=======================================================
// GameAILocal::InitRAM
//=======================================================
void GameAILocal::InitRAM(void * ram_ptr, int ram_size)
{
std::filesystem::path memDataPath = dir_path;
memDataPath += "/data.json";
//retro_data.load()
//std::cout << memDataPath << std::endl;
retro_data.load(memDataPath.string());
Retro::AddressSpace* m_addressSpace = nullptr;
m_addressSpace = &retro_data.addressSpace();
m_addressSpace->reset();
//Retro::configureData(data, m_core);
//reconfigureAddressSpace();
retro_data.addressSpace().setOverlay(Retro::MemoryOverlay{ '=', '>', 2 });
m_addressSpace->addBlock(16711680, ram_size, ram_ptr);
std::cout << "RAM size:" << ram_size << std::endl;
std::cout << "RAM ptr:" << ram_ptr << std::endl;
}
//=======================================================
// GameAILocal::LoadConfig_Player
//=======================================================
void GameAILocal::LoadConfig_Player(const nlohmann::detail::iter_impl<const nlohmann::json> &player)
{
for (auto var = player->cbegin(); var != player->cend(); ++var)
{
if(var.key() == "models")
{
for (auto model = var.value().cbegin(); model != var.value().cend(); ++model)
{
std::filesystem::path modelPath = dir_path;
modelPath += "/";
modelPath += model.value().get<std::string>();
RetroModel * load_model = this->LoadModel(modelPath.string().c_str());
if (models.count(model.key()) == 0)
{
models.insert(std::pair<std::string, RetroModel*>(model.key(), load_model));
}
}
}
}
}
//=======================================================
// GameAILocal::LoadConfig
//=======================================================
void GameAILocal::LoadConfig()
{
DebugPrint("GameAILocal::LoadConfig()");
std::filesystem::path configPath = dir_path;
configPath += "/config.json";
DebugPrint(configPath.string().c_str());
std::ifstream file;
try {
file.open(configPath);
//std::cout << file.rdbuf();
//std::cerr << "Error: " << strerror(errno);
//std::cout << file.get();
}
catch (std::exception & e){
DebugPrint("Error opening config file");
DebugPrint(e.what());
}
//file.clear();
//file.seekg(0, std::ios::beg);
using nlohmann::json;
json manifest;
try {
file >> manifest;
} catch (json::exception& e) {
DebugPrint("Error Loading config");
DebugPrint(e.what());
return;
}
const auto& p1 = const_cast<const json&>(manifest).find("p1");
if (p1 == manifest.cend())
{
DebugPrint("Error Loading config, no p1");
return;
}
LoadConfig_Player(p1);
const auto& p2 = const_cast<const json&>(manifest).find("p2");
if (p2 == manifest.cend())
{
DebugPrint("Error Loading config, no p1");
return;
}
LoadConfig_Player(p2);
}
//=======================================================
// GameAILocal::LoadModel
//=======================================================
RetroModel * GameAILocal::LoadModel(const char * path)
{
RetroModelPytorch * model = new RetroModelPytorch();
model->LoadModel(std::string(path));
return dynamic_cast<RetroModel*>(model);
}
//=======================================================
// GameAILocal::DebugPrint
//=======================================================
void GameAILocal::DebugPrint(const char * msg)
{
std::cout << msg << std::endl;
if (showDebug && debugLogFunc)
{
std::cout << msg << std::endl;
debugLogFunc(0, msg);
}
}

45
deps/game_ai_lib/GameAILocal.h vendored Normal file
View File

@ -0,0 +1,45 @@
#pragma once
#include "GameAI.h"
#include "RetroModel.h"
#include <bitset>
#include <string>
#include <filesystem>
#include <vector>
#include <queue>
#include "utils/data.h"
#include "./utils/json.hpp"
class GameAILocal : public GameAI {
public:
GameAILocal():showDebug(false),
debugLogFunc(nullptr){};
RetroModel * LoadModel(const char * path);
void SetShowDebug(const bool show){ this->showDebug = show; };
void SetDebugLog(debug_log_t func){debugLogFunc = func;};
void DebugPrint(const char * msg);
protected:
void InitRAM(void * ram_ptr, int ram_size);
void LoadConfig();
void LoadConfig_Player(const nlohmann::detail::iter_impl<const nlohmann::json> &player);
bool showDebug;
debug_log_t debugLogFunc;
Retro::GameData retro_data;
std::map<std::string, RetroModel*> models;
public:
std::string full_path;
std::string dir_path;
std::string game_name;
};

64
deps/game_ai_lib/README.md vendored Normal file
View File

@ -0,0 +1,64 @@
# stable-retro lib
Library to be used with emulator frontends (such as RetroArch) to enable ML models to overide player input.
Warning: Still in early prototype version
## Build for Linux
```
sudo apt update
sudo apt install git cmake unzip libqt5opengl5-dev qtbase5-dev zlib1g-dev python3 python3-pip build-essential libopencv-dev
```
```
git clone https://github.com/MatPoliquin/stable-retro-scripts.git
```
Download pytorch C++ lib:
```
cd stable-retro-scripts/ef_lib/
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.3.1%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.3.1+cpu.zip
```
Generate makefiles and compile
```
cmake . -DCMAKE_PREFIX_PATH=./libtorch
make
```
## Build for Windows
Clone stable-retro-scripts repo
Download pytorch C++ lib for Windows:
```
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.3.1%2Bcpu.zip -o libtorch_win.zip
Expand-Archive libtorch_win.zip
```
Note: 2.3.1 might have missing intel MLK dll issue:
https://github.com/pytorch/pytorch/issues/124009
So you can use nightly build instead and it fixes the issue:
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-win-shared-with-deps-latest.zip -o libtorch_win.zip
Download and Extract OpenCV for Windows:
```
https://sourceforge.net/projects/opencvlibrary/files/4.10.0/
```
The DLLs will be found here:
YourOpenCVFolder\opencv\build\x64\vc16\lib
Generate makefiles and compile
```
cd stable-retro-scripts
mkdir build
cd build
cmake .. -DCMAKE_PREFIX_PATH=Absolute\path\to\libtorch_win -DOpenCV_DIR=Absolute\path\to\opencv\build\x64\vc16\lib
cmake --build . --config Release
```
## Test the lib
```
export LD_LIBRARY_PATH=/path/to/game_ai.so
./retroarch
```

94
deps/game_ai_lib/RetroModel.cpp vendored Normal file
View File

@ -0,0 +1,94 @@
#include "RetroModel.h"
//=======================================================
// RetroModelPytorch::LoadModel
//=======================================================
void RetroModelPytorch::LoadModel(std::string path)
{
try {
this->module = torch::jit::load(path);
std::cerr << "LOADED MODEL:!" << path << std::endl;
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return;
}
}
//=======================================================
// RetroModelPytorch::Forward
//=======================================================
void RetroModelPytorch::Forward(std::vector<float> & output, const std::vector<float> & input)
{
std::vector<torch::jit::IValue> inputs;
at::Tensor tmp = torch::zeros({1, input.size()});
for(int i=0; i < input.size(); i++)
{
tmp[0][i] = input[i];
}
inputs.push_back(tmp);
at::Tensor result = module.forward(inputs).toTuple()->elements()[0].toTensor();
for(int i=0; i < output.size(); i++)
{
output[i] = result[0][i].item<float>();
}
}
//=======================================================
// RetroModelPytorch::Forward
//=======================================================
void RetroModelPytorch::Forward(std::vector<float> & output, RetroModelFrameData & input)
{
std::vector<torch::jit::IValue> inputs;
cv::Mat image(cv::Size(input.width, input.height), CV_8UC2, input.data);
cv::Mat rgb;
cv::Mat gray;
cv::Mat result;
// add new frame on the stack
cv::Mat * newFrame = input.PushNewFrameOnStack();
// Downsample to 84x84 and turn to greyscale
cv::cvtColor(image, gray, cv::COLOR_BGR5652GRAY);
cv::resize(gray, result, cv::Size(84,84), cv::INTER_AREA);
result.copyTo(*newFrame);
/*cv::namedWindow("Display Image", cv::WINDOW_NORMAL);
cv::imshow("Display Image", result);
cv::waitKey(0);*/
at::Tensor tmp = torch::ones({1, 4, 84, 84});
for(auto i : {0,1,2,3})
{
if(input.stack[i]->data)
tmp[0][3-i] = torch::from_blob(input.stack[i]->data, { result.rows, result.cols }, at::kByte);
}
/*test[0][3] = torch::from_blob(input.stack[0]->data, { result.rows, result.cols }, at::kByte);
if(input.stack[1]->data)
test[0][2] = torch::from_blob(input.stack[1]->data, { result.rows, result.cols }, at::kByte);
if(input.stack[2]->data)
test[0][1] = torch::from_blob(input.stack[2]->data, { result.rows, result.cols }, at::kByte);
if(input.stack[3]->data)
test[0][0] = torch::from_blob(input.stack[3]->data, { result.rows, result.cols }, at::kByte);*/
inputs.push_back(tmp);
// Execute the model and turn its output into a tensor.
torch::jit::IValue ret = module.forward(inputs);
at::Tensor actions = ret.toTuple()->elements()[0].toTensor();
for(int i=0; i < output.size(); i++)
{
output[i] = actions[0][i].item<float>();
}
}

69
deps/game_ai_lib/RetroModel.h vendored Normal file
View File

@ -0,0 +1,69 @@
#pragma once
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <bitset>
#include <string>
#include <filesystem>
#include <vector>
#include <queue>
class RetroModelFrameData
{
public:
RetroModelFrameData(): data(nullptr)
{
stack[0] = new cv::Mat;
stack[1] = new cv::Mat;
stack[2] = new cv::Mat;
stack[3] = new cv::Mat;
}
~RetroModelFrameData()
{
if(stack[0]) delete stack[0];
if(stack[1]) delete stack[1];
if(stack[2]) delete stack[2];
if(stack[3]) delete stack[3];
}
cv::Mat * PushNewFrameOnStack()
{
//push everything down
cv::Mat * tmp = stack[3];
stack[3] = stack[2];
stack[2] = stack[1];
stack[1] = stack[0];
stack[0] = tmp;
return stack[0];
}
void *data;
unsigned int width;
unsigned int height;
unsigned int pitch;
unsigned int format;
cv::Mat * stack[4];
};
class RetroModel {
public:
virtual void Forward(std::vector<float> & output, const std::vector<float> & input)=0;
virtual void Forward(std::vector<float> & output, RetroModelFrameData & input)=0;
};
class RetroModelPytorch : public RetroModel {
public:
virtual void LoadModel(std::string);
virtual void Forward(std::vector<float> & output, const std::vector<float> & input);
virtual void Forward(std::vector<float> & output, RetroModelFrameData & input);
private:
torch::jit::script::Module module;
};

View File

@ -0,0 +1,55 @@
#include "DefaultGameAI.h"
#include <cstdlib>
#include <iostream>
#include <assert.h>
#include <random>
enum DefaultButtons {
INPUT_B = 0,
INPUT_A = 1,
INPUT_MODE = 2,
INPUT_START = 3,
INPUT_UP = 4,
INPUT_DOWN = 5,
INPUT_LEFT = 6,
INPUT_RIGHT = 7,
INPUT_C = 8,
INPUT_Y = 9,
INPUT_X = 10,
INPUT_Z = 11,
INPUT_MAX = 12
};
//=======================================================
// DefaultGameAI::Init
//=======================================================
void DefaultGameAI::Init(void * ram_ptr, int ram_size)
{
LoadConfig();
InitRAM(ram_ptr, ram_size);
}
//=======================================================
// DefaultGameAI::Think
//=======================================================
void DefaultGameAI::Think(bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format)
{
std::vector<float> output(DefaultButtons::INPUT_MAX);
input.data = (void *) frame_data;
input.width = frame_width;
input.height = frame_height;
input.pitch = frame_pitch;
input.format = pixel_format;
models["Model"]->Forward(output, input);
for (int i=0; i < output.size(); i++)
{
buttons[i] = output[i] >= 1.0 ? 1 : 0;
}
buttons[DefaultButtons::INPUT_START] = 0;
buttons[DefaultButtons::INPUT_MODE] = 0;
}

16
deps/game_ai_lib/games/DefaultGameAI.h vendored Normal file
View File

@ -0,0 +1,16 @@
#pragma once
#include "../GameAILocal.h"
#include "memory.h"
#include "../utils/data.h"
class DefaultGameAI : public GameAILocal {
public:
virtual void Init(void * ram_ptr, int ram_size);
virtual void Think(bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format);
private:
RetroModel * model;
RetroModelFrameData input;
};

225
deps/game_ai_lib/games/NHL94GameAI.cpp vendored Normal file
View File

@ -0,0 +1,225 @@
#include "NHL94GameAI.h"
#include <cstdlib>
#include <iostream>
#include <assert.h>
#include <random>
//=======================================================
// NHL94GameAI::Init
//=======================================================
void NHL94GameAI::Init(void * ram_ptr, int ram_size)
{
LoadConfig();
InitRAM(ram_ptr, ram_size);
static_assert(NHL94NeuralNetInput::NN_INPUT_MAX == 16);
isShooting = false;
}
//=======================================================
// NHL94GameAI::SetModelInputs
//=======================================================
void NHL94GameAI::SetModelInputs(std::vector<float> & input, const NHL94Data & data)
{
// players
input[NHL94NeuralNetInput::P1_X] = (float)data.p1_x / (float) NHL94NeuralNetInput::MAX_PLAYER_X;
input[NHL94NeuralNetInput::P1_Y] = (float)data.p1_y / (float) NHL94NeuralNetInput::MAX_PLAYER_Y;
input[NHL94NeuralNetInput::P2_X] = (float)data.p2_x / (float) NHL94NeuralNetInput::MAX_PLAYER_X;
input[NHL94NeuralNetInput::P2_Y] = (float) data.p2_y / (float) NHL94NeuralNetInput::MAX_PLAYER_Y;
input[NHL94NeuralNetInput::G2_X] = (float) data.g2_x / (float) NHL94NeuralNetInput::MAX_PLAYER_X;
input[NHL94NeuralNetInput::G2_Y] = (float) data.g2_y / (float) NHL94NeuralNetInput::MAX_PLAYER_Y;
input[NHL94NeuralNetInput::P1_VEL_X] = (float) data.p1_vel_x / (float) NHL94NeuralNetInput::MAX_VEL_XY;
input[NHL94NeuralNetInput::P1_VEL_Y] = (float) data.p1_vel_y / (float) NHL94NeuralNetInput::MAX_VEL_XY;
input[NHL94NeuralNetInput::P2_VEL_X] = (float) data.p2_vel_x / (float) NHL94NeuralNetInput::MAX_VEL_XY;
input[NHL94NeuralNetInput::P2_VEL_Y] = (float) data.p2_vel_y / (float) NHL94NeuralNetInput::MAX_VEL_XY;
// puck
input[NHL94NeuralNetInput::PUCK_X] = (float) data.puck_x / (float) NHL94NeuralNetInput::MAX_PLAYER_X;
input[NHL94NeuralNetInput::PUCK_Y] = (float) data.puck_y / (float) NHL94NeuralNetInput::MAX_PLAYER_Y;
input[NHL94NeuralNetInput::PUCK_VEL_X] = (float) data.puck_vel_x / (float) NHL94NeuralNetInput::MAX_VEL_XY;
input[NHL94NeuralNetInput::PUCK_VEL_Y] = (float) data.puck_vel_y / (float) NHL94NeuralNetInput::MAX_VEL_XY;
input[NHL94NeuralNetInput::P1_HASPUCK] = data.p1_haspuck ? 0.0 : 1.0;
input[NHL94NeuralNetInput::G1_HASPUCK] = data.g1_haspuck ? 0.0 : 1.0;
}
//=======================================================
// NHL94GameAI::GotoTarget
//=======================================================
void NHL94GameAI::GotoTarget(std::vector<float> & input, int vec_x, int vec_y)
{
if (vec_x > 0)
input[NHL94Buttons::INPUT_LEFT] = 1;
else
input[NHL94Buttons::INPUT_RIGHT] = 1;
if (vec_y > 0)
input[NHL94Buttons::INPUT_DOWN] = 1;
else
input[NHL94Buttons::INPUT_UP] = 1;
}
//=======================================================
// isInsideAttackZone
//=======================================================
static bool isInsideAttackZone(NHL94Data & data)
{
if (data.attack_zone_y > 0 && data.p1_y >= data.attack_zone_y)
{
return true;
}
else if (data.attack_zone_y < 0 && data.p1_y <= data.attack_zone_y)
{
return true;
}
return false;
}
//=======================================================
// isInsideScoreZone
//=======================================================
static bool isInsideScoreZone(NHL94Data & data)
{
if (data.p1_y < data.score_zone_top && data.p1_y > data.score_zone_bottom)
{
return true;
}
return false;
}
//=======================================================
// isInsideDefenseZone
//=======================================================
static bool isInsideDefenseZone(NHL94Data & data)
{
if (data.defense_zone_y > 0 && data.p1_y >= data.defense_zone_y)
{
return true;
}
else if (data.defense_zone_y < 0 && data.p1_y <= data.defense_zone_y)
{
return true;
}
return false;
}
//=======================================================
// NHL94GameAI::Think
//=======================================================
void NHL94GameAI::Think(bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format)
{
NHL94Data data;
data.Init(retro_data);
if(player == 1)
{
data.Flip();
if(data.period % 2 == 0)
{
data.FlipZones();
}
}
else if (player == 0)
{
if(data.period % 2 == 1)
{
data.FlipZones();
}
}
std::vector<float> input(16);
std::vector<float> output(12);
this->SetModelInputs(input, data);
if (data.p1_haspuck)
{
DebugPrint("have puck");
if (isInsideAttackZone(data))
{
DebugPrint(" in attackzone");
models["ScoreGoal"]->Forward(output, input);
output[NHL94Buttons::INPUT_C] = 0;
output[NHL94Buttons::INPUT_B] = 0;
if (isInsideScoreZone(data))
{
if (data.p1_vel_x >= 30 && data.puck_x > -23 && data.puck_x < 0)
{
DebugPrint("Shoot");
output[NHL94Buttons::INPUT_C] = 1;
isShooting = true;
}
else if(data.p1_vel_x <= -30 && data.puck_x < 23 && data.puck_x > 0)
{
DebugPrint("Shoot");
output[NHL94Buttons::INPUT_C] = 1;
isShooting = true;
}
}
}
else
{
this->GotoTarget(output, data.p1_x, -data.attack_zone_y);
}
}
else if (data.g1_haspuck)
{
if (rand() > (RAND_MAX / 2))
output[NHL94Buttons::INPUT_B] = 1;
}
else
{
DebugPrint("Don't have puck");
isShooting = false;
if (isInsideDefenseZone(data) && data.p2_haspuck)
{
DebugPrint(" DefenseModel->Forward");
models["DefenseZone"]->Forward(output, input);
}
else
{
DebugPrint(" GOTO TARGET");
GotoTarget(output, data.p1_x - data.puck_x, data.p1_y - data.puck_y);
}
if (isShooting)
{
//output[NHL94Buttons::INPUT_MODE] = 1;
DebugPrint("Shooting");
output[NHL94Buttons::INPUT_C] = 1;
}
}
assert(output.size() <= 16);
for (int i=0; i < output.size(); i++)
{
buttons[i] = output[i] >= 1.0 ? 1 : 0;
}
buttons[NHL94Buttons::INPUT_START] = 0;
buttons[NHL94Buttons::INPUT_MODE] = 0;
buttons[NHL94Buttons::INPUT_A] = 0;
//buttons[NHL94Buttons::INPUT_B] = 0;
//buttons[NHL94Buttons::INPUT_C] = 0;
buttons[NHL94Buttons::INPUT_X] = 0;
buttons[NHL94Buttons::INPUT_Y] = 0;
buttons[NHL94Buttons::INPUT_Z] = 0;
//Flip directions
if(data.period % 2 != player)
{
std::swap(buttons[NHL94Buttons::INPUT_UP], buttons[NHL94Buttons::INPUT_DOWN]);
std::swap(buttons[NHL94Buttons::INPUT_LEFT], buttons[NHL94Buttons::INPUT_RIGHT]);
}
}

16
deps/game_ai_lib/games/NHL94GameAI.h vendored Normal file
View File

@ -0,0 +1,16 @@
#pragma once
#include "../GameAILocal.h"
#include "NHL94GameData.h"
class NHL94GameAI : public GameAILocal {
public:
virtual void Init(void * ram_ptr, int ram_size);
virtual void Think(bool buttons[GAMEAI_MAX_BUTTONS], int player, const void *frame_data, unsigned int frame_width, unsigned int frame_height, unsigned int frame_pitch, unsigned int pixel_format);
void SetModelInputs(std::vector<float> & input, const NHL94Data & data);
void GotoTarget(std::vector<float> & input, int vec_x, int vec_y);
private:
bool isShooting;
};

108
deps/game_ai_lib/games/NHL94GameData.cpp vendored Normal file
View File

@ -0,0 +1,108 @@
#include "NHL94GameData.h"
//=======================================================
// NHL94Data::Init
//=======================================================
void NHL94Data::Init(const Retro::GameData & data)
{
// players
p1_x = data.lookupValue("p1_x").cast<int>();
p1_y = data.lookupValue("p1_y").cast<int>();
p2_x = data.lookupValue("p2_x").cast<int>();
p2_y = data.lookupValue("p2_y").cast<int>();
p1_vel_x = data.lookupValue("p1_vel_x").cast<int>();
p1_vel_y = data.lookupValue("p1_vel_y").cast<int>();
p2_vel_x = data.lookupValue("p2_vel_x").cast<int>();
p2_vel_y = data.lookupValue("p2_vel_y").cast<int>();
// goalies
g1_x = data.lookupValue("g1_x").cast<int>();
g1_y = data.lookupValue("g1_y").cast<int>();
g2_x = data.lookupValue("g2_x").cast<int>();
g2_y = data.lookupValue("g2_y").cast<int>();
// puck
puck_x = data.lookupValue("puck_x").cast<int>();
puck_y = data.lookupValue("puck_y").cast<int>();
puck_vel_x = data.lookupValue("puck_vel_x").cast<int>();
puck_vel_y = data.lookupValue("puck_vel_y").cast<int>();
p1_fullstar_x = data.lookupValue("fullstar_x").cast<int>();
p1_fullstar_y = data.lookupValue("fullstar_y").cast<int>();
p2_fullstar_x = data.lookupValue("p2_fullstar_x").cast<int>();
p2_fullstar_y = data.lookupValue("p2_fullstar_y").cast<int>();
period = data.lookupValue("period").cast<int>();
// Knowing if the player has the puck is tricky since the fullstar in the game is not aligned with the player every frame
// There is an offset of up to 2 sometimes
if (std::abs(p1_x - p1_fullstar_x) < 3 && std::abs(p1_y - p1_fullstar_y) < 3)
p1_haspuck = true;
else
p1_haspuck = false;
if(std::abs(p2_x - p1_fullstar_x) < 3 && std::abs(p2_y - p1_fullstar_y) < 3)
p2_haspuck = true;
else
p2_haspuck = false;
if(std::abs(g1_x - p1_fullstar_x) < 3 && std::abs(g1_y - p1_fullstar_y) < 3)
g1_haspuck = true;
else
g1_haspuck = false;
if(std::abs(g2_x - p1_fullstar_x) < 3 && std::abs(g2_y - p1_fullstar_y) < 3)
g2_haspuck = true;
else
g2_haspuck = false;
attack_zone_y = NHL94Const::ATACKZONE_POS_Y;
defense_zone_y = NHL94Const::DEFENSEZONE_POS_Y;
score_zone_top = NHL94Const::SCORE_ZONE_TOP;
score_zone_bottom = NHL94Const::SCORE_ZONE_BOTTOM;
}
//=======================================================
// NHL94Data::Flip
//=======================================================
void NHL94Data::Flip()
{
std::swap(p1_x, p2_x);
std::swap(p1_y, p2_y);
std::swap(g1_x, g2_x);
std::swap(g1_y, g2_y);
std::swap(p1_haspuck, p2_haspuck);
std::swap(g1_haspuck, g2_haspuck);
std::swap(p1_vel_x, p2_vel_x);
std::swap(p1_vel_y, p2_vel_y);
}
//=======================================================
// NHL94Data::FlipZones
//=======================================================
void NHL94Data::FlipZones()
{
p1_x = -p1_x;
p1_y = -p1_y;
p2_x = -p2_x;
p2_y = -p2_y;
g1_x = -g1_x;
g1_y = -g1_y;
g2_x = -g2_x;
g2_y = -g2_y;
p1_vel_x = -p1_vel_x;
p1_vel_y = -p1_vel_y;
p2_vel_x = -p2_vel_x;
p2_vel_y = -p2_vel_y;
puck_x = -puck_x;
puck_y = -puck_y;
puck_vel_x = -puck_vel_x;
puck_vel_y = -puck_vel_y;
}

100
deps/game_ai_lib/games/NHL94GameData.h vendored Normal file
View File

@ -0,0 +1,100 @@
#pragma once
#include "memory.h"
#include "../utils/data.h"
enum NHL94Buttons {
INPUT_B = 0,
INPUT_A = 1,
INPUT_MODE = 2,
INPUT_START = 3,
INPUT_UP = 4,
INPUT_DOWN = 5,
INPUT_LEFT = 6,
INPUT_RIGHT = 7,
INPUT_C = 8,
INPUT_Y = 9,
INPUT_X = 10,
INPUT_Z = 11,
INPUT_MAX = 12
};
enum NHL94NeuralNetInput {
P1_X = 0,
P1_Y,
P1_VEL_X,
P1_VEL_Y,
P2_X,
P2_Y,
P2_VEL_X,
P2_VEL_Y,
PUCK_X,
PUCK_Y,
PUCK_VEL_X,
PUCK_VEL_Y,
G2_X,
G2_Y,
P1_HASPUCK,
G1_HASPUCK,
NN_INPUT_MAX,
// Used for normalization
MAX_PLAYER_X = 120,
MAX_PLAYER_Y = 270,
MAX_PUCK_X = 130,
MAX_PUCK_Y = 270,
MAX_VEL_XY = 50
};
enum NHL94Const {
ATACKZONE_POS_Y = 100,
DEFENSEZONE_POS_Y = -80,
SCORE_ZONE_TOP = 230,
SCORE_ZONE_BOTTOM = 210,
};
class NHL94Data {
public:
int p1_x;
int p1_y;
int p2_x;
int p2_y;
int p1_vel_x;
int p1_vel_y;
int p2_vel_x;
int p2_vel_y;
int g1_x;
int g1_y;
int g2_x;
int g2_y;
int puck_x;
int puck_y;
int puck_vel_x;
int puck_vel_y;
int p1_fullstar_x;
int p1_fullstar_y;
int p2_fullstar_x;
int p2_fullstar_y;
bool p1_haspuck;
bool g1_haspuck;
bool p2_haspuck;
bool g2_haspuck;
int attack_zone_y;
int defense_zone_y;
int score_zone_top;
int score_zone_bottom;
int period;
void Init(const Retro::GameData & data);
void Flip();
void FlipZones();
};

202
deps/game_ai_lib/test.cpp vendored Normal file
View File

@ -0,0 +1,202 @@
// test of game ai dynamic lib
#include <iostream>
#include <assert.h>
#include <filesystem>
#include <stdexcept>
#include <opencv2/opencv.hpp>
#include <torch/script.h>
#include "GameAI.h"
#include "RetroModel.h"
#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#if 0
/*
#include "onnxruntime_cxx_api.h"
void Test_ONNX()
{
// Load the model and create InferenceSession
Ort::Env env;
std::string model_path = "path/to/your/onnx/model";
Ort::Session session(env, model_path, Ort::SessionOptions{ nullptr });
// Load and preprocess the input image to inputTensor
...
// Run inference
std::vector outputTensors =
session.Run(Ort::RunOptions{nullptr}, inputNames.data(), &inputTensor,
inputNames.size(), outputNames.data(), outputNames.size());
const float* outputDataPtr = outputTensors[0].GetTensorMutableData();
std::cout << outputDataPtr[0] << std::endl;
}*/
void Test_Resnet()
{
torch::jit::script::Module module;
try {
module = torch::jit::load("/home/mat/github/stable-retro-scripts/traced_resnet_model.pt");
//module = torch::jit::load("/home/mat/github/stable-retro-scripts/model.pt");
std::cerr << "SUCCESS!\n";
module.eval();
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
//inputs.push_back(torch::ones({1, 4, 84, 84}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
//std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return;
}
}
#endif
//=======================================================
// test_opencv
//=======================================================
void test_opencv(std::map<std::string, bool> & tests)
{
cv::Mat image;
cv::Mat grey;
cv::Mat result;
image = cv::imread( "../screenshots/wwf.png", cv::IMREAD_COLOR );
cv::cvtColor(image, grey, cv::COLOR_RGB2GRAY);
cv::resize(grey, result, cv::Size(84,84), cv::INTER_AREA);
if ( !image.data )
{
printf("No image data \n");
return;
}
cv::namedWindow("Display Image", cv::WINDOW_NORMAL);
cv::imshow("Display Image", result);
cv::waitKey(1000);
tests["OPENCV GRAYSCALE DOWNSAMPLE TO 84x84"] = true;
}
//=======================================================
// test_loadlibrary
//=======================================================
void test_loadlibrary(std::map<std::string, bool> & tests)
{
GameAI * ga = nullptr;
create_game_ai_t func = nullptr;
#ifdef _WIN32
HINSTANCE hinstLib;
BOOL fFreeResult, fRunTimeLinkSuccess = FALSE;
hinstLib = LoadLibrary(TEXT("game_ai.dll"));
if (hinstLib != NULL)
{
tests["LOAD LIBRARY"] = true;
func = (create_game_ai_t) GetProcAddress(hinstLib, "create_game_ai");
}
#else
void *myso = dlopen("./libgame_ai.so", RTLD_NOW);
//std::cout << dlerror() << std::endl;
if(myso)
{
tests["LOAD LIBRARY"] = true;
func = reinterpret_cast<create_game_ai_t>(dlsym(myso, "create_game_ai"));
}
#endif
if(func)
{
tests["GET CREATEGAME FUNC"] = true;
ga = (GameAI *) func("./data/NHL941on1-Genesis/NHL941on1.md");
if(ga)
tests["CREATEGAME FUNC"] = true;
}
#ifdef _WIN32
fFreeResult = FreeLibrary(hinstLib);
#endif
}
//=======================================================
// test_pytorch
//=======================================================
void test_pytorch(std::map<std::string, bool> & tests)
{
try {
RetroModelPytorch * model = new RetroModelPytorch();
model->LoadModel(std::string("./data/NHL941on1-Genesis/ScoreGoal.pt"));
std::vector<float> input(16);
std::vector<float> output(12);
model->Forward(output, input);
//TODO validate output
tests["LOAD PYTORCH MODEL"] = true;
}
catch (const c10::Error& e) {
//std::cerr << "error loading the model\n";
throw std::runtime_error ("error loading the model\n");
return;
}
}
int main()
{
std::map<std::string, bool> tests;
tests.insert(std::pair<std::string, bool>("LOAD LIBRARY",false));
tests.insert(std::pair<std::string, bool>("GET CREATEGAME FUNC",false));
tests.insert(std::pair<std::string, bool>("CREATEGAME FUNC",false));
tests.insert(std::pair<std::string, bool>("OPENCV GRAYSCALE DOWNSAMPLE TO 84x84",false));
tests.insert(std::pair<std::string, bool>("LOAD PYTORCH MODEL",false));
std::cout << "========== RUNNING TESTS ==========" << std::endl;
try {
test_loadlibrary(tests);
test_opencv(tests);
test_pytorch(tests);
}
catch (std::exception &e) {
std::cout << "============= EXCEPTION =============" << std::endl;
std::cout << e.what();
}
std::cout << "============== RESULTS =============" << std::endl;
for(auto i: tests)
{
const char * result = i.second ? "PASS" : "FAIL";
std::cout << i.first << "..." << result << std::endl;
}
return 0;
}

405
deps/game_ai_lib/utils/data.cpp vendored Normal file
View File

@ -0,0 +1,405 @@
// Adapted from OpenAI's retro source code:
// https://github.com/openai/retro
#include "data.h"
//#include "script.h"
#include "utils.h"
#ifdef ERROR
#undef ERROR
#endif
#include "json.hpp"
#include <fstream>
using namespace Retro;
using namespace std;
using nlohmann::json;
static string s_dataDirectory;
template<typename T>
T find(json::const_reference j, const string& key) {
const auto& iter = j.find(key);
if (iter == j.end()) {
return T();
}
try {
T t = *iter;
return t;
} catch (json::exception&) {
return T();
}
}
static void setActions(const vector<string>& buttonList, const vector<vector<vector<string>>>& actionsIn, map<int, set<int>>& actions) {
actions.clear();
for (const auto& outer : actionsIn) {
set<int> sublist;
int mask = 0;
for (const auto& middle : outer) {
int buttons = 0;
for (const auto& button : middle) {
const auto& iter = find(buttonList.begin(), buttonList.end(), button);
buttons |= 1 << (iter - buttonList.begin());
}
mask |= buttons;
sublist.insert(buttons);
}
actions.emplace(mask, move(sublist));
}
}
static unsigned filterAction(unsigned action, const map<int, set<int>>& actions) {
unsigned newAction = 0;
for (const auto& actionSet : actions) {
unsigned maskedAction = action & actionSet.first;
if (actionSet.second.find(maskedAction) != actionSet.second.end()) {
newAction |= maskedAction;
}
}
return newAction;
}
Variable::Variable(const DataType& type, size_t address, uint64_t mask)
: type(type)
, address(address)
, mask(mask) {
}
bool Variable::operator==(const Variable& other) const {
return type == other.type && address == other.address && mask == other.mask;
}
bool GameData::load(const string& filename) {
ifstream file(filename);
return load(&file);
}
bool GameData::load(istream* file) {
json manifest;
try {
*file >> manifest;
} catch (json::exception&) {
return false;
}
const auto& info = const_cast<const json&>(manifest).find("info");
if (info == manifest.cend()) {
return false;
}
unordered_map<std::string, Variable> oldVars;
oldVars.swap(m_vars);
for (auto var = info->cbegin(); var != info->cend(); ++var) {
if (var->find("address") == var->cend() || var->find("type") == var->cend()) {
oldVars.swap(m_vars);
return false;
}
string dtype = var->at("type");
if (dtype.size() < 3) {
continue;
}
try {
Variable v(dtype, var->at("address"), var->value("mask", UINT64_MAX));
setVariable(var.key(), v);
} catch (std::out_of_range) {
continue;
}
}
return true;
}
bool GameData::save(const string& filename) const {
ofstream file(filename);
return save(&file);
}
bool GameData::save(ostream* file) const {
json manifest;
json info;
for (const auto& var : m_vars) {
json jvar;
jvar["address"] = var.second.address;
jvar["type"] = var.second.type.type;
if (var.second.mask != UINT64_MAX) {
jvar["mask"] = var.second.mask;
}
info[var.first] = jvar;
}
manifest["info"] = info;
try {
file->width(2);
*file << manifest;
*file << endl;
} catch (json::exception&) {
return false;
}
return true;
}
string GameData::dataPath(const string& hint) {
if (s_dataDirectory.size()) {
return s_dataDirectory;
}
const char* envDir = getenv("RETRO_DATA_PATH");
if (envDir) {
s_dataDirectory = envDir;
} else {
s_dataDirectory = drillUp({ "retro/data", "data" }, ".", hint);
}
return s_dataDirectory;
}
void GameData::reset() {
restart();
m_lastMem.reset();
m_cloneMem.reset();
m_vars.clear();
//m_searches.clear();
m_searchOldMem.clear();
}
void GameData::restart() {
m_customVars.clear();
}
void GameData::updateRam() {
m_lastMem = move(m_cloneMem);
m_cloneMem.clone(m_mem);
}
void GameData::setTypes(const vector<DataType> types) {
m_types = vector<DataType>(types);
}
void GameData::setButtons(const vector<string>& buttons) {
m_buttons = buttons;
}
vector<string> GameData::buttons() const {
return m_buttons;
}
void GameData::setActions(const vector<vector<vector<string>>>& actions) {
::setActions(m_buttons, actions, m_actions);
}
map<int, set<int>> GameData::validActions() const {
return m_actions;
}
unsigned GameData::filterAction(unsigned action) const {
return ::filterAction(action, m_actions);
}
Datum GameData::lookupValue(const string& name) {
auto variant = m_customVars.find(name);
if (variant != m_customVars.end()) {
return Datum(variant->second.get());
}
auto v = m_vars.find(name);
if (v == m_vars.end()) {
throw invalid_argument(name);
}
return m_mem[v->second];
}
Variant GameData::lookupValue(const string& name) const {
auto variant = m_customVars.find(name);
if (variant != m_customVars.end()) {
return *variant->second;
}
auto v = m_vars.find(name);
if (v == m_vars.end()) {
throw invalid_argument(name);
}
return m_mem[v->second];
}
/*Datum GameData::lookupValue(const TypedSearchResult& result) {
return m_mem[Variable{ result.type, result.address }];
}
int64_t GameData::lookupValue(const TypedSearchResult& result) const {
return m_mem[Variable{ result.type, result.address }];
}*/
int64_t GameData::lookupDelta(const string& name) const {
const auto& v = m_vars.find(name);
if (v == m_vars.end()) {
return 0;
}
int64_t newVal = m_cloneMem[v->second];
if (!m_lastMem.ok()) {
return 0;
}
int64_t oldVal = m_lastMem[v->second];
return newVal - oldVal;
}
unordered_map<string, Datum> GameData::lookupAll() {
unordered_map<string, Datum> data;
for (auto var = m_vars.cbegin(); var != m_vars.cend(); ++var) {
try {
data.emplace(var->first, m_mem[var->second]);
} catch (...) {
}
}
for (auto var = m_customVars.cbegin(); var != m_customVars.cend(); ++var) {
data.emplace(var->first, var->second.get());
}
return data;
}
unordered_map<string, int64_t> GameData::lookupAll() const {
unordered_map<string, int64_t> data;
for (auto var = m_vars.cbegin(); var != m_vars.cend(); ++var) {
try {
data.emplace(var->first, m_mem[var->second]);
} catch (...) {
}
}
for (auto var = m_customVars.cbegin(); var != m_customVars.cend(); ++var) {
data.emplace(var->first, *var->second);
}
return data;
}
void GameData::setValue(const std::string& name, int64_t v) {
auto variant = m_customVars.find(name);
if (variant != m_customVars.end()) {
*variant->second = v;
return;
}
auto var = m_vars.find(name);
if (var != m_vars.end()) {
m_mem[var->second] = v;
return;
}
m_customVars.emplace(name, std::make_unique<Variant>(v));
}
void GameData::setValue(const std::string& name, const Variant& v) {
auto variant = m_customVars.find(name);
if (variant != m_customVars.end()) {
*variant->second = v;
return;
}
auto var = m_vars.find(name);
if (var != m_vars.end()) {
m_mem[var->second] = v;
return;
}
m_customVars.emplace(name, std::make_unique<Variant>(v));
}
Variable GameData::getVariable(const string& name) const {
const auto& v = m_vars.find(name);
if (v == m_vars.end()) {
throw invalid_argument(name);
}
return v->second;
}
void GameData::setVariable(const string& name, const Variable& var) {
removeVariable(name);
m_vars.emplace(name, var);
}
void GameData::removeVariable(const string& name) {
auto iter = m_vars.find(name);
if (iter != m_vars.end()) {
m_vars.erase(iter);
}
}
unordered_map<string, Variable> GameData::listVariables() const {
return m_vars;
}
size_t GameData::numVariables() const {
return m_vars.size();
}
/*void GameData::search(const std::string& name, int64_t value) {
if (m_searches.find(name) == m_searches.cend()) {
if (m_types.size()) {
m_searches.emplace(name, Search{ m_types });
} else {
m_searches.emplace(name, Search{});
}
}
Search* search = &m_searches[name];
search->search(m_mem, value);
m_searchOldMem[name].clone(m_mem);
}*/
/*void GameData::deltaSearch(const std::string& name, Operation op, int64_t reference) {
if (m_searches.find(name) == m_searches.cend()) {
if (m_types.size()) {
m_searches.emplace(name, Search{ m_types });
} else {
m_searches.emplace(name, Search{});
}
}
if (m_searchOldMem.find(name) == m_searchOldMem.cend()) {
m_searchOldMem[name].clone(m_mem);
}
Search* search = &m_searches[name];
search->delta(m_mem, m_searchOldMem[name], op, reference);
m_searchOldMem[name].clone(m_mem);
}*
size_t GameData::numSearches() const {
return m_searches.size();
}
vector<string> GameData::listSearches() const {
vector<string> names;
for (const auto& search : m_searches) {
names.emplace_back(search.first);
}
return names;
}
Search* GameData::getSearch(const string& name) {
auto iter = m_searches.find(name);
if (iter != m_searches.end()) {
return &iter->second;
}
return nullptr;
}
void GameData::removeSearch(const string& name) {
auto iter = m_searches.find(name);
if (iter != m_searches.end()) {
m_searches.erase(iter);
}
}*/
static const vector<pair<string, Operation>> s_ops{
make_pair("equal", Operation::EQUAL),
make_pair("negative-equal", Operation::NEGATIVE_EQUAL),
make_pair("not-equal", Operation::NOT_EQUAL),
make_pair("less-than", Operation::LESS_THAN),
make_pair("greater-than", Operation::GREATER_THAN),
make_pair("less-or-equal", Operation::LESS_OR_EQUAL),
make_pair("greater-or-equal", Operation::GREATER_OR_EQUAL),
make_pair("less-or-equal", Operation::LESS_OR_EQUAL),
make_pair("nonzero", Operation::NONZERO),
make_pair("zero", Operation::ZERO),
make_pair("negative", Operation::NEGATIVE),
make_pair("positive", Operation::POSITIVE),
make_pair("sign", Operation::SIGN)
};

94
deps/game_ai_lib/utils/data.h vendored Normal file
View File

@ -0,0 +1,94 @@
// Adapted from OpenAI's retro source code:
// https://github.com/openai/retro
#pragma once
//#include "emulator.h"
#include "memory.h"
//#include "search.h"
#include "utils.h"
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>
#ifdef ABSOLUTE
#undef ABSOLUTE
#endif
namespace Retro {
class GameData {
public:
bool load(const std::string& filename);
bool load(std::istream* stream);
bool save(const std::string& filename) const;
bool save(std::ostream* stream) const;
void reset();
void restart();
static std::string dataPath(const std::string& hint = ".");
AddressSpace& addressSpace() { return m_mem; }
const AddressSpace& addressSpace() const { return m_mem; }
void updateRam();
void setTypes(const std::vector<DataType> types);
void setButtons(const std::vector<std::string>& names);
std::vector<std::string> buttons() const;
void setActions(const std::vector<std::vector<std::vector<std::string>>>& actions);
std::map<int, std::set<int>> validActions() const;
unsigned filterAction(unsigned) const;
Datum lookupValue(const std::string& name);
Variant lookupValue(const std::string& name) const;
//Datum lookupValue(const TypedSearchResult&);
//int64_t lookupValue(const TypedSearchResult&) const;
std::unordered_map<std::string, Datum> lookupAll();
std::unordered_map<std::string, int64_t> lookupAll() const;
void setValue(const std::string& name, int64_t);
void setValue(const std::string& name, const Variant&);
int64_t lookupDelta(const std::string& name) const;
Variable getVariable(const std::string& name) const;
void setVariable(const std::string& name, const Variable&);
void removeVariable(const std::string& name);
std::unordered_map<std::string, Variable> listVariables() const;
size_t numVariables() const;
void search(const std::string& name, int64_t value);
void deltaSearch(const std::string& name, Operation op, int64_t reference);
size_t numSearches() const;
std::vector<std::string> listSearches() const;
//Search* getSearch(const std::string& name);
void removeSearch(const std::string& name);
#ifdef USE_CAPNP
bool loadSearches(const std::string& filename);
bool saveSearches(const std::string& filename) const;
#endif
private:
AddressSpace m_mem;
AddressSpace m_cloneMem;
AddressSpace m_lastMem;
std::vector<DataType> m_types;
std::map<int, std::set<int>> m_actions;
std::vector<std::string> m_buttons;
std::unordered_map<std::string, Variable> m_vars;
//std::unordered_map<std::string, Search> m_searches;
std::unordered_map<std::string, AddressSpace> m_searchOldMem;
std::unordered_map<std::string, std::unique_ptr<Variant>> m_customVars;
};
}

17300
deps/game_ai_lib/utils/json.hpp vendored Normal file

File diff suppressed because it is too large Load Diff

521
deps/game_ai_lib/utils/memory.cpp vendored Normal file
View File

@ -0,0 +1,521 @@
// Adapted from OpenAI's retro source code:
// https://github.com/openai/retro
#include "memory.h"
#include <cstdlib>
#include <unordered_map>
#include <stdexcept>
#include <string.h>
using namespace Retro;
using namespace std;
Endian Retro::reduce(Endian e) {
switch (e) {
case Endian::BIG:
case Endian::LITTLE:
case Endian::UNDEF:
case Endian::MIXED_BL:
case Endian::MIXED_LB:
return e;
case Endian::NATIVE:
return Endian::REAL_NATIVE;
case Endian::MIXED_BN:
return Endian::REAL_MIXED_BN;
case Endian::MIXED_LN:
return Endian::REAL_MIXED_LN;
}
return e;
}
bool Retro::reduceCompare(Endian a, Endian b) {
return reduce(a) == reduce(b);
}
DataType::DataType(const char* type)
: width(type[strlen(type) - 1] - '0')
, endian(
type[0] == '=' ? Endian::NATIVE : type[0] == '>' ? (type[1] == '<' ? Endian::MIXED_BL : type[1] == '=' ? Endian::MIXED_BN : Endian::BIG) : type[0] == '<' ? (type[1] == '>' ? Endian::MIXED_LB : type[1] == '=' ? Endian::MIXED_LN : Endian::LITTLE) : Endian::UNDEF)
, repr(static_cast<Repr>(type[strlen(type) - 2]))
, type{ type[0], type[1], type[2], type[3] }
, maskLo(repr == Repr::LN_BCD || repr == Repr::BCD ? 0xF : 0xFF)
, maskHi(repr == Repr::BCD ? 0xF0 : 0x0)
, cvt(repr == Repr::BCD || repr == Repr::LN_BCD ? 10 : 256) {
uint64_t shiftInc =
repr == Repr::BCD ? 100 : repr == Repr::LN_BCD ? 10 : 256;
int baseLoc;
int baseEnd;
int halfLoc = -1;
int diff;
if (width > 8) {
throw std::out_of_range("Invalid DataType width");
}
switch (reduce(endian)) {
case Endian::LITTLE:
default:
baseLoc = 0;
baseEnd = width;
diff = 1;
break;
case Endian::BIG:
baseLoc = width - 1;
baseEnd = -1;
diff = -1;
break;
case Endian::MIXED_LB:
baseLoc = width / 2 - 1;
baseEnd = -1;
halfLoc = width - 1;
diff = -1;
break;
case Endian::MIXED_BL:
baseLoc = width / 2;
baseEnd = width;
halfLoc = 0;
diff = 1;
break;
}
uint64_t baseShift = 1;
for (int i = baseLoc; i != baseEnd; i += diff, baseShift *= shiftInc) {
shift[i] = baseShift;
}
if (halfLoc >= 0) {
for (int i = halfLoc; i != baseLoc; i += diff, baseShift *= shiftInc) {
shift[i] = baseShift;
}
}
}
DataType::DataType(const string& type)
: DataType(type.c_str()) {
}
Datum DataType::operator()(void* base) const {
return Datum(base, *this);
}
Datum DataType::operator()(void* base, size_t offset, const MemoryOverlay& overlay) const {
return Datum(base, offset, *this, overlay);
}
bool DataType::operator==(const DataType& other) const {
return width == other.width && endian == other.endian && repr == other.repr;
}
bool DataType::operator!=(const DataType& other) const {
return !(*this == other);
}
void DataType::encode(void* buffer, int64_t value) const {
for (size_t i = 0; i < width; ++i) {
uint64_t b = (uint64_t) value / shift[i];
b = b % cvt + b / cvt % cvt * (~maskHi + 1);
static_cast<uint8_t*>(buffer)[i] = b;
}
}
int64_t DataType::decode(const void* buffer) const {
int64_t datum = 0;
for (size_t i = 0; i < width; ++i) {
uint8_t b = static_cast<const uint8_t*>(buffer)[i];
datum += ((b & maskLo) % cvt + ((b & maskHi) >> 4) % cvt * 10) * shift[i];
}
if (repr == Repr::SIGNED) {
datum <<= 8 * (8 - width);
datum >>= 8 * (8 - width);
}
return datum;
}
size_t hash<DataType>::operator()(const DataType& type) const {
return hash<uint32_t>()(*reinterpret_cast<const uint32_t*>(type.type));
}
static constexpr char endianTag(Endian e) {
switch (e) {
case Endian::BIG:
return '>';
case Endian::LITTLE:
return '<';
default:
case Endian::UNDEF:
return '|';
case Endian::NATIVE:
return '=';
}
}
MemoryOverlay::MemoryOverlay(Endian backing, Endian real, size_t width)
: width(width)
, m_backing({ endianTag(backing), 'u', static_cast<char>('0' + width) })
, m_real({ endianTag(real), 'u', static_cast<char>('0' + width) }) {
}
MemoryOverlay::MemoryOverlay(char backing, char real, size_t width)
: width(width)
, m_backing({ backing, 'u', static_cast<char>('0' + width) })
, m_real({ real, 'u', static_cast<char>('0' + width) }) {
}
void* MemoryOverlay::parse(const void* in, size_t offset, void* out, size_t size) const {
size_t offsetEdge = offset & (width - 1);
uintptr_t base = reinterpret_cast<uintptr_t>(in);
base += offset & ~(width - 1);
size += offsetEdge;
uintptr_t outBase = reinterpret_cast<uintptr_t>(out);
for (size_t i = 0; i < size; i += width) {
int64_t val = m_backing.decode(reinterpret_cast<const void*>(base + i));
m_real.encode(reinterpret_cast<void*>(outBase + i), val);
}
return reinterpret_cast<void*>(outBase + offsetEdge);
}
void MemoryOverlay::unparse(void* out, size_t offset, const void* in, size_t size) const {
size_t offsetEdge = offset & (width - 1);
uintptr_t base = reinterpret_cast<uintptr_t>(out);
base += offset & ~(width - 1);
size += offsetEdge;
uintptr_t inBase = reinterpret_cast<uintptr_t>(in);
for (size_t i = 0; i < size; i += width) {
int64_t val = m_real.decode(reinterpret_cast<void*>(inBase + i));
m_backing.encode(reinterpret_cast<void*>(base + i), val);
}
}
Variant::Variant(int64_t i)
: m_type(Type::INT)
, m_vi(i) {
}
Variant::Variant(double d)
: m_type(Type::FLOAT)
, m_vf(d) {
}
Variant::Variant(bool b)
: m_type(Type::BOOL)
, m_vb(b) {
}
Variant::operator int64_t() const {
return cast<int64_t>();
}
Variant::operator int() const {
return cast<int>();
}
Variant::operator float() const {
return cast<float>();
}
Variant::operator double() const {
return cast<double>();
}
Variant::operator bool() const {
return cast<bool>();
}
void Variant::clear() {
m_type = Type::VOID;
}
Variant& Variant::operator=(int64_t v) {
m_type = Type::INT;
m_vi = v;
return *this;
}
Variant& Variant::operator=(double v) {
m_type = Type::FLOAT;
m_vf = v;
return *this;
}
Variant& Variant::operator=(bool v) {
m_type = Type::BOOL;
m_vb = v;
return *this;
}
Datum::Datum(void* base, const DataType& type)
: m_base(base)
, m_type(type) {
}
Datum::Datum(void* base, size_t offset, const DataType& type, const MemoryOverlay& overlay)
: m_base(base)
, m_offset(offset)
, m_type(type)
, m_overlay(overlay) {
}
Datum::Datum(void* base, const Variable& var, const MemoryOverlay& overlay)
: m_base(base)
, m_offset(var.address)
, m_type(var.type)
, m_mask(var.mask)
, m_overlay(overlay) {
}
Datum::Datum(Variant* variant)
: m_type("=i8")
, m_variant(variant) {
}
Datum& Datum::operator=(int64_t value) {
if (m_base) {
if (m_overlay.width > 1 || m_offset) {
uint8_t fakeBase[16]{};
m_type.encode(m_overlay.parse(m_base, m_offset, reinterpret_cast<void*>(fakeBase), m_type.width), value);
m_overlay.unparse(m_base, m_offset, reinterpret_cast<void*>(fakeBase), m_type.width);
} else {
m_type.encode(m_base, value);
}
} else if (m_variant) {
*m_variant = value;
}
return *this;
}
Datum::operator int64_t() const {
if (!m_base) {
if (m_variant) {
return *m_variant;
}
return 0;
}
int64_t value;
if (m_overlay.width > 1 || m_offset) {
uint8_t fakeBase[16]{};
value = m_type.decode(m_overlay.parse(m_base, m_offset, reinterpret_cast<void*>(fakeBase), m_type.width));
} else {
value = m_type.decode(m_base);
}
return value & m_mask;
}
Datum::operator Variant() const {
if (m_variant) {
return *m_variant;
}
return static_cast<int64_t>(*this);
}
DynamicMemoryView::DynamicMemoryView(void* buffer, size_t bytes, const DataType& dtype, const MemoryOverlay& overlay)
: dtype(dtype)
, overlay(overlay) {
m_mem.open(buffer, bytes);
}
Datum DynamicMemoryView::operator[](size_t offset) {
return dtype(m_mem.offset(0), offset, overlay);
}
int64_t DynamicMemoryView::operator[](size_t offset) const {
if (overlay.width > 1) {
uint8_t fakeBase[16]{};
return dtype.decode(overlay.parse(m_mem.offset(0), offset, reinterpret_cast<void*>(fakeBase), dtype.width));
}
return dtype.decode(m_mem.offset(offset));
}
const DataType AddressSpace::s_type{ "|u1" };
void AddressSpace::addBlock(size_t offset, size_t size, void* data) {
if (data) {
m_blocks[offset].open(data, size);
} else {
m_blocks[offset].open(size);
}
}
void AddressSpace::addBlock(size_t offset, size_t size, const void* data) {
if (data) {
m_blocks[offset].clone(data, size);
} else {
m_blocks[offset].open(size);
}
}
void AddressSpace::addBlock(size_t offset, const MemoryView<>& base) {
m_blocks[offset].clone(base);
}
void AddressSpace::updateBlock(size_t offset, void* data) {
m_blocks[offset].open(data, m_blocks[offset].size());
}
void AddressSpace::updateBlock(size_t offset, const void* data) {
m_blocks[offset].clone(data, m_blocks[offset].size());
}
void AddressSpace::updateBlock(size_t offset, const MemoryView<>& base) {
m_blocks[offset].clone(base);
}
bool AddressSpace::hasBlock(size_t offset) const {
for (const auto& block : m_blocks) {
if (offset < block.first) {
continue;
}
if (offset < block.first + block.second.size()) {
return true;
}
}
return false;
}
const MemoryView<>& AddressSpace::block(size_t offset) const {
for (const auto& block : m_blocks) {
if (offset < block.first) {
continue;
}
if (offset < block.first + block.second.size()) {
return block.second;
}
}
throw std::out_of_range("No known mapping 1");
}
MemoryView<>& AddressSpace::block(size_t offset) {
for (auto& block : m_blocks) {
if (offset < block.first) {
continue;
}
if (offset < block.first + block.second.size()) {
return block.second;
}
}
throw std::out_of_range("No known mapping 2");
}
bool AddressSpace::ok() const {
return m_blocks.size() > 0;
}
void AddressSpace::reset() {
m_blocks.clear();
}
void AddressSpace::clone(const AddressSpace& as) {
m_blocks.clear();
m_overlay = make_unique<MemoryOverlay>(*as.m_overlay);
for (auto& kv : as.m_blocks) {
m_blocks[kv.first].clone(kv.second);
}
}
void AddressSpace::clone() {
for (auto& kv : m_blocks) {
kv.second.clone();
}
}
void AddressSpace::setOverlay(const MemoryOverlay& overlay) {
m_overlay = make_unique<MemoryOverlay>(overlay);
}
Datum AddressSpace::operator[](size_t offset) {
for (auto& kv : m_blocks) {
if (offset < kv.first) {
throw std::out_of_range("No known mapping 3");
}
if (offset - kv.first >= kv.second.size()) {
continue;
}
return Datum(kv.second.offset(0), offset - kv.first, s_type, *m_overlay);
}
throw std::out_of_range("No known mapping 4");
}
Datum AddressSpace::operator[](const Variable& var) {
for (auto& kv : m_blocks) {
if (var.address < kv.first) {
throw std::out_of_range("No known mapping 5");
}
if (var.address - kv.first >= kv.second.size()) {
continue;
}
return Datum(kv.second.offset(0), Variable{ var.type, var.address - kv.first, var.mask }, *m_overlay);
}
throw std::out_of_range("No known mapping 6");
}
uint8_t AddressSpace::operator[](size_t offset) const {
for (const auto& kv : m_blocks) {
if (offset < kv.first) {
throw std::out_of_range("No known mapping 7");
}
if (offset - kv.first >= kv.second.size()) {
continue;
}
uint8_t fakeBase[16]{};
return s_type.decode(m_overlay->parse(kv.second.offset(0), offset - kv.first, reinterpret_cast<void*>(fakeBase), s_type.width));
}
throw std::out_of_range("No known mapping 8");
}
int64_t AddressSpace::operator[](const Variable& var) const {
for (const auto& kv : m_blocks) {
if (var.address < kv.first) {
throw std::out_of_range("No known mapping 9");
}
if (var.address - kv.first >= kv.second.size()) {
continue;
}
int64_t value;
if (m_overlay->width > 1) {
uint8_t fakeBase[16];
value = var.type.decode(m_overlay->parse(kv.second.offset(0), var.address - kv.first, reinterpret_cast<void*>(fakeBase), var.type.width));
} else {
value = var.type.decode(kv.second.offset(var.address - kv.first));
}
value &= var.mask;
return value;
}
throw std::out_of_range("No known mapping 10");
}
AddressSpace& AddressSpace::operator=(AddressSpace&& as) {
m_blocks.clear();
m_overlay = move(as.m_overlay);
for (auto& kv : as.m_blocks) {
m_blocks[kv.first] = move(as.m_blocks[kv.first]);
}
as.m_blocks.clear();
return *this;
}
int64_t Retro::toBcd(int64_t value) {
int64_t out = 0;
int shift = 0;
while (value) {
out |= (value % 10) << (shift * 4);
++shift;
value /= 10;
}
return out;
}
int64_t Retro::toLNBcd(int64_t value) {
int64_t out = 0;
int shift = 0;
while (value) {
out |= (value % 10) << (shift * 8);
++shift;
value /= 10;
}
return out;
}
bool Retro::isBcd(uint64_t value) {
uint64_t halfdigits = (value >> 1) & 0x7777777777777777;
return !((halfdigits + 0x3333333333333333) & 0x8888888888888888);
}

474
deps/game_ai_lib/utils/memory.h vendored Normal file
View File

@ -0,0 +1,474 @@
// Adapted from OpenAI's retro source code:
// https://github.com/openai/retro
#pragma once
//#include "gtest/gtest.h"
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <string.h>
#include <fcntl.h>
#ifndef _WIN32
#include <sys/mman.h>
#include <unistd.h>
#else
#include <windows.h>
#include "unistd.h"
#endif
#ifdef VOID
#undef VOID
#endif
namespace Retro {
template<typename T = uint8_t>
class MemoryView {
public:
MemoryView() {}
MemoryView(const MemoryView<T>&) = delete;
~MemoryView();
bool open(const std::string& file, size_t bytes = 0);
void open(void* buffer, size_t bytes);
void open(size_t bytes);
void open(std::initializer_list<T>);
void close();
bool ok() const;
void clone(const void* buffer, size_t bytes);
void clone(const MemoryView<T>&);
void clone();
T& operator[](size_t);
const T& operator[](size_t) const;
MemoryView<T>& operator=(MemoryView<T>&&);
void* offset(size_t);
const void* offset(size_t) const;
size_t size() const;
private:
T* m_buffer = nullptr;
int m_backingFd = -1;
bool m_managed = false;
size_t m_size = 0;
#ifdef _WIN32
HANDLE m_mapView;
#endif
};
template<typename T>
MemoryView<T>::~MemoryView() {
close();
}
template<typename T>
bool MemoryView<T>::open(const std::string& file, size_t bytes) {
if (ok()) {
close();
}
int flags = O_RDWR;
if (bytes) {
flags |= O_CREAT;
}
m_backingFd = ::open(file.c_str(), flags, 0600);
if (m_backingFd < 0) {
return false;
}
if (bytes) {
ftruncate(m_backingFd, bytes);
m_size = bytes;
} else {
m_size = lseek(m_backingFd, 0, SEEK_END);
}
m_managed = true;
#ifdef _WIN32
m_mapView = CreateFileMapping(reinterpret_cast<HANDLE>(_get_osfhandle(m_backingFd)), 0, PAGE_READWRITE, 0, m_size, 0);
m_buffer = reinterpret_cast<T*>(static_cast<uint8_t*>(MapViewOfFile(m_mapView, FILE_MAP_WRITE, 0, 0, m_size)));
#else
m_buffer = reinterpret_cast<T*>(static_cast<uint8_t*>(mmap(nullptr, m_size, PROT_READ | PROT_WRITE, MAP_SHARED, m_backingFd, 0)));
#endif
if (m_buffer == reinterpret_cast<T*>(-1)) {
m_buffer = nullptr;
m_managed = false;
::close(m_backingFd);
return false;
}
return true;
}
template<typename T>
void MemoryView<T>::open(void* buffer, size_t bytes) {
if (ok()) {
close();
}
m_backingFd = -1;
m_size = bytes;
m_managed = false;
m_buffer = static_cast<T*>(buffer);
}
template<typename T>
void MemoryView<T>::open(size_t bytes) {
if (ok()) {
close();
}
m_backingFd = -1;
m_size = bytes;
m_managed = true;
#ifdef _WIN32
m_buffer = static_cast<T*>(VirtualAlloc(nullptr, bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE));
#else
m_buffer = static_cast<T*>(mmap(nullptr, bytes, PROT_READ | PROT_WRITE, MAP_ANON | MAP_SHARED, -1, 0));
#endif
}
template<typename T>
void MemoryView<T>::open(std::initializer_list<T> list) {
open(list.size());
std::copy(list.begin(), list.end(), m_buffer);
}
template<typename T>
void MemoryView<T>::close() {
if (!ok()) {
return;
}
if (m_managed) {
if (m_buffer) {
#ifdef _WIN32
if (m_backingFd >= 0) {
UnmapViewOfFile(m_buffer);
CloseHandle(m_mapView);
} else {
VirtualFree(m_buffer, 0, MEM_RELEASE);
}
#else
munmap(m_buffer, m_size);
#endif
}
if (m_backingFd >= 0) {
::close(m_backingFd);
m_backingFd = -1;
}
}
m_buffer = nullptr;
m_size = 0;
m_managed = false;
}
template<typename T>
bool MemoryView<T>::ok() const {
return m_buffer && m_size;
}
template<typename T>
void MemoryView<T>::clone() {
if (!ok() || m_managed) {
return;
}
clone(static_cast<void*>(m_buffer), m_size);
}
template<typename T>
void MemoryView<T>::clone(const void* buffer, size_t bytes) {
if (m_managed && bytes == m_size) {
memmove(m_buffer, buffer, bytes);
return;
}
if (static_cast<void*>(m_buffer) != buffer || !m_managed) {
close();
}
#ifdef _WIN32
T* newBuffer = static_cast<T*>(VirtualAlloc(nullptr, bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE));
#else
T* newBuffer = static_cast<T*>(mmap(nullptr, bytes, PROT_READ | PROT_WRITE, MAP_ANON | MAP_SHARED, -1, 0));
#endif
memcpy(newBuffer, buffer, bytes);
m_buffer = newBuffer;
m_size = bytes;
m_managed = true;
}
template<typename T>
void MemoryView<T>::clone(const MemoryView<T>& other) {
clone(static_cast<const void*>(other.m_buffer), other.m_size);
}
template<typename T>
T& MemoryView<T>::operator[](size_t index) {
return m_buffer[index];
}
template<typename T>
const T& MemoryView<T>::operator[](size_t index) const {
return m_buffer[index];
}
template<typename T>
MemoryView<T>& MemoryView<T>::operator=(MemoryView<T>&& other) {
close();
m_buffer = other.m_buffer;
m_backingFd = other.m_backingFd;
m_managed = other.m_managed;
m_size = other.m_size;
other.m_managed = false;
return *this;
}
template<typename T>
void* MemoryView<T>::offset(size_t index) {
return reinterpret_cast<void*>(&m_buffer[index]);
}
template<typename T>
const void* MemoryView<T>::offset(size_t index) const {
return reinterpret_cast<const void*>(&m_buffer[index]);
}
template<typename T>
size_t MemoryView<T>::size() const {
return m_size;
}
enum class Endian : char {
BIG = 0b01,
LITTLE = 0b10,
NATIVE = 0b11,
MIXED_BL = 0b1001,
MIXED_LB = 0b0110,
MIXED_BN = 0b1101,
MIXED_LN = 0b1110,
#if defined(__LITTLE_ENDIAN__) || __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
REAL_NATIVE = LITTLE,
REAL_MIXED_BN = MIXED_BL,
REAL_MIXED_LN = LITTLE,
#else
REAL_NATIVE = BIG,
REAL_MIXED_BN = BIG,
REAL_MIXED_LN = MIXED_LN,
#endif
UNDEF = 0
};
Endian reduce(Endian);
bool reduceCompare(Endian, Endian);
enum class Repr : char {
SIGNED = 'i',
UNSIGNED = 'u',
BCD = 'd',
LN_BCD = 'n'
};
class Datum;
class MemoryOverlay;
class DataType {
public:
DataType(const char*);
DataType(const std::string&);
DataType(const DataType&) = default;
Datum operator()(void*) const;
Datum operator()(void*, size_t offset, const MemoryOverlay&) const;
bool operator==(const DataType&) const;
bool operator!=(const DataType&) const;
void encode(void* buffer, int64_t value) const;
int64_t decode(const void* buffer) const;
const size_t width;
const Endian endian;
const Repr repr;
const char type[5];
private:
#if 0
FRIEND_TEST(DataTypeShift, 1);
FRIEND_TEST(DataTypeShift, 2);
FRIEND_TEST(DataTypeShift, 3);
FRIEND_TEST(DataTypeShift, 4);
FRIEND_TEST(DataTypeShift, 5);
FRIEND_TEST(DataTypeShift, 6);
FRIEND_TEST(DataTypeShift, 7);
FRIEND_TEST(DataTypeShift, 8);
#endif
const uint8_t maskLo;
const uint8_t maskHi;
const unsigned cvt;
int64_t shift[8]{};
};
struct Variable {
Variable(const DataType&, size_t address, uint64_t mask = UINT64_MAX);
Variable(const Variable&) = default;
bool operator==(const Variable&) const;
const DataType type;
const size_t address;
const uint64_t mask = UINT64_MAX;
};
class MemoryOverlay {
public:
MemoryOverlay(Endian backing = Endian::NATIVE, Endian real = Endian::NATIVE, size_t width = 1);
MemoryOverlay(char backing, char real, size_t width = 1);
void* parse(const void* in, size_t offset, void* out, size_t size) const;
void unparse(void* out, size_t offset, const void* in, size_t size) const;
const size_t width;
private:
DataType m_backing;
DataType m_real;
};
class Variant {
public:
enum class Type {
BOOL,
INT,
FLOAT,
VOID
};
Variant() {}
Variant(int64_t);
Variant(double);
Variant(bool);
template<typename T>
T cast() const {
switch (m_type) {
case Type::BOOL:
return m_vb;
case Type::INT:
return m_vi;
case Type::FLOAT:
return m_vf;
case Type::VOID:
default:
return T();
}
}
operator int() const;
operator int64_t() const;
operator float() const;
operator double() const;
operator bool() const;
void clear();
Variant& operator=(int64_t);
Variant& operator=(double);
Variant& operator=(bool);
Type type() { return m_type; }
private:
Type m_type = Type::VOID;
union {
bool m_vb;
int64_t m_vi;
double m_vf;
};
};
class Datum {
public:
Datum() {}
Datum(void*, const DataType&);
Datum(void* base, const Variable&, const MemoryOverlay& overlay = {});
Datum(void* base, size_t offset, const DataType&, const MemoryOverlay& overlay = {});
Datum(Variant*);
Datum& operator=(int64_t);
operator int64_t() const;
operator Variant() const;
bool operator==(int64_t);
private:
void* const m_base = nullptr;
const size_t m_offset = 0;
const DataType m_type{ "|u1" };
const uint64_t m_mask = UINT64_MAX;
const MemoryOverlay m_overlay{};
Variant* m_variant = nullptr;
};
class DynamicMemoryView {
public:
DynamicMemoryView(void* buffer, size_t bytes, const DataType&, const MemoryOverlay& = {});
Datum operator[](size_t);
int64_t operator[](size_t) const;
const DataType dtype;
const MemoryOverlay overlay;
private:
MemoryView<> m_mem;
};
class AddressSpace {
public:
void addBlock(size_t offset, size_t size, void* data = nullptr);
void addBlock(size_t offset, size_t size, const void* data);
void addBlock(size_t offset, const MemoryView<>& base);
void updateBlock(size_t offset, void* data);
void updateBlock(size_t offset, const void* data);
void updateBlock(size_t offset, const MemoryView<>& base);
bool hasBlock(size_t offset) const;
const MemoryView<>& block(size_t offset) const;
MemoryView<>& block(size_t offset);
const std::map<size_t, MemoryView<>>& blocks() const { return m_blocks; }
std::map<size_t, MemoryView<>>& blocks() { return m_blocks; }
bool ok() const;
void reset();
void clone(const AddressSpace&);
void clone();
void setOverlay(const MemoryOverlay& overlay);
const MemoryOverlay& overlay() const { return *m_overlay; };
Datum operator[](size_t);
Datum operator[](const Variable&);
uint8_t operator[](size_t) const;
int64_t operator[](const Variable&) const;
AddressSpace& operator=(AddressSpace&&);
private:
static const DataType s_type;
;
std::map<size_t, MemoryView<>> m_blocks;
std::unique_ptr<MemoryOverlay> m_overlay = std::make_unique<MemoryOverlay>();
};
int64_t toBcd(int64_t);
int64_t toLNBcd(int64_t);
bool isBcd(uint64_t);
}
namespace std {
template<>
struct hash<Retro::DataType> {
size_t operator()(const Retro::DataType&) const;
};
}

59
deps/game_ai_lib/utils/unistd.h vendored Normal file
View File

@ -0,0 +1,59 @@
#ifndef _UNISTD_H
#define _UNISTD_H 1
/* This file intended to serve as a drop-in replacement for
* unistd.h on Windows.
* Please add functionality as neeeded.
* Original file from: http://stackoverflow.com/a/826027
*/
#include <stdlib.h>
#include <io.h>
//#include <getopt.h> /* getopt at: https://gist.github.com/bikerm16/1b75e2dd20d839dcea58 */
#include <process.h> /* for getpid() and the exec..() family */
#include <direct.h> /* for _getcwd() and _chdir() */
#define srandom srand
#define random rand
/* Values for the second argument to access.
These may be OR'd together. */
#define R_OK 4 /* Test for read permission. */
#define W_OK 2 /* Test for write permission. */
#define X_OK R_OK /* execute permission - unsupported in Windows,
use R_OK instead. */
#define F_OK 0 /* Test for existence. */
#define access _access
#define dup2 _dup2
#define execve _execve
#define ftruncate _chsize
#define unlink _unlink
#define fileno _fileno
#define getcwd _getcwd
#define chdir _chdir
#define isatty _isatty
#define lseek _lseek
/* read, write, and close are NOT being #defined here,
* because while there are file handle specific versions for Windows,
* they probably don't work for sockets.
* You need to look at your app and consider whether
* to call e.g. closesocket().
*/
#define ssize_t int
#define STDIN_FILENO 0
#define STDOUT_FILENO 1
#define STDERR_FILENO 2
/* should be in some equivalent to <sys/types.h> */
//typedef __int8 int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned __int8 uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
#endif /* unistd.h */

91
deps/game_ai_lib/utils/utils.cpp vendored Normal file
View File

@ -0,0 +1,91 @@
// Adapted from OpenAI's retro source code:
// https://github.com/openai/retro
#include "utils.h"
#include <climits>
#include <cstdlib>
#include <sys/stat.h>
using namespace std;
namespace Retro {
int64_t calculate(Operation op, int64_t reference, int64_t value) {
switch (op) {
case Operation::NOOP:
return value;
case Operation::EQUAL:
return value == reference;
case Operation::NEGATIVE_EQUAL:
return value == -reference;
case Operation::NOT_EQUAL:
return value != reference;
case Operation::LESS_THAN:
return value < reference;
case Operation::GREATER_THAN:
return value > reference;
case Operation::LESS_OR_EQUAL:
return value <= reference;
case Operation::GREATER_OR_EQUAL:
return value >= reference;
case Operation::NONZERO:
return value != 0;
case Operation::ZERO:
return value == 0;
case Operation::POSITIVE:
return value > 0;
case Operation::NEGATIVE:
return value < 0;
case Operation::SIGN:
return value < 0 ? -1 : value > 0 ? 1 : 0;
}
return 0;
}
string drillUp(const vector<string>& targets, const string& fail, const string& hint)
{
#if 0
char rpath[PATH_MAX];
string path(".");
#ifndef _WIN32
if (!hint.empty() && realpath(hint.c_str(), rpath)) {
path = rpath;
}
#else
if (!hint.empty()) {
path = hint;
}
#endif
while (!path.empty() && path != "/") {
for (const auto& target : targets) {
struct stat statbuf;
string testPath = path + "/" + target;
if (stat(testPath.c_str(), &statbuf) == 0 && S_ISDIR(statbuf.st_mode)) {
return testPath;
}
}
#ifndef _WIN32
string new_path = path.substr(0, path.find_last_of('/'));
#else
string new_path = path.substr(0, path.find_last_of("/\\"));
#endif
if (new_path == path) {
break;
}
path = new_path;
#ifndef _WIN32
if (!path.empty() && realpath(path.c_str(), rpath)) {
path = rpath;
}
#endif
}
if (!fail.empty()) {
return fail + "/" + targets[0];
}
return {};
#else
return {};
#endif
}
}

29
deps/game_ai_lib/utils/utils.h vendored Normal file
View File

@ -0,0 +1,29 @@
// Adapted from OpenAI's retro source code:
// https://github.com/openai/retro
#pragma once
#include <string>
#include <vector>
namespace Retro {
enum class Operation {
NOOP,
EQUAL,
NEGATIVE_EQUAL,
NOT_EQUAL,
LESS_THAN,
GREATER_THAN,
LESS_OR_EQUAL,
GREATER_OR_EQUAL,
NONZERO,
ZERO,
POSITIVE,
NEGATIVE,
SIGN,
};
int64_t calculate(Operation op, int64_t reference, int64_t value);
std::string drillUp(const std::vector<std::string>& targets, const std::string& fail = {}, const std::string& hint = ".");
}

View File

@ -1682,3 +1682,10 @@ CLOUD SYNC
#include "../network/cloud_sync_driver.c"
#include "../network/cloud_sync/webdav.c"
#endif
/*============================================================
GAME AI
============================================================ */
#if defined(HAVE_GAME_AI)
#include "../ai/game_ai.c"
#endif

View File

@ -120,3 +120,4 @@ FONTS
#include "../deps/discord-rpc/src/connection_unix.cpp"
#endif
#endif

View File

@ -61,6 +61,8 @@
#include "../tasks/tasks_internal.h"
#include "../verbosity.h"
#include "../ai/game_ai.h"
#define HOLD_BTN_DELAY_SEC 2
/* Depends on ASCII character values */
@ -6882,6 +6884,17 @@ int16_t input_driver_state_wrapper(unsigned port, unsigned device,
result);
#endif
#ifdef HAVE_GAME_AI
if(settings->bools.game_ai_override_p1 && port == 0)
{
result |= game_ai_input(port, device, idx, id, result);
}
if(settings->bools.game_ai_override_p2 && port == 1)
{
result |= game_ai_input(port, device, idx, id, result);
}
#endif
return result;
}

View File

@ -6624,3 +6624,25 @@ MSG_HASH(
MENU_ENUM_LABEL_GAMEMODE_ENABLE,
"game_mode_enable"
)
#ifdef HAVE_GAME_AI
MSG_HASH(
MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS,
"core_game_ai_options"
)
MSG_HASH(
MENU_ENUM_LABEL_QUICK_MENU_SHOW_GAME_AI,
"quick_menu_show_game_ai"
)
MSG_HASH(
MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P1,
"game_ai_override_p1"
)
MSG_HASH(
MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P2,
"game_ai_override_p2"
)
MSG_HASH(
MENU_ENUM_LABEL_GAME_AI_SHOW_DEBUG,
"game_ai_show_debug"
)
#endif

View File

@ -16640,3 +16640,54 @@ MSG_HASH(
MSG_AI_SERVICE_STOPPED,
"stopped."
)
#ifdef HAVE_GAME_AI
MSG_HASH(
MENU_ENUM_LABEL_VALUE_GAME_AI_MENU_OPTION,
"AI player override"
)
MSG_HASH(
MENU_ENUM_SUBLABEL_GAME_AI_MENU_OPTION,
"AI player override sublabel"
)
MSG_HASH(
MENU_ENUM_LABEL_VALUE_CORE_GAME_AI_OPTIONS,
"Game AI"
)
MSG_HASH(
MENU_ENUM_LABEL_VALUE_GAME_AI_OVERRIDE_P1,
"Override p1"
)
MSG_HASH(
MENU_ENUM_SUBLABEL_GAME_AI_OVERRIDE_P1,
"Override player 01"
)
MSG_HASH(
MENU_ENUM_LABEL_VALUE_GAME_AI_OVERRIDE_P2,
"Override p2"
)
MSG_HASH(
MENU_ENUM_SUBLABEL_GAME_AI_OVERRIDE_P2,
"Override player 02"
)
MSG_HASH(
MENU_ENUM_LABEL_VALUE_GAME_AI_SHOW_DEBUG,
"Show Debug"
)
MSG_HASH(
MENU_ENUM_SUBLABEL_GAME_AI_SHOW_DEBUG,
"Show Debug"
)
MSG_HASH(
MENU_ENUM_LABEL_VALUE_QUICK_MENU_SHOW_GAME_AI,
"Show 'Game AI'"
)
MSG_HASH(
MENU_ENUM_SUBLABEL_QUICK_MENU_SHOW_GAME_AI,
"Show the 'Game AI' option."
)
#endif

View File

@ -294,6 +294,10 @@ GENERIC_DEFERRED_PUSH(deferred_push_core_information_steam_list, DISPLAYLIST_
GENERIC_DEFERRED_PUSH(deferred_push_file_browser_select_sideload_core, DISPLAYLIST_FILE_BROWSER_SELECT_SIDELOAD_CORE)
#ifdef HAVE_GAME_AI
GENERIC_DEFERRED_PUSH(deferred_push_core_game_ai_options, DISPLAYLIST_OPTIONS_GAME_AI)
#endif
static int deferred_push_cursor_manager_list_deferred(
menu_displaylist_info_t *info)
{
@ -951,6 +955,10 @@ static int menu_cbs_init_bind_deferred_push_compare_label(
{MENU_ENUM_LABEL_DEFERRED_LAKKA_LIST, deferred_push_lakka_list},
#endif
{MENU_ENUM_LABEL_DEFERRED_ADD_TO_PLAYLIST_LIST, deferred_push_add_to_playlist_list},
#ifdef HAVE_GAME_AI
{MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS, deferred_push_core_game_ai_options},
#endif
};
if (!string_is_equal(label, "null"))
@ -1415,6 +1423,11 @@ static int menu_cbs_init_bind_deferred_push_compare_label(
case MENU_ENUM_LABEL_DEFERRED_ADD_TO_PLAYLIST_LIST:
BIND_ACTION_DEFERRED_PUSH(cbs, deferred_push_add_to_playlist_list);
break;
#ifdef HAVE_GAME_AI
case MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS:
BIND_ACTION_DEFERRED_PUSH(cbs, deferred_push_core_game_ai_options);
break;
#endif
default:
return -1;
}

View File

@ -2016,6 +2016,9 @@ static int menu_cbs_init_bind_get_string_representation_compare_label(
case MENU_ENUM_LABEL_SYSTEM_INFORMATION:
case MENU_ENUM_LABEL_ACHIEVEMENT_LIST:
case MENU_ENUM_LABEL_ACHIEVEMENT_LIST_HARDCORE:
#ifdef HAVE_GAME_AI
case MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS:
#endif
BIND_ACTION_GET_VALUE(cbs,
menu_action_setting_disp_set_label_menu_more);
break;

View File

@ -1399,6 +1399,15 @@ DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_core_create_backup,
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_core_restore_backup_list, MENU_ENUM_SUBLABEL_CORE_RESTORE_BACKUP_LIST)
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_core_delete_backup_list, MENU_ENUM_SUBLABEL_CORE_DELETE_BACKUP_LIST)
#ifdef HAVE_GAME_AI
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_game_ai_menu_option, MENU_ENUM_SUBLABEL_GAME_AI_MENU_OPTION)
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_quick_menu_show_game_ai, MENU_ENUM_SUBLABEL_QUICK_MENU_SHOW_GAME_AI)
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_core_game_ai_options, MENU_ENUM_SUBLABEL_CORE_GAME_AI_OPTIONS)
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_game_ai_override_p1, MENU_ENUM_SUBLABEL_GAME_AI_OVERRIDE_P1)
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_game_ai_override_p2, MENU_ENUM_SUBLABEL_GAME_AI_OVERRIDE_P2)
DEFAULT_SUBLABEL_MACRO(action_bind_sublabel_game_ai_show_debug, MENU_ENUM_SUBLABEL_GAME_AI_SHOW_DEBUG)
#endif
static int action_bind_sublabel_systeminfo_controller_entry(
file_list_t *list,
unsigned type, unsigned i,
@ -5663,6 +5672,23 @@ int menu_cbs_init_bind_sublabel(menu_file_list_cbs_t *cbs,
case MENU_ENUM_LABEL_CORE_DELETE_BACKUP_ENTRY:
BIND_ACTION_SUBLABEL(cbs, action_bind_sublabel_core_backup_entry);
break;
#ifdef HAVE_GAME_AI
case MENU_ENUM_LABEL_QUICK_MENU_SHOW_GAME_AI:
BIND_ACTION_SUBLABEL(cbs, action_bind_sublabel_quick_menu_show_game_ai);
break;
case MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS:
BIND_ACTION_SUBLABEL(cbs, action_bind_sublabel_core_game_ai_options);
break;
case MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P1:
BIND_ACTION_SUBLABEL(cbs, action_bind_sublabel_game_ai_override_p1);
break;
case MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P2:
BIND_ACTION_SUBLABEL(cbs, action_bind_sublabel_game_ai_override_p2);
break;
case MENU_ENUM_LABEL_GAME_AI_SHOW_DEBUG:
BIND_ACTION_SUBLABEL(cbs, action_bind_sublabel_game_ai_show_debug);
break;
#endif
default:
return -1;
}

View File

@ -807,6 +807,10 @@ DEFAULT_FILL_TITLE_SEARCH_FILTER_MACRO(action_get_title_cheat_file_load,
DEFAULT_FILL_TITLE_SEARCH_FILTER_MACRO(action_get_title_cheat_file_load_append, MENU_ENUM_LABEL_VALUE_CHEAT_FILE_APPEND)
DEFAULT_FILL_TITLE_SEARCH_FILTER_MACRO(action_get_title_overlay, MENU_ENUM_LABEL_VALUE_OVERLAY_PRESET)
#ifdef HAVE_GAME_AI
DEFAULT_TITLE_SEARCH_FILTER_MACRO(action_get_core_game_ai_options_list, MENU_ENUM_LABEL_VALUE_CORE_GAME_AI_OPTIONS)
#endif
static int action_get_title_generic(char *s, size_t len,
const char *path, const char *text)
{
@ -1312,6 +1316,10 @@ static int menu_cbs_init_bind_title_compare_label(menu_file_list_cbs_t *cbs,
action_get_title_deferred_core_list},
{MENU_ENUM_LABEL_DEFERRED_CORE_LIST_SET,
action_get_title_deferred_core_list},
#if defined(HAVE_GAME_AI)
{MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS,
action_get_core_game_ai_options_list},
#endif
};
if (cbs->setting)
@ -1751,6 +1759,12 @@ static int menu_cbs_init_bind_title_compare_label(menu_file_list_cbs_t *cbs,
case MENU_ENUM_LABEL_MANUAL_CONTENT_SCAN_DIR:
BIND_ACTION_GET_TITLE(cbs, action_get_title_manual_content_scan_dir);
break;
#ifdef HAVE_GAME_AI
case MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS:
BIND_ACTION_GET_TITLE(cbs, action_get_core_game_ai_options_list);
break;
#endif
default:
return -1;
}

View File

@ -4017,8 +4017,25 @@ static int menu_displaylist_parse_load_content_settings(
MENU_SETTING_ACTION, 0, 0, NULL))
count++;
}
#if HAVE_GAME_AI
if (MENU_DISPLAYLIST_PARSE_SETTINGS_ENUM(list,
MENU_ENUM_LABEL_GAME_AI_MENU_OPTION,
PARSE_ONLY_BOOL, false) == 0)
count++;
if (settings->bools.quick_menu_show_game_ai)
{
if (menu_entries_append(list,
msg_hash_to_str(MENU_ENUM_LABEL_VALUE_CORE_GAME_AI_OPTIONS),
msg_hash_to_str(MENU_ENUM_LABEL_CORE_GAME_AI_OPTIONS),
MENU_ENUM_LABEL_CORE_CHEAT_OPTIONS,
MENU_SETTING_ACTION, 0, 0, NULL))
count++;
}
#endif
}
return count;
}
@ -8450,6 +8467,28 @@ unsigned menu_displaylist_build_list(
}
#endif
break;
#ifdef HAVE_GAME_AI
case DISPLAYLIST_OPTIONS_GAME_AI:
{
if (MENU_DISPLAYLIST_PARSE_SETTINGS_ENUM(list,
MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P1,
PARSE_ONLY_BOOL, false) == 0)
count++;
if (MENU_DISPLAYLIST_PARSE_SETTINGS_ENUM(list,
MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P2,
PARSE_ONLY_BOOL, false) == 0)
count++;
if (MENU_DISPLAYLIST_PARSE_SETTINGS_ENUM(list,
MENU_ENUM_LABEL_GAME_AI_SHOW_DEBUG,
PARSE_ONLY_BOOL, false) == 0)
count++;
}
break;
#endif
case DISPLAYLIST_DROPDOWN_LIST_RESOLUTION:
menu_entries_clear(list);
{
@ -11188,6 +11227,10 @@ unsigned menu_displaylist_build_list(
{MENU_ENUM_LABEL_QUICK_MENU_SHOW_SAVE_CONTENT_DIR_OVERRIDES, PARSE_ONLY_BOOL},
{MENU_ENUM_LABEL_QUICK_MENU_SHOW_SAVE_GAME_OVERRIDES, PARSE_ONLY_BOOL},
{MENU_ENUM_LABEL_QUICK_MENU_SHOW_CHEATS, PARSE_ONLY_BOOL},
#ifdef HAVE_GAME_AI
{MENU_ENUM_LABEL_QUICK_MENU_SHOW_GAME_AI, PARSE_ONLY_BOOL},
#endif
};
for (i = 0; i < ARRAY_SIZE(build_list); i++)
@ -11226,6 +11269,7 @@ unsigned menu_displaylist_build_list(
#ifdef HAVE_NETWORKING
{MENU_ENUM_LABEL_QUICK_MENU_SHOW_DOWNLOAD_THUMBNAILS, PARSE_ONLY_BOOL},
#endif
{MENU_ENUM_LABEL_QUICK_MENU_SHOW_INFORMATION, PARSE_ONLY_BOOL},
};
@ -14557,6 +14601,9 @@ bool menu_displaylist_ctl(enum menu_displaylist_ctl_state type,
case DISPLAYLIST_SUBSYSTEM_SETTINGS_LIST:
#ifdef HAVE_MIST
case DISPLAYLIST_STEAM_SETTINGS_LIST:
#endif
#ifdef HAVE_GAME_AI
case DISPLAYLIST_OPTIONS_GAME_AI:
#endif
case DISPLAYLIST_OPTIONS_OVERRIDES:
menu_entries_clear(info->list);
@ -16200,6 +16247,7 @@ bool menu_displaylist_ctl(enum menu_displaylist_ctl_state type,
info->flags |= MD_FLAG_NEED_REFRESH
| MD_FLAG_NEED_PUSH;
break;
case DISPLAYLIST_NONE:
break;
}

View File

@ -295,6 +295,9 @@ enum menu_displaylist_ctl_state
#if defined(HAVE_LAKKA)
DISPLAYLIST_CPU_PERFPOWER_LIST,
DISPLAYLIST_CPU_POLICY_LIST,
#endif
#ifdef HAVE_GAME_AI
DISPLAYLIST_OPTIONS_GAME_AI,
#endif
DISPLAYLIST_PENDING_CLEAR,
DISPLAYLIST_SHADER_PRESET_PREPEND,

View File

@ -23566,6 +23566,69 @@ static bool setting_append_list(
(*list)[list_info->index - 1].ui_type = ST_UI_TYPE_UINT_COMBOBOX;
#endif
#ifdef HAVE_GAME_AI
CONFIG_BOOL(
list, list_info,
&settings->bools.quick_menu_show_game_ai,
MENU_ENUM_LABEL_QUICK_MENU_SHOW_GAME_AI,
MENU_ENUM_LABEL_VALUE_QUICK_MENU_SHOW_GAME_AI,
1,
MENU_ENUM_LABEL_VALUE_OFF,
MENU_ENUM_LABEL_VALUE_ON,
&group_info,
&subgroup_info,
parent_group,
general_write_handler,
general_read_handler,
SD_FLAG_NONE);
CONFIG_BOOL(
list, list_info,
&settings->bools.game_ai_override_p1,
MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P1,
MENU_ENUM_LABEL_VALUE_GAME_AI_OVERRIDE_P1,
1,
MENU_ENUM_LABEL_VALUE_OFF,
MENU_ENUM_LABEL_VALUE_ON,
&group_info,
&subgroup_info,
parent_group,
general_write_handler,
general_read_handler,
SD_FLAG_CMD_APPLY_AUTO);
CONFIG_BOOL(
list, list_info,
&settings->bools.game_ai_override_p2,
MENU_ENUM_LABEL_GAME_AI_OVERRIDE_P2,
MENU_ENUM_LABEL_VALUE_GAME_AI_OVERRIDE_P2,
1,
MENU_ENUM_LABEL_VALUE_OFF,
MENU_ENUM_LABEL_VALUE_ON,
&group_info,
&subgroup_info,
parent_group,
general_write_handler,
general_read_handler,
SD_FLAG_CMD_APPLY_AUTO);
CONFIG_BOOL(
list, list_info,
&settings->bools.game_ai_show_debug,
MENU_ENUM_LABEL_GAME_AI_SHOW_DEBUG,
MENU_ENUM_LABEL_VALUE_GAME_AI_SHOW_DEBUG,
1,
MENU_ENUM_LABEL_VALUE_OFF,
MENU_ENUM_LABEL_VALUE_ON,
&group_info,
&subgroup_info,
parent_group,
general_write_handler,
general_read_handler,
SD_FLAG_CMD_APPLY_AUTO);
#endif
END_SUB_GROUP(list, list_info, parent_group);
END_GROUP(list, list_info, parent_group);
break;

View File

@ -4275,6 +4275,16 @@ enum msg_hash_enums
MSG_3DS_BOTTOM_MENU_SAVE_STATE,
MSG_3DS_BOTTOM_MENU_LOAD_STATE,
#ifdef HAVE_GAME_AI
MENU_LABEL(QUICK_MENU_SHOW_GAME_AI),
MENU_LABEL(CORE_GAME_AI_OPTIONS),
MENU_LABEL(GAME_AI_MENU_OPTION),
MENU_LABEL(GAME_AI_OVERRIDE_P1),
MENU_LABEL(GAME_AI_OVERRIDE_P2),
MENU_LABEL(GAME_AI_SHOW_DEBUG),
#endif
MSG_LAST,
/* Ensure sizeof(enum) == sizeof(int) */

View File

@ -207,3 +207,4 @@ HAVE_MEMFD_CREATE=auto # libc supports memfd_create
C89_CRTSWITCHRES=no
HAVE_MICROPHONE=yes # Microphone support
HAVE_TEST_DRIVERS=yes # Test input driver
HAVE_GAME_AI=no

View File

@ -209,6 +209,10 @@
#include "accessibility.h"
#ifdef HAVE_GAME_AI
#include "ai/game_ai.h"
#endif
#if defined(HAVE_SDL) || defined(HAVE_SDL2) || defined(HAVE_SDL_DINGUX)
#include "SDL.h"
#endif
@ -7905,6 +7909,12 @@ bool retroarch_main_init(int argc, char *argv[])
preempt_init(runloop_st);
#endif
#ifdef HAVE_GAME_AI
game_ai_init();
#endif
return true;
error:
@ -8445,6 +8455,10 @@ bool retroarch_main_quit(void)
retroarch_menu_running_finished(true);
#endif
#ifdef HAVE_GAME_AI
game_ai_shutdown();
#endif
return true;
}

View File

@ -244,6 +244,10 @@
#include "JITSupport.h"
#endif
#if HAVE_GAME_AI
#include "ai/game_ai.h"
#endif
#define SHADER_FILE_WATCH_DELAY_MSEC 500
#define QUIT_DELAY_USEC 3 * 1000000 /* 3 seconds */
@ -7574,6 +7578,11 @@ bool core_load_game(retro_ctx_load_content_info_t *load_info)
* should be reset once core is deinitialised */
input_state_get_ptr()->flags |= INP_FLAG_REMAPPING_CACHE_ACTIVE;
runloop_st->current_core.flags |= RETRO_CORE_FLAG_GAME_LOADED;
#ifdef HAVE_GAME_AI
/* load models */
game_ai_load(load_info->info->path, runloop_st->current_core.retro_get_memory_data(RETRO_MEMORY_SYSTEM_RAM), runloop_st->current_core.retro_get_memory_size(RETRO_MEMORY_SYSTEM_RAM), libretro_log_cb);
#endif
return true;
}
@ -7714,6 +7723,16 @@ void core_run(void)
current_core->retro_run();
#ifdef HAVE_GAME_AI
settings_t *settings = config_get_ptr();
video_driver_state_t *video_st= video_state_get_ptr();
game_ai_think(settings->bools.game_ai_override_p1, settings->bools.game_ai_override_p2, settings->bools.game_ai_show_debug,
video_st->frame_cache_data, video_st->frame_cache_width, video_st->frame_cache_height, video_st->frame_cache_pitch, video_st->pix_fmt);
#endif
if ( late_polling
&& (!(current_core->flags & RETRO_CORE_FLAG_INPUT_POLLED)))
input_driver_poll();