Skip to content

Commit

Permalink
Clear all consumers when a module changes (#9)
Browse files Browse the repository at this point in the history
* Add the beginning of consumer cache invalidation

* Clear the consumer cache on source change

* Update .gitignore

* Add test for ensuring no interference

Also removes the ConsumerTest folder for generating on the fly

* Unfocus

* Remove debug prints

* Create a module for getCallerPath
  • Loading branch information
vocksel authored Jun 11, 2022
1 parent a472460 commit 8a7a226
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Rojo
/*.rbxl*
/*.rbxm*
sourcemap.json

# Selene
/roblox.toml
Expand Down
2 changes: 2 additions & 0 deletions scripts/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
rojo build dev.project.json -o studio-tests.rbxl
run-in-roblox --place studio-tests.rbxl --script tests/init.server.lua
32 changes: 32 additions & 0 deletions src/getCallerPath.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
local root = script.Parent

local LOADSTRING_PATH_PATTERN = '%[string "(.*)"%]'

local function getCallerPath()
local level = 1

while true do
local path = debug.info(level, "s")

if path then
-- Skip over any path that is a descendant of this package
if not path:find(root.Name, nil, true) then
-- Sometimes the path is represented as `[string "path.to.module"]`
-- so we match for the instance path and, if found, return it
local pathFromLoadstring = path:match(LOADSTRING_PATH_PATTERN)

if pathFromLoadstring then
return pathFromLoadstring
else
return path
end
end
else
return nil
end

level += 1
end
end

return getCallerPath
72 changes: 59 additions & 13 deletions src/init.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local Janitor = require(script.Parent.Janitor)
local GoodSignal = require(script.Parent.GoodSignal)
local bind = require(script.bind)
local getCallerPath = require(script.getCallerPath)
local getEnv = require(script.getEnv)

--[=[
Expand All @@ -17,7 +18,12 @@ local getEnv = require(script.getEnv)
local ModuleLoader = {}
ModuleLoader.__index = ModuleLoader

export type Class = typeof(ModuleLoader.new())
export type CachedModule = {
module: ModuleScript,
isLoaded: boolean,
result: any,
consumers: { ModuleScript },
}

--[=[
Constructs a new ModuleLoader instance.
Expand All @@ -27,6 +33,7 @@ function ModuleLoader.new()

self._cache = {}
self._loadstring = loadstring
self._debugInfo = debug.info
self._janitor = Janitor.new()

--[=[
Expand Down Expand Up @@ -56,19 +63,17 @@ function ModuleLoader.new()
end

function ModuleLoader:_loadCachedModule(module: ModuleScript)
local returnValues = self._cache[module]
local success = returnValues[1]
local result = returnValues[2]
local cachedModule: CachedModule = self._cache[module:GetFullName()]

assert(
success,
cachedModule.isLoaded,
"Requested module experienced an error while loading MODULE: "
.. module:GetFullName()
.. " - RESULT: "
.. tostring(result)
.. tostring(cachedModule.result)
)

return result
return cachedModule.result
end

--[=[
Expand All @@ -88,6 +93,19 @@ function ModuleLoader:_getSource(module: ModuleScript): any?
return if success then result else nil
end

function ModuleLoader:_clearConsumerFromCache(moduleFullName: string)
local cachedModule: CachedModule = self._cache[moduleFullName]

if cachedModule then
for _, consumer in ipairs(cachedModule.consumers) do
self._cache[consumer] = nil
self:_clearConsumerFromCache(consumer)
end

self._cache[moduleFullName] = nil
end
end

--[=[
Tracks the changes to a required module's ancestry and `Source`.
Expand All @@ -104,6 +122,7 @@ function ModuleLoader:_trackChanges(module: ModuleScript)

self._janitor:Add(module.Changed:Connect(function(prop: string)
if prop == "Source" then
self:_clearConsumerFromCache(module:GetFullName())
self.loadedModuleChanged:Fire(module)
end
end))
Expand All @@ -124,8 +143,15 @@ end
loader:cache(moduleInstance, module)
```
]=]
function ModuleLoader:cache(module: ModuleScript, source: any)
self._cache[module] = { true, source }
function ModuleLoader:cache(module: ModuleScript, result: any)
local cachedModule: CachedModule = {
module = module,
result = result,
isLoaded = true,
consumers = {},
}

self._cache[module:GetFullName()] = cachedModule
end

--[=[
Expand All @@ -141,7 +167,14 @@ end
```
]=]
function ModuleLoader:require(module: ModuleScript)
if self._cache[module] then
local cachedModule = self._cache[module:GetFullName()]
local callerPath = getCallerPath()

if cachedModule then
if self._cache[callerPath] then
table.insert(cachedModule.consumers, callerPath)
end

return self:_loadCachedModule(module)
end

Expand All @@ -152,18 +185,29 @@ function ModuleLoader:require(module: ModuleScript)
error(("Could not parse %s: %s"):format(module:GetFullName(), parseError))
end

local newCachedModule: CachedModule = {
module = module,
result = nil,
isLoaded = false,
consumers = {
if self._cache[callerPath] then callerPath else nil,
},
}
self._cache[module:GetFullName()] = newCachedModule

local env = getEnv(module)
env.require = bind(self, self.require)
setfenv(moduleFn, env)

local success, result = xpcall(moduleFn, debug.traceback)

if not success then
if success then
newCachedModule.isLoaded = true
newCachedModule.result = result
else
error(("Error requiring %s: %s"):format(module.Name, result))
end

self._cache[module] = { success, result }

self:_trackChanges(module)

return self:_loadCachedModule(module)
Expand Down Expand Up @@ -195,4 +239,6 @@ function ModuleLoader:clear()
self._janitor:Cleanup()
end

export type Class = typeof(ModuleLoader.new())

return ModuleLoader
95 changes: 91 additions & 4 deletions src/init.spec.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
return function()
local ReplicatedStorage = game:GetService("ReplicatedStorage")

local Mock = require(script.Parent.Parent.Mock)
local ModuleLoader = require(script.Parent)

Expand Down Expand Up @@ -110,13 +112,13 @@ return function()
end)

describe("cache", function()
it("should add a module and its source to the cache", function()
it("should add a module and its result to the cache", function()
loader:cache(mockModuleInstance, mockModule)

local cachedModule = loader._cache[mockModuleInstance]
local cachedModule = loader._cache[mockModuleInstance:GetFullName()]

expect(cachedModule).to.be.ok()
expect(cachedModule[2]).to.equal(mockModule)
expect(cachedModule.result).to.equal(mockModule)
end)
end)

Expand All @@ -128,7 +130,7 @@ return function()

it("should add the module to the cache", function()
loader:require(mockModuleInstance)
expect(loader._cache[mockModuleInstance]).to.be.ok()
expect(loader._cache[mockModuleInstance:GetFullName()]).to.be.ok()
end)
end)

Expand All @@ -143,4 +145,89 @@ return function()
expect(countDict(loader._cache)).to.equal(0)
end)
end)

-- For these tests to work, TestEZ must be run from a plugin context so that
-- loadstring works, along with assigning to the `Source` property of
-- modules
describe("consumers", function()
local modules = Instance.new("Folder") :: Folder & {
ModuleA: ModuleScript,
ModuleB: ModuleScript,
ModuleC: ModuleScript,
}

beforeEach(function()
local moduleA = Instance.new("ModuleScript")
moduleA.Name = "ModuleA"
moduleA.Source = [[
require(script.Parent.ModuleB)
return "ModuleA"
]]
moduleA.Parent = modules

local moduleB = Instance.new("ModuleScript")
moduleB.Name = "ModuleB"
moduleB.Source = [[
return "ModuleB"
]]
moduleB.Parent = modules

local moduleC = Instance.new("ModuleScript")
moduleC.Name = "ModuleC"
moduleC.Source = [[
return "ModuleC"
]]
moduleC.Parent = modules

modules.Parent = game

loader._loadstring = loadstring
end)

afterEach(function()
modules:ClearAllChildren()
end)

it("should keep track of the consumers for a module", function()
loader:require(modules.ModuleA)

expect(loader._cache[modules.ModuleA:GetFullName()]).to.be.ok()

local cachedModuleB = loader._cache[modules.ModuleB:GetFullName()]

expect(cachedModuleB).to.be.ok()
expect(#cachedModuleB.consumers).to.equal(1)
expect(cachedModuleB.consumers[1]).to.equal(modules.ModuleA:GetFullName())
end)

it("should remove all consumers of a changed module from the cache", function()
loader:require(modules.ModuleA)

expect(next(loader._cache)).to.be.ok()

task.defer(function()
modules.ModuleB.Source = 'return "ModuleB Reloaded"'
end)
loader.loadedModuleChanged:Wait()

expect(next(loader._cache)).never.to.be.ok()
end)

it("should not interfere with other cached modules", function()
loader:require(modules.ModuleA)
loader:require(modules.ModuleC)

expect(next(loader._cache)).to.be.ok()

task.defer(function()
modules.ModuleB.Source = 'return "ModuleB Reloaded"'
end)
loader.loadedModuleChanged:Wait()

expect(loader._cache[modules.ModuleA:GetFullName()]).never.to.be.ok()
expect(loader._cache[modules.ModuleB:GetFullName()]).never.to.be.ok()
expect(loader._cache[modules.ModuleC:GetFullName()]).to.be.ok()
end)
end)
end

0 comments on commit 8a7a226

Please sign in to comment.