#ifndef COMPONENTS_LUA_LUASTATE_H #define COMPONENTS_LUA_LUASTATE_H #include #include #include #include "configuration.hpp" namespace VFS { class Manager; } namespace LuaUtil { std::string getLuaVersion(); class ScriptsContainer; struct ScriptId { ScriptsContainer* mContainer = nullptr; int mIndex; // index in LuaUtil::ScriptsConfiguration }; struct LuaStateSettings { uint64_t mInstructionLimit = 0; // 0 is unlimited uint64_t mMemoryLimit = 0; // 0 is unlimited uint64_t mSmallAllocMaxSize = 1024 * 1024; // big default value efficiently disables memory tracking bool mLogMemoryUsage = false; }; // Holds Lua state. // Provides additional features: // - Load scripts from the virtual filesystem; // - Caching of loaded scripts; // - Disable unsafe Lua functions; // - Run every instance of every script in a separate sandbox; // - Forbid any interactions between sandboxes except than via provided API; // - Access to common read-only resources from different sandboxes; // - Replace standard `require` with a safe version that allows to search // Lua libraries (only source, no dll's) in the virtual filesystem; // - Make `print` to add the script name to every message and // write to the Log rather than directly to stdout; class LuaState { public: explicit LuaState(const VFS::Manager* vfs, const ScriptsConfiguration* conf, const LuaStateSettings& settings = LuaStateSettings{}); LuaState(const LuaState&) = delete; LuaState(LuaState&&) = delete; ~LuaState(); // Returns underlying sol::state. sol::state& sol() { return mLua; } // Can be used by a C++ function that is called from Lua to get the Lua traceback. // Makes no sense if called not from Lua code. // Note: It is a slow function, should be used for debug purposes only. std::string debugTraceback() { return mLua["debug"]["traceback"]().get(); } // A shortcut to create a new Lua table. sol::table newTable() { return sol::table(mLua, sol::create); } template sol::table tableFromPairs(std::initializer_list> list) { sol::table res(mLua, sol::create); for (const auto& [k, v] : list) res[k] = v; return res; } // Registers a package that will be available from every sandbox via `require(name)`. // The package can be either a sol::table with an API or a sol::function. If it is a function, // it will be evaluated (once per sandbox) the first time when requested. If the package // is a table, then `makeReadOnly` is applied to it automatically (but not to other tables it contains). void addCommonPackage(std::string packageName, sol::object package); // Creates a new sandbox, runs a script, and returns the result // (the result is expected to be an interface of the script). // Args: // path: path to the script in the virtual filesystem; // namePrefix: sandbox name will be "[]". Sandbox name // will be added to every `print` output. // packages: additional packages that should be available from the sandbox via `require`. Each package // should be either a sol::table or a sol::function. If it is a function, it will be evaluated // (once per sandbox) with the argument 'hiddenData' the first time when requested. sol::protected_function_result runInNewSandbox(const std::string& path, const std::string& namePrefix = "", const std::map& packages = {}, const sol::object& hiddenData = sol::nil); void dropScriptCache() { mCompiledScripts.clear(); } const ScriptsConfiguration& getConfiguration() const { return *mConf; } // Load internal Lua library. All libraries are loaded in one sandbox and shouldn't be exposed to scripts // directly. void addInternalLibSearchPath(const std::filesystem::path& path) { mLibSearchPaths.push_back(path); } sol::function loadInternalLib(std::string_view libName); sol::function loadFromVFS(const std::string& path); sol::environment newInternalLibEnvironment(); uint64_t getTotalMemoryUsage() const { return mTotalMemoryUsage; } uint64_t getSmallAllocMemoryUsage() const { return mSmallAllocMemoryUsage; } uint64_t getMemoryUsageByScriptIndex(unsigned id) const { return id < mMemoryUsage.size() ? mMemoryUsage[id] : 0; } const LuaStateSettings& getSettings() const { return mSettings; } private: static sol::protected_function_result throwIfError(sol::protected_function_result&&); template friend sol::protected_function_result call(const sol::protected_function& fn, Args&&... args); template friend sol::protected_function_result call( ScriptId scriptId, const sol::protected_function& fn, Args&&... args); sol::function loadScriptAndCache(const std::string& path); static void countHook(lua_State* L, lua_Debug* ar); static void* trackingAllocator(void* ud, void* ptr, size_t osize, size_t nsize); struct AllocOwner { std::shared_ptr mContainer; int mScriptIndex; }; const LuaStateSettings mSettings; // Needed to track resource usage per script, must be initialized before mLua. ScriptId mActiveScriptId; uint64_t mCurrentCallInstructionCounter = 0; std::map mBigAllocOwners; uint64_t mTotalMemoryUsage = 0; uint64_t mSmallAllocMemoryUsage = 0; std::vector mMemoryUsage; sol::state mLua; const ScriptsConfiguration* mConf; sol::table mSandboxEnv; std::map mCompiledScripts; std::map mCommonPackages; const VFS::Manager* mVFS; std::vector mLibSearchPaths; }; // LuaUtil::call should be used for every call of every Lua function. // 1) It is a workaround for a bug in `sol`. See https://github.com/ThePhD/sol2/issues/1078 // 2) When called with ScriptId it tracks resource usage (scriptId refers to the script that is responsible for this // call). template sol::protected_function_result call(const sol::protected_function& fn, Args&&... args) { try { auto res = LuaState::throwIfError(fn(std::forward(args)...)); return res; } catch (std::exception&) { throw; } catch (...) { throw std::runtime_error("Unknown error"); } } // Lua must be initialized through LuaUtil::LuaState, otherwise this function will segfault. template sol::protected_function_result call(ScriptId scriptId, const sol::protected_function& fn, Args&&... args) { LuaState* luaState; (void)lua_getallocf(fn.lua_state(), reinterpret_cast(&luaState)); assert(luaState->mActiveScriptId.mContainer == nullptr && "recursive call of LuaUtil::call(scriptId, ...) ?"); luaState->mActiveScriptId = scriptId; luaState->mCurrentCallInstructionCounter = 0; try { auto res = LuaState::throwIfError(fn(std::forward(args)...)); luaState->mActiveScriptId = {}; return res; } catch (std::exception&) { luaState->mActiveScriptId = {}; throw; } catch (...) { luaState->mActiveScriptId = {}; throw std::runtime_error("Unknown error"); } } // getFieldOrNil(table, "a", "b", "c") returns table["a"]["b"]["c"] or nil if some of the fields doesn't exist. template sol::object getFieldOrNil(const sol::object& table, std::string_view first, const Str&... str) { if (!table.is()) return sol::nil; if constexpr (sizeof...(str) == 0) return table.as()[first]; else return getFieldOrNil(table.as()[first], str...); } // String representation of a Lua object. Should be used for debugging/logging purposes only. std::string toString(const sol::object&); template T getValueOrDefault(const sol::object& obj, const T& defaultValue) { if (obj == sol::nil) return defaultValue; if (obj.is()) return obj.as(); else throw std::logic_error(std::string("Value \"") + toString(obj) + std::string("\" has unexpected type")); } // Makes a table read only (when accessed from Lua) by wrapping it with an empty userdata. // Needed to forbid any changes in common resources that can be accessed from different sandboxes. // `strictIndex = true` replaces default `__index` with a strict version that throws an error if key is not found. sol::table makeReadOnly(const sol::table&, bool strictIndex = false); inline sol::table makeStrictReadOnly(const sol::table& tbl) { return makeReadOnly(tbl, true); } sol::table getMutableFromReadOnly(const sol::userdata&); } #endif // COMPONENTS_LUA_LUASTATE_H