1
0
mirror of https://gitlab.com/OpenMW/openmw.git synced 2025-01-26 09:35:28 +00:00

Merge branch 'lua_casting_error' into 'master'

Fix crash on sol::object type mismatch in invalid Lua script

See merge request OpenMW/openmw!2975
This commit is contained in:
psi29a 2023-04-25 22:19:45 +00:00
commit 0cf9fe0e2e
15 changed files with 102 additions and 52 deletions

View File

@ -230,7 +230,7 @@ namespace MWLua
if (spellOrId.is<ESM::Spell>()) if (spellOrId.is<ESM::Spell>())
return spellOrId.as<const ESM::Spell*>()->mId; return spellOrId.as<const ESM::Spell*>()->mId;
else else
return ESM::RefId::deserializeText(spellOrId.as<std::string_view>()); return ESM::RefId::deserializeText(LuaUtil::cast<std::string_view>(spellOrId));
}; };
// types.Actor.spells(o):add(id) // types.Actor.spells(o):add(id)

View File

@ -57,7 +57,7 @@ namespace MWLua
cell = cellOrName.as<const GCell&>().mStore; cell = cellOrName.as<const GCell&>().mStore;
else else
{ {
std::string_view name = cellOrName.as<std::string_view>(); std::string_view name = LuaUtil::cast<std::string_view>(cellOrName);
if (name.empty()) if (name.empty())
cell = nullptr; // default exterior worldspace cell = nullptr; // default exterior worldspace
else else
@ -253,7 +253,7 @@ namespace MWLua
throw std::runtime_error("Attaching scripts to Static is not allowed: " + std::string(path)); throw std::runtime_error("Attaching scripts to Static is not allowed: " + std::string(path));
if (initData != sol::nil) if (initData != sol::nil)
context.mLuaManager->addCustomLocalScript(object.ptr(), *scriptId, context.mLuaManager->addCustomLocalScript(object.ptr(), *scriptId,
LuaUtil::serialize(initData.as<sol::table>(), context.mSerializer)); LuaUtil::serialize(LuaUtil::cast<sol::table>(initData), context.mSerializer));
else else
context.mLuaManager->addCustomLocalScript( context.mLuaManager->addCustomLocalScript(
object.ptr(), *scriptId, cfg[*scriptId].mInitializationData); object.ptr(), *scriptId, cfg[*scriptId].mInitializationData);

View File

@ -123,7 +123,7 @@ namespace MWLua
{ {
auto& stats = ptr.getClass().getCreatureStats(ptr); auto& stats = ptr.getClass().getCreatureStats(ptr);
if (prop == "current") if (prop == "current")
stats.setLevel(value.as<int>()); stats.setLevel(LuaUtil::cast<int>(value));
} }
}; };
@ -167,7 +167,7 @@ namespace MWLua
{ {
auto& stats = ptr.getClass().getCreatureStats(ptr); auto& stats = ptr.getClass().getCreatureStats(ptr);
auto stat = stats.getDynamic(index); auto stat = stats.getDynamic(index);
float floatValue = value.as<float>(); float floatValue = LuaUtil::cast<float>(value);
if (prop == "base") if (prop == "base")
stat.setBase(floatValue); stat.setBase(floatValue);
else if (prop == "current") else if (prop == "current")
@ -201,9 +201,9 @@ namespace MWLua
float getModified(const Context& context) const float getModified(const Context& context) const
{ {
auto base = get(context, "base", &MWMechanics::AttributeValue::getBase).as<float>(); auto base = LuaUtil::cast<float>(get(context, "base", &MWMechanics::AttributeValue::getBase));
auto damage = get(context, "damage", &MWMechanics::AttributeValue::getDamage).as<float>(); auto damage = LuaUtil::cast<float>(get(context, "damage", &MWMechanics::AttributeValue::getDamage));
auto modifier = get(context, "modifier", &MWMechanics::AttributeValue::getModifier).as<float>(); auto modifier = LuaUtil::cast<float>(get(context, "modifier", &MWMechanics::AttributeValue::getModifier));
return std::max(0.f, base - damage + modifier); // Should match AttributeValue::getModified return std::max(0.f, base - damage + modifier); // Should match AttributeValue::getModified
} }
@ -226,7 +226,7 @@ namespace MWLua
{ {
auto& stats = ptr.getClass().getCreatureStats(ptr); auto& stats = ptr.getClass().getCreatureStats(ptr);
auto stat = stats.getAttribute(index); auto stat = stats.getAttribute(index);
float floatValue = value.as<float>(); float floatValue = LuaUtil::cast<float>(value);
if (prop == "base") if (prop == "base")
stat.setBase(floatValue); stat.setBase(floatValue);
else if (prop == "damage") else if (prop == "damage")
@ -278,9 +278,9 @@ namespace MWLua
float getModified(const Context& context) const float getModified(const Context& context) const
{ {
auto base = get(context, "base", &MWMechanics::SkillValue::getBase).as<float>(); auto base = LuaUtil::cast<float>(get(context, "base", &MWMechanics::SkillValue::getBase));
auto damage = get(context, "damage", &MWMechanics::SkillValue::getDamage).as<float>(); auto damage = LuaUtil::cast<float>(get(context, "damage", &MWMechanics::SkillValue::getDamage));
auto modifier = get(context, "modifier", &MWMechanics::SkillValue::getModifier).as<float>(); auto modifier = LuaUtil::cast<float>(get(context, "modifier", &MWMechanics::SkillValue::getModifier));
return std::max(0.f, base - damage + modifier); // Should match SkillValue::getModified return std::max(0.f, base - damage + modifier); // Should match SkillValue::getModified
} }
@ -311,7 +311,7 @@ namespace MWLua
{ {
auto& stats = ptr.getClass().getNpcStats(ptr); auto& stats = ptr.getClass().getNpcStats(ptr);
auto stat = stats.getSkill(index); auto stat = stats.getSkill(index);
float floatValue = value.as<float>(); float floatValue = LuaUtil::cast<float>(value);
if (prop == "base") if (prop == "base")
stat.setBase(floatValue); stat.setBase(floatValue);
else if (prop == "damage") else if (prop == "damage")

View File

@ -272,11 +272,11 @@ namespace MWLua
SetEquipmentAction::Equipment eqp; SetEquipmentAction::Equipment eqp;
for (auto& [key, value] : equipment) for (auto& [key, value] : equipment)
{ {
int slot = key.as<int>(); int slot = LuaUtil::cast<int>(key);
if (value.is<Object>()) if (value.is<Object>())
eqp[slot] = value.as<Object>().id(); eqp[slot] = LuaUtil::cast<Object>(value).id();
else else
eqp[slot] = value.as<std::string>(); eqp[slot] = LuaUtil::cast<std::string>(value);
} }
context.mLuaManager->addAction( context.mLuaManager->addAction(
std::make_unique<SetEquipmentAction>(context.mLua, obj.id(), std::move(eqp))); std::make_unique<SetEquipmentAction>(context.mLua, obj.id(), std::move(eqp)));

View File

@ -100,6 +100,23 @@ return {
EXPECT_EQ(LuaUtil::toString(sol::make_object(mLua.sol(), "something")), "\"something\""); EXPECT_EQ(LuaUtil::toString(sol::make_object(mLua.sol(), "something")), "\"something\"");
} }
TEST_F(LuaStateTest, Cast)
{
EXPECT_EQ(LuaUtil::cast<int>(sol::make_object(mLua.sol(), 3.14)), 3);
EXPECT_ERROR(
LuaUtil::cast<int>(sol::make_object(mLua.sol(), "3.14")), "Value \"\"3.14\"\" can not be casted to int");
EXPECT_ERROR(LuaUtil::cast<std::string_view>(sol::make_object(mLua.sol(), sol::nil)),
"Value \"nil\" can not be casted to string");
EXPECT_ERROR(LuaUtil::cast<std::string>(sol::make_object(mLua.sol(), sol::nil)),
"Value \"nil\" can not be casted to string");
EXPECT_ERROR(LuaUtil::cast<sol::table>(sol::make_object(mLua.sol(), sol::nil)),
"Value \"nil\" can not be casted to sol::table");
EXPECT_ERROR(LuaUtil::cast<sol::function>(sol::make_object(mLua.sol(), "3.14")),
"Value \"\"3.14\"\" can not be casted to sol::function");
EXPECT_ERROR(LuaUtil::cast<sol::protected_function>(sol::make_object(mLua.sol(), "3.14")),
"Value \"\"3.14\"\" can not be casted to sol::function");
}
TEST_F(LuaStateTest, ErrorHandling) TEST_F(LuaStateTest, ErrorHandling)
{ {
EXPECT_ERROR(mLua.runInNewSandbox("invalid.lua"), "[string \"invalid.lua\"]:1:"); EXPECT_ERROR(mLua.runInNewSandbox("invalid.lua"), "[string \"invalid.lua\"]:1:");

View File

@ -2,6 +2,7 @@
#include <components/debug/debuglog.hpp> #include <components/debug/debuglog.hpp>
#include <components/l10n/manager.hpp> #include <components/l10n/manager.hpp>
#include <components/lua/luastate.hpp>
namespace namespace
{ {
@ -17,20 +18,20 @@ namespace
{ {
// Argument values // Argument values
if (value.is<std::string>()) if (value.is<std::string>())
args.push_back(icu::Formattable(value.as<std::string>().c_str())); args.push_back(icu::Formattable(LuaUtil::cast<std::string>(value).c_str()));
// Note: While we pass all numbers as doubles, they still seem to be handled appropriately. // Note: While we pass all numbers as doubles, they still seem to be handled appropriately.
// Numbers can be forced to be integers using the argType number and argStyle integer // Numbers can be forced to be integers using the argType number and argStyle integer
// E.g. {var, number, integer} // E.g. {var, number, integer}
else if (value.is<double>()) else if (value.is<double>())
args.push_back(icu::Formattable(value.as<double>())); args.push_back(icu::Formattable(LuaUtil::cast<double>(value)));
else else
{ {
Log(Debug::Error) << "Unrecognized argument type for key \"" << key.as<std::string>() Log(Debug::Error) << "Unrecognized argument type for key \"" << LuaUtil::cast<std::string>(key)
<< "\" when formatting message \"" << messageId << "\""; << "\" when formatting message \"" << messageId << "\"";
} }
// Argument names // Argument names
const auto str = key.as<std::string>(); const auto str = LuaUtil::cast<std::string>(key);
argNames.push_back(icu::UnicodeString::fromUTF8(icu::StringPiece(str.data(), str.size()))); argNames.push_back(icu::UnicodeString::fromUTF8(icu::StringPiece(str.data(), str.size())));
} }
} }

View File

@ -420,4 +420,25 @@ namespace LuaUtil
return call(sol::state_view(obj.lua_state())["tostring"], obj); return call(sol::state_view(obj.lua_state())["tostring"], obj);
} }
std::string internal::formatCastingError(const sol::object& obj, const std::type_info& t)
{
const char* typeName = t.name();
if (t == typeid(int))
typeName = "int";
else if (t == typeid(unsigned))
typeName = "uint32";
else if (t == typeid(size_t))
typeName = "size_t";
else if (t == typeid(float))
typeName = "float";
else if (t == typeid(double))
typeName = "double";
else if (t == typeid(sol::table))
typeName = "sol::table";
else if (t == typeid(sol::function) || t == typeid(sol::protected_function))
typeName = "sol::function";
else if (t == typeid(std::string) || t == typeid(std::string_view))
typeName = "string";
return std::string("Value \"") + toString(obj) + std::string("\" can not be casted to ") + typeName;
}
} }

View File

@ -1,12 +1,12 @@
#ifndef COMPONENTS_LUA_LUASTATE_H #ifndef COMPONENTS_LUA_LUASTATE_H
#define COMPONENTS_LUA_LUASTATE_H #define COMPONENTS_LUA_LUASTATE_H
#include <filesystem>
#include <map> #include <map>
#include <typeinfo>
#include <sol/sol.hpp> #include <sol/sol.hpp>
#include <filesystem>
#include "configuration.hpp" #include "configuration.hpp"
namespace VFS namespace VFS
@ -247,15 +247,25 @@ namespace LuaUtil
// String representation of a Lua object. Should be used for debugging/logging purposes only. // String representation of a Lua object. Should be used for debugging/logging purposes only.
std::string toString(const sol::object&); std::string toString(const sol::object&);
namespace internal
{
std::string formatCastingError(const sol::object& obj, const std::type_info&);
}
template <class T>
decltype(auto) cast(const sol::object& obj)
{
if (!obj.is<T>())
throw std::runtime_error(internal::formatCastingError(obj, typeid(T)));
return obj.as<T>();
}
template <class T> template <class T>
T getValueOrDefault(const sol::object& obj, const T& defaultValue) T getValueOrDefault(const sol::object& obj, const T& defaultValue)
{ {
if (obj == sol::nil) if (obj == sol::nil)
return defaultValue; return defaultValue;
if (obj.is<T>()) return cast<T>(obj);
return obj.as<T>();
else
throw std::logic_error(std::string("Value \"") + toString(obj) + std::string("\" has unexpected type"));
} }
// Makes a table read only (when accessed from Lua) by wrapping it with an empty userdata. // Makes a table read only (when accessed from Lua) by wrapping it with an empty userdata.

View File

@ -94,33 +94,34 @@ namespace LuaUtil
if (scriptOutput == sol::nil) if (scriptOutput == sol::nil)
return true; return true;
sol::object engineHandlers = sol::nil, eventHandlers = sol::nil; sol::object engineHandlers = sol::nil, eventHandlers = sol::nil;
for (const auto& [key, value] : sol::table(scriptOutput)) for (const auto& [key, value] : cast<sol::table>(scriptOutput))
{ {
std::string_view sectionName = key.as<std::string_view>(); std::string_view sectionName = cast<std::string_view>(key);
if (sectionName == ENGINE_HANDLERS) if (sectionName == ENGINE_HANDLERS)
engineHandlers = value; engineHandlers = value;
else if (sectionName == EVENT_HANDLERS) else if (sectionName == EVENT_HANDLERS)
eventHandlers = value; eventHandlers = value;
else if (sectionName == INTERFACE_NAME) else if (sectionName == INTERFACE_NAME)
script.mInterfaceName = value.as<std::string>(); script.mInterfaceName = cast<std::string>(value);
else if (sectionName == INTERFACE) else if (sectionName == INTERFACE)
script.mInterface = value.as<sol::table>(); script.mInterface = cast<sol::table>(value);
else else
Log(Debug::Error) << "Not supported section '" << sectionName << "' in " << debugName; Log(Debug::Error) << "Not supported section '" << sectionName << "' in " << debugName;
} }
if (engineHandlers != sol::nil) if (engineHandlers != sol::nil)
{ {
for (const auto& [key, fn] : sol::table(engineHandlers)) for (const auto& [key, handler] : cast<sol::table>(engineHandlers))
{ {
std::string_view handlerName = key.as<std::string_view>(); std::string_view handlerName = cast<std::string_view>(key);
sol::function fn = cast<sol::function>(handler);
if (handlerName == HANDLER_INIT) if (handlerName == HANDLER_INIT)
onInit = sol::function(fn); onInit = fn;
else if (handlerName == HANDLER_LOAD) else if (handlerName == HANDLER_LOAD)
onLoad = sol::function(fn); onLoad = fn;
else if (handlerName == HANDLER_SAVE) else if (handlerName == HANDLER_SAVE)
script.mOnSave = sol::function(fn); script.mOnSave = fn;
else if (handlerName == HANDLER_INTERFACE_OVERRIDE) else if (handlerName == HANDLER_INTERFACE_OVERRIDE)
script.mOnOverride = sol::function(fn); script.mOnOverride = fn;
else else
{ {
auto it = mEngineHandlers.find(handlerName); auto it = mEngineHandlers.find(handlerName);
@ -133,13 +134,13 @@ namespace LuaUtil
} }
if (eventHandlers != sol::nil) if (eventHandlers != sol::nil)
{ {
for (const auto& [key, fn] : sol::table(eventHandlers)) for (const auto& [key, fn] : cast<sol::table>(eventHandlers))
{ {
std::string_view eventName = key.as<std::string_view>(); std::string_view eventName = cast<std::string_view>(key);
auto it = mEventHandlers.find(eventName); auto it = mEventHandlers.find(eventName);
if (it == mEventHandlers.end()) if (it == mEventHandlers.end())
it = mEventHandlers.emplace(std::string(eventName), EventHandlerList()).first; it = mEventHandlers.emplace(std::string(eventName), EventHandlerList()).first;
insertHandler(it->second, scriptId, fn); insertHandler(it->second, scriptId, cast<sol::function>(fn));
} }
} }
@ -318,7 +319,7 @@ namespace LuaUtil
try try
{ {
sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data); sol::object res = LuaUtil::call({ this, h.mScriptId }, h.mFn, data);
if (res != sol::nil && !res.as<bool>()) if (res.is<bool>() && !res.as<bool>())
break; // Skip other handlers if 'false' was returned. break; // Skip other handlers if 'false' was returned.
} }
catch (std::exception& e) catch (std::exception& e)

View File

@ -106,7 +106,7 @@ namespace LuaUtil
bool BasicSerializer::serialize(BinaryData& out, const sol::userdata& data) const bool BasicSerializer::serialize(BinaryData& out, const sol::userdata& data) const
{ {
appendRefNum(out, data.as<ESM::RefNum>()); appendRefNum(out, cast<ESM::RefNum>(data));
return true; return true;
} }

View File

@ -85,7 +85,7 @@ namespace LuaUtil
if (values) if (values)
{ {
for (const auto& [k, v] : *values) for (const auto& [k, v] : *values)
mValues[k.as<std::string>()] = Value(v); mValues[cast<std::string>(k)] = Value(v);
} }
if (mStorage->mListener) if (mStorage->mListener)
mStorage->mListener->sectionReplaced(mSectionName, values); mStorage->mListener->sectionReplaced(mSectionName, values);
@ -166,9 +166,9 @@ namespace LuaUtil
sol::table data = deserialize(mLua, serializedData); sol::table data = deserialize(mLua, serializedData);
for (const auto& [sectionName, sectionTable] : data) for (const auto& [sectionName, sectionTable] : data)
{ {
const std::shared_ptr<Section>& section = getSection(sectionName.as<std::string_view>()); const std::shared_ptr<Section>& section = getSection(cast<std::string_view>(sectionName));
for (const auto& [key, value] : sol::table(sectionTable)) for (const auto& [key, value] : cast<sol::table>(sectionTable))
section->set(key.as<std::string_view>(), value); section->set(cast<std::string_view>(key), value);
} }
} }
catch (std::exception& e) catch (std::exception& e)

View File

@ -236,17 +236,17 @@ namespace LuaUtil
{ {
util["bitOr"] = [](unsigned a, sol::variadic_args va) { util["bitOr"] = [](unsigned a, sol::variadic_args va) {
for (const auto& v : va) for (const auto& v : va)
a |= v.as<unsigned>(); a |= cast<unsigned>(v);
return a; return a;
}; };
util["bitAnd"] = [](unsigned a, sol::variadic_args va) { util["bitAnd"] = [](unsigned a, sol::variadic_args va) {
for (const auto& v : va) for (const auto& v : va)
a &= v.as<unsigned>(); a &= cast<unsigned>(v);
return a; return a;
}; };
util["bitXor"] = [](unsigned a, sol::variadic_args va) { util["bitXor"] = [](unsigned a, sol::variadic_args va) {
for (const auto& v : va) for (const auto& v : va)
a ^= v.as<unsigned>(); a ^= cast<unsigned>(v);
return a; return a;
}; };
util["bitNot"] = [](unsigned a) { return ~a; }; util["bitNot"] = [](unsigned a) { return ~a; };

View File

@ -78,7 +78,7 @@ namespace LuaUi
{ {
sol::object result = callMethod("indexOf", name); sol::object result = callMethod("indexOf", name);
if (result.is<size_t>()) if (result.is<size_t>())
return fromLua(result.as<size_t>()); return fromLua(LuaUtil::cast<size_t>(result));
else else
return std::nullopt; return std::nullopt;
} }
@ -86,7 +86,7 @@ namespace LuaUi
{ {
sol::object result = callMethod("indexOf", table); sol::object result = callMethod("indexOf", table);
if (result.is<size_t>()) if (result.is<size_t>())
return fromLua(result.as<size_t>()); return fromLua(LuaUtil::cast<size_t>(result));
else else
return std::nullopt; return std::nullopt;
} }

View File

@ -63,7 +63,7 @@ namespace LuaUi
destroyWidget(w); destroyWidget(w);
return result; return result;
} }
ContentView content(contentObj.as<sol::table>()); ContentView content(LuaUtil::cast<sol::table>(contentObj));
result.resize(content.size()); result.resize(content.size());
size_t minSize = std::min(children.size(), content.size()); size_t minSize = std::min(children.size(), content.size());
for (size_t i = 0; i < minSize; i++) for (size_t i = 0; i < minSize; i++)

View File

@ -36,7 +36,7 @@ namespace LuaUi
private: private:
Element(sol::table layout); Element(sol::table layout);
sol::table layout() { return mLayout.as<sol::table>(); } sol::table layout() { return LuaUtil::cast<sol::table>(mLayout); }
static std::map<Element*, std::shared_ptr<Element>> sAllElements; static std::map<Element*, std::shared_ptr<Element>> sAllElements;
void updateAttachment(); void updateAttachment();
}; };