From 8a7a226bfa832a3c23cd4018b832ed470aa25b6c Mon Sep 17 00:00:00 2001 From: vocksel Date: Sat, 11 Jun 2022 15:40:07 -0700 Subject: [PATCH] Clear all consumers when a module changes (#9) * Add the beginning of consumer cache invalidation * Clear the consumer cache on source change * Update .gitignore * Add test for ensuring no interference Also removes the ConsumerTest folder for generating on the fly * Unfocus * Remove debug prints * Create a module for getCallerPath --- .gitignore | 1 + scripts/test.sh | 2 + src/getCallerPath.lua | 32 +++++++++++++++ src/init.lua | 72 ++++++++++++++++++++++++++------ src/init.spec.lua | 95 +++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 185 insertions(+), 17 deletions(-) create mode 100644 scripts/test.sh create mode 100644 src/getCallerPath.lua diff --git a/.gitignore b/.gitignore index 49cc1e0..6ae64d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Rojo /*.rbxl* /*.rbxm* +sourcemap.json # Selene /roblox.toml diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 0000000..22c4808 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,2 @@ +rojo build dev.project.json -o studio-tests.rbxl +run-in-roblox --place studio-tests.rbxl --script tests/init.server.lua \ No newline at end of file diff --git a/src/getCallerPath.lua b/src/getCallerPath.lua new file mode 100644 index 0000000..bb2b867 --- /dev/null +++ b/src/getCallerPath.lua @@ -0,0 +1,32 @@ +local root = script.Parent + +local LOADSTRING_PATH_PATTERN = '%[string "(.*)"%]' + +local function getCallerPath() + local level = 1 + + while true do + local path = debug.info(level, "s") + + if path then + -- Skip over any path that is a descendant of this package + if not path:find(root.Name, nil, true) then + -- Sometimes the path is represented as `[string "path.to.module"]` + -- so we match for the instance path and, if found, return it + local pathFromLoadstring = path:match(LOADSTRING_PATH_PATTERN) + + if pathFromLoadstring then + return pathFromLoadstring + else + return path + end + end + else + return nil + end + + level += 1 + end +end + +return getCallerPath diff --git a/src/init.lua b/src/init.lua index a335eed..de6b822 100644 --- a/src/init.lua +++ b/src/init.lua @@ -1,6 +1,7 @@ local Janitor = require(script.Parent.Janitor) local GoodSignal = require(script.Parent.GoodSignal) local bind = require(script.bind) +local getCallerPath = require(script.getCallerPath) local getEnv = require(script.getEnv) --[=[ @@ -17,7 +18,12 @@ local getEnv = require(script.getEnv) local ModuleLoader = {} ModuleLoader.__index = ModuleLoader -export type Class = typeof(ModuleLoader.new()) +export type CachedModule = { + module: ModuleScript, + isLoaded: boolean, + result: any, + consumers: { ModuleScript }, +} --[=[ Constructs a new ModuleLoader instance. @@ -27,6 +33,7 @@ function ModuleLoader.new() self._cache = {} self._loadstring = loadstring + self._debugInfo = debug.info self._janitor = Janitor.new() --[=[ @@ -56,19 +63,17 @@ function ModuleLoader.new() end function ModuleLoader:_loadCachedModule(module: ModuleScript) - local returnValues = self._cache[module] - local success = returnValues[1] - local result = returnValues[2] + local cachedModule: CachedModule = self._cache[module:GetFullName()] assert( - success, + cachedModule.isLoaded, "Requested module experienced an error while loading MODULE: " .. module:GetFullName() .. " - RESULT: " - .. tostring(result) + .. tostring(cachedModule.result) ) - return result + return cachedModule.result end --[=[ @@ -88,6 +93,19 @@ function ModuleLoader:_getSource(module: ModuleScript): any? return if success then result else nil end +function ModuleLoader:_clearConsumerFromCache(moduleFullName: string) + local cachedModule: CachedModule = self._cache[moduleFullName] + + if cachedModule then + for _, consumer in ipairs(cachedModule.consumers) do + self._cache[consumer] = nil + self:_clearConsumerFromCache(consumer) + end + + self._cache[moduleFullName] = nil + end +end + --[=[ Tracks the changes to a required module's ancestry and `Source`. @@ -104,6 +122,7 @@ function ModuleLoader:_trackChanges(module: ModuleScript) self._janitor:Add(module.Changed:Connect(function(prop: string) if prop == "Source" then + self:_clearConsumerFromCache(module:GetFullName()) self.loadedModuleChanged:Fire(module) end end)) @@ -124,8 +143,15 @@ end loader:cache(moduleInstance, module) ``` ]=] -function ModuleLoader:cache(module: ModuleScript, source: any) - self._cache[module] = { true, source } +function ModuleLoader:cache(module: ModuleScript, result: any) + local cachedModule: CachedModule = { + module = module, + result = result, + isLoaded = true, + consumers = {}, + } + + self._cache[module:GetFullName()] = cachedModule end --[=[ @@ -141,7 +167,14 @@ end ``` ]=] function ModuleLoader:require(module: ModuleScript) - if self._cache[module] then + local cachedModule = self._cache[module:GetFullName()] + local callerPath = getCallerPath() + + if cachedModule then + if self._cache[callerPath] then + table.insert(cachedModule.consumers, callerPath) + end + return self:_loadCachedModule(module) end @@ -152,18 +185,29 @@ function ModuleLoader:require(module: ModuleScript) error(("Could not parse %s: %s"):format(module:GetFullName(), parseError)) end + local newCachedModule: CachedModule = { + module = module, + result = nil, + isLoaded = false, + consumers = { + if self._cache[callerPath] then callerPath else nil, + }, + } + self._cache[module:GetFullName()] = newCachedModule + local env = getEnv(module) env.require = bind(self, self.require) setfenv(moduleFn, env) local success, result = xpcall(moduleFn, debug.traceback) - if not success then + if success then + newCachedModule.isLoaded = true + newCachedModule.result = result + else error(("Error requiring %s: %s"):format(module.Name, result)) end - self._cache[module] = { success, result } - self:_trackChanges(module) return self:_loadCachedModule(module) @@ -195,4 +239,6 @@ function ModuleLoader:clear() self._janitor:Cleanup() end +export type Class = typeof(ModuleLoader.new()) + return ModuleLoader diff --git a/src/init.spec.lua b/src/init.spec.lua index 65cca0b..b90af5a 100644 --- a/src/init.spec.lua +++ b/src/init.spec.lua @@ -1,4 +1,6 @@ return function() + local ReplicatedStorage = game:GetService("ReplicatedStorage") + local Mock = require(script.Parent.Parent.Mock) local ModuleLoader = require(script.Parent) @@ -110,13 +112,13 @@ return function() end) describe("cache", function() - it("should add a module and its source to the cache", function() + it("should add a module and its result to the cache", function() loader:cache(mockModuleInstance, mockModule) - local cachedModule = loader._cache[mockModuleInstance] + local cachedModule = loader._cache[mockModuleInstance:GetFullName()] expect(cachedModule).to.be.ok() - expect(cachedModule[2]).to.equal(mockModule) + expect(cachedModule.result).to.equal(mockModule) end) end) @@ -128,7 +130,7 @@ return function() it("should add the module to the cache", function() loader:require(mockModuleInstance) - expect(loader._cache[mockModuleInstance]).to.be.ok() + expect(loader._cache[mockModuleInstance:GetFullName()]).to.be.ok() end) end) @@ -143,4 +145,89 @@ return function() expect(countDict(loader._cache)).to.equal(0) end) end) + + -- For these tests to work, TestEZ must be run from a plugin context so that + -- loadstring works, along with assigning to the `Source` property of + -- modules + describe("consumers", function() + local modules = Instance.new("Folder") :: Folder & { + ModuleA: ModuleScript, + ModuleB: ModuleScript, + ModuleC: ModuleScript, + } + + beforeEach(function() + local moduleA = Instance.new("ModuleScript") + moduleA.Name = "ModuleA" + moduleA.Source = [[ + require(script.Parent.ModuleB) + + return "ModuleA" + ]] + moduleA.Parent = modules + + local moduleB = Instance.new("ModuleScript") + moduleB.Name = "ModuleB" + moduleB.Source = [[ + return "ModuleB" + ]] + moduleB.Parent = modules + + local moduleC = Instance.new("ModuleScript") + moduleC.Name = "ModuleC" + moduleC.Source = [[ + return "ModuleC" + ]] + moduleC.Parent = modules + + modules.Parent = game + + loader._loadstring = loadstring + end) + + afterEach(function() + modules:ClearAllChildren() + end) + + it("should keep track of the consumers for a module", function() + loader:require(modules.ModuleA) + + expect(loader._cache[modules.ModuleA:GetFullName()]).to.be.ok() + + local cachedModuleB = loader._cache[modules.ModuleB:GetFullName()] + + expect(cachedModuleB).to.be.ok() + expect(#cachedModuleB.consumers).to.equal(1) + expect(cachedModuleB.consumers[1]).to.equal(modules.ModuleA:GetFullName()) + end) + + it("should remove all consumers of a changed module from the cache", function() + loader:require(modules.ModuleA) + + expect(next(loader._cache)).to.be.ok() + + task.defer(function() + modules.ModuleB.Source = 'return "ModuleB Reloaded"' + end) + loader.loadedModuleChanged:Wait() + + expect(next(loader._cache)).never.to.be.ok() + end) + + it("should not interfere with other cached modules", function() + loader:require(modules.ModuleA) + loader:require(modules.ModuleC) + + expect(next(loader._cache)).to.be.ok() + + task.defer(function() + modules.ModuleB.Source = 'return "ModuleB Reloaded"' + end) + loader.loadedModuleChanged:Wait() + + expect(loader._cache[modules.ModuleA:GetFullName()]).never.to.be.ok() + expect(loader._cache[modules.ModuleB:GetFullName()]).never.to.be.ok() + expect(loader._cache[modules.ModuleC:GetFullName()]).to.be.ok() + end) + end) end