diff --git a/build.sh b/build.sh index b734608f..44fcbe57 100755 --- a/build.sh +++ b/build.sh @@ -97,11 +97,13 @@ function build_rbxmx { function build_lua { compile echo "[net-build] compiling lua output..." + rm -rf lualib mkdir -p lualib cp -r out/* lualib cp -r include lualib/vendor - find dist/lua -name '*.d.ts' -delete + find lualib -name '*.d.ts' -delete + rm -rf lualib/Test echo "[net-build] Output to ./lualib" } diff --git a/lualib/ClientEvent.lua b/lualib/ClientEvent.lua new file mode 100644 index 00000000..b81fb354 --- /dev/null +++ b/lualib/ClientEvent.lua @@ -0,0 +1,45 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local NetClientEvent; +local _0 = TS.import(script.Parent, "internal"); +local getRemoteOrThrow, IS_CLIENT, waitForEvent, MAX_CLIENT_WAITFORCHILD_TIMEOUT = _0.getRemoteOrThrow, _0.IS_CLIENT, _0.waitForEvent, _0.MAX_CLIENT_WAITFORCHILD_TIMEOUT; +do + NetClientEvent = setmetatable({}, { + __tostring = function() return "NetClientEvent" end; + }); + NetClientEvent.__index = NetClientEvent; + function NetClientEvent.new(...) + local self = setmetatable({}, NetClientEvent); + self:constructor(...); + return self; + end; + function NetClientEvent:constructor(name) + self.instance = getRemoteOrThrow("RemoteEvent", name); + assert(IS_CLIENT, "Cannot create a Net.ClientEvent on the Server!"); + end; + NetClientEvent.WaitFor = TS.async(function(self, name) + local fun = waitForEvent(name, MAX_CLIENT_WAITFORCHILD_TIMEOUT); + if not fun then + error("Failed to retrieve client Event!"); + end; + return NetClientEvent.new(name); + end); + function NetClientEvent:getInstance() + return self.instance; + end; + function NetClientEvent:getEvent() + return self.instance.OnClientEvent; + end; + function NetClientEvent:Connect(callback) + return self:getEvent():Connect(callback); + end; + function NetClientEvent:SendToServer(...) + local args = { ... }; + self.instance:FireServer(unpack(args)); + end; +end; +exports.default = NetClientEvent; +return exports; diff --git a/lualib/ClientFunction.lua b/lualib/ClientFunction.lua new file mode 100644 index 00000000..7c0d9c68 --- /dev/null +++ b/lualib/ClientFunction.lua @@ -0,0 +1,67 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local NetClientFunction; +local _0 = TS.import(script.Parent, "internal"); +local getRemoteOrThrow, IS_CLIENT, functionExists, waitForFunction, MAX_CLIENT_WAITFORCHILD_TIMEOUT = _0.getRemoteOrThrow, _0.IS_CLIENT, _0.functionExists, _0.waitForFunction, _0.MAX_CLIENT_WAITFORCHILD_TIMEOUT; +do + NetClientFunction = setmetatable({}, { + __tostring = function() return "NetClientFunction" end; + }); + NetClientFunction.__index = NetClientFunction; + function NetClientFunction.new(...) + local self = setmetatable({}, NetClientFunction); + self:constructor(...); + return self; + end; + function NetClientFunction:constructor(name) + self.lastPing = -1; + self.cached = {}; + self.instance = getRemoteOrThrow("RemoteFunction", name); + assert(IS_CLIENT, "Cannot create a Net.ClientFunction on the Server!"); + assert(functionExists(name), "The specified function '" .. name .. "' does not exist!"); + end; + NetClientFunction.WaitFor = TS.async(function(self, name) + local fun = waitForFunction(name, MAX_CLIENT_WAITFORCHILD_TIMEOUT); + if not fun then + error("Failed to retrieve client Function!"); + end; + return NetClientFunction.new(name); + end); + function NetClientFunction:getCallback() + return self.instance.OnClientInvoke; + end; + function NetClientFunction:setCallback(func) + self.instance.OnClientInvoke = func; + end; + function NetClientFunction:getInstance() + return self.instance; + end; + function NetClientFunction:getCache() + local cache = self.instance:FindFirstChild("Cache"); + if cache then + return cache.Value; + else + return 0; + end; + end; + function NetClientFunction:CallServer(...) + local args = { ... }; + if self.lastPing < os.time() + self:getCache() then + local result = self.instance:InvokeServer(unpack(args)); + self.cached = result; + self.lastPing = os.time(); + return result; + else + return self.cached; + end; + end; + NetClientFunction.CallServerAsync = TS.async(function(self, ...) + local args = { ... }; + return self:CallServer(unpack(args)); + end); +end; +exports.default = NetClientFunction; +return exports; diff --git a/lualib/GlobalClientEvent.lua b/lualib/GlobalClientEvent.lua new file mode 100644 index 00000000..7c77e313 --- /dev/null +++ b/lualib/GlobalClientEvent.lua @@ -0,0 +1,27 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local getGlobalRemoteId = TS.import(script.Parent, "internal").getGlobalRemote; +local NetClientEvent = TS.import(script.Parent, "ClientEvent").default; +local NetGlobalClientEvent; +do + NetGlobalClientEvent = setmetatable({}, { + __tostring = function() return "NetGlobalClientEvent" end; + }); + NetGlobalClientEvent.__index = NetGlobalClientEvent; + function NetGlobalClientEvent.new(...) + local self = setmetatable({}, NetGlobalClientEvent); + self:constructor(...); + return self; + end; + function NetGlobalClientEvent:constructor(name) + self.instance = NetClientEvent.new(getGlobalRemoteId(name)); + end; + function NetGlobalClientEvent:Connect(callback) + self.instance:Connect(callback); + end; +end; +exports.default = NetGlobalClientEvent; +return exports; diff --git a/lualib/GlobalEvent.lua b/lualib/GlobalEvent.lua new file mode 100644 index 00000000..6ec5b5ee --- /dev/null +++ b/lualib/GlobalEvent.lua @@ -0,0 +1,112 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local NetGlobalEvent; +local _0 = TS.import(script.Parent, "internal"); +local ServerTickFunctions, isLuaTable = _0.ServerTickFunctions, _0.isLuaTable; +local MockMessagingService = TS.import(script.Parent, "MockMessagingService"); +local MessagingService = game:GetService("MessagingService"); +local Players = game:GetService("Players"); +local IS_STUDIO = game:GetService("RunService"):IsStudio(); +local function isSubscriptionMessage(value) + if isLuaTable(value) then + local hasData = (value["Data"] ~= nil); + return hasData; + else + return false; + end; +end; +local function isJobTargetMessage(value) + if isSubscriptionMessage(value) then + if isLuaTable(value.Data) then + return (value.Data["jobId"] ~= nil); + end; + end; + return false; +end; +local globalMessageQueue = {}; +local lastQueueTick = 0; +local globalEventMessageCounter = 0; +local globalSubscriptionCounter = 0; +local function processMessageQueue() + if tick() >= lastQueueTick + 60 then + globalEventMessageCounter = 0; + globalSubscriptionCounter = 0; + lastQueueTick = tick(); + while #globalMessageQueue > 0 do + local _1 = #globalMessageQueue; + local message = globalMessageQueue[_1]; + globalMessageQueue[_1] = nil; -- globalMessageQueue.pop + MessagingService:PublishAsync(message.Name, message.Data); + globalEventMessageCounter = globalEventMessageCounter + 1; + end; + if globalEventMessageCounter >= NetGlobalEvent:GetMessageLimit() then + warn("[rbx-net] Too many messages are being sent, any further messages will be queued!"); + end; + end; +end; +do + NetGlobalEvent = setmetatable({}, { + __tostring = function() return "NetGlobalEvent" end; + }); + NetGlobalEvent.__index = NetGlobalEvent; + function NetGlobalEvent.new(...) + local self = setmetatable({}, NetGlobalEvent); + self:constructor(...); + return self; + end; + function NetGlobalEvent:constructor(name) + self.name = name; + end; + function NetGlobalEvent.GetMessageLimit(self) + return 150 + 60 * #Players:GetPlayers(); + end; + function NetGlobalEvent.GetSubscriptionLimit(self) + return 5 + 2 * #Players:GetPlayers(); + end; + function NetGlobalEvent:SendToServer(jobId, message) + self:SendToAllServers({ + jobId = jobId; + message = message; + }); + end; + function NetGlobalEvent:SendToAllServers(message) + local limit = NetGlobalEvent:GetMessageLimit(); + if globalEventMessageCounter >= limit then + warn("[rbx-net] Exceeded message limit of " .. tostring(limit) .. ", adding to queue..."); + globalMessageQueue[#globalMessageQueue + 1] = { + Name = self.name; + Data = message; + }; + else + globalEventMessageCounter = globalEventMessageCounter + 1; + TS.Promise.spawn(function() + ((IS_STUDIO and MockMessagingService) or MessagingService):PublishAsync(self.name, message); + end); + end; + end; + function NetGlobalEvent:Connect(handler) + local limit = NetGlobalEvent:GetSubscriptionLimit(); + if globalSubscriptionCounter >= limit then + error("[rbx-net] Exceeded Subscription limit of " .. tostring(limit) .. "!"); + end; + globalSubscriptionCounter = globalSubscriptionCounter + 1; + return ((IS_STUDIO and MockMessagingService) or MessagingService):SubscribeAsync(self.name, function(recieved) + local Sent = recieved.Sent; + if isJobTargetMessage(recieved) then + local Data = recieved.Data; + if game.JobId == Data.JobId then + handler(Data.InnerData, Sent); + end; + else + handler(recieved.Data, Sent); + end; + end); + end; +end; +ServerTickFunctions[#ServerTickFunctions + 1] = processMessageQueue; +exports.isSubscriptionMessage = isSubscriptionMessage; +exports.default = NetGlobalEvent; +return exports; diff --git a/lualib/GlobalServerEvent.lua b/lualib/GlobalServerEvent.lua new file mode 100644 index 00000000..9374073c --- /dev/null +++ b/lualib/GlobalServerEvent.lua @@ -0,0 +1,119 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local NetServerEvent = TS.import(script.Parent, "ServerEvent").default; +local _0 = TS.import(script.Parent, "GlobalEvent"); +local NetGlobalEvent, isSubscriptionMessage, ISubscriptionMessage = _0.default, _0.isSubscriptionMessage, _0.ISubscriptionMessage; +local _1 = TS.import(script.Parent, "internal"); +local getGlobalRemote, IS_CLIENT, isLuaTable = _1.getGlobalRemote, _1.IS_CLIENT, _1.isLuaTable; +local Players = game:GetService("Players"); +local function isTargetedSubscriptionMessage(value) + if isSubscriptionMessage(value) then + if isLuaTable(value.Data) then + return (value.Data["InnerData"] ~= nil); + end; + end; + return false; +end; +local NetGlobalServerEvent; +do + NetGlobalServerEvent = setmetatable({}, { + __tostring = function() return "NetGlobalServerEvent" end; + }); + NetGlobalServerEvent.__index = NetGlobalServerEvent; + function NetGlobalServerEvent.new(...) + local self = setmetatable({}, NetGlobalServerEvent); + self:constructor(...); + return self; + end; + function NetGlobalServerEvent:constructor(name) + self.instance = NetServerEvent.new(getGlobalRemote(name)); + self.event = NetGlobalEvent.new(name); + assert(not IS_CLIENT, "Cannot create a Net.GlobalServerEvent on the Client!"); + self.eventHandler = self.event:Connect(function(message) + if isTargetedSubscriptionMessage(message) then + self:recievedMessage(message.Data); + else + warn("[rbx-net] Recieved malformed message for GlobalServerEvent: " .. name); + end; + end); + end; + function NetGlobalServerEvent:getPlayersMatchingId(matching) + if (typeof(matching) == "number") then + return Players:GetPlayerByUserId(matching); + else + local players = {}; + for _2 = 1, #matching do + local id = matching[_2]; + local player = Players:GetPlayerByUserId(id); + if player then + players[#players + 1] = player; + end; + end; + return players; + end; + end; + function NetGlobalServerEvent:recievedMessage(message) + if message.TargetIds then + local players = self:getPlayersMatchingId(message.TargetIds); + if players then + self.instance:SendToPlayers(players, unpack(message.InnerData)); + end; + elseif message.TargetId then + local player = self:getPlayersMatchingId(message.TargetId); + if player then + self.instance:SendToPlayer(player, unpack(message.InnerData)); + end; + else + self.instance:SendToAllPlayers(unpack(message.InnerData)); + end; + end; + function NetGlobalServerEvent:Disconnect() + self.eventHandler:Disconnect(); + end; + function NetGlobalServerEvent:SendToAllServers(...) + local args = { ... }; + self.event:SendToAllServers({ + data = { unpack(args) }; + }); + end; + function NetGlobalServerEvent:SendToServer(jobId, ...) + local args = { ... }; + self.event:SendToServer(jobId, { + data = { unpack(args) }; + }); + end; + function NetGlobalServerEvent:SendToPlayer(userId, ...) + local args = { ... }; + local player = Players:GetPlayerByUserId(userId); + if player then + self.instance:SendToPlayer(player, unpack(args)); + else + self.event:SendToAllServers({ + data = { unpack(args) }; + targetId = userId; + }); + end; + end; + function NetGlobalServerEvent:SendToPlayers(userIds, ...) + local args = { ... }; + for _2 = 1, #userIds do + local targetId = userIds[_2]; + local player = Players:GetPlayerByUserId(targetId); + if player then + self.instance:SendToPlayer(player, unpack(args)); + table.remove(userIds, targetId + 1); + end; + end; + if #userIds > 0 then + self.event:SendToAllServers({ + data = { unpack(args) }; + targetIds = userIds; + }); + end; + end; +end; +exports.default = NetGlobalServerEvent; +return exports; diff --git a/lualib/MockMessagingService/init.lua b/lualib/MockMessagingService/init.lua new file mode 100644 index 00000000..4acc43bf --- /dev/null +++ b/lualib/MockMessagingService/init.lua @@ -0,0 +1,28 @@ +local MockMessagingService = {} + +local topics = {} + +function MockMessagingService:PublishAsync(topicName, message) + local topic = topics[topicName] + if topic then + topic:Fire( + { + Sent = tick(), + Data = message + } + ) + end +end + +function MockMessagingService:SubscribeAsync(topicName, callback) + local topic = topics[topicName] + if not topic then + topic = Instance.new("BindableEvent", script) + topic.Name = topicName + topics[topicName] = topic + end + + return topic.Event:Connect(callback) +end + +return MockMessagingService diff --git a/lualib/Serializer/init.lua b/lualib/Serializer/init.lua new file mode 100644 index 00000000..7d4173cf --- /dev/null +++ b/lualib/Serializer/init.lua @@ -0,0 +1,89 @@ +local Serializer = {} + +function Serializer.Serialize(object) + if type(object) ~= "table" then + error("Cannot serialize non-object", 2) + end + + if type(object.serialize) == "function" then + return object:serialize() + end + + local serialized = {} + for index, value in next, object do + if type(value) == "table" then + serialized[index] = Serializer.Serialize(value) + else + serialized[index] = value + end + end + + return serialized +end + +function Serializer.Deserialize(struct, deserializer) + if type(deserializer) == "function" then + return deserializer(struct) + elseif type(deserializer) == "table" then + if type(deserializer.deserialize) == "function" then + return deserializer:deserialize(struct) + end + + for index, value in next, struct do + deserializer[index] = value + end + end +end + +local function isMixed(t) + assert(type(t) == "table") + local mixed = false + local _idxType + + for index, value in next, t do + if _idxType and _idxType ~= type(index) then + return true + end + + _idxType = type(index) + if type(value) == "table" then + mixed = mixed and isMixed(value) and not (not getmetatable(value)) + end + end + + return mixed +end + +function Serializer.IsSerializable(value) + local _type = type(value) + if _type == "number" or _type == "boolean" or _type == "string" then + return true + elseif _type == "table" then + return not isMixed(value) and not getmetatable(value) + elseif _type == "userdata" and typeof(_type) ~= "userdata" then -- Instances / Value Types + return true + else + return false + end +end + +function Serializer.makeDeserializable(class, callback) + local wrapper = {} + if (type(callback) == "function") then + wrapper.deserialize = function(_, serialized) + return callback(serialized) + end + else + wrapper.deserialize = function(_, serialized, ...) + local obj = class.new(...) + for index, value in next, serialized do + obj[index] = value + end + return obj + end + end + + return setmetatable(wrapper, {__index = class}) +end + +return Serializer diff --git a/lualib/ServerEvent.lua b/lualib/ServerEvent.lua new file mode 100644 index 00000000..c852551e --- /dev/null +++ b/lualib/ServerEvent.lua @@ -0,0 +1,70 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local _0 = TS.import(script.Parent, "internal"); +local findOrCreateRemote, IS_CLIENT = _0.findOrCreateRemote, _0.IS_CLIENT; +local Players = game:GetService("Players"); +local NetServerEvent; +do + NetServerEvent = setmetatable({}, { + __tostring = function() return "NetServerEvent" end; + }); + NetServerEvent.__index = NetServerEvent; + function NetServerEvent.new(...) + local self = setmetatable({}, NetServerEvent); + self:constructor(...); + return self; + end; + function NetServerEvent:constructor(name) + self.instance = findOrCreateRemote("RemoteEvent", name); + assert(not IS_CLIENT, "Cannot create a Net.ServerEvent on the Client!"); + end; + function NetServerEvent:getInstance() + return self.instance; + end; + function NetServerEvent:getEvent() + return self.instance.OnServerEvent; + end; + function NetServerEvent:Connect(callback) + return self:getEvent():Connect(callback); + end; + function NetServerEvent:SendToAllPlayers(...) + local args = { ... }; + self.instance:FireAllClients(unpack(args)); + end; + function NetServerEvent:SendToAllPlayersExcept(blacklist, ...) + local args = { ... }; + if (typeof(blacklist) == "Instance") then + local otherPlayers = TS.array_filter(Players:GetPlayers(), function(p) + return p ~= blacklist; + end); + for _1 = 1, #otherPlayers do + local player = otherPlayers[_1]; + self.instance:FireClient(player, unpack(args)); + end; + elseif (typeof(blacklist) == "table") then + local _1 = Players:GetPlayers(); + for _2 = 1, #_1 do + local player = _1[_2]; + if TS.array_indexOf(blacklist, player) == -1 then + self.instance:FireClient(player, unpack(args)); + end; + end; + end; + end; + function NetServerEvent:SendToPlayer(player, ...) + local args = { ... }; + self.instance:FireClient(player, unpack(args)); + end; + function NetServerEvent:SendToPlayers(players, ...) + local args = { ... }; + for _1 = 1, #players do + local player = players[_1]; + self:SendToPlayer(player, unpack(args)); + end; + end; +end; +exports.default = NetServerEvent; +return exports; diff --git a/lualib/ServerFunction.lua b/lualib/ServerFunction.lua new file mode 100644 index 00000000..a55b97cb --- /dev/null +++ b/lualib/ServerFunction.lua @@ -0,0 +1,59 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local _0 = TS.import(script.Parent, "internal"); +local findOrCreateRemote, IS_CLIENT = _0.findOrCreateRemote, _0.IS_CLIENT; +local NetServerFunction; +do + NetServerFunction = setmetatable({}, { + __tostring = function() return "NetServerFunction" end; + }); + NetServerFunction.__index = NetServerFunction; + function NetServerFunction.new(...) + local self = setmetatable({}, NetServerFunction); + self:constructor(...); + return self; + end; + function NetServerFunction:constructor(name) + self.instance = findOrCreateRemote("RemoteFunction", name); + assert(not IS_CLIENT, "Cannot create a Net.ServerFunction on the Client!"); + end; + function NetServerFunction:getCallback() + return self.instance.OnServerInvoke; + end; + function NetServerFunction:setCallback(func) + self.instance.OnServerInvoke = func; + return self; + end; + function NetServerFunction:getInstance() + return self.instance; + end; + function NetServerFunction:getClientCache() + local cache = self.instance:FindFirstChild("Cache"); + if cache then + return cache.Value; + else + return 0; + end; + end; + function NetServerFunction:setClientCache(time) + local cache = self.instance:FindFirstChild("Cache"); + if not cache then + local cacheTimer = Instance.new("NumberValue", self.instance); + cacheTimer.Value = time; + cacheTimer.Name = "Cache"; + else + cache.Value = time; + end; + return self; + end; + NetServerFunction.CallPlayerAsync = TS.async(function(self, player, ...) + local args = { ... }; + warn("[rbx-net] CallPlayerAsync is possibly going to be removed\n" .. "\tsee https://github.com/roblox-aurora/rbx-net/issues/13 for more details."); + return self.instance:InvokeClient(player, unpack(args)); + end); +end; +exports.default = NetServerFunction; +return exports; diff --git a/lualib/ServerThrottledEvent.lua b/lualib/ServerThrottledEvent.lua new file mode 100644 index 00000000..e3d9c908 --- /dev/null +++ b/lualib/ServerThrottledEvent.lua @@ -0,0 +1,65 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local NetServerEvent = TS.import(script.Parent, "ServerEvent").default; +local errorft = TS.import(script.Parent, "internal").errorft; +local throttler = TS.import(script.Parent, "Throttle"); +local GetConfiguration = TS.import(script.Parent, "configuration").GetConfiguration; +local NetServerThrottledEvent; +do + local super = NetServerEvent; + NetServerThrottledEvent = setmetatable({}, { + __index = super; + __tostring = function() return "NetServerThrottledEvent" end; + }); + NetServerThrottledEvent.__index = NetServerThrottledEvent; + function NetServerThrottledEvent.new(...) + local self = setmetatable({}, NetServerThrottledEvent); + self:constructor(...); + return self; + end; + function NetServerThrottledEvent:constructor(name, rateLimit) + super.constructor(self, name); + self.maxRequestsPerMinute = 0; + self.maxRequestsPerMinute = rateLimit; + self.clientRequests = throttler:Get("Event~" .. name); + local clientValue = Instance.new("IntValue", self.instance); + clientValue.Name = "RateLimit"; + clientValue.Value = rateLimit; + end; + function NetServerThrottledEvent:Connect(callback) + return self.instance.OnServerEvent:Connect(function(player, ...) + local args = { ... }; + local maxRequests = self.maxRequestsPerMinute; + local clientRequestCount = self.clientRequests:Get(player); + if clientRequestCount >= maxRequests then + errorft(GetConfiguration("ServerThrottleMessage"), { + player = player.UserId; + remote = self.instance.Name; + limit = maxRequests; + }); + else + self.clientRequests:Increment(player); + callback(player, unpack((args))); + end; + end); + end; + function NetServerThrottledEvent:setRateLimit(requestsPerMinute) + self.maxRequestsPerMinute = requestsPerMinute; + local clientValue = self.instance:FindFirstChild("RateLimit"); + if clientValue then + clientValue.Value = requestsPerMinute; + else + clientValue = Instance.new("IntValue", self.instance); + clientValue.Name = "RateLimit"; + clientValue.Value = requestsPerMinute; + end; + end; + function NetServerThrottledEvent:getRateLimit() + return self.maxRequestsPerMinute; + end; +end; +exports.default = NetServerThrottledEvent; +return exports; diff --git a/lualib/ServerThrottledFunction.lua b/lualib/ServerThrottledFunction.lua new file mode 100644 index 00000000..f84999e3 --- /dev/null +++ b/lualib/ServerThrottledFunction.lua @@ -0,0 +1,68 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.Parent.vendor.RuntimeLib); +local exports = {}; +local _0 = TS.import(script.Parent, "internal"); +local RequestCounter, errorft = _0.RequestCounter, _0.errorft; +local throttler = TS.import(script.Parent, "Throttle"); +local GetConfiguration = TS.import(script.Parent, "configuration").GetConfiguration; +local NetServerFunction = TS.import(script.Parent, "ServerFunction").default; +local NetServerThrottledFunction; +do + local super = NetServerFunction; + NetServerThrottledFunction = setmetatable({}, { + __index = super; + __tostring = function() return "NetServerThrottledFunction" end; + }); + NetServerThrottledFunction.__index = NetServerThrottledFunction; + function NetServerThrottledFunction.new(...) + local self = setmetatable({}, NetServerThrottledFunction); + self:constructor(...); + return self; + end; + function NetServerThrottledFunction:constructor(name, rateLimit) + super.constructor(self, name); + self.maxRequestsPerMinute = 0; + self.maxRequestsPerMinute = rateLimit; + self.clientRequests = throttler:Get("Function~" .. name); + local clientValue = Instance.new("IntValue", self.instance); + clientValue.Name = "RateLimit"; + clientValue.Value = rateLimit; + end; + function NetServerThrottledFunction:setCallback(callback) + self.instance.OnServerInvoke = function(player, ...) + local args = { ... }; + local maxRequests = self.maxRequestsPerMinute; + local clientRequestCount = self.clientRequests:Get(player); + if clientRequestCount >= maxRequests then + errorft(GetConfiguration("ServerThrottleMessage"), { + player = player.UserId; + remote = self.instance.Name; + limit = maxRequests; + }); + else + self.clientRequests:Increment(player); + return callback(player, unpack(args)); + end; + end; + return self; + end; + function NetServerThrottledFunction:setRateLimit(requestsPerMinute) + self.maxRequestsPerMinute = requestsPerMinute; + local clientValue = self.instance:FindFirstChild("RateLimit"); + if clientValue then + clientValue.Value = requestsPerMinute; + else + clientValue = Instance.new("IntValue", self.instance); + clientValue.Name = "RateLimit"; + clientValue.Value = requestsPerMinute; + end; + end; + function NetServerThrottledFunction:getRateLimit() + return self.maxRequestsPerMinute; + end; + NetServerThrottledFunction.rates = {}; +end; +exports.default = NetServerThrottledFunction; +return exports; diff --git a/lualib/Throttle/init.lua b/lualib/Throttle/init.lua new file mode 100644 index 00000000..35828908 --- /dev/null +++ b/lualib/Throttle/init.lua @@ -0,0 +1,57 @@ +local Throttle = { + counters = {} +} +local RequestCounter = {} +RequestCounter.__index = RequestCounter + +function RequestCounter.new() + local self = { + counter = {} + } + + return setmetatable(self, RequestCounter) +end + +function RequestCounter:Get(player) + local counter = self.counter + local playerQueue = counter[player.UserId] + return playerQueue or 0 +end + +function RequestCounter:Increment(player) + local counter = self.counter + local playerQueue = counter[player.UserId] + + if not counter[player.UserId] then + counter[player.UserId] = 1 + else + counter[player.UserId] = playerQueue + 1 + end +end + +function RequestCounter:__tostring() + return "RequestCounter" +end + +function RequestCounter:ClearAll() + self.counter = {} +end + +function Throttle:Get(name) + local existing = self.counters[name] + if (existing) then + return existing + else + local newCounter = RequestCounter.new() + self.counters[name] = newCounter + return newCounter + end +end + +function Throttle:Clear() + for _, counter in pairs(self.counters) do + counter:ClearAll() + end +end + +return Throttle diff --git a/lualib/configuration.lua b/lualib/configuration.lua new file mode 100644 index 00000000..27e00c75 --- /dev/null +++ b/lualib/configuration.lua @@ -0,0 +1,35 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local exports; +local NetConfig; +local runService = game:GetService("RunService"); +local IS_SERVER = runService:IsServer(); +local throttleResetTimer = 60; +local rateLimitReachedMessage = "Request limit exceeded ({limit}) by {player} via {remote}"; +NetConfig = NetConfig or {} do + local _0 = NetConfig; + local function SetConfiguration(key, value) + assert(IS_SERVER, "Cannot modify configuration on client!"); + if key == "ServerThrottleMessage" then + throttleResetTimer = value; + elseif key == "ServerThrottleMessage" then + rateLimitReachedMessage = value; + end; + end; + local function GetConfiguration(key) + if key == "ServerThrottleResetTimer" then + assert(IS_SERVER, "ServerThrottleResetTimer is not used on the client!"); + return throttleResetTimer; + elseif key == "ServerThrottleMessage" then + assert(IS_SERVER, "ServerThrottleMessage is not used on the client!"); + return rateLimitReachedMessage; + else + return nil; + end; + end; + _0.SetConfiguration = SetConfiguration; + _0.GetConfiguration = GetConfiguration; +end; +exports = NetConfig; +return exports; diff --git a/lualib/init.lua b/lualib/init.lua new file mode 100644 index 00000000..5b864e30 --- /dev/null +++ b/lualib/init.lua @@ -0,0 +1,175 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local TS = require(script.vendor.RuntimeLib); +local exports; +local Net; +local throttler = TS.import(script, "Throttle"); +local Serializer = TS.import(script, "Serializer"); +local config = TS.import(script, "configuration"); +local _0 = TS.import(script, "internal"); +local functionExists, eventExists, ServerTickFunctions = _0.functionExists, _0.eventExists, _0.ServerTickFunctions; +local NetServerEvent = TS.import(script, "ServerEvent").default; +local NetClientEvent = TS.import(script, "ClientEvent").default; +local NetClientFunction = TS.import(script, "ClientFunction").default; +local NetServerFunction = TS.import(script, "ServerFunction").default; +local NetServerThrottledFunction = TS.import(script, "ServerThrottledFunction").default; +local NetServerThrottledEvent = TS.import(script, "ServerThrottledEvent").default; +local NetGlobalEvent = TS.import(script, "GlobalEvent").default; +local NetGlobalServerEvent = TS.import(script, "GlobalServerEvent").default; +local runService = game:GetService("RunService"); +local IS_CLIENT = (__LEMUR__ and not runService:IsServer()) or runService:IsClient(); +local IS_SERVER = runService:IsServer(); +local IS_STUDIO = runService:IsStudio(); +Net = Net or {} do + local _1 = Net; + local SetConfiguration = config.SetConfiguration; + local GetConfiguration = config.GetConfiguration; + local VERSION = { + number = { + major = 1; + minor = 0; + revision = 12; + }; + date = 190602; + tag = "release"; + }; + setmetatable(VERSION, { + __tostring = function(self) + local _2 = self.number; + local major = _2.major; + local minor = _2.minor; + local revision = _2.revision; + return tostring(major) .. "." .. tostring(minor) .. "." .. tostring(revision); + end; + }); + local ServerEvent = NetServerEvent; + local ClientEvent = NetClientEvent; + local ClientFunction = NetClientFunction; + local ServerFunction = NetServerFunction; + local GlobalEvent = NetGlobalEvent; + local GlobalServerEvent = NetGlobalServerEvent; + local ServerThrottledEvent = NetServerThrottledEvent; + local ServerThrottledFunction = NetServerThrottledFunction; + local function IsClient() + return IS_CLIENT; + end; + local function IsServer() + return IS_SERVER; + end; + local Serialize = Serializer.Serialize; + local Deserialize = Serializer.Deserialize; + local IsSerializable = Serializer.IsSerializable; + local function CreateFunction(nameOrOptions) + if IS_SERVER then + if (typeof(nameOrOptions) == "string") then + return NetServerFunction.new(nameOrOptions); + else + local fn; + if nameOrOptions.rateLimit ~= nil then + fn = NetServerThrottledFunction.new(nameOrOptions.name, nameOrOptions.rateLimit); + else + fn = NetServerFunction.new(nameOrOptions.name); + end; + if nameOrOptions.callback then + fn:setCallback(nameOrOptions.callback); + end; + if nameOrOptions.cacheSeconds then + fn:setClientCache(nameOrOptions.cacheSeconds); + end; + return fn; + end; + else + error("Net.createFunction can only be used on the server!"); + error(""); + end; + end; + local function CreateThrottledFunction(name, rateLimit) + if IS_SERVER then + return NetServerThrottledFunction.new(name, rateLimit); + else + error("Net.createFunction can only be used on the server!"); + error(""); + end; + end; + local function CreateThrottledEvent(name, rateLimit) + if IS_SERVER then + return NetServerThrottledEvent.new(name, rateLimit); + else + error("Net.createFunction can only be used on the server!"); + error("Net.createFunction can only be used on the server!"); + end; + end; + local function CreateEvent(name) + if IS_SERVER then + return NetServerEvent.new(name); + else + error("Net.createFunction can only be used on the server!"); + error("Net.createFunction can only be used on the server!"); + end; + end; + local WaitForClientFunctionAsync = TS.async(function(name) + return NetClientFunction:WaitFor(name); + end); + local WaitForClientEventAsync = TS.async(function(name) + return NetClientEvent:WaitFor(name); + end); + local function GetServerEventAsync(name) + return TS.Promise.new(function(resolve, reject) + if eventExists(name) then + local newFunc = ServerEvent.new(name); + resolve(newFunc); + else + reject("Could not find Server Event: " .. name .. " (did you create it on the server?)"); + end; + end); + end; + local function GetServerFunctionAsync(name) + return TS.Promise.new(function(resolve, reject) + if functionExists(name) then + local newFunc = NetServerFunction.new(name); + resolve(newFunc); + else + reject("Could not find Server Function: " .. name .. " (did you create it?)"); + end; + end); + end; + if IS_STUDIO then + print("[rbx-net] Loaded rbx-net", "v" .. tostring(VERSION)); + end; + if IS_SERVER then + local lastTick = 0; + ServerTickFunctions[#ServerTickFunctions + 1] = function() + if tick() > lastTick + GetConfiguration("ServerThrottleResetTimer") then + lastTick = tick(); + throttler:Clear(); + end; + end; + end; + _1.SetConfiguration = SetConfiguration; + _1.GetConfiguration = GetConfiguration; + _1.VERSION = VERSION; + _1.ServerEvent = ServerEvent; + _1.ClientEvent = ClientEvent; + _1.ClientFunction = ClientFunction; + _1.ServerFunction = ServerFunction; + _1.GlobalEvent = GlobalEvent; + _1.GlobalServerEvent = GlobalServerEvent; + _1.ServerThrottledEvent = ServerThrottledEvent; + _1.ServerThrottledFunction = ServerThrottledFunction; + _1.IsClient = IsClient; + _1.IsServer = IsServer; + _1.Serialize = Serialize; + _1.Deserialize = Deserialize; + _1.IsSerializable = IsSerializable; + _1.CreateFunction = CreateFunction; + _1.CreateThrottledFunction = CreateThrottledFunction; + _1.CreateThrottledEvent = CreateThrottledEvent; + _1.CreateEvent = CreateEvent; + _1.WaitForClientFunctionAsync = WaitForClientFunctionAsync; + _1.WaitForClientEventAsync = WaitForClientEventAsync; + _1.GetServerEventAsync = GetServerEventAsync; + _1.GetServerFunctionAsync = GetServerFunctionAsync; +end; +exports = Net; +return exports; diff --git a/lualib/internal.lua b/lualib/internal.lua new file mode 100644 index 00000000..b9920e20 --- /dev/null +++ b/lualib/internal.lua @@ -0,0 +1,121 @@ +-- Compiled with https://roblox-ts.github.io v0.2.14 +-- August 6, 2019, 5:39 PM New Zealand Standard Time + +local exports = {}; +local replicatedStorage = game:GetService("ReplicatedStorage"); +local runService = game:GetService("RunService"); +local IS_SERVER = runService:IsServer(); +local IS_CLIENT = (__LEMUR__ and not runService:IsServer()) or runService:IsClient(); +local MAX_CLIENT_WAITFORCHILD_TIMEOUT = 10; +local function getGlobalRemote(name) + return "$" .. name; +end; +local function isLuaTable(value) + return (typeof(value) == "table"); +end; +local REMOTES_FOLDER_NAME = "Remotes"; +local FUNCTIONS_FOLDER_NAME = "Functions"; +local EVENTS_FOLDER_NAME = "Events"; +local remoteFolder; +local eventFolder; +local functionFolder; +local ServerTickFunctions = {}; +local function findOrCreateFolder(parent, name) + local folder = parent:FindFirstChild(name); + if folder then + return folder; + else + folder = Instance.new("Folder", parent); + folder.Name = name; + return folder; + end; +end; +remoteFolder = findOrCreateFolder(replicatedStorage, REMOTES_FOLDER_NAME); +functionFolder = findOrCreateFolder(remoteFolder, FUNCTIONS_FOLDER_NAME); +eventFolder = findOrCreateFolder(remoteFolder, EVENTS_FOLDER_NAME); +local function errorft(message, vars) + message = message:gsub("{([%w_][%w%d_]*)}", function(token) + return vars[token] or token; + end); + error(message, 2); +end; +local function eventExists(name) + return eventFolder:FindFirstChild(name) ~= nil; +end; +local function functionExists(name) + return functionFolder:FindFirstChild(name) ~= nil; +end; +local function waitForEvent(name, timeOut) + return eventFolder:WaitForChild(name, timeOut); +end; +local function waitForFunction(name, timeOut) + return functionFolder:WaitForChild(name, timeOut); +end; +local function getRemoteFolder(remoteType) + local targetFolder; + if remoteType == "RemoteEvent" then + targetFolder = eventFolder; + elseif remoteType == "RemoteFunction" then + targetFolder = functionFolder; + else + return error("Invalid type: " .. remoteType); + end; + return targetFolder; +end; +local function findRemote(remoteType, name) + local targetFolder = getRemoteFolder(remoteType); + local existing = targetFolder:FindFirstChild(name); + return existing; +end; +local function getRemoteOrThrow(remoteType, name) + local existing = findRemote(remoteType, name); + if existing then + return existing; + else + error("Could not find Remote of type " .. remoteType .. " called \"" .. name .. "\""); + end; +end; +local function findOrCreateRemote(remoteType, name) + local existing = findRemote(remoteType, name); + if existing then + return existing; + else + if not IS_SERVER then + error("Creation of Events or Functions must be done on server!"); + end; + local remote; + if remoteType == "RemoteEvent" or remoteType == "RemoteFunction" then + remote = Instance.new(remoteType); + else + error("Invalid Remote Type: " .. remoteType); + end; + remote.Name = name; + remote.Parent = getRemoteFolder(remoteType); + return remote; + end; +end; +if IS_SERVER then + game:GetService("RunService").Stepped:Connect(function(time, step) + for _0 = 1, #ServerTickFunctions do + local f = ServerTickFunctions[_0]; + f(); + end; + end); +end; +exports.IS_SERVER = IS_SERVER; +exports.IS_CLIENT = IS_CLIENT; +exports.MAX_CLIENT_WAITFORCHILD_TIMEOUT = MAX_CLIENT_WAITFORCHILD_TIMEOUT; +exports.getGlobalRemote = getGlobalRemote; +exports.isLuaTable = isLuaTable; +exports.ServerTickFunctions = ServerTickFunctions; +exports.findOrCreateFolder = findOrCreateFolder; +exports.errorft = errorft; +exports.eventExists = eventExists; +exports.functionExists = functionExists; +exports.waitForEvent = waitForEvent; +exports.waitForFunction = waitForFunction; +exports.getRemoteFolder = getRemoteFolder; +exports.findRemote = findRemote; +exports.getRemoteOrThrow = getRemoteOrThrow; +exports.findOrCreateRemote = findOrCreateRemote; +return exports; diff --git a/lualib/vendor/Promise.lua b/lualib/vendor/Promise.lua new file mode 100644 index 00000000..63037bc8 --- /dev/null +++ b/lualib/vendor/Promise.lua @@ -0,0 +1,578 @@ +--[[ + An implementation of Promises similar to Promise/A+. + Forked from LPGhatguy/roblox-lua-promise, modified for roblox-ts. +]] + +local PROMISE_DEBUG = false + +--[[ + Packs a number of arguments into a table and returns its length. + + Used to cajole varargs without dropping sparse values. +]] +local function pack(...) + local len = select("#", ...) + + return len, { ... } +end + +--[[ + wpcallPacked is a version of xpcall that: + * Returns the length of the result first + * Returns the result packed into a table + * Passes extra arguments through to the passed function; xpcall doesn't + * Issues a warning if PROMISE_DEBUG is enabled +]] +local function wpcallPacked(f, ...) + local argsLength, args = pack(...) + + local body = function() + return f(unpack(args, 1, argsLength)) + end + + local resultLength, result = pack(xpcall(body, debug.traceback)) + + -- If promise debugging is on, warn whenever a pcall fails. + -- This is useful for debugging issues within the Promise implementation + -- itself. + if PROMISE_DEBUG and not result[1] then + warn(result[2]) + end + + return resultLength, result +end + +--[[ + Creates a function that invokes a callback with correct error handling and + resolution mechanisms. +]] +local function createAdvancer(callback, resolve, reject) + return function(...) + local resultLength, result = wpcallPacked(callback, ...) + local ok = result[1] + + if ok then + resolve(unpack(result, 2, resultLength)) + else + reject(unpack(result, 2, resultLength)) + end + end +end + +local function isEmpty(t) + return next(t) == nil +end + +local function createSymbol(name) + assert(type(name) == "string", "createSymbol requires `name` to be a string.") + + local symbol = newproxy(true) + + getmetatable(symbol).__tostring = function() + return ("Symbol(%s)"):format(name) + end + + return symbol +end + +local PromiseMarker = createSymbol("PromiseMarker") + +local Promise = {} +Promise.prototype = {} +Promise.__index = Promise.prototype + +Promise.Status = { + Started = createSymbol("Started"), + Resolved = createSymbol("Resolved"), + Rejected = createSymbol("Rejected"), + Cancelled = createSymbol("Cancelled"), +} + +--[[ + Constructs a new Promise with the given initializing callback. + + This is generally only called when directly wrapping a non-promise API into + a promise-based version. + + The callback will receive 'resolve' and 'reject' methods, used to start + invoking the promise chain. + + For example: + + local function get(url) + return Promise.new(function(resolve, reject) + spawn(function() + resolve(HttpService:GetAsync(url)) + end) + end) + end + + get("https://google.com") + :andThen(function(stuff) + print("Got some stuff!", stuff) + end) + + Second parameter, parent, is used internally for tracking the "parent" in a + promise chain. External code shouldn't need to worry about this. +]] +function Promise.new(callback, parent) + if parent ~= nil and not Promise.is(parent) then + error("Argument #2 to Promise.new must be a promise or nil", 2) + end + + local self = { + -- Used to locate where a promise was created + _source = debug.traceback(), + + -- A tag to identify us as a promise + [PromiseMarker] = true, + + _status = Promise.Status.Started, + + -- A table containing a list of all results, whether success or failure. + -- Only valid if _status is set to something besides Started + _values = nil, + + -- Lua doesn't like sparse arrays very much, so we explicitly store the + -- length of _values to handle middle nils. + _valuesLength = -1, + + -- If an error occurs with no observers, this will be set. + _unhandledRejection = false, + + -- Queues representing functions we should invoke when we update! + _queuedResolve = {}, + _queuedReject = {}, + _queuedFinally = {}, + + -- The function to run when/if this promise is cancelled. + _cancellationHook = nil, + + -- The "parent" of this promise in a promise chain. Required for + -- cancellation propagation. + _parent = parent, + + -- The number of consumers attached to this promise. This is needed so that + -- we don't propagate promise cancellations when there are still uncancelled + -- consumers. + _numConsumers = 0, + } + + setmetatable(self, Promise) + + local function resolve(...) + self:_resolve(...) + end + + local function reject(...) + self:_reject(...) + end + + local function onCancel(cancellationHook) + assert(type(cancellationHook) == "function", "onCancel must be called with a function as its first argument.") + + if self._status == Promise.Status.Cancelled then + cancellationHook() + else + self._cancellationHook = cancellationHook + end + end + + local _, result = wpcallPacked(callback, resolve, reject, onCancel) + local ok = result[1] + local err = result[2] + + if not ok and self._status == Promise.Status.Started then + reject(err) + end + + return self +end + +--[[ + Fast spawn: Spawns a thread with predictable timing. + Runs immediately instead of first cycle being deferred. +]] +local spawnBindable = Instance.new("BindableEvent") +function Promise.spawn(callback, ...) + local args = { ... } + local length = select("#", ...) + local connection = spawnBindable.Event:Connect(function() + callback(unpack(args, 1, length)) + end) + spawnBindable:Fire() + connection:Disconnect() +end + +--[[ + Create a promise that represents the immediately resolved value. +]] +function Promise.resolve(value) + return Promise.new(function(resolve) + resolve(value) + end) +end + +--[[ + Create a promise that represents the immediately rejected value. +]] +function Promise.reject(value) + return Promise.new(function(_, reject) + reject(value) + end) +end + +--[[ + Returns a new promise that: + * is resolved when all input promises resolve + * is rejected if ANY input promises reject +]] +function Promise.all(promises) + if type(promises) ~= "table" then + error("Please pass a list of promises to Promise.all", 2) + end + + -- If there are no values then return an already resolved promise. + if #promises == 0 then + return Promise.resolve({}) + end + + -- We need to check that each value is a promise here so that we can produce + -- a proper error rather than a rejected promise with our error. + for i = 1, #promises do + if not Promise.is(promises[i]) then + error(("Non-promise value passed into Promise.all at index #%d"):format(i), 2) + end + end + + return Promise.new(function(resolve, reject) + -- An array to contain our resolved values from the given promises. + local resolvedValues = {} + + -- Keep a count of resolved promises because just checking the resolved + -- values length wouldn't account for promises that resolve with nil. + local resolvedCount = 0 + + -- Called when a single value is resolved and resolves if all are done. + local function resolveOne(i, ...) + resolvedValues[i] = ... + resolvedCount = resolvedCount + 1 + + if resolvedCount == #promises then + resolve(resolvedValues) + end + end + + -- We can assume the values inside `promises` are all promises since we + -- checked above. + for i = 1, #promises do + promises[i]:andThen( + function(...) + resolveOne(i, ...) + end, + function(...) + reject(...) + end + ) + end + end) +end + +--[[ + Is the given object a Promise instance? +]] +function Promise.is(object) + if type(object) ~= "table" then + return false + end + + return object[PromiseMarker] == true +end + +function Promise.prototype:getStatus() + return self._status +end + +function Promise.prototype:isRejected() + return self._status == Promise.Status.Rejected +end + +function Promise.prototype:isResolved() + return self._status == Promise.Status.Resolved +end + +function Promise.prototype:isPending() + return self._status == Promise.Status.Started +end + +function Promise.prototype:isCancelled() + return self._status == Promise.Status.Cancelled +end + +--[[ + Creates a new promise that receives the result of this promise. + + The given callbacks are invoked depending on that result. +]] +function Promise.prototype:andThen(successHandler, failureHandler) + self._unhandledRejection = false + self._numConsumers = self._numConsumers + 1 + + -- Create a new promise to follow this part of the chain + return Promise.new(function(resolve, reject) + -- Our default callbacks just pass values onto the next promise. + -- This lets success and failure cascade correctly! + + local successCallback = resolve + if successHandler then + successCallback = createAdvancer(successHandler, resolve, reject) + end + + local failureCallback = reject + if failureHandler then + failureCallback = createAdvancer(failureHandler, resolve, reject) + end + + if self._status == Promise.Status.Started then + -- If we haven't resolved yet, put ourselves into the queue + table.insert(self._queuedResolve, successCallback) + table.insert(self._queuedReject, failureCallback) + elseif self._status == Promise.Status.Resolved then + -- This promise has already resolved! Trigger success immediately. + successCallback(unpack(self._values, 1, self._valuesLength)) + elseif self._status == Promise.Status.Rejected then + -- This promise died a terrible death! Trigger failure immediately. + failureCallback(unpack(self._values, 1, self._valuesLength)) + elseif self._status == Promise.Status.Cancelled then + -- We don't want to call the success handler or the failure handler, + -- we just reject this promise outright. + reject("Promise is cancelled") + end + end, self) +end + +--[[ + Used to catch any errors that may have occurred in the promise. +]] +function Promise.prototype:catch(failureCallback) + return self:andThen(nil, failureCallback) +end + +--[[ + Cancels the promise, disallowing it from rejecting or resolving, and calls + the cancellation hook if provided. +]] +function Promise.prototype:cancel() + if self._status ~= Promise.Status.Started then + return + end + + self._status = Promise.Status.Cancelled + + if self._cancellationHook then + self._cancellationHook() + end + + if self._parent then + self._parent:_consumerCancelled() + end + + self:_finalize() +end + +--[[ + Used to decrease the number of consumers by 1, and if there are no more, + cancel this promise. +]] +function Promise.prototype:_consumerCancelled() + self._numConsumers = self._numConsumers - 1 + + if self._numConsumers <= 0 then + self:cancel() + end +end + +--[[ + Used to set a handler for when the promise resolves, rejects, or is + cancelled. Returns a new promise chained from this promise. +]] +function Promise.prototype:finally(finallyHandler) + self._numConsumers = self._numConsumers + 1 + + -- Return a promise chained off of this promise + return Promise.new(function(resolve, reject) + local finallyCallback = resolve + if finallyHandler then + finallyCallback = createAdvancer(finallyHandler, resolve, reject) + end + + if self._status == Promise.Status.Started then + -- The promise is not settled, so queue this. + table.insert(self._queuedFinally, finallyCallback) + else + -- The promise already settled or was cancelled, run the callback now. + finallyCallback() + end + end, self) +end + +--[[ + Yield until the promise is completed. + + This matches the execution model of normal Roblox functions. +]] +function Promise.prototype:await() + self._unhandledRejection = false + + if self._status == Promise.Status.Started then + local result + local resultLength + local bindable = Instance.new("BindableEvent") + + self:andThen( + function(...) + resultLength, result = pack(...) + bindable:Fire(true) + end, + function(...) + resultLength, result = pack(...) + bindable:Fire(false) + end + ) + self:finally(function() + bindable:Fire(nil) + end) + + local ok = bindable.Event:Wait() + bindable:Destroy() + + if ok == nil then + -- If cancelled, we return nil. + return nil + end + + return ok, unpack(result, 1, resultLength) + elseif self._status == Promise.Status.Resolved then + return true, unpack(self._values, 1, self._valuesLength) + elseif self._status == Promise.Status.Rejected then + return false, unpack(self._values, 1, self._valuesLength) + end + + -- If the promise is cancelled, fall through to nil. + return nil +end + +--[[ + Intended for use in tests. + + Similar to await(), but instead of yielding if the promise is unresolved, + _unwrap will throw. This indicates an assumption that a promise has + resolved. +]] +function Promise.prototype:_unwrap() + if self._status == Promise.Status.Started then + error("Promise has not resolved or rejected.", 2) + end + + local success = self._status == Promise.Status.Resolved + + return success, unpack(self._values, 1, self._valuesLength) +end + +function Promise.prototype:_resolve(...) + if self._status ~= Promise.Status.Started then + return + end + + -- If the resolved value was a Promise, we chain onto it! + if Promise.is((...)) then + -- Without this warning, arguments sometimes mysteriously disappear + if select("#", ...) > 1 then + local message = ( + "When returning a Promise from andThen, extra arguments are " .. + "discarded! See:\n\n%s" + ):format( + self._source + ) + warn(message) + end + + (...):andThen( + function(...) + self:_resolve(...) + end, + function(...) + self:_reject(...) + end + ) + + return + end + + self._status = Promise.Status.Resolved + self._valuesLength, self._values = pack(...) + + -- We assume that these callbacks will not throw errors. + for _, callback in ipairs(self._queuedResolve) do + callback(...) + end + + self:_finalize() +end + +function Promise.prototype:_reject(...) + if self._status ~= Promise.Status.Started then + return + end + + self._status = Promise.Status.Rejected + self._valuesLength, self._values = pack(...) + + -- If there are any rejection handlers, call those! + if not isEmpty(self._queuedReject) then + -- We assume that these callbacks will not throw errors. + for _, callback in ipairs(self._queuedReject) do + callback(...) + end + else + -- At this point, no one was able to observe the error. + -- An error handler might still be attached if the error occurred + -- synchronously. We'll wait one tick, and if there are still no + -- observers, then we should put a message in the console. + + self._unhandledRejection = true + local err = tostring((...)) + + spawn(function() + -- Someone observed the error, hooray! + if not self._unhandledRejection then + return + end + + -- Build a reasonable message + local message = ("Unhandled promise rejection:\n\n%s\n\n%s"):format( + err, + self._source + ) + warn(message) + end) + end + + self:_finalize() +end + +--[[ + Calls any :finally handlers. We need this to be a separate method and + queue because we must call all of the finally callbacks upon a success, + failure, *and* cancellation. +]] +function Promise.prototype:_finalize() + for _, callback in ipairs(self._queuedFinally) do + -- Purposefully not passing values to callbacks here, as it could be the + -- resolved values, or rejected errors. If the developer needs the values, + -- they should use :andThen or :catch explicitly. + callback() + end +end + +return Promise diff --git a/lualib/vendor/RuntimeLib.lua b/lualib/vendor/RuntimeLib.lua new file mode 100644 index 00000000..0a80c1f2 --- /dev/null +++ b/lualib/vendor/RuntimeLib.lua @@ -0,0 +1,1057 @@ +local Promise = require(script.Parent.Promise) + +local HttpService = game:GetService("HttpService") + +-- constants +local table_sort = table.sort +local table_concat = table.concat +local math_ceil = math.ceil +local math_floor = math.floor + +local TS = {} + +-- runtime classes +TS.Promise = Promise + +local Symbol do + Symbol = {} + Symbol.__index = Symbol + setmetatable(Symbol, { + __call = function(_, description) + local self = setmetatable({}, Symbol) + self.description = "Symbol(" .. (description or "") .. ")" + return self + end + }) + + local symbolRegistry = setmetatable({}, { + __index = function(self, k) + self[k] = Symbol(k) + return self[k] + end + }) + + function Symbol:toString() + return self.description + end + + Symbol.__tostring = Symbol.toString + + -- Symbol.for + function Symbol.getFor(key) + return symbolRegistry[key] + end + + function Symbol.keyFor(goalSymbol) + for key, symbol in pairs(symbolRegistry) do + if symbol == goalSymbol then + return key + end + end + end +end + +TS.Symbol = Symbol +TS.Symbol_iterator = Symbol("Symbol.iterator") + +-- module resolution +local globalModules = script.Parent:FindFirstChild("node_modules") + +function TS.getModule(moduleName) + local object = getfenv(2).script.Parent + if not globalModules then + error("Could not find any modules!", 2) + end + if object:IsDescendantOf(globalModules) then + while object.Parent do + local modules = object:FindFirstChild("node_modules") + if modules then + local module = modules:FindFirstChild(moduleName) + if module then + return module + end + end + object = object.Parent + end + else + local module = globalModules:FindFirstChild(moduleName) + if module then + return module + end + end + error("Could not find module: " .. moduleName, 2) +end + +-- This is a hash which TS.import uses as a kind of linked-list-like history of [Script who Loaded] -> Library +local loadedLibraries = {} +local currentlyLoading = {} + +function TS.import(module, ...) + for i = 1, select("#", ...) do + module = module:WaitForChild((select(i, ...))) + end + + if module.ClassName == "ModuleScript" then + local data = loadedLibraries[module] + + if data == nil then + -- If called from command bar, use table as a reference (this is never concatenated) + local caller = getfenv(0).script or { Name = "Command bar" } + currentlyLoading[caller] = module + + -- Check to see if a case like this occurs: + -- module -> Module1 -> Module2 -> module + + -- WHERE currentlyLoading[module] is Module1 + -- and currentlyLoading[Module1] is Module2 + -- and currentlyLoading[Module2] is module + + local currentModule = module + local depth = 0 + + while currentModule do + depth = depth + 1 + currentModule = currentlyLoading[currentModule] + + if currentModule == module then + local str = currentModule.Name -- Get the string traceback + + for _ = 1, depth do + currentModule = currentlyLoading[currentModule] + str = str .. " -> " .. currentModule.Name + end + + error("Failed to import! Detected a circular dependency chain: " .. str, 2) + end + end + + assert(_G[module] == nil, "Invalid module access!") + _G[module] = TS + data = { value = require(module) } + + if currentlyLoading[caller] == module then -- Thread-safe cleanup! + currentlyLoading[caller] = nil + end + + loadedLibraries[module] = data -- Cache for subsequent calls + end + + return data.value + else + error("Failed to import! Expected ModuleScript, got " .. module.ClassName, 2) + end +end + +function TS.exportNamespace(module, ancestor) + for key, value in pairs(module) do + ancestor[key] = value + end +end + +-- general utility functions +function TS.instanceof(obj, class) + -- custom Class.instanceof() check + if type(class) == "table" and type(class.instanceof) == "function" then + return class.instanceof(obj) + end + + -- metatable check + if type(obj) == "table" then + obj = getmetatable(obj) + while obj ~= nil do + if obj == class then + return true + end + local mt = getmetatable(obj) + if mt then + obj = mt.__index + else + obj = nil + end + end + end + + return false +end + +function TS.async(callback) + return function(...) + local n = select("#", ...) + local args = { ... } + return Promise.new(function(resolve, reject) + coroutine.wrap(function() + local ok, result = pcall(callback, unpack(args, 1, n)) + if ok then + resolve(result) + else + reject(result) + end + end)() + end) + end +end + +function TS.await(promise) + if not Promise.is(promise) then + return promise + end + + local ok, result = promise:await() + if ok then + return result + else + TS.throw(ok == nil and "The awaited Promise was cancelled" or result) + end +end + +function TS.add(a, b) + if type(a) == "string" or type(b) == "string" then + return a .. b + else + return a + b + end +end + +local function bitTruncate(a) + if a < 0 then + return math_ceil(a) + else + return math_floor(a) + end +end + +TS.bit_truncate = bitTruncate + +-- bitwise operations +local powOfTwo = setmetatable({}, { + __index = function(self, i) + local v = 2 ^ i + self[i] = v + return v + end; +}) + +local _2_52 = powOfTwo[52] +local function bitop(a, b, oper) + local r, m, s = 0, _2_52 + repeat + s, a, b = a + b + m, a % m, b % m + r, m = r + m * oper % (s - a - b), m / 2 + until m < 1 + return r +end + +function TS.bit_not(a) + return -a - 1 +end + +function TS.bit_or(a, b) + a = bitTruncate(tonumber(a)) + b = bitTruncate(tonumber(b)) + return bitop(a, b, 1) +end + +function TS.bit_and(a, b) + a = bitTruncate(tonumber(a)) + b = bitTruncate(tonumber(b)) + return bitop(a, b, 4) +end + +function TS.bit_xor(a, b) + a = bitTruncate(tonumber(a)) + b = bitTruncate(tonumber(b)) + return bitop(a, b, 3) +end + +function TS.bit_lsh(a, b) + a = bitTruncate(tonumber(a)) + b = bitTruncate(tonumber(b)) + return a * powOfTwo[b] +end + +function TS.bit_rsh(a, b) + a = bitTruncate(tonumber(a)) + b = bitTruncate(tonumber(b)) + return bitTruncate(a / powOfTwo[b]) +end + +function TS.bit_lrsh(a, b) + a = bitTruncate(tonumber(a)) + b = bitTruncate(tonumber(b)) + if a >= 0 then return TS.bit_rsh(a, b) end + return TS.bit_rsh((a % powOfTwo[32]), b) +end + +-- utility functions +local function copy(object) + local result = {} + for k, v in pairs(object) do + result[k] = v + end + return result +end + +local function deepCopy(object) + local result = {} + for k, v in pairs(object) do + if type(v) == "table" then + result[k] = deepCopy(v) + else + result[k] = v + end + end + return result +end + +local function deepEquals(a, b) + -- a[k] == b[k] + for k in pairs(a) do + local av = a[k] + local bv = b[k] + if type(av) == "table" and type(bv) == "table" then + local result = deepEquals(av, bv) + if not result then + return false + end + elseif av ~= bv then + return false + end + end + + -- extra keys in b + for k in pairs(b) do + if a[k] == nil then + return false + end + end + + return true +end + +-- Object static functions + +function TS.Object_keys(object) + local result = {} + for key in pairs(object) do + result[#result + 1] = key + end + return result +end + +function TS.Object_values(object) + local result = {} + for _, value in pairs(object) do + result[#result + 1] = value + end + return result +end + +function TS.Object_entries(object) + local result = {} + for key, value in pairs(object) do + result[#result + 1] = { key, value } + end + return result +end + +function TS.Object_assign(toObj, ...) + for i = 1, select("#", ...) do + local arg = select(i, ...) + if type(arg) == "table" then + for key, value in pairs(arg) do + toObj[key] = value + end + end + end + return toObj +end + +TS.Object_copy = copy + +TS.Object_deepCopy = deepCopy + +TS.Object_deepEquals = deepEquals + +local function toString(data) + return HttpService:JSONEncode(data) +end + +TS.Object_toString = toString + +-- string macro functions +function TS.string_find_wrap(a, b, ...) + if a then + return a - 1, b - 1, ... + end +end + +-- array macro functions +local function array_copy(list) + local result = {} + for i = 1, #list do + result[i] = list[i] + end + return result +end + +TS.array_copy = array_copy + +function TS.array_entries(list) + local result = {} + for key = 1, #list do + result[key] = { key - 1, list[key] } + end + return result +end + +function TS.array_forEach(list, callback) + for i = 1, #list do + local v = list[i] + if v ~= nil then + callback(v, i - 1, list) + end + end +end + +function TS.array_map(list, callback) + local result = {} + for i = 1, #list do + local v = list[i] + if v ~= nil then + result[i] = callback(v, i - 1, list) + end + end + return result +end + +function TS.array_filter(list, callback) + local result = {} + for i = 1, #list do + local v = list[i] + if v ~= nil and callback(v, i - 1, list) == true then + result[#result + 1] = v + end + end + return result +end + +local function sortFallback(a, b) + return tostring(a) < tostring(b) +end + +function TS.array_sort(list, callback) + local sorted = array_copy(list) + + if callback then + table_sort(sorted, function(a, b) + return 0 < callback(a, b) + end) + else + table_sort(sorted, sortFallback) + end + + return sorted +end + +TS.array_toString = toString + +function TS.array_slice(list, startI, endI) + local length = #list + + if startI == nil then startI = 0 end + if endI == nil then endI = length end + + if startI < 0 then startI = length + startI end + if endI < 0 then endI = length + endI end + + local result = {} + + for i = startI + 1, endI do + result[i - startI] = list[i] + end + + return result +end + +function TS.array_splice(list, start, deleteCount, ...) + local len = #list + local actualStart + if start < 0 then + actualStart = len + start + if actualStart < 0 then + actualStart = 0 + end + else + if start < len then + actualStart = start + else + actualStart = len + end + end + local items = { ... } + local itemCount = #items + local actualDeleteCount + if start == nil then + actualDeleteCount = 0 + elseif deleteCount == nil then + actualDeleteCount = len - actualStart + else + if deleteCount < 0 then + deleteCount = 0 + end + actualDeleteCount = len - actualStart + if deleteCount < actualDeleteCount then + actualDeleteCount = deleteCount + end + end + local out = {} + local k = 0 + while k < actualDeleteCount do + local from = actualStart + k + if list[from + 1] then + out[k + 1] = list[from + 1] + end + k = k + 1 + end + if itemCount < actualDeleteCount then + k = actualStart + while k < len - actualDeleteCount do + local from = k + actualDeleteCount + local to = k + itemCount + if list[from + 1] then + list[to + 1] = list[from + 1] + else + list[to + 1] = nil + end + k = k + 1 + end + k = len + while k > len - actualDeleteCount + itemCount do + list[k] = nil + k = k - 1 + end + elseif itemCount > actualDeleteCount then + k = len - actualDeleteCount + while k > actualStart do + local from = k + actualDeleteCount + local to = k + itemCount + if list[from] then + list[to] = list[from] + else + list[to] = nil + end + k = k - 1 + end + end + k = actualStart + for i = 1, #items do + list[k + 1] = items[i] + k = k + 1 + end + k = #list + while k > len - actualDeleteCount + itemCount do + list[k] = nil + k = k - 1 + end + return out +end + +function TS.array_some(list, callback) + for i = 1, #list do + local v = list[i] + if v ~= nil and callback(v, i - 1, list) == true then + return true + end + end + return false +end + +function TS.array_every(list, callback) + for i = 1, #list do + local v = list[i] + if v ~= nil and callback(v, i - 1, list) == false then + return false + end + end + return true +end + +function TS.array_includes(list, item, startingIndex) + for i = (startingIndex or 0) + 1, #list do + if list[i] == item then + return true + end + end + return false +end + +function TS.array_indexOf(list, value, fromIndex) + for i = (fromIndex or 0) + 1, #list do + if value == list[i] then + return i - 1 + end + end + return -1 +end + +function TS.array_lastIndexOf(list, value, fromIndex) + for i = (fromIndex or #list - 1) + 1, 1, -1 do + if value == list[i] then + return i - 1 + end + end + return -1 +end + +function TS.array_reverse(list) + local result = {} + local length = #list + local n = length + 1 + for i = 1, length do + result[i] = list[n - i] + end + return result +end + +function TS.array_reduce(list, callback, initialValue) + local start = 1 + if initialValue == nil then + initialValue = list[start] + start = 2 + end + local accumulator = initialValue + for i = start, #list do + local v = list[i] + if v ~= nil then + accumulator = callback(accumulator, v, i) + end + end + return accumulator +end + +function TS.array_reduceRight(list, callback, initialValue) + local start = #list + if initialValue == nil then + initialValue = list[start] + start = start - 1 + end + local accumulator = initialValue + for i = start, 1, -1 do + local v = list[i] + if v ~= nil then + accumulator = callback(accumulator, v, i) + end + end + return accumulator +end + +function TS.array_unshift(list, ...) + local n = #list + local argsLength = select("#", ...) + for i = n, 1, -1 do + list[i + argsLength] = list[i] + end + for i = 1, argsLength do + list[i] = select(i, ...) + end + return n + argsLength +end + +local function array_push_apply(list, ...) + local len = #list + for i = 1, select("#", ...) do + local list2 = select(i, ...) + local len2 = #list2 + for j = 1, len2 do + list[len + j] = list2[j] + end + len = len + len2 + end + return len +end + +TS.array_push_apply = array_push_apply + +function TS.array_push_stack(list, ...) + local len = #list + local len2 = select("#", ...) + for i = 1, len2 do + list[len + i] = select(i, ...) + end + return len + len2 +end + +function TS.array_concat(...) + local result = {} + array_push_apply(result, ...) + return result +end + +function TS.array_join(list, separator) + local result = {} + for i = 1, #list do + local item = list[i] + if item == nil then + result[i] = "" + else + result[i] = tostring(list[i]) + end + end + return table_concat(result, separator or ",") +end + +function TS.array_find(list, callback) + for i = 1, #list do + local v = list[i] + if callback(v, i - 1, list) == true then + return v + end + end +end + +function TS.array_findIndex(list, callback) + for i = 0, #list - 1 do + if callback(list[i + 1], i, list) == true then + return i + end + end + return -1 +end + +local function array_flat_helper(list, depth, count, result) + for i = 1, #list do + local v = list[i] + + if v ~= nil then + if type(v) == "table" then + if depth ~= 0 then + count = array_flat_helper(v, depth - 1, count, result) + else + count = count + 1 + result[count] = v + end + else + count = count + 1 + result[count] = v + end + end + end + + return count +end + +function TS.array_flat(list, depth) + local result = {} + array_flat_helper(list, depth or 1, 0, result) + return result +end + +function TS.array_fill(list, value, from, to) + local length = #list + + if from == nil then + from = 0 + elseif from < 0 then + from = from + length + end + + if to == nil or to > length then + to = length + elseif to < 0 then + to = to + length + end + + for i = from + 1, to do + list[i] = value + end + + return list +end + +function TS.array_copyWithin(list, target, from, to) + local length = #list + + if target < 0 then + target = target + length + end + + if from == nil then + from = 0 + elseif from < 0 then + from = from + length + end + + if to == nil or to > length then + to = length + elseif to < 0 then + to = to + length + end + + local tf = target - from + local overshoot = to + tf - length + + if overshoot > 0 then + to = from + length - target + end + + for i = to, from + 1, -1 do + list[i + tf] = list[i] + end + + return list +end + +TS.array_deepCopy = deepCopy + +TS.array_deepEquals = deepEquals + +-- map macro functions + +function TS.map_new(pairs) + local result = {} + for i = 1, #pairs do + local pair = pairs[i] + result[pair[1]] = pair[2] + end + return result +end + +function TS.map_clear(map) + for key in pairs(map) do + map[key] = nil + end +end + +local function getNumKeys(map) + local result = 0 + for _ in pairs(map) do + result = result + 1 + end + return result +end + +TS.map_size = getNumKeys +TS.map_entries = TS.Object_entries + +function TS.map_forEach(map, callback) + for key, value in pairs(map) do + callback(value, key, map) + end +end + +TS.map_keys = TS.Object_keys + +TS.map_values = TS.Object_values +TS.map_toString = toString + +-- set macro functions + +function TS.set_new(values) + local result = {} + for i = 1, #values do + result[values[i]] = true + end + return result +end + +TS.set_clear = TS.map_clear + +function TS.set_forEach(set, callback) + for key in pairs(set) do + callback(key, key, set) + end +end + +function TS.set_union(set1, set2) + local result = {} + + for value in pairs(set1) do + result[value] = true + end + + for value in pairs(set2) do + result[value] = true + end + + return result +end + +function TS.set_intersect(set1, set2) + local result = {} + + for value in pairs(set1) do + if set2[value] then + result[value] = true + end + end + + return result +end + +function TS.set_isDisjointWith(set1, set2) + for value in pairs(set1) do + if set2[value] then + return false + end + end + return true +end + +function TS.set_isSubsetOf(set1, set2) + for value in pairs(set1) do + if set2[value] == nil then + return false + end + end + + return true +end + +function TS.set_difference(set1, set2) + local result = {} + for value in pairs(set1) do + if set2[value] == nil then + result[value] = true + end + end + return result +end + +TS.set_values = TS.Object_keys + +TS.set_size = getNumKeys + +TS.set_toString = toString + +function TS.iterableCache(iter) + local results = {} + local count = 0 + for _0 in iter.next do + if _0.done then break end + count = count + 1 + results[count] = _0.value + end + return results +end + +local function package(...) + return select("#", ...), {...} +end + +function TS.iterableFunctionCache(iter) + local results = {} + local count = 0 + + while true do + local size, t = package(iter()); + if size == 0 then break end + count = count + 1 + results[count] = t + end + + return results +end + +-- roact functions + +function TS.Roact_combine(...) + local args = { ... } + local result = {} + for i = 1, #args do + for key, value in pairs(args[i]) do + if (type(key) == "number") then + table.insert(result, value) + else + result[key] = value + end + end + end + return result +end + +-- opcall + +function TS.opcall(func, ...) + local success, valueOrErr = pcall(func, ...) + if success then + return { + success = true, + value = valueOrErr, + } + else + return { + success = false, + error = valueOrErr, + } + end +end + +-- try catch utilities + +local function pack(...) + return { size = select("#", ...), ... } +end + +local throwStack = {} + +function TS.throw(value) + if #throwStack > 0 then + throwStack[#throwStack](value) + else + error("Uncaught " .. tostring(value), 2) + end +end + +function TS.try(tryCallback, catchCallback) + local done = false + local yielded = false + local popped = false + local resumeThread = coroutine.running() + + local returns + + local function pop() + if not popped then + popped = true + throwStack[#throwStack] = nil + end + end + + local function resume() + if yielded then + local success, errorMsg = coroutine.resume(resumeThread) + if not success then + warn(errorMsg) + end + else + done = true + end + end + + local function throw(value) + pop() + if catchCallback then + returns = pack(catchCallback(value)) + end + resume() + coroutine.yield() + end + + throwStack[#throwStack + 1] = throw + + coroutine.wrap(function() + returns = pack(tryCallback()) + resume() + end)() + + if not done then + yielded = true + coroutine.yield() + end + + pop() + + return returns +end + +return TS