From f16bfd3a7e2ed2dff2869404ac8a1f48aeea766d Mon Sep 17 00:00:00 2001 From: chreden <4263940+chreden@users.noreply.github.com> Date: Mon, 22 Jul 2024 02:01:08 +0100 Subject: [PATCH] First version of setting plugin directory (#1259) Set the working directory for `dofile` to the plugin directory. Also set the path for `require` to be plugin directory. Fully recreate lua state on plugin reload. Closes #1258 --- trview.app/Lua/ILua.h | 1 + trview.app/Lua/Lua.cpp | 46 ++++++++++++++++++++++++++++++----- trview.app/Lua/Lua.h | 4 +++ trview.app/Mocks/Lua/ILua.h | 1 + trview.app/Plugins/Plugin.cpp | 14 ++++++++++- trview.app/Plugins/Plugin.h | 2 ++ trview.common/Files.cpp | 13 ++++++++++ trview.common/Files.h | 2 ++ trview.common/IFiles.h | 2 ++ trview.common/Mocks/IFiles.h | 2 ++ 10 files changed, 80 insertions(+), 7 deletions(-) diff --git a/trview.app/Lua/ILua.h b/trview.app/Lua/ILua.h index a377d1806..11646e13f 100644 --- a/trview.app/Lua/ILua.h +++ b/trview.app/Lua/ILua.h @@ -10,6 +10,7 @@ namespace trview virtual void do_file(const std::string& file) = 0; virtual void execute(const std::string& command) = 0; virtual void initialise(IApplication* application) = 0; + virtual void set_directory(const std::string& directory) = 0; Event on_print; }; diff --git a/trview.app/Lua/Lua.cpp b/trview.app/Lua/Lua.cpp index 3c6505c4d..b79d45e73 100644 --- a/trview.app/Lua/Lua.cpp +++ b/trview.app/Lua/Lua.cpp @@ -44,6 +44,14 @@ namespace trview return 0; } + int dofile(lua_State* L) + { + luaL_checktype(L, lua_upvalueindex(1), LUA_TUSERDATA); + ILua* self = *static_cast(lua_touserdata(L, lua_upvalueindex(1))); + self->do_file(lua_tostring(L, 1)); + return 0; + } + constexpr luaL_Reg loadedlibs[] = { {LUA_GNAME, luaopen_base}, {LUA_LOADLIBNAME, luaopen_package}, @@ -63,12 +71,7 @@ namespace trview Lua::Lua(const IRoute::Source& route_source, const IRandomizerRoute::Source& randomizer_route_source, const IWaypoint::Source& waypoint_source, const std::shared_ptr& dialogs, const std::shared_ptr& files) : _route_source(route_source), _randomizer_route_source(randomizer_route_source), _waypoint_source(waypoint_source), _dialogs(dialogs), _files(files) { - L = luaL_newstate(); - for (const auto& lib : loadedlibs) - { - luaL_requiref(L, lib.name, lib.func, 1); - lua_pop(L, 1); - } + create_state(); } Lua::~Lua() @@ -78,6 +81,9 @@ namespace trview void Lua::do_file(const std::string& file) { + const auto current_working_directory = _files->working_directory(); + _files->set_working_directory(_directory); + if (luaL_dofile(L, file.c_str()) != LUA_OK) { if (lua_type(L, -1) == LUA_TSTRING) @@ -89,6 +95,8 @@ namespace trview on_print("An error occurred"); } } + + _files->set_working_directory(current_working_directory); } void Lua::execute(const std::string& command) @@ -108,14 +116,40 @@ namespace trview void Lua::initialise(IApplication* application) { + create_state(); ILua** userdata = static_cast(lua_newuserdata(L, sizeof(this))); *userdata = this; lua_pushcclosure(L, print, 1); lua_setglobal(L, "print"); + userdata = static_cast(lua_newuserdata(L, sizeof(this))); + *userdata = this; + lua_pushcclosure(L, dofile, 1); + lua_setglobal(L, "dofile"); lua::trview_register(L, application, _route_source, _randomizer_route_source, _waypoint_source, _dialogs, _files); lua::imgui_register(L); } + void Lua::set_directory(const std::string& directory) + { + _directory = directory; + } + + void Lua::create_state() + { + if (L) + { + lua_close(L); + L = nullptr; + } + + L = luaL_newstate(); + for (const auto& lib : loadedlibs) + { + luaL_requiref(L, lib.name, lib.func, 1); + lua_pop(L, 1); + } + } + namespace lua { int push_string(lua_State* L, const std::string& text) diff --git a/trview.app/Lua/Lua.h b/trview.app/Lua/Lua.h index a278969ab..5ca681556 100644 --- a/trview.app/Lua/Lua.h +++ b/trview.app/Lua/Lua.h @@ -27,13 +27,17 @@ namespace trview void do_file(const std::string& file) override; void execute(const std::string& command) override; void initialise(IApplication* application) override; + void set_directory(const std::string& directory) override; private: + void create_state(); + lua_State* L{ nullptr }; IRoute::Source _route_source; IRandomizerRoute::Source _randomizer_route_source; IWaypoint::Source _waypoint_source; std::shared_ptr _dialogs; std::shared_ptr _files; + std::string _directory; }; namespace lua diff --git a/trview.app/Mocks/Lua/ILua.h b/trview.app/Mocks/Lua/ILua.h index 4a4fce593..bec943280 100644 --- a/trview.app/Mocks/Lua/ILua.h +++ b/trview.app/Mocks/Lua/ILua.h @@ -13,6 +13,7 @@ namespace trview MOCK_METHOD(void, do_file, (const std::string&), (override)); MOCK_METHOD(void, execute, (const std::string&), (override)); MOCK_METHOD(void, initialise, (IApplication*), (override)); + MOCK_METHOD(void, set_directory, (const std::string&), (override)); }; } } diff --git a/trview.app/Plugins/Plugin.cpp b/trview.app/Plugins/Plugin.cpp index aa5d4778a..96d04b41d 100644 --- a/trview.app/Plugins/Plugin.cpp +++ b/trview.app/Plugins/Plugin.cpp @@ -1,4 +1,6 @@ #include "Plugin.h" +#include +#include namespace trview { @@ -17,6 +19,7 @@ namespace trview const std::string& path) : _lua(std::move(lua)), _path(path), _files(files) { + _lua->set_directory(_path); register_print(); load(); } @@ -48,7 +51,9 @@ namespace trview void Plugin::initialise(IApplication* application) { + _application = application; _lua->initialise(application); + set_package_path(); load_script(); } @@ -85,7 +90,7 @@ namespace trview void Plugin::reload() { load(); - load_script(); + initialise(_application); } void Plugin::load() @@ -129,4 +134,11 @@ namespace trview { _lua->execute("if render_ui ~= nil then render_ui() end"); } + + void Plugin::set_package_path() + { + std::string escaped = _path; + std::ranges::replace(escaped, '\\', '/'); + _lua->execute(std::format("package.path = \"{}/?.lua\"", escaped)); + } } diff --git a/trview.app/Plugins/Plugin.h b/trview.app/Plugins/Plugin.h index 1564e19b2..33e9ee0f8 100644 --- a/trview.app/Plugins/Plugin.h +++ b/trview.app/Plugins/Plugin.h @@ -32,6 +32,7 @@ namespace trview void load(); void load_script(); void register_print(); + void set_package_path(); std::shared_ptr _files; std::unique_ptr _lua; @@ -42,5 +43,6 @@ namespace trview std::string _script; std::string _messages; TokenStore _token_store; + IApplication* _application; }; } diff --git a/trview.common/Files.cpp b/trview.common/Files.cpp index ae94219b9..88258890c 100644 --- a/trview.common/Files.cpp +++ b/trview.common/Files.cpp @@ -144,6 +144,19 @@ namespace trview return data; } + std::string Files::working_directory() const + { + DWORD length = GetCurrentDirectory(0, nullptr); + std::vector buffer (static_cast(length) + 1, static_cast(0)); + GetCurrentDirectory(static_cast(buffer.size()), &buffer[0]); + return to_utf8(&buffer[0]); + } + + void Files::set_working_directory(const std::string& directory) + { + SetCurrentDirectory(to_utf16(directory).c_str()); + } + std::vector Files::get_files(const std::wstring& folder, const std::vector& patterns) const { std::vector data; diff --git a/trview.common/Files.h b/trview.common/Files.h index 6078f8b44..c370dfd38 100644 --- a/trview.common/Files.h +++ b/trview.common/Files.h @@ -18,6 +18,8 @@ namespace trview virtual void save_file(const std::string& filename, const std::string& text) const override; virtual std::vector get_files(const std::string& folder, const std::string& pattern) const override; std::vector get_directories(const std::string& folder) const override; + std::string working_directory() const override; + void set_working_directory(const std::string& directory) override; private: std::vector get_files(const std::wstring& folder, const std::vector& patterns) const; }; diff --git a/trview.common/IFiles.h b/trview.common/IFiles.h index 328568959..e3347c391 100644 --- a/trview.common/IFiles.h +++ b/trview.common/IFiles.h @@ -33,5 +33,7 @@ namespace trview virtual void save_file(const std::string& filename, const std::string& text) const = 0; virtual std::vector get_files(const std::string& folder, const std::string& pattern) const = 0; virtual std::vector get_directories(const std::string& folder) const = 0; + virtual std::string working_directory() const = 0; + virtual void set_working_directory(const std::string& directory) = 0; }; } diff --git a/trview.common/Mocks/IFiles.h b/trview.common/Mocks/IFiles.h index 99eb6a1ea..b9b0d3990 100644 --- a/trview.common/Mocks/IFiles.h +++ b/trview.common/Mocks/IFiles.h @@ -20,6 +20,8 @@ namespace trview MOCK_METHOD(void, save_file, (const std::string&, const std::string&), (const, override)); MOCK_METHOD(std::vector, get_files, (const std::string&, const std::string&), (const, override)); MOCK_METHOD(std::vector, get_directories, (const std::string&), (const, override)); + MOCK_METHOD(std::string, working_directory, (), (const, override)); + MOCK_METHOD(void, set_working_directory, (const std::string&), (override)); }; } }