diff --git a/.lune/test-cloud.luau b/.lune/test-cloud.luau index 755ea6f..7be749a 100644 --- a/.lune/test-cloud.luau +++ b/.lune/test-cloud.luau @@ -10,6 +10,7 @@ local apiKey = assert(args.apiKey, "--apiKey must be supplied with a valid Open local testPlaceFile = "test-place.rbxl" +run("rm", { "-rf", project.BUILD_PATH }) compile("dev") run("rojo", { "build", project.ROJO_TESTS_PROJECT, "-o", testPlaceFile }) diff --git a/src/ModuleLoader.luau b/src/ModuleLoader.luau index e4fb142..23d4bf9 100644 --- a/src/ModuleLoader.luau +++ b/src/ModuleLoader.luau @@ -1,61 +1,25 @@ -local Janitor = require("@pkg/Janitor") +local createModuleLoader = require("./createModuleLoader") +local types = require("./types") -local bind = require("./bind") -local createTablePassthrough = require("./createTablePassthrough") -local getCallerPath = require("./getCallerPath") -local getEnv = require("./getEnv") -local getRobloxTsRuntime = require("./getRobloxTsRuntime") - -type CachedModuleResult = any - -type ModuleConsumers = { - [string]: boolean, -} - --- Each module gets its own global table that it can modify via _G. This makes --- it easy to clear out a module and the globals it defines without impacting --- other modules. A module's function environment has all globals merged --- together on _G -type ModuleGlobals = { - [any]: any, -} - -type CachedModule = { - instance: ModuleScript, - isLoaded: boolean, - result: CachedModuleResult, - consumers: ModuleConsumers, - globals: ModuleGlobals, -} +type ModuleLoader = types.ModuleLoader type ModuleLoaderProps = { - _cache: { [string]: CachedModule }, - _loadstring: typeof(loadstring), - _debugInfo: typeof(debug.info), - _janitors: { [string]: any }, - _globals: { [any]: any }, - - _loadedModuleChangedBindable: BindableEvent, + _loader: ModuleLoader, loadedModuleChanged: RBXScriptSignal, } type ModuleLoaderImpl = { __index: ModuleLoaderImpl, - new: () -> ModuleLoader, - - require: (self: ModuleLoader, moduleScript: ModuleScript) -> any, - cache: (self: ModuleLoader, moduleScript: ModuleScript, result: any) -> (), - clearModule: (self: ModuleLoader, moduleScript: ModuleScript) -> (), - clear: (self: ModuleLoader) -> (), + new: () -> ModuleLoaderClass, - _loadCachedModule: (self: ModuleLoader, moduleScript: ModuleScript) -> CachedModuleResult, - _getSource: (self: ModuleLoader, moduleScript: ModuleScript) -> string, - _trackChanges: (self: ModuleLoader, moduleScript: ModuleScript) -> (), - _getConsumers: (self: ModuleLoader, moduleScript: ModuleScript) -> { ModuleScript }, + require: (self: ModuleLoaderClass, moduleScript: ModuleScript) -> any, + cache: (self: ModuleLoaderClass, moduleScript: ModuleScript, result: any) -> (), + clearModule: (self: ModuleLoaderClass, moduleScript: ModuleScript) -> (), + clear: (self: ModuleLoaderClass) -> (), } -export type ModuleLoader = typeof(setmetatable({} :: ModuleLoaderProps, {} :: ModuleLoaderImpl)) +export type ModuleLoaderClass = typeof(setmetatable({} :: ModuleLoaderProps, {} :: ModuleLoaderImpl)) --[=[ ModuleScript loader that bypasses Roblox's require cache. @@ -77,12 +41,7 @@ ModuleLoader.__index = ModuleLoader function ModuleLoader.new() local self = {} - self._cache = {} - self._loadstring = loadstring - self._debugInfo = debug.info - self._janitors = {} - self._globals = {} - self._loadedModuleChangedBindable = Instance.new("BindableEvent") + self._loader = createModuleLoader() --[=[ Fired when any ModuleScript required through this class has its ancestry @@ -105,70 +64,11 @@ function ModuleLoader.new() @prop loadedModuleChanged RBXScriptSignal @within ModuleLoader ]=] - self.loadedModuleChanged = self._loadedModuleChangedBindable.Event + self.loadedModuleChanged = self._loader.loadedModuleChanged return setmetatable(self, ModuleLoader) end -function ModuleLoader:_loadCachedModule(moduleScript: ModuleScript) - local cachedModule: CachedModule = self._cache[moduleScript:GetFullName()] - - assert( - cachedModule.isLoaded, - "Requested module experienced an error while loading MODULE: " - .. moduleScript:GetFullName() - .. " - RESULT: " - .. tostring(cachedModule.result) - ) - - return cachedModule.result -end - ---[=[ - Gets the Source of a ModuleScript. - - This method exists primarily so we can better write unit tests. Attempting - to index the Source property from a regular script context throws an error, - so this method allows us to safely fallback in tests. - - @private -]=] -function ModuleLoader:_getSource(moduleScript: ModuleScript): string - local success, result = pcall(function() - return moduleScript.Source - end) - - return if success then result else "" -end - ---[=[ - Tracks the changes to a required module's ancestry and `Source`. - - When ancestry or `Source` changes, the `loadedModuleChanged` event is fired. - When this happens, the user should clear the cache and require the root - module again to reload. - - @private -]=] -function ModuleLoader:_trackChanges(moduleScript: ModuleScript) - local existingJanitor = self._janitors[moduleScript:GetFullName()] - local janitor = if existingJanitor then existingJanitor else Janitor.new() - - janitor:Cleanup() - - janitor:Add(moduleScript.AncestryChanged:Connect(function() - self:clearModule(moduleScript) - end)) - - janitor:Add(moduleScript.Changed:Connect(function(prop: string) - if prop == "Source" then - self:clearModule(moduleScript) - end - end)) - - self._janitors[moduleScript:GetFullName()] = janitor -end - --[=[ Set the cached value for a module before it is loaded. @@ -185,15 +85,7 @@ end ``` ]=] function ModuleLoader:cache(moduleScript: ModuleScript, result: any) - local cachedModule: CachedModule = { - instance = moduleScript, - result = result, - isLoaded = true, - consumers = {}, - globals = createTablePassthrough(self._globals), - } - - self._cache[moduleScript:GetFullName()] = cachedModule + self._loader.cache(moduleScript, result) end --[=[ @@ -209,117 +101,11 @@ end ``` ]=] function ModuleLoader:require(moduleScript: ModuleScript) - local cachedModule = self._cache[moduleScript:GetFullName()] - local callerPath = getCallerPath() - - if not callerPath then - return nil - end - - if cachedModule then - cachedModule.consumers[callerPath] = true - return self:_loadCachedModule(moduleScript) - end - - local source = self:_getSource(moduleScript) - local moduleFn, parseError = self._loadstring(source, moduleScript:GetFullName()) - - if not moduleFn then - local message = if parseError then parseError else "" - error(`Could not parse {moduleScript:GetFullName()}: {message}`) - end - - local globals = createTablePassthrough(self._globals) - - local newCachedModule: CachedModule = { - instance = moduleScript, - result = nil, - isLoaded = false, - consumers = { - [callerPath] = true, - }, - globals = globals, - } - self._cache[moduleScript:GetFullName()] = newCachedModule - - local env: any = getEnv(moduleScript, globals) - env.require = bind(self, self.require) - setfenv(moduleFn, env) - - local success, result = xpcall(moduleFn, debug.traceback) - - if success then - newCachedModule.isLoaded = true - newCachedModule.result = result - else - error(("Error requiring %s: %s"):format(moduleScript.Name, result)) - end - - self:_trackChanges(moduleScript) - - return self:_loadCachedModule(moduleScript) -end - -function ModuleLoader:_getConsumers(moduleScript: ModuleScript): { ModuleScript } - local function getConsumersRecursively(cachedModule: CachedModule, found: { [ModuleScript]: true }) - for consumer in cachedModule.consumers do - local cachedConsumer = self._cache[consumer] - - if cachedConsumer then - if not found[cachedConsumer.instance] then - found[cachedConsumer.instance] = true - getConsumersRecursively(cachedConsumer, found) - end - end - end - end - - local cachedModule: CachedModule = self._cache[moduleScript:GetFullName()] - local found = {} - - getConsumersRecursively(cachedModule, found) - - local consumers = {} - for consumer in found do - table.insert(consumers, consumer) - end - - return consumers + return self._loader.require(moduleScript) end function ModuleLoader:clearModule(moduleToClear: ModuleScript) - if not self._cache[moduleToClear:GetFullName()] then - return - end - - local consumers = self:_getConsumers(moduleToClear) - local modulesToClear = { moduleToClear, table.unpack(consumers) } - - local index = table.find(modulesToClear, getRobloxTsRuntime()) - if index then - table.remove(modulesToClear, index) - end - - for _, moduleScript in modulesToClear do - local fullName = moduleScript:GetFullName() - - local cachedModule = self._cache[fullName] - - if cachedModule then - self._cache[fullName] = nil - - for key in cachedModule.globals do - self._globals[key] = nil - end - - local janitor = self._janitors[fullName] - janitor:Cleanup() - end - end - - for _, moduleScript in modulesToClear do - self._loadedModuleChangedBindable:Fire(moduleScript) - end + self._loader.clearModule(moduleToClear) end --[=[ @@ -344,13 +130,7 @@ end ``` ]=] function ModuleLoader:clear() - self._cache = {} - self._globals = {} - - for _, janitor in self._janitors do - janitor:Cleanup() - end - self._janitors = {} + self._loader.clear() end return ModuleLoader diff --git a/src/ModuleLoader.spec.luau b/src/ModuleLoader.spec.luau deleted file mode 100644 index 15c77d0..0000000 --- a/src/ModuleLoader.spec.luau +++ /dev/null @@ -1,643 +0,0 @@ -local ReplicatedStorage = game:GetService("ReplicatedStorage") - -local JestGlobals = require("@pkg/JestGlobals") -local test = JestGlobals.test -local expect = JestGlobals.expect -local describe = JestGlobals.describe -local beforeEach = JestGlobals.beforeEach -local afterEach = JestGlobals.afterEach - -local ModuleLoader = require("./ModuleLoader") - -local function countDict(dict: { [string]: any }) - local count = 0 - for _ in pairs(dict) do - count += 1 - end - return count -end - -type ModuleTestTree = { - [string]: string | ModuleTestTree, -} -local testNumber = 0 -local function createModuleTest(tree: ModuleTestTree, parent: Instance?): any - testNumber += 1 - - local root = Instance.new("Folder") - root.Name = "ModuleTest" .. testNumber - - parent = if parent then parent else root - - for name, sourceOrDescendants in tree do - if typeof(sourceOrDescendants) == "table" then - createModuleTest(sourceOrDescendants, parent) - else - local module = Instance.new("ModuleScript") - module.Name = name - module.Source = sourceOrDescendants - module.Parent = parent - end - end - - root.Parent = game - - return root -end - -local mockModuleSource = {} -local loader: ModuleLoader.ModuleLoader -local tree - -beforeEach(function() - loader = ModuleLoader.new() -end) - -afterEach(function() - loader:clear() - - if tree then - tree:Destroy() - end -end) - -describe("_getSource", function() - -- This test doesn't supply much value. Essentially, the "Source" - -- property requires elevated permissions, so we need the _getSource - -- method so that that if tests are being run from within a normal - -- script context that an error will not be produced. - test("returns the Source property if it can be indexed", function() - local mockModuleInstance = Instance.new("ModuleScript") - - local canIndex = pcall(function() - return mockModuleInstance.Source - end) - - local source = loader:_getSource(mockModuleInstance) - - if canIndex then - expect(source).toBeDefined() - else - expect(source).toBeUndefined() - end - end) -end) - -describe("_trackChanges", function() - test("creates a Janitor instance if it doesn't exist", function() - local mockModuleInstance = Instance.new("ModuleScript") - - expect(loader._janitors[mockModuleInstance.Name]).toBeUndefined() - - loader:_trackChanges(mockModuleInstance) - - expect(loader._janitors[mockModuleInstance.Name]).toBeDefined() - end) - - test("reuses the same Janitor instance for future calls", function() - local mockModuleInstance = Instance.new("ModuleScript") - - loader:_trackChanges(mockModuleInstance) - - local janitor = loader._janitors[mockModuleInstance.Name] - - loader:_trackChanges(mockModuleInstance) - - expect(loader._janitors[mockModuleInstance.Name]).toBe(janitor) - end) -end) - -describe("loadedModuleChanged", function() - test("fires when a required module has its ancestry changed", function() - local mockModuleInstance = Instance.new("ModuleScript") - - local wasFired = false - - -- Parent the ModuleScript somewhere in the DataModel so we can - -- listen for AncestryChanged. - mockModuleInstance.Parent = game - - loader.loadedModuleChanged:Connect(function(other: ModuleScript) - if other == mockModuleInstance then - wasFired = true - end - end) - - -- Require the module so that events get setup - loader:require(mockModuleInstance) - - -- Trigger AncestryChanged to fire - mockModuleInstance.Parent = nil - - expect(wasFired).toBe(true) - end) - - test("fires when a required module has its Source property change", function() - local mockModuleInstance = Instance.new("ModuleScript") - - local wasFired = false - loader.loadedModuleChanged:Connect(function(other: ModuleScript) - if other == mockModuleInstance then - wasFired = true - end - end) - - -- Require the module so that events get setup - loader:require(mockModuleInstance) - - mockModuleInstance.Source = "Something different" - - expect(wasFired).toBe(true) - end) - - test("fires for every consumer up the chain", function() - tree = createModuleTest({ - ModuleA = [[ - return "ModuleA" - ]], - ModuleB = [[ - require(script.Parent.ModuleA) - return "ModuleB" - ]], - ModuleC = [[ - require(script.Parent.ModuleB) - return "ModuleC" - ]], - }) - - local count = 0 - loader.loadedModuleChanged:Connect(function(module) - for _, child in tree:GetChildren() do - if module == child then - count += 1 - end - end - end) - - loader:require(tree.ModuleC) - - tree.ModuleA.Source = "Changed" - - expect(count).toBe(3) - end) -end) - -describe("cache", function() - test("adds a module and its result to the cache", function() - local mockModuleInstance = Instance.new("ModuleScript") - - loader:cache(mockModuleInstance, mockModuleSource) - - local cachedModule = loader._cache[mockModuleInstance:GetFullName()] - - expect(cachedModule).toBeDefined() - expect(cachedModule.result).toBe(mockModuleSource) - end) -end) - -describe("require", function() - test("adds the module to the cache", function() - local mockModuleInstance = Instance.new("ModuleScript") - - loader:require(mockModuleInstance) - expect(loader._cache[mockModuleInstance:GetFullName()]).toBeDefined() - end) - - test("returns cached results", function() - tree = createModuleTest({ - -- We return a table since it can act as a unique symbol. So if - -- both consumers are getting the same table we can perform an - -- equality check - SharedModule = [[ - local module = {} - return module - ]], - Consumer1 = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - Consumer2 = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - }) - - local sharedModuleFromConsumer1 = loader:require(tree.Consumer1) - local sharedModuleFromConsumer2 = loader:require(tree.Consumer2) - - expect(sharedModuleFromConsumer1).toBe(sharedModuleFromConsumer2) - end) - - test("adds the calling script as a consumer", function() - tree = createModuleTest({ - SharedModule = [[ - local module = {} - return module - ]], - Consumer = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - }) - - loader:require(tree.Consumer) - - local cachedModule = loader._cache[tree.SharedModule:GetFullName()] - - expect(cachedModule).toBeDefined() - expect(cachedModule.consumers[tree.Consumer:GetFullName()]).toBeDefined() - end) - - test("updates consumers when requiring a cached module from a different script", function() - tree = createModuleTest({ - SharedModule = [[ - local module = {} - return module - ]], - Consumer1 = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - Consumer2 = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - }) - - loader:require(tree.Consumer1) - - local cachedModule = loader._cache[tree.SharedModule:GetFullName()] - - expect(cachedModule.consumers[tree.Consumer1:GetFullName()]).toBeDefined() - expect(cachedModule.consumers[tree.Consumer2:GetFullName()]).toBeUndefined() - - loader:require(tree.Consumer2) - - expect(cachedModule.consumers[tree.Consumer1:GetFullName()]).toBeDefined() - expect(cachedModule.consumers[tree.Consumer2:GetFullName()]).toBeDefined() - end) - - test("keeps track of _G between modules", function() - tree = createModuleTest({ - WriteGlobal = [[ - _G.foo = true - return nil - ]], - ReadGlobal = [[ - return _G.foo - ]], - }) - - loader:require(tree.WriteGlobal) - - expect(loader._globals.foo).toBe(true) - - local result = loader:require(tree.ReadGlobal) - - expect(result).toBe(true) - end) - - test("keeps track of _G in nested requires", function() - tree = createModuleTest({ - DefineGlobal = [[ - _G.foo = true - return nil - ]], - UseGlobal = [[ - require(script.Parent.DefineGlobal) - return _G.foo - ]], - }) - - local result = loader:require(tree.UseGlobal) - - expect(result).toBe(true) - - loader:clear() - - expect(loader._globals.foo).toBeUndefined() - end) - - test("adds globals on _G to the cachedModule's globals", function() - tree = createModuleTest({ - DefineGlobal = [[ - _G.foo = true - return nil - ]], - }) - - loader:require(tree.DefineGlobal) - - local cachedModule = loader._cache[tree.DefineGlobal:GetFullName()] - expect(cachedModule.globals.foo).toBe(true) - end) -end) - -describe("clearModule", function() - test("clears a module from the cache", function() - tree = createModuleTest({ - Module = [[ - return "Module" - ]], - }) - - loader:require(tree.Module) - - expect(loader._cache[tree.Module:GetFullName()]).toBeDefined() - - loader:clearModule(tree.Module) - - expect(loader._cache[tree.Module:GetFullName()]).toBeUndefined() - end) - - test("clears all consumers of a module from the cache", function() - tree = createModuleTest({ - SharedModule = [[ - local module = {} - return module - ]], - Consumer1 = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - Consumer2 = [[ - local sharedModule = require(script.Parent.SharedModule) - return sharedModule - ]], - }) - - loader:require(tree.Consumer1) - loader:require(tree.Consumer2) - - expect(loader._cache[tree.Consumer1:GetFullName()]).toBeDefined() - expect(loader._cache[tree.Consumer2:GetFullName()]).toBeDefined() - expect(loader._cache[tree.SharedModule:GetFullName()]).toBeDefined() - - loader:clearModule(tree.SharedModule) - - expect(loader._cache[tree.Consumer1:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.Consumer2:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.SharedModule:GetFullName()]).toBeUndefined() - end) - - test("only clears modules in the consumer chain", function() - tree = createModuleTest({ - Module = [[ - return nil - ]], - Consumer = [[ - require(script.Parent.Module) - return nil - ]], - Independent = [[ - return nil - ]], - }) - - loader:require(tree.Consumer) - loader:require(tree.Independent) - - expect(countDict(loader._cache)).toBe(3) - - loader:clearModule(tree.Module) - - expect(countDict(loader._cache)).toBe(1) - expect(loader._cache[tree.Independent:GetFullName()]).toBeDefined() - end) - - test("clears all globals that a module supplied", function() - tree = createModuleTest({ - DefineGlobalFoo = [[ - _G.foo = true - return nil - ]], - DefineGlobalBar = [[ - _G.bar = false - return nil - ]], - }) - - loader:require(tree.DefineGlobalFoo) - loader:require(tree.DefineGlobalBar) - - loader:clearModule(tree.DefineGlobalBar) - - expect(loader._globals.foo).toBeDefined() - expect(loader._globals.bar).toBeUndefined() - end) - - test("fires loadedModuleChanged when clearing a module", function() - tree = createModuleTest({ - Module = [[ - return nil - ]], - Consumer = [[ - require(script.Parent.Module) - return nil - ]], - }) - - local wasFired = false - - loader.loadedModuleChanged:Connect(function() - wasFired = true - end) - - loader:require(tree.Consumer) - loader:clearModule(tree.Consumer) - - expect(wasFired).toBe(true) - end) - - test("fires loadedModuleChanged for every module up the chain", function() - tree = createModuleTest({ - Module3 = [[ - return {} - ]], - Module2 = [[ - require(script.Parent.Module3) - return {} - ]], - Module1 = [[ - require(script.Parent.Module2) - return {} - ]], - Consumer = [[ - require(script.Parent.Module1) - return nil - ]], - }) - - local count = 0 - - loader.loadedModuleChanged:Connect(function() - count += 1 - end) - - loader:require(tree.Consumer) - loader:clearModule(tree.Module3) - - expect(count).toBe(4) - end) - - test("never fires loadedModuleChanged for a module that hasn't been required", function() - local wasFired = false - - loader.loadedModuleChanged:Connect(function() - wasFired = true - end) - - -- Do nothing if the module hasn't been cached - local module = Instance.new("ModuleScript") - loader:clearModule(module) - expect(wasFired).toBe(false) - end) -end) - -describe("clear", function() - test("removes all modules from the cache", function() - local mockModuleInstance = Instance.new("ModuleScript") - - loader:cache(mockModuleInstance, mockModuleSource) - - expect(countDict(loader._cache)).toBe(1) - - loader:clear() - - expect(countDict(loader._cache)).toBe(0) - end) - - test("resets globals", function() - local globals = loader._globals - - loader:clear() - - expect(loader._globals).never.toBe(globals) - end) -end) - -describe("consumers", function() - beforeEach(function() - tree = createModuleTest({ - ModuleA = [[ - require(script.Parent.ModuleB) - - return "ModuleA" - ]], - ModuleB = [[ - return "ModuleB" - ]], - - ModuleC = [[ - return "ModuleC" - ]], - }) - end) - - test("removes all consumers of a changed module from the cache", function() - loader:require(tree.ModuleA) - - local hasItems = next(loader._cache) ~= nil - expect(hasItems).toBe(true) - - tree.ModuleB.Source = 'return "ModuleB Reloaded"' - task.wait() - - hasItems = next(loader._cache) ~= nil - expect(hasItems).toBe(false) - end) - - test("never interferes with other cached modules", function() - loader:require(tree.ModuleA) - loader:require(tree.ModuleC) - - local hasItems = next(loader._cache) ~= nil - expect(hasItems).toBe(true) - - tree.ModuleB.Source = 'return "ModuleB Reloaded"' - task.wait() - - expect(loader._cache[tree.ModuleA:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.ModuleB:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.ModuleC:GetFullName()]).toBeDefined() - end) -end) - -describe("roblox-ts", function() - local rbxtsInclude - local mockRuntime - - beforeEach(function() - rbxtsInclude = Instance.new("Folder") - rbxtsInclude.Name = "rbxts_include" - - mockRuntime = Instance.new("ModuleScript") - mockRuntime.Name = "RuntimeLib" - mockRuntime.Source = [[ - local function import(...) - return require(...) - end - return { - import = import - } - ]] - mockRuntime.Parent = rbxtsInclude - - rbxtsInclude.Parent = ReplicatedStorage - end) - - afterEach(function() - loader:clear() - rbxtsInclude:Destroy() - end) - - test("clearModule() should never clear the roblox-ts runtime from the cache", function() - -- This example isn't quite how a roblox-ts project would be setup - -- in practice since the require's for `Shared` would be using - -- `TS.import`, but it should be close enough for our test case - tree = createModuleTest({ - Shared = [[ - local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) - return {} - ]], - Module1 = [[ - local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) - local Shared = TS.import(script.Parent.Shared) - return nil - ]], - Module2 = [[ - local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) - local Shared = TS.import(script.Parent.Shared) - return nil - ]], - Root = [[ - local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) - local Module1 = TS.import(script.Parent.Module1) - local Module2 = TS.import(script.Parent.Module2) - ]], - }) - - loader:require(tree.Root) - loader:clearModule(tree.Shared) - - expect(loader._cache[mockRuntime:GetFullName()]).toBeDefined() - expect(loader._cache[tree.Shared:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.Module1:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.Module2:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.Root:GetFullName()]).toBeUndefined() - end) - - test("clear() should clear the roblox-ts runtime when calling", function() - tree = createModuleTest({ - Module = [[ - local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) - ]], - }) - - loader:require(tree.Module) - loader:clear() - - expect(loader._cache[mockRuntime:GetFullName()]).toBeUndefined() - expect(loader._cache[tree.Module:GetFullName()]).toBeUndefined() - end) -end) diff --git a/src/bind.luau b/src/bind.luau deleted file mode 100644 index 27016a6..0000000 --- a/src/bind.luau +++ /dev/null @@ -1,35 +0,0 @@ ---[=[ - Binds an instance method so that it can be called like a function. - - Usage: - - ```lua - local Class = {} - Class.__index = Class - - function Class.new() - local self = {} - self.value = "foo" - return setmetatable(self, Class) - end - - function Class:getValue() - return self.value - end - - local instance = Class.new() - local getValue = bind(instance, instance.getValue) - - print(getValue()) -- "foo" - ``` - - @within ModuleLoader - @private -]=] -local function bind(self: T, callback: (self: T, ...any) -> any) - return function(...) - return callback(self, ...) - end -end - -return bind diff --git a/src/bind.spec.luau b/src/bind.spec.luau deleted file mode 100644 index 9817b0b..0000000 --- a/src/bind.spec.luau +++ /dev/null @@ -1,38 +0,0 @@ -local JestGlobals = require("@pkg/JestGlobals") -local test = JestGlobals.test -local expect = JestGlobals.expect - -local bind = require("./bind") - -test("binds 'self' to the given callback", function() - local module = { - value = "foo", - callback = function(self) - return self.value - end, - } - - local callback = bind(module, module.callback) - - expect(callback()).toBe("foo") -end) - -test("works for the usage example", function() - local Class = {} - Class.__index = Class - - function Class.new() - local self = {} - self.value = "foo" - return setmetatable(self, Class) - end - - function Class:getValue() - return self.value - end - - local instance = Class.new() - local getValue = bind(instance, instance.getValue) - - expect(getValue()).toBe("foo") -end) diff --git a/src/cleanLoadstringStack.luau b/src/cleanLoadstringStack.luau new file mode 100644 index 0000000..94b3b43 --- /dev/null +++ b/src/cleanLoadstringStack.luau @@ -0,0 +1,19 @@ +-- upstream: https://github.com/Roblox/jest-roblox/blob/408eac1b8d210e6e07387fb341fa9b9e181de897/src/roblox-shared/src/cleanLoadStringStack.lua + +return function(line: string): string + local spacing, filePath, lineNumber, extra = line:match('(%s*)%[string "(.-)"%]:(%d+)(.*)') + if filePath then + local match = filePath + if spacing then + match = spacing .. match + end + if lineNumber then + match = match .. ":" .. lineNumber + end + if extra then + match = match .. extra + end + return match + end + return line +end diff --git a/src/createModuleLoader.luau b/src/createModuleLoader.luau new file mode 100644 index 0000000..096f32b --- /dev/null +++ b/src/createModuleLoader.luau @@ -0,0 +1,412 @@ +local Janitor = require("@pkg/Janitor") +local LuauPolyfill = require("@pkg/LuauPolyfill") + +local cleanLoadstringStack = require("./cleanLoadstringStack") +local createModuleRegistry = require("./createModuleRegistry") +local getCallerPath = require("./getCallerPath") +local getModuleSource = require("./getModuleSource") +local getRobloxTsRuntime = require("./getRobloxTsRuntime") +local types = require("./types") + +local Error = LuauPolyfill.Error + +type LoadedModule = types.LoadedModule +type LoadedModuleExports = types.LoadedModuleExports +type LoadingStrategy = types.LoadingStrategy +type LoadModuleFn = types.LoadModuleFn +type ModuleLoader = types.ModuleLoader +type ModuleRegistry = types.ModuleRegistry + +local loadmodule: (ModuleScript) -> (() -> LoadedModuleExports, string?, () -> ()) = debug["loadmodule"] +local loadModuleEnabled = pcall(function() + return loadmodule(Instance.new("ModuleScript")) +end) + +--[=[ + ModuleScript loader that bypasses Roblox's require cache. + + This class aims to solve a common problem where code needs to be run in + Studio, but once a change is made to an already required module the whole + place must be reloaded for the cache to be reset. With this class, the cache + is ignored when requiring a module so you are able to load a module, make + changes, and load it again without reloading the whole place. + + @class ModuleLoader +]=] +function createModuleLoader(): ModuleLoader + local moduleRegistry = createModuleRegistry() + local loadedModuleFns: { [ModuleScript]: { any } } = {} + local cleanupFns: { () -> () } = {} + local loadingStrategy: LoadingStrategy = "Automatic" + local janitors: { [ModuleScript]: any } = {} + + local function _getModuleRegistry() + return moduleRegistry + end + + local function setLoadingStrategy(strategy: LoadingStrategy) + loadingStrategy = strategy + end + + --[=[ + Fired when any ModuleScript required through this class has its ancestry + or `Source` property changed. This applies to the ModuleScript passed to + `ModuleLoader:require()` and every module that it subsequently requirs. + + This event is useful for reloading a module when it or any of it + dependencies change. + + ```lua + local loader = createModuleLoader() + local result = loader.require(module) + + loader.loadedModuleChanged:Connect(function() + loader.clear() + result = loader.require(module) + end) + ``` + + @prop loadedModuleChanged RBXScriptSignal + @within ModuleLoader + ]=] + local loadedModuleChanged = Instance.new("BindableEvent") + + local loadModule: LoadModuleFn + + local function getConsumers(moduleScript: ModuleScript): { ModuleScript } + local function getConsumersRecursively(loadedModule: LoadedModule, found: { [ModuleScript]: true }) + for consumer in loadedModule.consumers do + local loadedChildModule = moduleRegistry.getByInstance(consumer) + + if loadedChildModule then + if not found[loadedChildModule.instance] then + found[loadedChildModule.instance] = true + getConsumersRecursively(loadedChildModule, found) + end + end + end + end + + local loadedModule: LoadedModule? = moduleRegistry.getByInstance(moduleScript) + + if loadedModule then + local found = {} + + getConsumersRecursively(loadedModule, found) + + local consumers = {} + for consumer in found do + table.insert(consumers, consumer) + end + + return consumers + else + return {} + end + end + + local function clearModule(moduleToClear: ModuleScript) + if not moduleRegistry.getByInstance(moduleToClear) then + return + end + + local consumers = getConsumers(moduleToClear) + local modulesToClear = { moduleToClear, table.unpack(consumers) } + + local index = table.find(modulesToClear, getRobloxTsRuntime()) + if index then + table.remove(modulesToClear, index) + end + + for _, moduleScript in modulesToClear do + local loadedModule = moduleRegistry.getByInstance(moduleScript) + + if loadedModule then + local janitor = janitors[moduleScript] + janitor:Cleanup() + end + end + + for _, moduleScript in modulesToClear do + loadedModuleChanged:Fire(moduleScript) + end + end + + --[=[ + Tracks the changes to a required module's ancestry and `Source`. + + When ancestry or `Source` changes, the `loadedModuleChanged` event is fired. + When this happens, the user should clear the cache and require the root + module again to reload. + + @private + ]=] + local function trackChanges(moduleScript: ModuleScript) + local existingJanitor = janitors[moduleScript] + local janitor = if existingJanitor then existingJanitor else Janitor.new() + + janitor:Cleanup() + + janitor:Add(moduleScript.AncestryChanged:Connect(function() + clearModule(moduleScript) + end)) + + janitor:Add(moduleScript.Changed:Connect(function(prop: string) + if prop == "Source" then + clearModule(moduleScript) + end + end)) + + janitor:Add(function() + moduleRegistry.remove(moduleScript) + loadedModuleFns[moduleScript] = nil + end) + + janitors[moduleScript] = janitor + end + + --[=[ + Set the cached value for a module before it is loaded. + + This is useful is very specific situations. For example, this method is + used to cache a copy of Roact so that when a module is loaded with this + class it uses the same table instance. + + ```lua + local moduleInstance = script.Parent.ModuleScript + local moduleScript = require(moduleInstance) + + local loader = createModuleLoader() + loader.cache(moduleInstance, moduleScript) + ``` + ]=] + local function cache(moduleScript: ModuleScript, result: any) + local loadedModule: LoadedModule = { + instance = moduleScript, + exports = result, + isLoaded = true, + dependencies = {}, + consumers = {}, + } + + moduleRegistry.add(moduleScript, loadedModule) + end + + local function execModule(loadedModule: LoadedModule) + -- This method is adapted from: + -- https://github.com/Roblox/jest-roblox/blob/408eac/src/jest-runtime/src/init.lua#L1847-L2102 + + local moduleFunction, defaultEnvironment, errorMessage, cleanupFn + local moduleScript = loadedModule.instance + + local shouldUseLoadmodule = loadingStrategy == "LoadModule" + or (loadingStrategy == "Automatic" and loadModuleEnabled) + + local existingLoadedModuleFns = loadedModuleFns[moduleScript] + if existingLoadedModuleFns then + moduleFunction = existingLoadedModuleFns[1] + defaultEnvironment = existingLoadedModuleFns[2] + else + if shouldUseLoadmodule then + moduleFunction, errorMessage, cleanupFn = loadmodule(moduleScript) + else + moduleFunction, errorMessage = loadstring(getModuleSource(moduleScript), moduleScript:GetFullName()) + + if errorMessage then + errorMessage = cleanLoadstringStack(errorMessage) + end + end + + if not moduleFunction then + error(Error.new(errorMessage)) + end + + -- Cache initial environment table to inherit from later + defaultEnvironment = getfenv(moduleFunction) + + if loadedModuleFns then + loadedModuleFns[moduleScript] = { moduleFunction, defaultEnvironment, cleanupFn } + else + if cleanupFn ~= nil then + table.insert(cleanupFns, cleanupFn) + end + end + end + + -- The default behavior for function environments is to inherit the table + -- instance from the parent environment. This means that each invocation of + -- `moduleFunction()` will return a new module instance but with the same + -- environment table as `moduleFunction` loadModule the time of invocation. + -- In order to properly sanbox module instances, we need to ensure that each + -- instance has its own distinct environment table containing the specific + -- overrides for it, but still inherits from the default parent environment + -- for non-overriden environment goodies. + + -- This is the 'least mocked' environment that scripts will be able to see. + -- The final function environment inherits from this sandbox. This is + -- separate so that, in the future, `globalEnv` could expose these + -- 'unmocked' functions instead of the ones in the global environment. + local sandboxEnvironment = setmetatable({ + script = if shouldUseLoadmodule then defaultEnvironment.script else moduleScript, + game = defaultEnvironment.game, + workspace = defaultEnvironment.workspace, + plugin = defaultEnvironment.plugin, + + -- legacy aliases for data model + Game = defaultEnvironment.game, + Workspace = defaultEnvironment.workspace, + + require = function(otherModule: ModuleScript | string) + if typeof(otherModule) == "string" then + -- Disabling this at the surface level of the API until we have + -- deeper support in Jest. + error("Require-by-string is not enabled for use inside Jest at this time.") + end + + loadedModule.dependencies[otherModule] = true + + return loadModule(otherModule) + end, + }, { + __index = defaultEnvironment, + }) + + -- This is the environment actually passed to scripts, including all global + -- mocks and other customisations the user might choose to apply. + local mockedSandboxEnvironment = setmetatable({}, { + __index = sandboxEnvironment, + }) + + setfenv(moduleFunction, mockedSandboxEnvironment :: any) + local moduleResult = table.pack(moduleFunction()) + + if moduleResult.n ~= 1 then + error( + string.format( + "[Module Error]: %s did not return a valid result\n" + .. "\tModuleScripts must return exactly one value", + tostring(moduleScript) + ) + ) + end + + trackChanges(moduleScript) + + loadedModule.exports = moduleResult[1] + end + + --[=[ + Require a module with a fresh ModuleScript require cache. + + This method is functionally the same as running `require(script.Parent.ModuleScript)`, + however in this case the module is not cached. As such, if a change occurs + to the module you can call this method again to get the latest changes. + + ```lua + local loader = createModuleLoader() + local module = loader.require(script.Parent.ModuleScript) + ``` + ]=] + function loadModule(moduleScript: ModuleScript) + if moduleScript.Name:find(".global$") then + return (require :: any)(moduleScript) + end + + local caller: ModuleScript? + local callerPath = getCallerPath() + if callerPath then + local loadedCallerModule = moduleRegistry.getByFullName(callerPath) + if loadedCallerModule and loadedCallerModule.instance then + caller = loadedCallerModule.instance + end + end + + local existingModule = moduleRegistry.getByInstance(moduleScript) + if existingModule then + if caller then + existingModule.consumers[caller] = true + end + + return existingModule.exports + end + + -- We must register the pre-allocated module object first so that any + -- circular dependencies that may arise while evaluating the module can + -- be satisfied. + local loadedModule: LoadedModule = { + instance = moduleScript, + exports = nil, + isLoaded = false, + dependencies = {}, + consumers = if caller + then { + [caller] = true, + } + else {}, + } + + moduleRegistry.add(moduleScript, loadedModule) + + local success, result = pcall(function() + execModule(loadedModule) + loadedModule.isLoaded = true + end) + if not success then + moduleRegistry.remove(moduleScript) + error(result) + end + + return loadedModule.exports + end + + --[=[ + Clears out the internal cache. + + While this module bypasses Roblox's ModuleScript cache, one is still + maintained internally so that repeated requires to the same module return a + cached value. + + This method should be called when you need to require a module again. i.e. + if the module's Source has been changed. + + ```lua + local loader = createModuleLoader() + loader.require(script.Parent.ModuleScript) + + -- Later... + + -- Clear the cache and require the module again + loader.clear() + loader.require(script.Parent.ModuleScript) + ``` + ]=] + local function clear() + for _, janitor in janitors do + janitor:Cleanup() + end + + for _, cleanupFn in cleanupFns do + cleanupFn() + end + + moduleRegistry.reset() + loadedModuleFns = {} + cleanupFns = {} + janitors = {} + end + + return { + _getModuleRegistry = _getModuleRegistry, + + cache = cache, + loadModule = loadModule, + require = loadModule, + clearModule = clearModule, + clear = clear, + setLoadingStrategy = setLoadingStrategy, + + loadedModuleChanged = loadedModuleChanged.Event, + } +end + +return createModuleLoader diff --git a/src/createModuleLoader.spec.luau b/src/createModuleLoader.spec.luau new file mode 100644 index 0000000..8d73b37 --- /dev/null +++ b/src/createModuleLoader.spec.luau @@ -0,0 +1,664 @@ +local ReplicatedStorage = game:GetService("ReplicatedStorage") + +local JestGlobals = require("@pkg/JestGlobals") +local test = JestGlobals.test +local expect = JestGlobals.expect +local describe = JestGlobals.describe +local beforeEach = JestGlobals.beforeEach +local afterEach = JestGlobals.afterEach +local describeEach = describe.each :: any + +local createModuleLoader = require("./createModuleLoader") +local types = require("./types") + +local loadModuleEnabled = pcall(function() + return debug["loadmodule"](Instance.new("ModuleScript")) +end) + +type ModuleTestTree = { + [string]: string | ModuleTestTree, +} +local testNumber = 0 +local function createModuleTest(tree: ModuleTestTree, parent: Instance?): any + testNumber += 1 + + local root = Instance.new("Folder") + root.Name = "ModuleTest" .. testNumber + + parent = if parent then parent else root + + for name, sourceOrDescendants in tree do + if typeof(sourceOrDescendants) == "table" then + createModuleTest(sourceOrDescendants, parent) + else + local module = Instance.new("ModuleScript") + module.Name = name + module.Source = sourceOrDescendants + module.Parent = parent + end + end + + root.Parent = game + + return root +end + +local mockModuleSource = {} +local loader +local moduleRegistry +local tree + +describeEach({ + "LoadString", + "LoadModule", +} :: { types.LoadingStrategy })("%s", function(loadingStrategy) + if loadingStrategy == "LoadModule" and not loadModuleEnabled then + test = test.skip :: any + end + + beforeEach(function() + loader = createModuleLoader() + loader.setLoadingStrategy(loadingStrategy) + + moduleRegistry = loader._getModuleRegistry() + end) + + afterEach(function() + loader.clear() + + if tree then + tree:Destroy() + end + end) + + describe("loadedModuleChanged", function() + test("fires when a required module has its ancestry changed", function() + local mockModuleInstance = Instance.new("ModuleScript") + mockModuleInstance.Source = "return nil" + + local wasFired = false + + -- Parent the ModuleScript somewhere in the DataModel so we can + -- listen for AncestryChanged. + mockModuleInstance.Parent = game + + loader.loadedModuleChanged:Connect(function(other: ModuleScript) + if other == mockModuleInstance then + wasFired = true + end + end) + + -- Require the module so that events get setup + loader.require(mockModuleInstance) + + expect(wasFired).toBe(false) + + -- Trigger AncestryChanged to fire + mockModuleInstance.Parent = nil + + expect(wasFired).toBe(true) + end) + + test("fires when a required module has its Source property change", function() + local mockModuleInstance = Instance.new("ModuleScript") + mockModuleInstance.Source = "return nil" + + local wasFired = false + loader.loadedModuleChanged:Connect(function(other: ModuleScript) + if other == mockModuleInstance then + wasFired = true + end + end) + + -- Require the module so that events get setup + loader.require(mockModuleInstance) + + expect(wasFired).toBe(false) + + mockModuleInstance.Source = "Something different" + task.wait() + + expect(wasFired).toBe(true) + end) + + test("fires for every consumer up the chain", function() + tree = createModuleTest({ + ModuleA = [[ + return "ModuleA" + ]], + ModuleB = [[ + require(script.Parent.ModuleA) + return "ModuleB" + ]], + ModuleC = [[ + require(script.Parent.ModuleB) + return "ModuleC" + ]], + }) + + local count = 0 + loader.loadedModuleChanged:Connect(function(module) + for _, child in tree:GetChildren() do + if module == child then + count += 1 + end + end + end) + + loader.require(tree.ModuleC) + + tree.ModuleA.Source = "Changed" + task.wait() + + expect(count).toBe(3) + end) + end) + + describe("cache", function() + test("adds a module and its result to the cache", function() + local mockModuleInstance = Instance.new("ModuleScript") + mockModuleInstance.Source = "return nil" + + loader.cache(mockModuleInstance, mockModuleSource) + + local loadedModule = moduleRegistry.getByInstance(mockModuleInstance) + + assert(loadedModule, "loaded module undefined") + expect(loadedModule.exports).toBe(mockModuleSource) + end) + end) + + describe("require", function() + test("adds the module to the cache", function() + local mockModuleInstance = Instance.new("ModuleScript") + mockModuleInstance.Source = "return nil" + + loader.require(mockModuleInstance) + expect(moduleRegistry.getByInstance(mockModuleInstance)).toBeDefined() + end) + + test("keeps track of module dependencies", function() + tree = createModuleTest({ + Module3 = [[ + return nil + ]], + Module2 = [[ + require(script.Parent.Module3) + return nil + ]], + Module1 = [[ + require(script.Parent.Module2) + return nil + ]], + Root = [[ + require(script.Parent.Module1) + return nil + ]], + }) + + loader.require(tree.Root) + + expect(moduleRegistry.getByInstance(tree.Root)).toMatchObject({ + dependencies = { + [tree.Module1] = true, + }, + }) + expect(moduleRegistry.getByInstance(tree.Module1 :: ModuleScript)).toMatchObject({ + dependencies = { + [tree.Module2] = true, + }, + }) + expect(moduleRegistry.getByInstance(tree.Module2 :: ModuleScript)).toMatchObject({ + dependencies = { + [tree.Module3] = true, + }, + }) + expect(moduleRegistry.getByInstance(tree.Module3 :: ModuleScript)).toMatchObject({ + dependencies = {}, + }) + end) + + test("returns cached results", function() + tree = createModuleTest({ + -- We return a table since it can act as a unique symbol. So if + -- both consumers are getting the same table we can perform an + -- equality check + SharedModule = [[ + local module = {} + return module + ]], + Consumer1 = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + Consumer2 = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + }) + + local sharedModuleFromConsumer1 = loader.require(tree.Consumer1) + local sharedModuleFromConsumer2 = loader.require(tree.Consumer2) + + expect(sharedModuleFromConsumer1).toBe(sharedModuleFromConsumer2) + end) + + test("adds the calling script as a consumer", function() + tree = createModuleTest({ + SharedModule = [[ + local module = {} + return module + ]], + Consumer = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + }) + + loader.require(tree.Consumer) + + local loadedModule = moduleRegistry.getByInstance(tree.SharedModule) + assert(loadedModule, "loaded module undefined") + expect(loadedModule.consumers[tree.Consumer]).toBeDefined() + end) + + test("updates consumers when requiring a cached module from a different script", function() + tree = createModuleTest({ + SharedModule = [[ + local module = {} + return module + ]], + Consumer1 = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + Consumer2 = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + }) + + loader.require(tree.Consumer1) + + local loadedModule = moduleRegistry.getByInstance(tree.SharedModule) + + assert(loadedModule, "loaded module undefined") + + expect(loadedModule.consumers[tree.Consumer1]).toBeDefined() + expect(loadedModule.consumers[tree.Consumer2]).toBeUndefined() + + loader.require(tree.Consumer2) + + expect(loadedModule.consumers[tree.Consumer1]).toBeDefined() + expect(loadedModule.consumers[tree.Consumer2]).toBeDefined() + end) + + test("keeps track of _G between modules", function() + tree = createModuleTest({ + WriteGlobal = [[ + _G.foo = true + return nil + ]], + ReadGlobal = [[ + return _G.foo + ]], + }) + + loader.require(tree.WriteGlobal) + + local result = loader.require(tree.ReadGlobal) + + expect(result).toBe(true) + end) + + test("keeps track of _G in nested requires", function() + tree = createModuleTest({ + DefineGlobal = [[ + _G.foo = true + return nil + ]], + UseGlobal = [[ + require(script.Parent.DefineGlobal) + return _G.foo + ]], + }) + + local result = loader.require(tree.UseGlobal) + + expect(result).toBe(true) + end) + + test("handles syntax errors for direct require", function() + tree = createModuleTest({ + Module = [[ + syntax error + ]], + }) + tree.Name = "SyntaxError" + + expect(function() + loader.require(tree.Module) + end).toThrow(`SyntaxError.Module:1: Incomplete statement: expected assignment or a function call`) + end) + + test("handles syntax errors for nested requires", function() + tree = createModuleTest({ + Module3 = [[ + syntax error + ]], + Module2 = [[ + require(script.Parent.Module3) + return {} + ]], + Module1 = [[ + require(script.Parent.Module2) + return {} + ]], + Consumer = [[ + require(script.Parent.Module1) + return nil + ]], + }) + tree.Name = "SyntaxError" + + expect(function() + loader.require(tree.Consumer) + end).toThrow(`SyntaxError.Module3:1: Incomplete statement: expected assignment or a function call`) + end) + + test("when a module's source changes requiring it again uses the new source", function() + local moduleScript = Instance.new("ModuleScript") + moduleScript.Source = "return true" + + expect(loader.require(moduleScript)).toBeTruthy() + + moduleScript.Source = "return false" + + expect(loader.require(moduleScript)).toBeFalsy() + end) + end) + + describe("clearModule", function() + test("clears a module from the cache", function() + tree = createModuleTest({ + Module = [[ + return "Module" + ]], + }) + + loader.require(tree.Module) + + expect(moduleRegistry.getByInstance(tree.Module)).toBeDefined() + + loader.clearModule(tree.Module) + + expect(moduleRegistry.getByInstance(tree.Module)).toBeUndefined() + end) + + test("clears all consumers of a module from the cache", function() + tree = createModuleTest({ + SharedModule = [[ + local module = {} + return module + ]], + Consumer1 = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + Consumer2 = [[ + local sharedModule = require(script.Parent.SharedModule) + return sharedModule + ]], + }) + + loader.require(tree.Consumer1) + loader.require(tree.Consumer2) + + expect(moduleRegistry.getByInstance(tree.Consumer1)).toBeDefined() + expect(moduleRegistry.getByInstance(tree.Consumer2)).toBeDefined() + expect(moduleRegistry.getByInstance(tree.SharedModule)).toBeDefined() + + loader.clearModule(tree.SharedModule) + + expect(moduleRegistry.getByInstance(tree.Consumer1)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.Consumer2)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.SharedModule)).toBeUndefined() + end) + + test("only clears modules in the consumer chain", function() + tree = createModuleTest({ + Module = [[ + return nil + ]], + Consumer = [[ + require(script.Parent.Module) + return nil + ]], + Independent = [[ + return nil + ]], + }) + + loader.require(tree.Consumer) + loader.require(tree.Independent) + + expect(#moduleRegistry.getAllModules()).toBe(3) + + loader.clearModule(tree.Module) + + expect(#moduleRegistry.getAllModules()).toBe(1) + expect(moduleRegistry.getByInstance(tree.Independent)).toBeDefined() + end) + + test("fires loadedModuleChanged when clearing a module", function() + tree = createModuleTest({ + Module = [[ + return nil + ]], + Consumer = [[ + require(script.Parent.Module) + return nil + ]], + }) + + local wasFired = false + + loader.loadedModuleChanged:Connect(function() + wasFired = true + end) + + loader.require(tree.Consumer) + loader.clearModule(tree.Consumer) + + expect(wasFired).toBe(true) + end) + + test("fires loadedModuleChanged for every module up the chain", function() + tree = createModuleTest({ + Module3 = [[ + return {} + ]], + Module2 = [[ + require(script.Parent.Module3) + return {} + ]], + Module1 = [[ + require(script.Parent.Module2) + return {} + ]], + Consumer = [[ + require(script.Parent.Module1) + return nil + ]], + }) + + local count = 0 + + loader.loadedModuleChanged:Connect(function() + count += 1 + end) + + loader.require(tree.Consumer) + loader.clearModule(tree.Module3 :: ModuleScript) + + expect(count).toBe(4) + end) + + test("never fires loadedModuleChanged for a module that hasn't been required", function() + local wasFired = false + + loader.loadedModuleChanged:Connect(function() + wasFired = true + end) + + -- Do nothing if the module hasn't been cached + local module = Instance.new("ModuleScript") + loader.clearModule(module) + expect(wasFired).toBe(false) + end) + end) + + describe("clear", function() + test("removes all modules from the cache", function() + local mockModuleInstance = Instance.new("ModuleScript") + mockModuleInstance.Source = "return nil" + + loader.cache(mockModuleInstance, mockModuleSource) + + expect(#moduleRegistry.getAllModules()).toBe(1) + + loader.clear() + + expect(#moduleRegistry.getAllModules()).toBe(0) + end) + end) + + describe("consumers", function() + beforeEach(function() + tree = createModuleTest({ + ModuleA = [[ + require(script.Parent.ModuleB) + + return "ModuleA" + ]], + ModuleB = [[ + return "ModuleB" + ]], + + ModuleC = [[ + return "ModuleC" + ]], + }) + end) + + test("removes all consumers of a changed module from the cache", function() + loader.require(tree.ModuleA) + + local hasItems = next(moduleRegistry.getAllModules()) ~= nil + expect(hasItems).toBe(true) + + tree.ModuleB.Source = 'return "ModuleB Reloaded"' + task.wait() + + hasItems = next(moduleRegistry.getAllModules()) ~= nil + expect(hasItems).toBe(false) + end) + + test("does not interfere with other cached modules", function() + loader.require(tree.ModuleA) + loader.require(tree.ModuleC) + + local hasItems = next(moduleRegistry.getAllModules()) ~= nil + expect(hasItems).toBe(true) + + tree.ModuleB.Source = 'return "ModuleB Reloaded"' + task.wait() + + expect(moduleRegistry.getByInstance(tree.ModuleA)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.ModuleB)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.ModuleC)).toBeDefined() + end) + end) + + describe("roblox-ts", function() + local rbxtsInclude + local mockRuntime + + beforeEach(function() + rbxtsInclude = Instance.new("Folder") + rbxtsInclude.Name = "rbxts_include" + + mockRuntime = Instance.new("ModuleScript") + mockRuntime.Name = "RuntimeLib" + mockRuntime.Source = [[ + local function import(...) + return require(...) + end + return { + import = import + } + ]] + mockRuntime.Parent = rbxtsInclude + + rbxtsInclude.Parent = ReplicatedStorage + end) + + afterEach(function() + loader.clear() + rbxtsInclude:Destroy() + end) + + test("clearModule() never clears the roblox-ts runtime from the cache", function() + -- This example isn't quite how a roblox-ts project would be setup since + -- the requires for `Shared` would be using `TS.import`, but it should + -- be close enough for our test case + tree = createModuleTest({ + Shared = [[ + local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) + return {} + ]], + Module1 = [[ + local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) + local Shared = TS.import(script.Parent.Shared) + return nil + ]], + Module2 = [[ + local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) + local Shared = TS.import(script.Parent.Shared) + return nil + ]], + Root = [[ + local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) + local Module1 = TS.import(script.Parent.Module1) + local Module2 = TS.import(script.Parent.Module2) + + return nil + ]], + }) + + loader.require(tree.Root) + loader.clearModule(tree.Shared) + + expect(moduleRegistry.getByInstance(mockRuntime)).toBeDefined() + expect(moduleRegistry.getByInstance(tree.Shared)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.Module1 :: ModuleScript)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.Module2 :: ModuleScript)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.Root)).toBeUndefined() + end) + + test("clear() clears the roblox-ts runtime when calling", function() + tree = createModuleTest({ + Module = [[ + local TS = require(game:GetService("ReplicatedStorage").rbxts_include.RuntimeLib) + return nil + ]], + }) + + loader.require(tree.Module) + loader.clear() + + expect(moduleRegistry.getByInstance(mockRuntime)).toBeUndefined() + expect(moduleRegistry.getByInstance(tree.Module)).toBeUndefined() + end) + end) +end) diff --git a/src/createModuleRegistry.luau b/src/createModuleRegistry.luau new file mode 100644 index 0000000..7afd9d5 --- /dev/null +++ b/src/createModuleRegistry.luau @@ -0,0 +1,53 @@ +local types = require("./types") + +type LoadedModule = types.LoadedModule +type ModuleRegistry = types.ModuleRegistry + +local function createModuleRegistry(): ModuleRegistry + local registry = { + byInstance = {}, + byPath = {}, + } + + local function add(moduleScript: ModuleScript, loadedModule: LoadedModule) + registry.byInstance[moduleScript] = loadedModule + registry.byPath[moduleScript:GetFullName()] = loadedModule + end + + local function remove(moduleScript: ModuleScript) + registry.byInstance[moduleScript] = nil + registry.byPath[moduleScript:GetFullName()] = nil + end + + local function getAllModules(): { ModuleScript } + local modules = {} + for moduleScript in registry.byInstance do + table.insert(modules, moduleScript) + end + return modules + end + + local function reset() + table.clear(registry.byInstance) + table.clear(registry.byPath) + end + + local function getByInstance(moduleScript: ModuleScript): LoadedModule? + return registry.byInstance[moduleScript] + end + + local function getByFullName(fullName: string): LoadedModule? + return registry.byPath[fullName] + end + + return { + add = add, + remove = remove, + reset = reset, + getAllModules = getAllModules, + getByInstance = getByInstance, + getByFullName = getByFullName, + } +end + +return createModuleRegistry diff --git a/src/createModuleRegistry.spec.luau b/src/createModuleRegistry.spec.luau new file mode 100644 index 0000000..4dc4d9e --- /dev/null +++ b/src/createModuleRegistry.spec.luau @@ -0,0 +1,62 @@ +local JestGlobals = require("@pkg/JestGlobals") +local test = JestGlobals.test +local expect = JestGlobals.expect + +local createModuleRegistry = require("./createModuleRegistry") + +test("adds a module to the registry", function() + local registry = createModuleRegistry() + + local moduleScript = Instance.new("ModuleScript") + local mockLoadedModule = {} :: any + + registry.add(moduleScript, mockLoadedModule) + + expect(registry.getByInstance(moduleScript)).toBeDefined() + expect(registry.getByFullName(moduleScript:GetFullName())).toBeDefined() +end) + +test("removes a module from the registry", function() + local registry = createModuleRegistry() + + local moduleScript = Instance.new("ModuleScript") + local mockLoadedModule = {} :: any + + registry.add(moduleScript, mockLoadedModule) + + expect(registry.getByInstance(moduleScript)).toBeDefined() + expect(registry.getByFullName(moduleScript:GetFullName())).toBeDefined() + + registry.remove(moduleScript) + + expect(registry.getByInstance(moduleScript)).toBeUndefined() + expect(registry.getByFullName(moduleScript:GetFullName())).toBeUndefined() +end) + +test("reset the registry", function() + local registry = createModuleRegistry() + + local folder = Instance.new("Folder") + + for _, name in { "ModuleA", "ModuleB", "ModuleC" } do + local moduleScript = Instance.new("ModuleScript") + moduleScript.Name = name + moduleScript.Parent = folder + + local mockLoadedModule = {} :: any + + registry.add(moduleScript, mockLoadedModule) + end + + for _, moduleScript in folder:GetChildren() do + expect(registry.getByInstance(moduleScript :: ModuleScript)).toBeDefined() + expect(registry.getByFullName(moduleScript:GetFullName())).toBeDefined() + end + + registry.reset() + + for _, moduleScript in folder:GetChildren() do + expect(registry.getByInstance(moduleScript :: ModuleScript)).toBeUndefined() + expect(registry.getByFullName(moduleScript:GetFullName())).toBeUndefined() + end +end) diff --git a/src/createTablePassthrough.luau b/src/createTablePassthrough.luau deleted file mode 100644 index cdb9258..0000000 --- a/src/createTablePassthrough.luau +++ /dev/null @@ -1,32 +0,0 @@ ---[[ - Creates a table that can be indexed and added to while also adding to a base - table. - - This is used for module globals so that a module can define variables on _G - which are maintained in a dictionary of all globals AND a dictionary of the - globals a given module has defined. - - This makes it easy to clear out the globals a modeule defines when removing - it from the cache. -]] - -type AnyTable = { [any]: any } - -local function createTablePassthrough(base: AnyTable): AnyTable - local proxy = {} - - setmetatable(proxy, { - __index = function(self, key) - local global = rawget(self, key) - return if global then global else base[key] - end, - __newindex = function(self, key, value) - base[key] = value - rawset(self, key, value) - end, - }) - - return proxy :: any -end - -return createTablePassthrough diff --git a/src/createTablePassthrough.spec.luau b/src/createTablePassthrough.spec.luau deleted file mode 100644 index f8191ff..0000000 --- a/src/createTablePassthrough.spec.luau +++ /dev/null @@ -1,25 +0,0 @@ -local JestGlobals = require("@pkg/JestGlobals") -local test = JestGlobals.test -local expect = JestGlobals.expect - -local createTablePassthrough = require("./createTablePassthrough") - -test("works for the use case of maintaining global variables", function() - local allGlobals = {} - local moduleGlobals1 = createTablePassthrough(allGlobals) - local moduleGlobals2 = createTablePassthrough(allGlobals) - - moduleGlobals1.foo = true - moduleGlobals2.bar = true - - expect(moduleGlobals1.foo).toBe(true) - expect(moduleGlobals1.bar).toBe(true) - expect(rawget(moduleGlobals1, "bar")).toBeUndefined() - - expect(moduleGlobals2.bar).toBe(true) - expect(moduleGlobals2.foo).toBe(true) - expect(rawget(moduleGlobals2, "foo")).toBeUndefined() - - expect(allGlobals.foo).toBe(true) - expect(allGlobals.bar).toBe(true) -end) diff --git a/src/getEnv.luau b/src/getEnv.luau deleted file mode 100644 index efb9871..0000000 --- a/src/getEnv.luau +++ /dev/null @@ -1,31 +0,0 @@ -local baseEnv = getfenv() - -local function getEnv(scriptRelativeTo: LuaSourceContainer?, globals: { [any]: any }?) - local newEnv = {} - - setmetatable(newEnv, { - __index = function(_, key) - if key ~= "plugin" then - return baseEnv[key] - else - return nil - end - end, - }) - - newEnv._G = globals - newEnv.script = scriptRelativeTo - - local realDebug = debug - - newEnv.debug = setmetatable({ - traceback = function(message) - -- Block traces to prevent overly verbose TestEZ output - return message or "" - end, - }, { __index = realDebug }) - - return newEnv -end - -return getEnv diff --git a/src/getEnv.spec.luau b/src/getEnv.spec.luau deleted file mode 100644 index ee8948e..0000000 --- a/src/getEnv.spec.luau +++ /dev/null @@ -1,24 +0,0 @@ -local JestGlobals = require("@pkg/JestGlobals") -local test = JestGlobals.test -local expect = JestGlobals.expect - -local getEnv = require("./getEnv") - -test("returns a table", function() - expect(typeof(getEnv())).toBe("table") -end) - -test("has the correct 'script' global", function() - local env = getEnv(script.Parent.getEnv) - expect(env.script).toBe(script.Parent.getEnv) -end) - -test("sets _G to the 'globals' argument", function() - local globals = {} - local env = getEnv(script.Parent.getEnv, globals) - - expect(env._G).toBeDefined() - expect(env._G).toBe(globals) - -- selene: allow(global_usage) - expect(env._G).never.toBe(_G) -end) diff --git a/src/getModuleSource.luau b/src/getModuleSource.luau new file mode 100644 index 0000000..3466078 --- /dev/null +++ b/src/getModuleSource.luau @@ -0,0 +1,9 @@ +local function getModuleSource(moduleScript: ModuleScript): string + local success, result = pcall(function() + return moduleScript.Source + end) + + return if success then result else "" +end + +return getModuleSource diff --git a/src/init.luau b/src/init.luau index 4278cfc..b9bb6d8 100644 --- a/src/init.luau +++ b/src/init.luau @@ -1,6 +1,6 @@ local ModuleLoader = require("./ModuleLoader") -export type ModuleLoader = ModuleLoader.ModuleLoader -export type Class = ModuleLoader.ModuleLoader +export type ModuleLoader = ModuleLoader.ModuleLoaderClass +export type Class = ModuleLoader.ModuleLoaderClass return ModuleLoader diff --git a/src/types.luau b/src/types.luau new file mode 100644 index 0000000..036ccde --- /dev/null +++ b/src/types.luau @@ -0,0 +1,38 @@ +export type LoadingStrategy = "Automatic" | "LoadString" | "LoadModule" + +export type LoadedModuleExports = unknown? + +export type LoadedModule = { + instance: ModuleScript, + isLoaded: boolean, + exports: LoadedModuleExports, + dependencies: { [ModuleScript]: boolean }, + consumers: { [ModuleScript]: boolean }, +} + +export type ModuleRegistry = { + add: (moduleScript: ModuleScript, loadedModule: LoadedModule) -> (), + remove: (moduleScript: ModuleScript) -> (), + reset: () -> (), + getAllModules: () -> { ModuleScript }, + getByInstance: (moduleScript: ModuleScript) -> LoadedModule?, + getByFullName: (fullName: string) -> LoadedModule?, +} + +export type LoadModuleFn = (moduleScript: ModuleScript) -> LoadedModuleExports + +export type ModuleLoader = { + _getModuleRegistry: () -> ModuleRegistry, + + require: LoadModuleFn, + loadModule: LoadModuleFn, + cache: (moduleScript: ModuleScript, result: LoadedModuleExports) -> (), + clearModule: (moduleScript: ModuleScript) -> (), + clear: () -> (), + + setLoadingStrategy: (loadingStrategy: LoadingStrategy) -> (), + + loadedModuleChanged: RBXScriptSignal, +} + +return nil diff --git a/wally.toml b/wally.toml index 1c4ea2f..a2fba36 100644 --- a/wally.toml +++ b/wally.toml @@ -17,6 +17,8 @@ include = [ [dependencies] Janitor = "howmanysmall/janitor@1.13.15" +Sift = "csqrl/sift@0.0.8" +LuauPolyfill = "jsdotlua/luau-polyfill@1.2.7" #[dev-dependencies] Jest = "jsdotlua/jest@3.10.0"