diff --git a/.gitmodules b/.gitmodules index c164164..051942e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "libs/yaml-cpp"] path = libs/yaml-cpp url = https://github.com/jbeder/yaml-cpp.git +[submodule "libs/json"] + path = libs/json + url = https://github.com/nlohmann/json.git diff --git a/.idea/vcs.xml b/.idea/vcs.xml index be8446c..444b296 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -3,6 +3,7 @@ + diff --git a/CMakeLists.txt b/CMakeLists.txt index 864c5a2..005e8e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,10 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin) +# Include OpenSSL. +find_package(OpenSSL REQUIRED) +include_directories(${OPENSSL_INCLUDE_DIR}) + # Include libuv. add_subdirectory(libs/libuv) include_directories(libs/libuv/include) @@ -18,6 +22,10 @@ include_directories(libs/libuv/include) # Include uvw (libuv C++ wrapper). include_directories(libs/uvw/src) +# Include JSON. +add_subdirectory(libs/json) +include_directories(libs/json/include) + # Include YAML-CPP. add_subdirectory(libs/yaml-cpp) include_directories(libs/yaml-cpp/include) @@ -63,7 +71,7 @@ add_executable(ardos ${ARDOS_SOURCES} ${ARDOS_HEADERS}) add_subdirectory(libs/dclass) include_directories(libs/dclass) -target_link_libraries(ardos PRIVATE uv yaml-cpp amqpcpp prometheus-cpp::pull) +target_link_libraries(ardos PRIVATE uv OpenSSL::SSL nlohmann_json::nlohmann_json yaml-cpp amqpcpp prometheus-cpp::pull) if (ARDOS_WANT_DB_SERVER) target_link_libraries(ardos PRIVATE mongo::mongocxx_shared mongo::bsoncxx_shared) diff --git a/config.example.yml b/config.example.yml index 1c95a98..b86c1c7 100644 --- a/config.example.yml +++ b/config.example.yml @@ -23,6 +23,9 @@ want-db-state-server: true # Do we want metrics collection on this instance? want-metrics: true +# Do we want a web interface running on this instance? +want-web-panel: true + # UberDOG definitions. # Some example ones follow: uberdogs: @@ -36,6 +39,22 @@ uberdogs: - id: 4667 class: FriendsManager +# Web Panel configuration. +# Can be accessed via the Ardos Web panel for debugging. +web-panel: + name: Ardos # The cluster name to appear in the dashboard. + port: 7781 # Port the WS connection listens on. + + # Auth options. + # Make sure to change these in PROD environments. + username: ardos + password: ardos + + # SSL/TLS config for web panel. + # The below two options can be omitted to disable SSL/TLS. + certificate: cert.pem + private-key: key.pem + # Metrics (Prometheus) configuration. # This should be configured as a target in your Prometheus config. metrics: diff --git a/libs/json b/libs/json new file mode 160000 index 0000000..9cca280 --- /dev/null +++ b/libs/json @@ -0,0 +1 @@ +Subproject commit 9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03 diff --git a/src/clientagent/client_agent.cpp b/src/clientagent/client_agent.cpp index 2287c47..9751ec4 100644 --- a/src/clientagent/client_agent.cpp +++ b/src/clientagent/client_agent.cpp @@ -4,6 +4,7 @@ #include "../util/globals.h" #include "../util/logger.h" #include "../util/metrics.h" +#include "../web/web_panel.h" #include "client_participant.h" namespace Ardos { @@ -27,7 +28,7 @@ ClientAgent::ClientAgent() { _version = config["version"].as(); // DC hash configuration. - // Can be manually overriden in CA config. + // Can be manually overridden in CA config. _dcHash = g_dc_file->get_hash(); if (auto manualHash = config["manual-dc-hash"]) { _dcHash = manualHash.as(); @@ -123,8 +124,7 @@ ClientAgent::ClientAgent() { srv.accept(*client); // Create a new client for this connected participant. - // TODO: These should be tracked in a vector. - new ClientParticipant(this, client); + _participants.insert(new ClientParticipant(this, client)); }); // Initialize metrics. @@ -259,10 +259,12 @@ void ClientAgent::ParticipantJoined() { /** * Called when a participant disconnects. */ -void ClientAgent::ParticipantLeft() { +void ClientAgent::ParticipantLeft(ClientParticipant *client) { if (_participantsGauge) { _participantsGauge->Decrement(); } + + _participants.erase(client); } /** @@ -354,4 +356,92 @@ void ClientAgent::InitMetrics() { _freeChannelsGauge->Set((double)(_channelsMax - _nextChannel)); } +void ClientAgent::HandleWeb(ws28::Client *client, nlohmann::json &data) { + if (data["msg"] == "init") { + // Build up an array of connected clients. + nlohmann::json clientInfo = nlohmann::json::array(); + for (const auto &participant : _participants) { + clientInfo.push_back({ + {"channel", std::to_string(participant->GetChannel())}, + {"ip", participant->GetRemoteAddress().ip}, + {"port", participant->GetRemoteAddress().port}, + {"state", participant->GetAuthState()}, + {"channels", participant->GetLocalChannels().size()}, + {"postRemoves", participant->GetPostRemoves().size()}, + }); + } + + WebPanel::Send(client, { + {"type", "ca:init"}, + {"success", true}, + {"listenIp", _host}, + {"listenPort", _port}, +#ifdef ARDOS_USE_LEGACY_CLIENT + {"legacy", true}, +#else + {"legacy", false}, +#endif + {"clients", clientInfo}, + }); + } else if (data["msg"] == "client") { + // We have to do this terribleness because JavaScript doesn't support + // uint64's. + auto channel = std::stoull(data["channel"].template get()); + + // Try to find a matching client for the provided channel. + auto participant = + std::find_if(_participants.begin(), _participants.end(), + [&channel](ClientParticipant *participant) { + return participant->GetChannel() == channel; + }); + if (participant == _participants.end()) { + WebPanel::Send(client, { + {"type", "ca:client"}, + {"success", false}, + }); + return; + } + + // Build an owned object array. + nlohmann::json ownedObjs = nlohmann::json::array(); + for (const auto &obj : (*participant)->GetOwnedObjects()) { + ownedObjs.push_back({{"doId", obj.first}, + {"clsName", obj.second.dcc->get_name()}, + {"parent", obj.second.parent}, + {"zone", obj.second.zone}}); + } + + // Build a session object array. + nlohmann::json sessionObjs = nlohmann::json::array(); + for (const auto &doId : (*participant)->GetSessionObjects()) { + sessionObjs.push_back({{"doId", doId}}); + } + + // Build an active interests array. + nlohmann::json interests = nlohmann::json::array(); + for (const auto &interest : (*participant)->GetInterests()) { + interests.push_back({{"id", interest.first}, + {"parent", interest.second.parent}, + {"zones", interest.second.zones}}); + } + + WebPanel::Send( + client, + { + {"type", "ca:client"}, + {"success", true}, + {"ip", (*participant)->GetRemoteAddress().ip}, + {"port", (*participant)->GetRemoteAddress().port}, + {"state", (*participant)->GetAuthState()}, + {"channelHi", ((*participant)->GetChannel() >> 32) & 0xFFFFFFFF}, + {"channelLo", (*participant)->GetChannel() & 0xFFFFFFFF}, + {"channels", (*participant)->GetLocalChannels().size()}, + {"postRemoves", (*participant)->GetPostRemoves().size()}, + {"owned", ownedObjs}, + {"session", sessionObjs}, + {"interests", interests}, + }); + } +} + } // namespace Ardos diff --git a/src/clientagent/client_agent.h b/src/clientagent/client_agent.h index 55a0b39..67c2e4a 100644 --- a/src/clientagent/client_agent.h +++ b/src/clientagent/client_agent.h @@ -3,13 +3,17 @@ #include #include +#include #include +#include #include #include #include #include +#include "../net/ws/Client.h" + namespace Ardos { struct Uberdog { @@ -24,6 +28,8 @@ enum InterestsPermission { INTERESTS_DISABLED, }; +class ClientParticipant; + class ClientAgent { public: ClientAgent(); @@ -43,11 +49,13 @@ class ClientAgent { [[nodiscard]] unsigned long GetInterestTimeout() const; void ParticipantJoined(); - void ParticipantLeft(); + void ParticipantLeft(ClientParticipant *client); void RecordDatagram(const uint16_t &size); void RecordInterestTimeout(); void RecordInterestTime(const double &seconds); + void HandleWeb(ws28::Client *client, nlohmann::json &data); + private: void InitMetrics(); @@ -66,6 +74,8 @@ class ClientAgent { std::unordered_map _uberdogs; + std::unordered_set _participants; + uint64_t _nextChannel; uint64_t _channelsMax; std::queue _freedChannels; diff --git a/src/clientagent/client_participant.cpp b/src/clientagent/client_participant.cpp index b166745..f4fcd5e 100644 --- a/src/clientagent/client_participant.cpp +++ b/src/clientagent/client_participant.cpp @@ -55,7 +55,7 @@ ClientParticipant::~ClientParticipant() { // Call shutdown just in-case (most likely redundant.) Shutdown(); - _clientAgent->ParticipantLeft(); + _clientAgent->ParticipantLeft(this); } /** @@ -338,8 +338,9 @@ void ClientParticipant::HandleDatagram(const std::shared_ptr &dg) { break; } case CLIENTAGENT_GET_NETWORK_ADDRESS: { - auto resp = std::make_shared(sender, _channel, CLIENTAGENT_GET_NETWORK_ADDRESS_RESP); - resp->AddUint32(dgi.GetUint32()); // Context. + auto resp = std::make_shared( + sender, _channel, CLIENTAGENT_GET_NETWORK_ADDRESS_RESP); + resp->AddUint32(dgi.GetUint32()); // Context. resp->AddString(GetRemoteAddress().ip); resp->AddUint16(GetRemoteAddress().port); resp->AddString(GetLocalAddress().ip); @@ -517,6 +518,9 @@ void ClientParticipant::HandleDatagram(const std::shared_ptr &dg) { case STATESERVER_OBJECT_CHANGING_LOCATION: { uint32_t doId = dgi.GetUint32(); if (TryQueuePending(doId, dgi.GetUnderlyingDatagram())) { + // The object that's changing location is currently generating inside an + // active InterestOperation. Queue this message to be handled after it + // generates. return; } diff --git a/src/clientagent/client_participant.h b/src/clientagent/client_participant.h index 698d033..af0ca17 100644 --- a/src/clientagent/client_participant.h +++ b/src/clientagent/client_participant.h @@ -44,6 +44,22 @@ class ClientParticipant final : public NetworkClient, public ChannelSubscriber { friend class InterestOperation; + [[nodiscard]] uint64_t GetChannel() const { return _channel; } + [[nodiscard]] uint8_t GetAuthState() const { return _authState; } + [[nodiscard]] std::vector> GetPostRemoves() const { + return _postRemoves; + } + [[nodiscard]] std::unordered_map + GetOwnedObjects() const { + return _ownedObjects; + } + [[nodiscard]] std::unordered_set GetSessionObjects() const { + return _sessionObjects; + } + [[nodiscard]] std::unordered_map GetInterests() const { + return _interests; + } + private: void Shutdown() override; diff --git a/src/database/database_server.cpp b/src/database/database_server.cpp index fd3bf6e..c380bda 100644 --- a/src/database/database_server.cpp +++ b/src/database/database_server.cpp @@ -12,6 +12,7 @@ #include "../util/globals.h" #include "../util/logger.h" #include "../util/metrics.h" +#include "../web/web_panel.h" #include "database_utils.h" // For document, finalize, et al. @@ -528,8 +529,7 @@ void DatabaseServer::HandleGetField(DatagramIterator &dgi, auto dbField = fields[field->get_name()]; if (dbField) { // Pack the field into our object datagram. - DatabaseUtils::PackField(field, dbField.get_value(), - objectDg); + DatabaseUtils::PackField(field, dbField.get_value(), objectDg); } else { // Pack a default value. objectDg.AddData(field->get_default_value()); @@ -1004,4 +1004,15 @@ void DatabaseServer::ReportFailed(const DatabaseServer::OperationType &type) { } } +void DatabaseServer::HandleWeb(ws28::Client *client, nlohmann::json &data) { + WebPanel::Send(client, { + {"type", "db"}, + {"success", true}, + {"host", _uri.to_string()}, + {"channel", _channel}, + {"minDoId", _minDoId}, + {"maxDoId", _maxDoId}, + }); +} + } // namespace Ardos diff --git a/src/database/database_server.h b/src/database/database_server.h index 4899c19..7fb31a0 100644 --- a/src/database/database_server.h +++ b/src/database/database_server.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -10,6 +11,7 @@ #include "../messagedirector/channel_subscriber.h" #include "../net/datagram_iterator.h" #include "../net/message_types.h" +#include "../net/ws/Client.h" namespace Ardos { @@ -17,6 +19,8 @@ class DatabaseServer final : public ChannelSubscriber { public: DatabaseServer(); + void HandleWeb(ws28::Client *client, nlohmann::json &data); + private: void HandleDatagram(const std::shared_ptr &dg) override; diff --git a/src/messagedirector/channel_subscriber.h b/src/messagedirector/channel_subscriber.h index ac5907d..c3192c1 100644 --- a/src/messagedirector/channel_subscriber.h +++ b/src/messagedirector/channel_subscriber.h @@ -33,6 +33,10 @@ class ChannelSubscriber { */ void PublishDatagram(const std::shared_ptr &dg); + [[nodiscard]] std::vector GetLocalChannels() const { + return _localChannels; + } + protected: virtual void HandleDatagram(const std::shared_ptr &dg) = 0; diff --git a/src/messagedirector/md_participant.cpp b/src/messagedirector/md_participant.cpp index 9846b4b..eda8899 100644 --- a/src/messagedirector/md_participant.cpp +++ b/src/messagedirector/md_participant.cpp @@ -20,7 +20,7 @@ MDParticipant::~MDParticipant() { // Call shutdown just in-case (most likely redundant.) Shutdown(); - MessageDirector::Instance()->ParticipantLeft(); + MessageDirector::Instance()->ParticipantLeft(this); } /** @@ -34,7 +34,8 @@ void MDParticipant::Shutdown() { // Kill the network connection. NetworkClient::Shutdown(); - // Unsubscribe from all channels so post removes aren't accidently routed to us. + // Unsubscribe from all channels so post removes aren't accidentally routed to + // us. ChannelSubscriber::Shutdown(); Logger::Verbose(std::format("[MD] Routing {} post-remove(s) for '{}'", diff --git a/src/messagedirector/md_participant.h b/src/messagedirector/md_participant.h index dc97751..e765473 100644 --- a/src/messagedirector/md_participant.h +++ b/src/messagedirector/md_participant.h @@ -16,6 +16,11 @@ class MDParticipant final : public NetworkClient, public ChannelSubscriber { explicit MDParticipant(const std::shared_ptr &socket); ~MDParticipant() override; + [[nodiscard]] std::string GetName() const { return _connName; } + [[nodiscard]] std::vector> GetPostRemoves() const { + return _postRemoves; + } + private: void Shutdown() override; void HandleDisconnect(uv_errno_t code) override; diff --git a/src/messagedirector/message_director.cpp b/src/messagedirector/message_director.cpp index 3abf4a0..aec1db6 100644 --- a/src/messagedirector/message_director.cpp +++ b/src/messagedirector/message_director.cpp @@ -6,11 +6,10 @@ #endif #include "../net/address_utils.h" #include "../stateserver/database_state_server.h" -#include "../stateserver/state_server.h" #include "../util/config.h" -#include "../util/globals.h" #include "../util/logger.h" #include "../util/metrics.h" +#include "../web/web_panel.h" #include "md_participant.h" namespace Ardos { @@ -42,13 +41,11 @@ MessageDirector::MessageDirector() { } // RabbitMQ configuration. - std::string rHost = "127.0.0.1"; if (auto hostParam = config["rabbitmq-host"]) { - rHost = hostParam.as(); + _rHost = hostParam.as(); } - int rPort = 5672; if (auto portParam = config["rabbitmq-port"]) { - rPort = portParam.as(); + _rPort = portParam.as(); } std::string user = "guest"; if (auto userParam = config["rabbitmq-user"]) { @@ -61,14 +58,13 @@ MessageDirector::MessageDirector() { // Socket events. _listenHandle->on( - [](const uvw::listen_event &, uvw::tcp_handle &srv) { + [this](const uvw::listen_event &, uvw::tcp_handle &srv) { std::shared_ptr client = srv.parent().resource(); srv.accept(*client); // Create a new client for this connected participant. - // TODO: These should be tracked in a vector. - new MDParticipant(client); + _participants.insert(new MDParticipant(client)); }); _connectHandle->on( @@ -112,8 +108,8 @@ MessageDirector::MessageDirector() { // Start connecting/listening! _listenHandle->bind(_host, _port); - _connectHandle->connect(AddressUtils::resolve_host(g_loop, rHost, rPort), - rPort); + _connectHandle->connect(AddressUtils::resolve_host(g_loop, _rHost, _rPort), + _rPort); } /** @@ -177,16 +173,16 @@ void MessageDirector::onReady(AMQP::Connection *connection) { // Startup configured roles. if (Config::Instance()->GetBool("want-state-server")) { - new StateServer(); + _stateServer = std::make_unique(); } if (Config::Instance()->GetBool("want-client-agent")) { - new ClientAgent(); + _clientAgent = std::make_unique(); } if (Config::Instance()->GetBool("want-database")) { #ifdef ARDOS_WANT_DB_SERVER - new DatabaseServer(); + _db = std::make_unique(); #else Logger::Error("want-database was set to true but Ardos was " "built without ARDOS_WANT_DB_SERVER"); @@ -195,7 +191,11 @@ void MessageDirector::onReady(AMQP::Connection *connection) { } if (Config::Instance()->GetBool("want-db-state-server")) { - new DatabaseStateServer(); + _dbss = std::make_unique(); + } + + if (Config::Instance()->GetBool("want-web-panel")) { + new WebPanel(); } // Start listening for incoming connections. @@ -292,10 +292,12 @@ void MessageDirector::ParticipantJoined() { /** * Called when a participant disconnects. */ -void MessageDirector::ParticipantLeft() { +void MessageDirector::ParticipantLeft(MDParticipant *participant) { if (_participantsGauge) { _participantsGauge->Decrement(); } + + _participants.erase(participant); } /** @@ -416,4 +418,28 @@ bool MessageDirector::WithinGlobalRange(const std::string &routingKey) { }); } +void MessageDirector::HandleWeb(ws28::Client *client, nlohmann::json &data) { + // Build up an array of connected participants. + nlohmann::json participantInfo = nlohmann::json::array(); + for (const auto &participant : _participants) { + participantInfo.push_back({ + {"name", participant->GetName()}, + {"ip", participant->GetRemoteAddress().ip}, + {"port", participant->GetRemoteAddress().port}, + {"channels", participant->GetLocalChannels().size()}, + {"postRemoves", participant->GetPostRemoves().size()}, + }); + } + + WebPanel::Send(client, { + {"type", "md"}, + {"success", true}, + {"listenIp", _host}, + {"listenPort", _port}, + {"connectIp", _rHost}, + {"connectPort", _rPort}, + {"participants", participantInfo}, + }); +} + } // namespace Ardos diff --git a/src/messagedirector/message_director.h b/src/messagedirector/message_director.h index 3976304..61e5295 100644 --- a/src/messagedirector/message_director.h +++ b/src/messagedirector/message_director.h @@ -4,16 +4,25 @@ #include #include +#include #include #include #include #include +#include "../net/ws/Client.h" + namespace Ardos { const std::string kGlobalExchange = "global-exchange"; class ChannelSubscriber; +class MDParticipant; + +class StateServer; +class ClientAgent; +class DatabaseServer; +class DatabaseStateServer; class MessageDirector : public AMQP::ConnectionHandler { public: @@ -32,7 +41,14 @@ class MessageDirector : public AMQP::ConnectionHandler { void RemoveSubscriber(ChannelSubscriber *subscriber); void ParticipantJoined(); - void ParticipantLeft(); + void ParticipantLeft(MDParticipant *participant); + + void HandleWeb(ws28::Client *client, nlohmann::json &data); + + StateServer *GetStateServer() { return _stateServer.get(); } + ClientAgent *GetClientAgent() { return _clientAgent.get(); } + DatabaseServer *GetDbServer() { return _db.get(); } + DatabaseStateServer *GetDbStateServer() { return _dbss.get(); } private: MessageDirector(); @@ -45,8 +61,14 @@ class MessageDirector : public AMQP::ConnectionHandler { static MessageDirector *_instance; + std::unique_ptr _stateServer; + std::unique_ptr _clientAgent; + std::unique_ptr _db; + std::unique_ptr _dbss; + std::unordered_set _subscribers; std::unordered_set _leavingSubscribers; + std::unordered_set _participants; std::shared_ptr _connectHandle; std::shared_ptr _listenHandle; @@ -56,8 +78,12 @@ class MessageDirector : public AMQP::ConnectionHandler { std::string _consumeTag; std::vector _frameBuffer; + // Listen info. std::string _host = "127.0.0.1"; int _port = 7100; + // RabbitMQ connect info. + std::string _rHost = "127.0.0.1"; + int _rPort = 5672; prometheus::Counter *_datagramsObservedCounter = nullptr; prometheus::Counter *_datagramsProcessedCounter = nullptr; diff --git a/src/net/network_client.h b/src/net/network_client.h index b81546b..556f4d0 100644 --- a/src/net/network_client.h +++ b/src/net/network_client.h @@ -13,12 +13,13 @@ class NetworkClient { public: explicit NetworkClient(const std::shared_ptr &socket); + [[nodiscard]] uvw::socket_address GetRemoteAddress() const; + [[nodiscard]] uvw::socket_address GetLocalAddress() const; + protected: ~NetworkClient(); [[nodiscard]] bool Disconnected() const; - [[nodiscard]] uvw::socket_address GetRemoteAddress() const; - [[nodiscard]] uvw::socket_address GetLocalAddress() const; void Shutdown(); diff --git a/src/net/ws/Client.cpp b/src/net/ws/Client.cpp new file mode 100644 index 0000000..5e7574b --- /dev/null +++ b/src/net/ws/Client.cpp @@ -0,0 +1,971 @@ +#include "Client.h" +#include "Server.h" +#include "base64.h" +#include +#include +#include + +#include + +namespace ws28 { + +namespace detail { + bool equalsi(std::string_view a, std::string_view b){ + if(a.size() != b.size()) return false; + for(;;){ + if(tolower(a.front()) != tolower(b.front())) return false; + + a.remove_prefix(1); + b.remove_prefix(1); + if(a.empty()) return true; + } + } + + bool equalsi(std::string_view a, std::string_view b, size_t n){ + while(n--){ + if(a.empty()) return b.empty(); + else if(b.empty()) return false; + + if(tolower(a.front()) != tolower(b.front())) return false; + + a.remove_prefix(1); + b.remove_prefix(1); + } + + return true; + } + + bool HeaderContains(std::string_view header, std::string_view substring){ + bool hasMatch = false; + + while(!header.empty()){ + if(header.front() == ' ' || header.front() == '\t'){ + header.remove_prefix(1); + continue; + } + + if(hasMatch){ + if(header.front() == ',') return true; + hasMatch = false; + header.remove_prefix(1); + + // Skip to comma or end of string + while(!header.empty() && header.front() != ',') header.remove_prefix(1); + if(header.empty()) return false; + + // Skip comma + assert(header.front() == ','); + header.remove_prefix(1); + }else{ + if(detail::equalsi(header, substring, substring.size())){ + // We have a match... if the header ends here, or has a comma + hasMatch = true; + header.remove_prefix(substring.size()); + }else{ + header.remove_prefix(1); + + // Skip to comma or end of string + while(!header.empty() && header.front() != ',') header.remove_prefix(1); + if(header.empty()) return false; + + // Skip comma + assert(header.front() == ','); + header.remove_prefix(1); + } + } + } + + return hasMatch; + } + + + struct Corker { + Client &client; + + Corker(Client &client) : client(client) { client.Cork(true); } + ~Corker(){ client.Cork(false); } + }; +} + +struct DataFrameHeader { + char *data; + + DataFrameHeader(char *data) : data(data){ + + } + + void reset(){ + data[0] = 0; + data[1] = 0; + } + + void fin(bool v) { data[0] &= ~(1 << 7); data[0] |= v << 7; } + void rsv1(bool v) { data[0] &= ~(1 << 6); data[0] |= v << 6; } + void rsv2(bool v) { data[0] &= ~(1 << 5); data[0] |= v << 5; } + void rsv3(bool v) { data[0] &= ~(1 << 4); data[0] |= v << 4; } + void mask(bool v) { data[1] &= ~(1 << 7); data[1] |= v << 7; } + void opcode(uint8_t v) { + data[0] &= ~0x0F; + data[0] |= v & 0x0F; + } + + void len(uint8_t v) { + data[1] &= ~0x7F; + data[1] |= v & 0x7F; + } + + bool fin() { return (data[0] >> 7) & 1; } + bool rsv1() { return (data[0] >> 6) & 1; } + bool rsv2() { return (data[0] >> 5) & 1; } + bool rsv3() { return (data[0] >> 4) & 1; } + bool mask() { return (data[1] >> 7) & 1; } + + uint8_t opcode() { + return data[0] & 0x0F; + } + + uint8_t len() { + return data[1] & 0x7F; + } +}; + +Client::Client(Server *server, SocketHandle socket) : m_pServer(server), m_Socket(std::move(socket)){ + m_Socket->data = this; + + // Default to true since that's what most people want + uv_tcp_nodelay(m_Socket.get(), true); + uv_tcp_keepalive(m_Socket.get(), true, 10000); + + { // Put IP in m_IP + m_IP[0] = '\0'; + struct sockaddr_storage addr; + int addrLen = sizeof(addr); + uv_tcp_getpeername(m_Socket.get(), (sockaddr*) &addr, &addrLen); + + if(addr.ss_family == AF_INET){ + int r = uv_ip4_name((sockaddr_in*) &addr, m_IP, sizeof(m_IP)); + (void) r; + assert(r == 0); + }else if(addr.ss_family == AF_INET6){ + int r = uv_ip6_name((sockaddr_in6*) &addr, m_IP, sizeof(m_IP)); + (void) r; + assert(r == 0); + + // Remove this prefix if it exists, it means that we actually have a ipv4 + static const char *ipv4Prefix = "::ffff:"; + if(strncmp(m_IP, ipv4Prefix, strlen(ipv4Prefix)) == 0){ + memmove(m_IP, m_IP + strlen(ipv4Prefix), strlen(m_IP) - strlen(ipv4Prefix) + 1); + } + }else{ + // Server::OnConnection will destroy us + } + } + + uv_read_start((uv_stream_t*) m_Socket.get(), [](uv_handle_t*, size_t suggested_size, uv_buf_t *buf){ + buf->base = new char[suggested_size]; + buf->len = suggested_size; + }, [](uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf){ + auto client = (Client*) stream->data; + + if(client != nullptr){ + if(nread < 0){ + client->Destroy(); + }else if(nread > 0){ + client->OnRawSocketData(buf->base, (size_t) nread); + } + } + + if(buf != nullptr) delete[] buf->base; + }); +} + +Client::~Client(){ + assert(!m_Socket); +} + +void Client::Destroy(){ + if(!m_Socket) return; + + Cork(false); + + m_Socket->data = nullptr; + + auto myself = m_pServer->NotifyClientPreDestroyed(this); + + struct ShutdownRequest : uv_shutdown_t { + SocketHandle socket; + std::unique_ptr client; + Server::ClientDisconnectedFn cb; + }; + + auto req = new ShutdownRequest(); + req->socket = std::move(m_Socket); + req->client = std::move(myself); + req->cb = m_pServer->m_fnClientDisconnected; + + m_pServer = nullptr; + + static auto cb = [](uv_shutdown_t* reqq, int){ + auto req = (ShutdownRequest*) reqq; + + if(req->cb && req->client->m_bHasCompletedHandshake){ + req->cb(req->client.get()); + } + + delete req; + }; + + if(uv_shutdown(req, (uv_stream_t*) req->socket.get(), cb) != 0){ + // Shutdown failed, but we have to delay the destruction to the next event loop + auto timer = new uv_timer_t; + uv_timer_init(req->socket->loop, timer); + timer->data = req; + uv_timer_start(timer, [](uv_timer_t *timer){ + auto req = (ShutdownRequest*) timer->data; + cb(req, 0); + uv_close((uv_handle_t*) timer, [](uv_handle_t *h){ delete (uv_timer_t*) h;}); + }, 0, 0); + } +} + + + +template +void Client::WriteRaw(uv_buf_t bufs[N]){ + if(!m_Socket) return; + + // Try to write without allocating memory first, if that doesn't work, we call WriteRawQueue + int written = uv_try_write((uv_stream_t*) m_Socket.get(), bufs, N); + if(written == UV_EAGAIN) written = 0; + + if(written >= 0){ + size_t totalLength = 0; + + for(size_t i = 0; i < N; ++i){ + auto &buf = bufs[i]; + totalLength += buf.len; + } + + size_t skipping = (size_t) written; + if(skipping == totalLength) return; // Complete write + + // Partial write + // Copy the remainder into a buffer to send to WriteRawQueue + + auto cpy = std::make_unique(totalLength); + size_t offset = 0; + + for(size_t i = 0; i < N; ++i){ + auto &buf = bufs[i]; + if(skipping >= buf.len){ + skipping -= buf.len; + continue; + } + + memcpy(cpy.get() + offset, buf.base + skipping, buf.len - skipping); + offset += buf.len - skipping; + skipping = 0; + } + + WriteRawQueue(std::move(cpy), offset); + }else{ + // Write error + Destroy(); + return; + } +} + +void Client::WriteRawQueue(std::unique_ptr data, size_t len){ + if(!m_Socket) return; + + struct CustomWriteRequest { + uv_write_t req; + Client *client; + std::unique_ptr data; + }; + + uv_buf_t buf; + buf.base = data.get(); + buf.len = len; + + auto request = new CustomWriteRequest(); + request->client = this; + request->data = std::move(data); + + if(uv_write(&request->req, (uv_stream_t*) m_Socket.get(), &buf, 1, [](uv_write_t* req, int status){ + auto request = (CustomWriteRequest*) req; + + if(status < 0){ + request->client->Destroy(); + } + + delete request; + }) != 0){ + delete request; + Destroy(); + } +} + +template +void Client::Write(uv_buf_t bufs[N]){ + if(!m_Socket) return; + if(IsSecure()){ + for(size_t i = 0; i < N; ++i){ + if(!m_pTLS->Write(bufs[i].base, bufs[i].len)) return Destroy(); + } + FlushTLS(); + }else{ + WriteRaw(bufs); + } +} + +void Client::Write(const char *data, size_t len){ + uv_buf_t bufs[1]; + bufs[0].base = (char*) data; + bufs[0].len = len; + Write<1>(bufs); +} + +void Client::Write(const char *data){ + Write(data, strlen(data)); +} + +void Client::WriteDataFrameHeader(uint8_t opcode, size_t len, char *headerStart){ + DataFrameHeader header{ headerStart }; + + header.reset(); + header.fin(true); + header.opcode(opcode); + header.mask(false); + header.rsv1(false); + header.rsv2(false); + header.rsv3(false); + if(len >= 126){ + if(len > UINT16_MAX){ + header.len(127); + *(uint8_t*)(headerStart + 2) = (len >> 56) & 0xFF; + *(uint8_t*)(headerStart + 3) = (len >> 48) & 0xFF; + *(uint8_t*)(headerStart + 4) = (len >> 40) & 0xFF; + *(uint8_t*)(headerStart + 5) = (len >> 32) & 0xFF; + *(uint8_t*)(headerStart + 6) = (len >> 24) & 0xFF; + *(uint8_t*)(headerStart + 7) = (len >> 16) & 0xFF; + *(uint8_t*)(headerStart + 8) = (len >> 8) & 0xFF; + *(uint8_t*)(headerStart + 9) = (len >> 0) & 0xFF; + }else{ + header.len(126); + *(uint8_t*)(headerStart + 2) = (len >> 8) & 0xFF; + *(uint8_t*)(headerStart + 3) = (len >> 0) & 0xFF; + } + }else{ + header.len(len); + } +} + + +size_t Client::GetDataFrameHeaderSize(size_t len){ + if(len >= 126){ + if(len > UINT16_MAX){ + return 10; + }else{ + return 4; + } + }else{ + return 2; + } +} + +void Client::OnRawSocketData(char *data, size_t len){ + if(len == 0) return; + if(!m_Socket) return; + + if(m_bWaitingForFirstPacket){ + m_bWaitingForFirstPacket = false; + + assert(!IsSecure()); + + if(m_pServer->GetSSLContext() != nullptr && (data[0] == 0x16 || uint8_t(data[0]) == 0x80)){ + if(m_pServer->m_fnCheckTCPConnection && !m_pServer->m_fnCheckTCPConnection(GetIP(), true)){ + return Destroy(); + } + + InitSecure(); + }else{ + if(m_pServer->m_fnCheckTCPConnection && !m_pServer->m_fnCheckTCPConnection(GetIP(), false)){ + return Destroy(); + } + } + } + + if(IsSecure()){ + if(!m_pTLS->ReceivedData(data, len, [&](char *data, size_t len){ + OnSocketData(data, len); + })){ + return Destroy(); + } + + FlushTLS(); + }else{ + OnSocketData(data, len); + } +} + +void Client::OnSocketData(char *data, size_t len){ + if(m_pServer == nullptr) return; + + // This gives us an extra byte just in case + if(m_Buffer.size() + len + 1 >= m_pServer->m_iMaxMessageSize){ + if(m_bHasCompletedHandshake){ + Close(1009, "Message too large"); + } + + Destroy(); + return; + } + + // If we don't have anything stored in our class-level buffer (m_Buffer), + // we use the buffer we received in the function arguments so we don't have to + // perform any copying. The Bail function needs to be called before we leave this + // function (unless we're destroying the client), to copy the unused part of the buffer + // back to our class-level buffer + std::string_view buffer; + bool usingLocalBuffer; + + if(m_Buffer.empty()){ + usingLocalBuffer = true; + buffer = std::string_view(data, len); + }else{ + usingLocalBuffer = false; + + m_Buffer.insert(m_Buffer.end(), data, data + len); + buffer = std::string_view(m_Buffer.data(), m_Buffer.size()); + } + + auto Consume = [&](size_t amount){ + assert(buffer.size() >= amount); + buffer.remove_prefix(amount); + }; + + auto Bail = [&](){ + // Copy partial HTTP headers to our buffer + if(usingLocalBuffer){ + if(!buffer.empty()){ + assert(m_Buffer.empty()); + m_Buffer.insert(m_Buffer.end(), buffer.data(), buffer.data() + buffer.size()); + } + }else{ + if(buffer.empty()){ + m_Buffer.clear(); + }else if(buffer.size() != m_Buffer.size()){ + memmove(m_Buffer.data(), buffer.data(), buffer.size()); + m_Buffer.resize(buffer.size()); + } + } + }; + + if(!m_bHasCompletedHandshake && m_pServer->GetAllowAlternativeProtocol() && buffer[0] == 0x00){ + m_bHasCompletedHandshake = true; + m_bUsingAlternativeProtocol = true; + Consume(1); + + if(!m_pServer->m_fnCheckAlternativeConnection || m_pServer->m_fnCheckAlternativeConnection(this)){ + Destroy(); + return; + } + + RequestHeaders headers; + HTTPRequest req{ + m_pServer, + "GET", + "/", + m_IP, + headers, + }; + + m_pServer->NotifyClientInit(this, req); + }else if(!m_bHasCompletedHandshake){ + // HTTP headers not done yet, wait + auto endOfHeaders = buffer.find("\r\n\r\n"); + if(endOfHeaders == std::string_view::npos) return Bail(); + + auto MalformedRequest = [&](){ + Write("HTTP/1.1 400 Bad Request\r\n\r\n"); + Destroy(); + }; + + auto headersBuffer = buffer.substr(0, endOfHeaders+4); // Include \r\n\r\n + + std::string_view method; + std::string_view path; + + { + auto methodEnd = headersBuffer.find(' '); + + auto endOfLine = headersBuffer.find("\r\n"); + assert(endOfLine != std::string_view::npos); // Can't fail because of a check above + + if(methodEnd == std::string_view::npos || methodEnd > endOfLine) return MalformedRequest(); + + method = headersBuffer.substr(0, methodEnd); + + // Uppercase method + std::transform(method.begin(), method.end(), (char*) method.data(), [](char v) -> char{ + if(v < 0 || v >= 127) return v; + return toupper(v); + }); + + auto pathStart = methodEnd + 1; + auto pathEnd = headersBuffer.find(' ', pathStart); + + if(pathEnd == std::string_view::npos || pathEnd > endOfLine) return MalformedRequest(); + + path = headersBuffer.substr(pathStart, pathEnd - pathStart); + + // Skip line + headersBuffer = headersBuffer.substr(endOfLine + 2); + } + + RequestHeaders headers; + + for(;;) { + auto nextLine = headersBuffer.find("\r\n"); + + // This means that we have finished parsing the headers + if(nextLine == 0) { + break; + } + + // This can't happen... right? + if(nextLine == std::string_view::npos) return MalformedRequest(); + + auto colonPos = headersBuffer.find(':'); + if(colonPos == std::string_view::npos || colonPos > nextLine) return MalformedRequest(); + + auto key = headersBuffer.substr(0, colonPos); + + // Key to lower case + std::transform(key.begin(), key.end(), (char*) key.data(), [](char v) -> char { + if(v < 0 || v >= 127) return v; + return tolower(v); + }); + + auto value = headersBuffer.substr(colonPos + 1, nextLine - (colonPos + 1)); + + // Trim spaces + while(!key.empty() && key.front() == ' ') key.remove_prefix(1); + while(!key.empty() && key.back() == ' ') key.remove_suffix(1); + while(!value.empty() && value.front() == ' ') value.remove_prefix(1); + while(!value.empty() && value.back() == ' ') value.remove_suffix(1); + + headers.Set(key, value); + + headersBuffer = headersBuffer.substr(nextLine+2); + } + + HTTPRequest req{ + m_pServer, + method, + path, + m_IP, + headers, + }; + + { + if(auto upgrade = headers.Get("upgrade")){ + if(!detail::equalsi(*upgrade, "websocket")){ + return MalformedRequest(); + } + }else{ + + HTTPResponse res; + + if(m_pServer->m_fnHTTPRequest) m_pServer->m_fnHTTPRequest(req, res); + + if(res.statusCode == 0) res.statusCode = 404; + if(res.statusCode < 200 || res.statusCode >= 1000) res.statusCode = 500; + + + const char *statusCodeText = "WS28"; // Too lazy to create a table of those + + std::stringstream ss; + ss << "HTTP/1.1 " << res.statusCode << " " << statusCodeText << "\r\n"; + ss << "Connection: close\r\n"; + ss << "Content-Length: " << res.body.size() << "\r\n"; + + for(auto &p : res.headers){ + ss << p.first << ": " << p.second << "\r\n"; + } + + ss << "\r\n"; + + ss << res.body; + + std::string str = ss.str(); + Write(str.data(), str.size()); + + Destroy(); + return; + } + } + + // WebSocket upgrades must be GET + if(method != "GET") return MalformedRequest(); + + auto connectionHeader = headers.Get("connection"); + if(!connectionHeader) return MalformedRequest(); + + // Hackish, ideally we should check it's surrounded by commas (or start/end of string) + if(!detail::HeaderContains(*connectionHeader, "upgrade")) return MalformedRequest(); + + bool sendMyVersion = false; + + auto websocketVersion = headers.Get("sec-websocket-version"); + if(!websocketVersion) return MalformedRequest(); + if(!detail::equalsi(*websocketVersion, "13")){ + sendMyVersion = true; + } + + auto websocketKey = headers.Get("sec-websocket-key"); + if(!websocketKey) return MalformedRequest(); + + std::string securityKey = std::string(*websocketKey); + + if(m_pServer->m_fnCheckConnection && !m_pServer->m_fnCheckConnection(this, req)){ + Write("HTTP/1.1 403 Forbidden\r\n\r\n"); + Destroy(); + return; + } + + + securityKey += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + unsigned char hash[20]; +#if OPENSSL_VERSION_NUMBER <= 0x030000000L + SHA_CTX sha1; + SHA1_Init(&sha1); + SHA1_Update(&sha1, securityKey.data(), securityKey.size()); + SHA1_Final(hash, &sha1); +#else + EVP_MD_CTX *sha1 = EVP_MD_CTX_new(); + EVP_DigestInit_ex(sha1, EVP_sha1(), NULL); + EVP_DigestUpdate(sha1, securityKey.data(), securityKey.size()); + EVP_DigestFinal_ex(sha1, hash, NULL); + EVP_MD_CTX_free(sha1); +#endif + + auto solvedHash = base64_encode(hash, sizeof(hash)); + + char buf[256]; // We can use up to 101 + 27 + 28 + 1 characters, and we round up just because + int bufLen = snprintf(buf, sizeof(buf), + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "%s" + "Sec-WebSocket-Accept: %s\r\n\r\n", + + sendMyVersion ? "Sec-WebSocket-Version: 13\r\n" : "", + solvedHash.c_str() + ); + + assert(bufLen >= 0 && (size_t) bufLen < sizeof(buf)); + + Write(buf, bufLen); + if(!m_Socket) return; // if write failed, we're being destroyed + + m_bHasCompletedHandshake = true; + + m_pServer->NotifyClientInit(this, req); + + m_Buffer.clear(); + + return; + } + + detail::Corker corker{*this}; + + for(;;){ + if(!m_Socket) return; // No need to destroy even + + if(m_bUsingAlternativeProtocol){ + if(buffer.size() < 4) return Bail(); + uint32_t frameLength = ((uint32_t)(uint8_t) buffer[0]) | ((uint32_t)(uint8_t) buffer[1] << 8) | ((uint32_t)(uint8_t) buffer[2] << 16) | ((uint32_t)(uint8_t) buffer[3] << 24); + if(frameLength > m_pServer->m_iMaxMessageSize) return Close(1002, "Too large"); + if(buffer.size() < 4 + frameLength) return Bail(); + + ProcessDataFrame(2, (char*)buffer.data() + 4, frameLength); + Consume(4 + frameLength); + }else{ // Websockets + // Not enough to read the header + if(buffer.size() < 2) return Bail(); + + DataFrameHeader header((char*) buffer.data()); + + if(header.rsv1() || header.rsv2() || header.rsv3()) return Close(1002, "Reserved bit used"); + + // Clients MUST mask their headers + if(!header.mask()) return Close(1002, "Clients must mask their payload"); + assert(header.mask()); + + char *curPosition = (char*) buffer.data() + 2; + + size_t frameLength = header.len(); + if(frameLength == 126){ + if(buffer.size() < 4) return Bail(); + frameLength = (*(uint8_t*)(curPosition) << 8) | (*(uint8_t*)(curPosition + 1)); + curPosition += 2; + }else if(frameLength == 127){ + if(buffer.size() < 10) return Bail(); + + frameLength = ((uint64_t)*(uint8_t*)(curPosition) << 56) | ((uint64_t)*(uint8_t*)(curPosition + 1) << 48) + | ((uint64_t)*(uint8_t*)(curPosition + 2) << 40) | ((uint64_t)*(uint8_t*)(curPosition + 3) << 32) + | (*(uint8_t*)(curPosition + 4) << 24) | (*(uint8_t*)(curPosition + 5) << 16) + | (*(uint8_t*)(curPosition + 6) << 8) | (*(uint8_t*)(curPosition + 7) << 0); + + curPosition += 8; + } + + auto amountLeft = buffer.size() - (curPosition - buffer.data()); + const char *maskKey = nullptr; + + { // Read mask + if(amountLeft < 4) return Bail(); + maskKey = curPosition; + curPosition += 4; + amountLeft -= 4; + } + + if(frameLength > amountLeft) return Bail(); + + { // Unmask + for(size_t i = 0; i < (frameLength & ~3); i += 4){ + curPosition[i + 0] ^= maskKey[0]; + curPosition[i + 1] ^= maskKey[1]; + curPosition[i + 2] ^= maskKey[2]; + curPosition[i + 3] ^= maskKey[3]; + } + + for(size_t i = frameLength & ~3; i < frameLength; ++i){ + curPosition[i] ^= maskKey[i % 4]; + } + } + + if(header.opcode() >= 0x08){ + if(!header.fin()) return Close(1002, "Control op codes can't be fragmented"); + if(frameLength > 125) return Close(1002, "Control op codes can't be more than 125 bytes"); + + + ProcessDataFrame(header.opcode(), curPosition, frameLength); + }else if(!IsBuildingFrames() && header.fin()){ + // Fast path, we received a whole frame and we don't need to combine it with anything + ProcessDataFrame(header.opcode(), curPosition, frameLength); + }else{ + if(IsBuildingFrames()){ + if(header.opcode() != 0) return Close(1002, "Expected continuation frame"); + }else{ + if(header.opcode() == 0) return Close(1002, "Unexpected continuation frame"); + m_iFrameOpcode = header.opcode(); + } + + if(m_FrameBuffer.size() + frameLength >= m_pServer->m_iMaxMessageSize) return Close(1009, "Message too large"); + + m_FrameBuffer.insert(m_FrameBuffer.end(), curPosition, curPosition + frameLength); + + if(header.fin()){ + // Assemble frame + + ProcessDataFrame(m_iFrameOpcode, m_FrameBuffer.data(), m_FrameBuffer.size()); + + m_iFrameOpcode = 0; + m_FrameBuffer.clear(); + } + + } + + Consume((curPosition - buffer.data()) + frameLength); + } + } + + // Unreachable +} + + +void Client::ProcessDataFrame(uint8_t opcode, char *data, size_t len){ + switch(opcode){ + case 9: // Ping + if(m_bIsClosing) return; + Send(data, len, 10); // Send Pong + break; + + case 10: break; // Pong + + case 8: // Close + m_bClientRequestedClose = true; + if(m_bIsClosing){ + Destroy(); + }else{ + + if(len == 1){ + Close(1002, "Incomplete close code"); + return; + }else if(len >= 2){ + bool invalid = false; + uint16_t code = (uint8_t(data[0]) << 8) | uint8_t(data[1]); + if(code < 1000 || code >= 5000) invalid = true; + + switch(code){ + case 1004: + case 1005: + case 1006: + case 1015: + invalid = true; + default:; + } + + if(invalid){ + Close(1002, "Invalid close code"); + return; + } + + if(len > 2 && !IsValidUTF8(data + 2, len - 2)){ + Close(1002, "Close reason is not UTF-8"); + return; + } + } + + // Copy close message + m_bIsClosing = true; + + char header[MAX_HEADER_SIZE]; + WriteDataFrameHeader(8, len, header); + + uv_buf_t bufs[2]; + bufs[0].base = header; + bufs[0].len = GetDataFrameHeaderSize(len); + bufs[1].base = (char*) data; + bufs[1].len = len; + + Write<2>(bufs); + + // We always close the tcp connection on our side, as allowed in 7.1.1 + Destroy(); + } + break; + + case 1: // Text + case 2: // Binary + if(m_bIsClosing) return; + if(opcode == 1 && !IsValidUTF8(data, len)) return Close(1007, "Invalid UTF-8 in text frame"); + + m_pServer->NotifyClientData(this, data, len, opcode); + break; + + default: + return Close(1002, "Unknown op code"); + } +} + +void Client::Close(uint16_t code, const char *reason, size_t reasonLen){ + if(m_bIsClosing) return; + + m_bIsClosing = true; + + if(!m_bUsingAlternativeProtocol){ + char coded[2]; + coded[0] = (code >> 8) & 0xFF; + coded[1] = (code >> 0) & 0xFF; + + if(reason == nullptr){ + Send(coded, sizeof(coded), 8); + }else{ + if(reasonLen == (size_t) -1) reasonLen = strlen(reason); + + char header[MAX_HEADER_SIZE]; + WriteDataFrameHeader(8, 2 + reasonLen, header); + + uv_buf_t bufs[2]; + bufs[0].base = header; + bufs[0].len = GetDataFrameHeaderSize(2 + reasonLen); + bufs[1].base = (char*) reason; + bufs[1].len = reasonLen; + + Write<2>(bufs); + } + } + + // We always close the tcp connection on our side, as allowed in 7.1.1 + Destroy(); +} + + +void Client::Send(const char *data, size_t len, uint8_t opcode){ + if(!m_Socket) return; + + if(m_bUsingAlternativeProtocol){ + uint32_t len32 = (uint32_t) len; + uint8_t header[4]; + header[0] = (len32 >> 0) & 0xFF; + header[1] = (len32 >> 8) & 0xFF; + header[2] = (len32 >> 16) & 0xFF; + header[3] = (len32 >> 24) & 0xFF; + + uv_buf_t bufs[2]; + bufs[0].base = (char*) header; + bufs[0].len = 4; + bufs[1].base = (char*) data; + bufs[1].len = len; + + Write<2>(bufs); + }else{ + char header[MAX_HEADER_SIZE]; + WriteDataFrameHeader(opcode, len, header); + + uv_buf_t bufs[2]; + bufs[0].base = header; + bufs[0].len = GetDataFrameHeaderSize(len); + bufs[1].base = (char*) data; + bufs[1].len = len; + + Write<2>(bufs); + } +} + +void Client::InitSecure(){ + m_pTLS = std::make_unique(m_pServer->GetSSLContext()); +} + +void Client::FlushTLS(){ + assert(m_pTLS != nullptr); + m_pTLS->ForEachPendingWrite([&](const char *data, size_t len){ + uv_buf_t bufs[1]; + bufs[0].base = (char*) data; + bufs[0].len = len; + WriteRaw<1>(bufs); + }); +} + +void Client::Cork(bool v){ + if(!m_Socket) return; + +#if defined(TCP_CORK) || defined(TCP_NOPUSH) + + int enable = v; + uv_os_fd_t fd; + uv_fileno((uv_handle_t*) m_Socket.get(), &fd); + + // Shamelessly copied from uWebSockets +#if defined(TCP_CORK) + // Linux + setsockopt(fd, IPPROTO_TCP, TCP_CORK, &enable, sizeof(int)); +#elif defined(TCP_NOPUSH) + // Mac OS X & FreeBSD + setsockopt(fd, IPPROTO_TCP, TCP_NOPUSH, &enable, sizeof(int)); + + // MacOS needs this to flush the messages out + if(!enable){ + ::send(fd, "", 0, 0); + } +#endif + +#endif +} + + +} diff --git a/src/net/ws/Client.h b/src/net/ws/Client.h new file mode 100644 index 0000000..c699ddf --- /dev/null +++ b/src/net/ws/Client.h @@ -0,0 +1,118 @@ +#ifndef H_0AC5AB22DD724A3F8FE93E27C178D633 +#define H_0AC5AB22DD724A3F8FE93E27C178D633 + +#include +#include +#include +#include +#include + +#include "Headers.h" +#include "TLS.h" + +namespace ws28 { + namespace detail { + struct Corker; + struct SocketDeleter { + void operator()(uv_tcp_t *socket) const { + if(socket == nullptr) return; + uv_close((uv_handle_t*) socket, [](uv_handle_t *h){ + delete (uv_tcp_t*) h; + }); + } + }; + } + + typedef std::unique_ptr SocketHandle; + + class Server; + class Client { + enum { MAX_HEADER_SIZE = 10 }; + enum : unsigned char { NO_FRAMES = 0 }; + public: + ~Client(); + + // If reasonLen is -1, it'll use strlen + void Close(uint16_t code, const char *reason = nullptr, size_t reasonLen = -1); + void Destroy(); + void Send(const char *data, size_t len, uint8_t opCode = 2); + + inline void SetUserData(void *v){ m_pUserData = v; } + inline void* GetUserData(){ return m_pUserData; } + + inline bool IsSecure(){ return m_pTLS != nullptr; } + inline bool IsUsingAlternativeProtocol(){ return m_bUsingAlternativeProtocol; } + + inline Server* GetServer(){ return m_pServer; } + + inline const char* GetIP() const { return m_IP; } + + inline bool HasClientRequestedClose() const { return m_bClientRequestedClose; } + + private: + + struct DataFrame { + uint8_t opcode; + std::unique_ptr data; + size_t len; + }; + + Client(Server *server, SocketHandle socket); + + Client(const Client &other) = delete; + Client& operator=(Client &other) = delete; + + size_t GetDataFrameHeaderSize(size_t len); + void WriteDataFrameHeader(uint8_t opcode, size_t len, char *out); + void EncryptAndWrite(const char *data, size_t len); + + void OnRawSocketData(char *data, size_t len); + void OnSocketData(char *data, size_t len); + void ProcessDataFrame(uint8_t opcode, char *data, size_t len); + + void InitSecure(); + void FlushTLS(); + + void Write(const char *data); + void Write(const char *data, size_t len); + + template + void Write(uv_buf_t bufs[N]); + + template + void WriteRaw(uv_buf_t bufs[N]); + + void WriteRawQueue(std::unique_ptr data, size_t len); + + void Cork(bool v); + + // Stub, maybe some day + inline bool IsValidUTF8(const char *, size_t){ return true; } + + inline bool IsBuildingFrames(){ return m_iFrameOpcode != NO_FRAMES; } + + Server *m_pServer; + SocketHandle m_Socket; + void *m_pUserData = nullptr; + bool m_bWaitingForFirstPacket = true; + bool m_bHasCompletedHandshake = false; + bool m_bIsClosing = false; + bool m_bUsingAlternativeProtocol = false; + bool m_bClientRequestedClose = false; + char m_IP[46]; + + std::unique_ptr m_pTLS; + + std::vector m_Buffer; + + uint8_t m_iFrameOpcode = NO_FRAMES; + std::vector m_FrameBuffer; + + friend class Server; + friend struct detail::Corker; + friend class std::unique_ptr; + }; + +} + +#endif diff --git a/src/net/ws/Headers.h b/src/net/ws/Headers.h new file mode 100644 index 0000000..d281dbe --- /dev/null +++ b/src/net/ws/Headers.h @@ -0,0 +1,51 @@ +#ifndef H_39B56032251A44728943666BD008D047 +#define H_39B56032251A44728943666BD008D047 + +#include +#include +#include +#include +#include + +namespace ws28 { + class Client; + class Server; + + class RequestHeaders { + public: + void Set(std::string_view key, std::string_view value){ + m_Headers.push_back({ key, value }); + } + + template + void ForEachValueOf(std::string_view key, const F &f) const { + for(auto &p : m_Headers){ + if(p.first == key) f(p.second); + } + } + + std::optional Get(std::string_view key) const { + for(auto &p : m_Headers){ + if(p.first == key) return p.second; + } + + return std::nullopt; + } + + template + void ForEach(const F &f) const { + for(auto &p : m_Headers){ + f(p.first, p.second); + } + } + + private: + std::vector> m_Headers; + + friend class Client; + friend class Server; + }; + +} + +#endif diff --git a/src/net/ws/Server.cpp b/src/net/ws/Server.cpp new file mode 100644 index 0000000..5c9733b --- /dev/null +++ b/src/net/ws/Server.cpp @@ -0,0 +1,118 @@ +#include "Server.h" + +#ifndef _WIN32 +#include +#endif + +namespace ws28{ + +Server::Server(uv_loop_t *loop, SSL_CTX *ctx) : m_pLoop(loop), m_pSSLContext(ctx){ + + m_fnCheckConnection = [](Client*, HTTPRequest &req) -> bool { + auto host = req.headers.Get("host"); + if(!host) return true; // No host header, default to accept + + auto origin = req.headers.Get("origin"); + if(!origin) return true; + + return origin == host; + }; + +} + +bool Server::Listen(int port, bool ipv4Only){ + if(m_Server) return false; + +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif + + auto server = SocketHandle{new uv_tcp_t}; + uv_tcp_init_ex(m_pLoop, server.get(), ipv4Only ? AF_INET : AF_INET6); + server->data = this; + + struct sockaddr_storage addr; + + if(ipv4Only){ + uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*) &addr); + }else{ + uv_ip6_addr("::0", port, (struct sockaddr_in6*) &addr); + } + + uv_tcp_nodelay(server.get(), (int) true); + + // Enable SO_REUSEPORT +#ifndef _WIN32 + uv_os_fd_t fd; + int r = uv_fileno((uv_handle_t*) server.get(), &fd); + (void) r; + assert(r == 0); + int optval = 1; + setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)); +#endif + + if(uv_tcp_bind(server.get(), (struct sockaddr*) &addr, 0) != 0){ + return false; + } + + if(uv_listen((uv_stream_t*) server.get(), 512, [](uv_stream_t* server, int status){ + ((Server*) server->data)->OnConnection(server, status); + }) != 0){ + return false; + } + + m_Server = std::move(server); + return true; +} + +void Server::StopListening(){ + // Just in case we have more logic in the future + if(!m_Server) return; + + m_Server.reset(); +} + +void Server::DestroyClients(){ + // Clients will erase themselves from this vector + while(!m_Clients.empty()){ + m_Clients.back()->Destroy(); + } +} + +Server::~Server(){ + StopListening(); + DestroyClients(); +} + +void Server::OnConnection(uv_stream_t* server, int status){ + if(status < 0) return; + + SocketHandle socket{new uv_tcp_t}; + uv_tcp_init(m_pLoop, socket.get()); + + socket->data = nullptr; + + if(uv_accept(server, (uv_stream_t*) socket.get()) == 0){ + auto client = new Client(this, std::move(socket)); + m_Clients.emplace_back(client); + + // If for whatever reason uv_tcp_getpeername failed (happens... somehow?) + if(client->GetIP()[0] == '\0') client->Destroy(); + } +} + +std::unique_ptr Server::NotifyClientPreDestroyed(Client *client){ + for(auto it = m_Clients.begin(); it != m_Clients.end(); ++it){ + if(it->get() == client){ + std::unique_ptr r = std::move(*it); + *it = std::move(m_Clients.back()); + m_Clients.pop_back(); + return r; + } + } + + assert(false); + return {}; +} + +} diff --git a/src/net/ws/Server.h b/src/net/ws/Server.h new file mode 100644 index 0000000..a671fbb --- /dev/null +++ b/src/net/ws/Server.h @@ -0,0 +1,152 @@ +#ifndef H_2ABA91710E664A51814F459521E1C4D4 +#define H_2ABA91710E664A51814F459521E1C4D4 + +#include +#include +#include +#include +#include + +#include "Client.h" + +namespace ws28 { + class Server; + + struct HTTPRequest { + Server *server; + std::string_view method; + std::string_view path; + std::string_view ip; + + // Header keys are always lower case + const RequestHeaders &headers; + }; + + class HTTPResponse { + public: + + HTTPResponse& status(int v){ statusCode = v; return *this; } + HTTPResponse& send(std::string_view v){ body.append(v); return *this; } + + // Appends a response header. The following headers cannot be changed: + // Connection: close + // Content-Length: body.size() + HTTPResponse& header(std::string_view key, std::string_view value){ headers.emplace(std::string(key), std::string(value)); return *this; } + + private: + int statusCode = 200; + std::string body; + std::multimap headers; + + friend class Client; + }; + + class Server { + typedef bool (*CheckTCPConnectionFn)(std::string_view ip, bool secure); + typedef bool (*CheckConnectionFn)(Client *, HTTPRequest&); + typedef bool (*CheckAlternativeConnectionFn)(Client *); + typedef void (*ClientConnectedFn)(Client *, HTTPRequest&); + typedef void (*ClientDisconnectedFn)(Client *); + typedef void (*ClientDataFn)(Client *, char *data, size_t len, int opcode); + typedef void (*HTTPRequestFn)(HTTPRequest&, HTTPResponse&); + public: + + // Note: By default, this listens on both ipv4 and ipv6 + // Note: if you provide a SSL_CTX, this server will listen to *BOTH* secure and insecure connections at that port, + // sniffing the first byte to figure out whether it's secure or not + Server(uv_loop_t *loop, SSL_CTX *ctx = nullptr); + Server(const Server &other) = delete; + Server& operator=(const Server &other) = delete; + ~Server(); + + bool Listen(int port, bool ipv4Only = false); + void StopListening(); + void DestroyClients(); + + // This callback is called when we know whether a TCP connection wants a secure connection or not, + // once we receive the very first byte from the client + void SetCheckTCPConnectionCallback(CheckTCPConnectionFn v){ m_fnCheckTCPConnection = v; } + + // This callback is called when the client is trying to connect using websockets + // By default, for safety, this checks the Origin and makes sure it matches the Host + // It's likely you wanna change this check if your websocket server is in a different domain. + void SetCheckConnectionCallback(CheckConnectionFn v){ m_fnCheckConnection = v; } + + // This is called instead of CheckConnection for connections using the alternative protocol (if enabled) + void SetCheckAlternativeConnectionCallback(CheckAlternativeConnectionFn v){ m_fnCheckAlternativeConnection = v; } + + // This callback is called when a client establishes a connection (after websocket handshake) + // This is paired with the disconnected callback + void SetClientConnectedCallback(ClientConnectedFn v){ m_fnClientConnected = v; } + + // This callback is called when a client disconnects + // This is paired with the connected callback, and will *always* be called for clients that called the other callback + // Note that clients grab this value when you call Destroy on them, so changing this after clients are connected + // might lead to weird results. In practice, just set it once and forget about it. + void SetClientDisconnectedCallback(ClientDisconnectedFn v){ m_fnClientDisconnected = v; } + + // This callback is called when the client receives a data frame + // Note that both text and binary op codes end up here + void SetClientDataCallback(ClientDataFn v){ m_fnClientData = v; } + + // This callback is called when a normal http request is received + // If you don't send anything in response, the status code is 404 + // If you send anything in response without setting a specific status code, it will be 200 + // Connections that call this callback never lead to a connection + void SetHTTPCallback(HTTPRequestFn v){ m_fnHTTPRequest = v;} + + SSL_CTX* GetSSLContext() const { return m_pSSLContext; } + + inline void SetUserData(void *v){ m_pUserData = v; } + inline void* GetUserData() const { return m_pUserData; } + + // Adjusts how much we're willing to accept from clients + // Note: this can only be set while we don't have clients (preferably before listening) + inline void SetMaxMessageSize(size_t v){ assert(m_Clients.empty()); m_iMaxMessageSize = v;} + + // Alternative protocol means that the client sends a 0x00, and we skip all websocket protocol + // This means clients don't call CheckConnection, and they receive an empty request header in the connection callback + // Opcode is always binary + // In the alternative protocol, clients and servers send the packet length as a Little Endian uint32, then its contents + inline void SetAllowAlternativeProtocol(bool v){ m_bAllowAlternativeProtocol = v; } + inline bool GetAllowAlternativeProtocol(){ return m_bAllowAlternativeProtocol; } + + void Ref(){ if(m_Server) uv_ref((uv_handle_t*) m_Server.get()); } + void Unref(){ if(m_Server) uv_unref((uv_handle_t*) m_Server.get()); } + + private: + void OnConnection(uv_stream_t* server, int status); + + void NotifyClientInit(Client *client, HTTPRequest &req){ + if(m_fnClientConnected) m_fnClientConnected(client, req); + } + + std::unique_ptr NotifyClientPreDestroyed(Client *client); + + void NotifyClientData(Client *client, char *data, size_t len, int opcode){ + if(m_fnClientData) m_fnClientData(client, data, len, opcode); + } + + uv_loop_t *m_pLoop; + SocketHandle m_Server; + SSL_CTX *m_pSSLContext; + void *m_pUserData = nullptr; + std::vector> m_Clients; + bool m_bAllowAlternativeProtocol = false; + + CheckTCPConnectionFn m_fnCheckTCPConnection = nullptr; + CheckConnectionFn m_fnCheckConnection = nullptr; + CheckAlternativeConnectionFn m_fnCheckAlternativeConnection = nullptr; + ClientConnectedFn m_fnClientConnected = nullptr; + ClientDisconnectedFn m_fnClientDisconnected = nullptr; + ClientDataFn m_fnClientData = nullptr; + HTTPRequestFn m_fnHTTPRequest = nullptr; + + size_t m_iMaxMessageSize = 16 * 1024; + + friend class Client; + }; + +} + +#endif diff --git a/src/net/ws/TLS.h b/src/net/ws/TLS.h new file mode 100644 index 0000000..0ea7767 --- /dev/null +++ b/src/net/ws/TLS.h @@ -0,0 +1,220 @@ +#ifndef H_3078F0E9644347BD9F28E4C56F162FC8 +#define H_3078F0E9644347BD9F28E4C56F162FC8 + +#include +#include + +#include +#include +#include +#include + +namespace ws28 { + +// Ported from https://github.com/darrenjs/openssl_examples +// MIT licensed +class TLS { + enum SSLStatus { + SSLSTATUS_OK, SSLSTATUS_WANT_IO, SSLSTATUS_FAIL + }; + +public: + + TLS(SSL_CTX *ctx, bool server = true, const char *hostname = nullptr){ + m_ReadBIO = BIO_new(BIO_s_mem()); + m_WriteBIO = BIO_new(BIO_s_mem()); + m_SSL = SSL_new(ctx); + + if(server){ + SSL_set_accept_state(m_SSL); + }else{ + SSL_set_connect_state(m_SSL); + } + + if(!server && hostname) SSL_set_tlsext_host_name(m_SSL, hostname); + SSL_set_bio(m_SSL, m_ReadBIO, m_WriteBIO); + + if(!server) DoSSLHandhake(); + } + + ~TLS(){ + SSL_free(m_SSL); + } + + TLS(const TLS &other) = delete; + TLS& operator=(const TLS &other) = delete; + + // Helper to setup SSL, you still need to create the context + static void InitSSL(){ + static std::once_flag f; + std::call_once(f, [](){ + SSL_library_init(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); +#if OPENSSL_VERSION_NUMBER < 0x30000000L || defined(LIBRESSL_VERSION_NUMBER) + ERR_load_BIO_strings(); +#endif + ERR_load_crypto_strings(); + }); + } + + // Writes unencrypted bytes to be encrypted and sent out + // If this returns false, the connection must be closed + bool Write(const char *buf, size_t len){ + m_EncryptBuf.insert(m_EncryptBuf.end(), buf, buf + len); + return DoEncrypt(); + } + + // Process raw bytes received from the other side + // If this returns false, the connection must be closed + template + bool ReceivedData(const char *src, size_t len, const F &f){ + int n; + while(len > 0){ + n = BIO_write(m_ReadBIO, src, len); + + // Assume bio write failure is unrecoverable + if(n <= 0) return false; + + src += n; + len -= n; + + if(!SSL_is_init_finished(m_SSL)){ + if(DoSSLHandhake() == SSLSTATUS_FAIL) return false; + if(!SSL_is_init_finished(m_SSL)) return true; + } + + ERR_clear_error(); + do { + char buf[4096]; + n = SSL_read(m_SSL, buf, sizeof buf); + if(n > 0){ + f(buf, (size_t) n); + } + }while(n > 0); + + auto status = GetSSLStatus(n); + if(status == SSLSTATUS_WANT_IO){ + do { + char buf[4096]; + n = BIO_read(m_WriteBIO, buf, sizeof(buf)); + if(n > 0){ + QueueEncrypted(buf, n); + }else if(!BIO_should_retry(m_WriteBIO)){ + return false; + } + }while(n > 0); + }else if(status == SSLSTATUS_FAIL){ + return false; + } + } + + return true; + } + + template + void ForEachPendingWrite(const F &f){ + // If the callback does something crazy like calling Write inside of it + // We need to handle this carefully, thus the swap. + for(;;){ + if(m_WriteBuf.empty()) return; + + std::vector buf; + std::swap(buf, m_WriteBuf); + + f(buf.data(), buf.size()); + } + } + + bool IsHandshakeFinished(){ + return SSL_is_init_finished(m_SSL); + } + +private: + SSLStatus GetSSLStatus(int n){ + switch(SSL_get_error(m_SSL, n)){ + case SSL_ERROR_NONE: + return SSLSTATUS_OK; + + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + return SSLSTATUS_WANT_IO; + + case SSL_ERROR_ZERO_RETURN: + case SSL_ERROR_SYSCALL: + default: + return SSLSTATUS_FAIL; + } + } + + void QueueEncrypted(const char *buf, size_t len){ + m_WriteBuf.insert(m_WriteBuf.end(), buf, buf + len); + } + + bool DoEncrypt(){ + if(!SSL_is_init_finished(m_SSL)) return true; + + int n; + + while(!m_EncryptBuf.empty()){ + ERR_clear_error(); + n = SSL_write(m_SSL, m_EncryptBuf.data(), (int) m_EncryptBuf.size()); + + if(GetSSLStatus(n) == SSLSTATUS_FAIL) return false; + + if(n > 0){ + // Consume bytes + m_EncryptBuf.erase(m_EncryptBuf.begin(), m_EncryptBuf.begin() + n); + + // Write them out + do { + char buf[4096]; + n = BIO_read(m_WriteBIO, buf, sizeof buf); + if(n > 0){ + QueueEncrypted(buf, n); + }else if(!BIO_should_retry(m_WriteBIO)){ + return false; + } + }while(n > 0); + } + } + + return true; + } + + SSLStatus DoSSLHandhake(){ + ERR_clear_error(); + SSLStatus status = GetSSLStatus(SSL_do_handshake(m_SSL)); + + // Did SSL request to write bytes? + if(status == SSLSTATUS_WANT_IO){ + int n; + do { + char buf[4096]; + n = BIO_read(m_WriteBIO, buf, sizeof buf); + + if(n > 0){ + QueueEncrypted(buf, n); + }else if(!BIO_should_retry(m_WriteBIO)){ + return SSLSTATUS_FAIL; + } + + } while(n > 0); + } + + return status; + } + + + + std::vector m_EncryptBuf; // Bytes waiting to be encrypted + std::vector m_WriteBuf; // Bytes waiting to be written to the socket + + SSL *m_SSL; + BIO *m_ReadBIO; + BIO *m_WriteBIO; +}; + +} + +#endif diff --git a/src/net/ws/base64.cpp b/src/net/ws/base64.cpp new file mode 100644 index 0000000..415798d --- /dev/null +++ b/src/net/ws/base64.cpp @@ -0,0 +1,125 @@ +/* + base64.cpp and base64.h + + Copyright (C) 2004-2008 Ren� Nyffenegger + + This source code is provided 'as-is', without any express or implied + warranty. In no event will the author be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this source code must not be misrepresented; you must not + claim that you wrote the original source code. If you use this source code + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original source code. + + 3. This notice may not be removed or altered from any source distribution. + + Ren� Nyffenegger rene.nyffenegger@adp-gmbh.ch + +*/ +#include "base64.h" +#include + +namespace ws28 { + +static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; + + +static inline bool is_base64(unsigned char c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) { + std::string ret; + int i = 0; + int j = 0; + unsigned char char_array_3[3]; + unsigned char char_array_4[4]; + + while (in_len--) { + char_array_3[i++] = *(bytes_to_encode++); + if (i == 3) { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for(i = 0; (i <4) ; i++) + ret += base64_chars[char_array_4[i]]; + i = 0; + } + } + + if (i) + { + for(j = i; j < 3; j++) + char_array_3[j] = '\0'; + + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for (j = 0; (j < i + 1); j++) + ret += base64_chars[char_array_4[j]]; + + while((i++ < 3)) + ret += '='; + + } + + return ret; + +} + +std::string base64_decode(std::string const& encoded_string) { + int in_len = int(encoded_string.size()); + int i = 0; + int j = 0; + int in_ = 0; + unsigned char char_array_4[4], char_array_3[3]; + std::string ret; + + while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; in_++; + if (i ==4) { + for (i = 0; i <4; i++) + char_array_4[i] = (unsigned char) base64_chars.find(char_array_4[i]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; (i < 3); i++) + ret += char_array_3[i]; + i = 0; + } + } + + if (i) { + for (j = i; j <4; j++) + char_array_4[j] = 0; + + for (j = 0; j <4; j++) + char_array_4[j] = (unsigned char) base64_chars.find(char_array_4[j]); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; + } + + return ret; +} +} diff --git a/src/net/ws/base64.h b/src/net/ws/base64.h new file mode 100644 index 0000000..b5ebb29 --- /dev/null +++ b/src/net/ws/base64.h @@ -0,0 +1,43 @@ +/* + Copyright (c) 2011, Micael Hildenborg + All rights reserved. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Micael Hildenborg nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY Micael Hildenborg ''AS IS'' AND ANY + EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Micael Hildenborg BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + Contributors: + Gustav + Several members in the gamedev.se forum. + Gregory Petrosyan + */ + +#ifndef H_F322B9F4CB8E431398CA8441845F438C +#define H_F322B9F4CB8E431398CA8441845F438C + +#include + +namespace ws28 { + std::string base64_encode(unsigned char const* , unsigned int len); + std::string base64_decode(std::string const& s); +} + +#endif \ No newline at end of file diff --git a/src/stateserver/database_state_server.cpp b/src/stateserver/database_state_server.cpp index 8a718ad..f80d16f 100644 --- a/src/stateserver/database_state_server.cpp +++ b/src/stateserver/database_state_server.cpp @@ -5,6 +5,7 @@ #include "../util/config.h" #include "../util/logger.h" #include "../util/metrics.h" +#include "../web/web_panel.h" #include "loading_object.h" namespace Ardos { @@ -26,11 +27,11 @@ DatabaseStateServer::DatabaseStateServer() : ChannelSubscriber() { _dbChannel = config["database"].as(); auto rangeParam = config["ranges"]; - auto min = rangeParam["min"].as(); - auto max = rangeParam["max"].as(); + _minDoId = rangeParam["min"].as(); + _maxDoId = rangeParam["max"].as(); // Start listening to DoId's in our listening range. - SubscribeRange(min, max); + SubscribeRange(_minDoId, _maxDoId); // Initialize metrics. InitMetrics(); @@ -441,4 +442,74 @@ bool UnpackDBFields(DatagramIterator &dgi, DCClass *dclass, FieldMap &required, return true; } +void DatabaseStateServer::HandleWeb(ws28::Client *client, + nlohmann::json &data) { + if (data["msg"] == "init") { + // Build up an array of distributed objects. + nlohmann::json distObjInfo = nlohmann::json::array(); + for (const auto &distObj : _distObjs) { + distObjInfo.push_back({ + {"doId", distObj.first}, + {"clsName", distObj.second->GetDClass()->get_name()}, + {"parentId", distObj.second->GetParentId()}, + {"zoneId", distObj.second->GetZoneId()}, + }); + } + + WebPanel::Send(client, { + {"type", "dbss:init"}, + {"success", true}, + {"dbChannel", _dbChannel}, + {"minDoId", _minDoId}, + {"maxDoId", _maxDoId}, + {"distObjs", distObjInfo}, + }); + } else if (data["msg"] == "distobj") { + auto doId = data["doId"].template get(); + + // Try to find a matching Distributed Object for the provided DoId. + if (!_distObjs.contains(doId)) { + WebPanel::Send(client, { + {"type", "dbss:distobj"}, + {"success", false}, + }); + return; + } + + auto distObj = _distObjs[doId]; + + // Build an array of explicitly set RAM fields. + nlohmann::json ramFields = nlohmann::json::array(); + for (const auto &field : distObj->GetRamFields()) { + ramFields.push_back({{"fieldName", field.first->get_name()}}); + } + + // Build a dictionary of zone objects under this Distributed Object. + nlohmann::json zoneObjs = nlohmann::json::object(); + for (const auto &zoneData : distObj->GetZoneObjects()) { + for (const auto &zoneDoId : zoneData.second) { + // Try to get the DClass name for the zone object. + auto clsName = _distObjs.contains(zoneDoId) + ? _distObjs[zoneDoId]->GetDClass()->get_name() + : "Unknown"; + + zoneObjs[std::to_string(zoneData.first)].push_back( + {{"doId", zoneDoId}, {"clsName", clsName}}); + } + } + + WebPanel::Send(client, { + {"type", "dbss:distobj"}, + {"success", true}, + {"clsName", distObj->GetDClass()->get_name()}, + {"parentId", distObj->GetParentId()}, + {"zoneId", distObj->GetZoneId()}, + {"owner", distObj->GetOwner()}, + {"size", distObj->Size()}, + {"ram", ramFields}, + {"zones", zoneObjs}, + }); + } +} + } // namespace Ardos diff --git a/src/stateserver/database_state_server.h b/src/stateserver/database_state_server.h index 94cdefd..f171975 100644 --- a/src/stateserver/database_state_server.h +++ b/src/stateserver/database_state_server.h @@ -16,7 +16,7 @@ bool UnpackDBFields(DatagramIterator &dgi, DCClass *dclass, FieldMap &required, class LoadingObject; class DatabaseStateServer final : public StateServerImplementation, - public ChannelSubscriber { + public ChannelSubscriber { public: friend class LoadingObject; @@ -24,6 +24,8 @@ class DatabaseStateServer final : public StateServerImplementation, void RemoveDistributedObject(const uint32_t &doId) override; + void HandleWeb(ws28::Client *client, nlohmann::json &data); + private: void HandleDatagram(const std::shared_ptr &dg) override; @@ -48,6 +50,8 @@ class DatabaseStateServer final : public StateServerImplementation, void ReportActivateTime(const uvw::timer_handle::time &startTime); uint64_t _dbChannel; + uint64_t _minDoId; + uint64_t _maxDoId; std::unordered_map _distObjs; std::unordered_map _loadObjs; diff --git a/src/stateserver/distributed_object.cpp b/src/stateserver/distributed_object.cpp index 004f317..ccea20e 100644 --- a/src/stateserver/distributed_object.cpp +++ b/src/stateserver/distributed_object.cpp @@ -93,18 +93,10 @@ size_t DistributedObject::Size() const { return objectSize; } -uint64_t DistributedObject::GetAI() const { return _aiChannel; } - -bool DistributedObject::IsAIExplicitlySet() const { return _aiExplicitlySet; } - -uint32_t DistributedObject::GetDoId() const { return _doId; } - uint64_t DistributedObject::GetLocation() const { return LocationAsChannel(_parentId, _zoneId); } -uint64_t DistributedObject::GetOwner() const { return _ownerChannel; } - void DistributedObject::Annihilate(const uint64_t &sender, const bool ¬ifyParent) { std::unordered_set targets; @@ -659,7 +651,7 @@ void DistributedObject::HandleLocationChange(const uint32_t &newParent, } // Send changing location message. - auto dg = std::make_shared(targets, sender, + auto dg = std::make_shared(targets, _doId, STATESERVER_OBJECT_CHANGING_LOCATION); dg->AddUint32(_doId); dg->AddLocation(newParent, newZone); diff --git a/src/stateserver/distributed_object.h b/src/stateserver/distributed_object.h index 04007a6..bd7c6cd 100644 --- a/src/stateserver/distributed_object.h +++ b/src/stateserver/distributed_object.h @@ -24,13 +24,23 @@ class DistributedObject final : public ChannelSubscriber { [[nodiscard]] size_t Size() const; - [[nodiscard]] uint64_t GetAI() const; - [[nodiscard]] bool IsAIExplicitlySet() const; + [[nodiscard]] uint64_t GetAI() const { return _aiChannel; } + [[nodiscard]] bool IsAIExplicitlySet() const { return _aiExplicitlySet; } - [[nodiscard]] uint32_t GetDoId() const; + [[nodiscard]] uint32_t GetDoId() const { return _doId; } + [[nodiscard]] DCClass *GetDClass() const { return _dclass; } [[nodiscard]] uint64_t GetLocation() const; - [[nodiscard]] uint64_t GetOwner() const; + [[nodiscard]] uint64_t GetOwner() const { return _ownerChannel; } + [[nodiscard]] uint32_t GetParentId() const { return _parentId; } + [[nodiscard]] uint32_t GetZoneId() const { return _zoneId; } + + [[nodiscard]] std::unordered_map> + GetZoneObjects() const { + return _zoneObjects; + } + + [[nodiscard]] FieldMap GetRamFields() const { return _ramFields; } private: void Annihilate(const uint64_t &sender, const bool ¬ifyParent = true); diff --git a/src/stateserver/state_server.cpp b/src/stateserver/state_server.cpp index e0fed9e..5609359 100644 --- a/src/stateserver/state_server.cpp +++ b/src/stateserver/state_server.cpp @@ -7,6 +7,7 @@ #include "../util/globals.h" #include "../util/logger.h" #include "../util/metrics.h" +#include "../web/web_panel.h" #include "distributed_object.h" namespace Ardos { @@ -148,4 +149,71 @@ void StateServer::InitMetrics() { 16384, 65536}); } +void StateServer::HandleWeb(ws28::Client *client, nlohmann::json &data) { + if (data["msg"] == "init") { + // Build up an array of distributed objects. + nlohmann::json distObjInfo = nlohmann::json::array(); + for (const auto &distObj : _distObjs) { + distObjInfo.push_back({ + {"doId", distObj.first}, + {"clsName", distObj.second->GetDClass()->get_name()}, + {"parentId", distObj.second->GetParentId()}, + {"zoneId", distObj.second->GetZoneId()}, + }); + } + + WebPanel::Send(client, { + {"type", "ss:init"}, + {"success", true}, + {"channel", _channel}, + {"distObjs", distObjInfo}, + }); + } else if (data["msg"] == "distobj") { + auto doId = data["doId"].template get(); + + // Try to find a matching Distributed Object for the provided DoId. + if (!_distObjs.contains(doId)) { + WebPanel::Send(client, { + {"type", "ss:distobj"}, + {"success", false}, + }); + return; + } + + auto distObj = _distObjs[doId]; + + // Build an array of explicitly set RAM fields. + nlohmann::json ramFields = nlohmann::json::array(); + for (const auto &field : distObj->GetRamFields()) { + ramFields.push_back({{"fieldName", field.first->get_name()}}); + } + + // Build a dictionary of zone objects under this Distributed Object. + nlohmann::json zoneObjs = nlohmann::json::object(); + for (const auto &zoneData : distObj->GetZoneObjects()) { + for (const auto &zoneDoId : zoneData.second) { + // Try to get the DClass name for the zone object. + auto clsName = _distObjs.contains(zoneDoId) + ? _distObjs[zoneDoId]->GetDClass()->get_name() + : "Unknown"; + + zoneObjs[std::to_string(zoneData.first)].push_back( + {{"doId", zoneDoId}, {"clsName", clsName}}); + } + } + + WebPanel::Send(client, { + {"type", "ss:distobj"}, + {"success", true}, + {"clsName", distObj->GetDClass()->get_name()}, + {"parentId", distObj->GetParentId()}, + {"zoneId", distObj->GetZoneId()}, + {"owner", distObj->GetOwner()}, + {"size", distObj->Size()}, + {"ram", ramFields}, + {"zones", zoneObjs}, + }); + } +} + } // namespace Ardos diff --git a/src/stateserver/state_server.h b/src/stateserver/state_server.h index 5ececea..a0e5482 100644 --- a/src/stateserver/state_server.h +++ b/src/stateserver/state_server.h @@ -3,24 +3,29 @@ #include +#include #include #include #include "../messagedirector/channel_subscriber.h" #include "../net/datagram.h" #include "../net/datagram_iterator.h" +#include "../net/ws/Client.h" #include "state_server_implementation.h" namespace Ardos { class DistributedObject; -class StateServer final : public StateServerImplementation, public ChannelSubscriber { +class StateServer final : public StateServerImplementation, + public ChannelSubscriber { public: StateServer(); void RemoveDistributedObject(const uint32_t &doId) override; + void HandleWeb(ws28::Client *client, nlohmann::json &data); + private: void HandleDatagram(const std::shared_ptr &dg) override; void HandleGenerate(DatagramIterator &dgi, const bool &other); diff --git a/src/util/config.h b/src/util/config.h index 0109b1f..6b83510 100644 --- a/src/util/config.h +++ b/src/util/config.h @@ -10,6 +10,7 @@ class Config { static Config *Instance(); void LoadConfig(const std::string &name); + YAML::Node GetConfig() { return _config; } std::string GetString(const std::string &key, const std::string &defVal = ""); YAML::Node GetNode(const std::string &key); diff --git a/src/web/web_panel.cpp b/src/web/web_panel.cpp new file mode 100644 index 0000000..68f545e --- /dev/null +++ b/src/web/web_panel.cpp @@ -0,0 +1,232 @@ +#include "web_panel.h" + +#include "../clientagent/client_agent.h" +#include "../database/database_server.h" +#include "../messagedirector/message_director.h" +#include "../net/datagram.h" +#include "../stateserver/database_state_server.h" +#include "../stateserver/state_server.h" +#include "../util/config.h" +#include "../util/globals.h" +#include "../util/logger.h" + +namespace Ardos { + +WebPanel *WebPanel::Instance = nullptr; + +WebPanel::WebPanel() { + Logger::Info("Starting Web Panel component..."); + + Instance = this; + + // Web Panel configuration. + auto config = Config::Instance()->GetNode("web-panel"); + + // Cluster name configuration. + if (auto nameParam = config["name"]) { + _name = nameParam.as(); + } + // Port configuration. + if (auto portParam = config["port"]) { + _port = portParam.as(); + } + + // Login configuration. + if (auto userParam = config["username"]) { + _username = userParam.as(); + } + if (auto passParam = config["password"]) { + _password = passParam.as(); + } + + // SSL configuration. + if (auto certParam = config["certificate"]) { + _cert = certParam.as(); + } + if (auto keyParam = config["private-key"]) { + _key = keyParam.as(); + } + + if (!_cert.empty() && !_key.empty()) { + // Configure SSL (if keys were supplied.) + ws28::TLS::InitSSL(); + _secure = true; + + const SSL_METHOD *method = TLS_server_method(); + + SSL_CTX *ctx = SSL_CTX_new(method); + if (!ctx) { + Logger::Error("Unable to create SSL context"); + } + + if (SSL_CTX_use_certificate_file(ctx, _cert.c_str(), SSL_FILETYPE_PEM) <= + 0) { + Logger::Error(std::format("[WEB] Failed to load cert file: {}", _cert)); + exit(1); + } + + if (SSL_CTX_use_PrivateKey_file(ctx, _key.c_str(), SSL_FILETYPE_PEM) <= 0) { + Logger::Error( + std::format("[WEB] Failed to load private key file: {}", _cert)); + exit(1); + } + + _server = std::make_unique(g_loop->raw(), ctx); + } else { + // Otherwise, create an unsecure server. + _server = std::make_unique(g_loop->raw()); + } + + // Set a max message size that reflects the + // max length of a Datagram (+2 for length header.) + _server->SetMaxMessageSize(kMaxDgSize + 2); + + // Disable Origin checks. + _server->SetCheckConnectionCallback( + [](ws28::Client *client, ws28::HTTPRequest &) { return true; }); + + _server->SetClientConnectedCallback( + [](ws28::Client *client, ws28::HTTPRequest &) { + Logger::Verbose( + std::format("[WEB] Client connected from {}", client->GetIP())); + + auto *data = (ClientData *)malloc(sizeof(ClientData)); + data->authed = false; + + client->SetUserData(data); + }); + + _server->SetClientDisconnectedCallback([](ws28::Client *client) { + Logger::Verbose( + std::format("[WEB] Client '{}' disconnected", client->GetIP())); + + // Free alloc'd user data. + if (client->GetUserData() != nullptr) { + free(client->GetUserData()); + client->SetUserData(nullptr); + } + }); + + _server->SetClientDataCallback( + [](ws28::Client *client, char *data, size_t len, int opcode) { + Instance->HandleData(client, {data, len}); + }); + + // Start listening! + _server->Listen(_port); + + Logger::Info(std::format("[WEB] Listening on {} [{}]", _port, + _secure ? "SECURE" : "UNSECURE")); +} + +void WebPanel::Send(ws28::Client *client, const nlohmann::json &data) { + auto res = data.dump(); + client->Send(res.c_str(), res.length(), 1); +} + +void WebPanel::HandleData(ws28::Client *client, const std::string &data) { + // Make sure we have a valid JSON request. + if (!nlohmann::json::accept(data)) { + client->Close(400, "Improperly formatted request"); + return; + } + + // Parse the request data and client data. + nlohmann::json message = nlohmann::json::parse(data); + auto clientData = (ClientData *)client->GetUserData(); + + // Make sure the request is valid. + if (!message.contains("type") || !message["type"].is_string()) { + client->Close(400, "Improperly formatted request"); + return; + } + + // Make sure the first message is authentication. + auto messageType = message["type"].template get(); + if (!clientData->authed && messageType != "auth") { + client->Close(403, "First message was not auth"); + return; + } + + if (messageType == "auth") { + // Validate the auth message. + if (!message.contains("username") || !message["username"].is_string() || + !message.contains("password") || !message["password"].is_string()) { + client->Close(400, "Improperly formatted request"); + return; + } + + // Validate the auth credentials. + if (message["username"].template get() != _username || + message["password"].template get() != _password) { + // Send the auth response. + Send(client, {{"type", "auth"}, {"success", false}}); + client->Close(401, "Invalid auth credentials"); + return; + } + + clientData->authed = true; + + // Send the auth response. + Send(client, { + {"type", "auth"}, + {"success", true}, + {"name", _name}, + }); + } else if (messageType == "md") { + // Handle the request on the Message Director. + MessageDirector::Instance()->HandleWeb(client, message); + } else if (messageType == "ss") { + // Handle the request on the State Server. + auto ss = MessageDirector::Instance()->GetStateServer(); + if (ss) { + ss->HandleWeb(client, message); + } else { + Send(client, { + {"type", "ss"}, + {"success", false}, + }); + } + } else if (messageType == "ca") { + // Handle the request on the Client Agent. + auto ca = MessageDirector::Instance()->GetClientAgent(); + if (ca) { + ca->HandleWeb(client, message); + } else { + Send(client, { + {"type", "ca"}, + {"success", false}, + }); + } + } else if (messageType == "db") { + // Handle the request on the Database Server. + auto db = MessageDirector::Instance()->GetDbServer(); + if (db) { + db->HandleWeb(client, message); + } else { + Send(client, { + {"type", "db"}, + {"success", false}, + }); + } + } else if (messageType == "dbss") { + // Handle the request on the Database State Server. + auto dbss = MessageDirector::Instance()->GetDbStateServer(); + if (dbss) { + dbss->HandleWeb(client, message); + } else { + Send(client, { + {"type", "dbss"}, + {"success", false}, + }); + } + } else if (messageType == "config") { + // Return the full config file this deployment has been loaded with. + Send(client, { + {"type", "config"}, + {"config", YAML::Dump(Config::Instance()->GetConfig())}, + }); + } +} + +} // namespace Ardos \ No newline at end of file diff --git a/src/web/web_panel.h b/src/web/web_panel.h new file mode 100644 index 0000000..f228be6 --- /dev/null +++ b/src/web/web_panel.h @@ -0,0 +1,43 @@ +#ifndef ARDOS_WEB_PANEL_H +#define ARDOS_WEB_PANEL_H + +#include + +#include + +#include "../net/ws/Server.h" + +namespace Ardos { + +class WebPanel { +public: + WebPanel(); + + static void Send(ws28::Client *client, const nlohmann::json &data); + + static WebPanel *Instance; + + typedef struct { + bool authed; + } ClientData; + +private: + void HandleData(ws28::Client *client, const std::string &data); + + std::string _name = "Ardos"; + int _port = 7781; + + std::string _username = "ardos"; + std::string _password = "ardos"; + + std::string _cert; + std::string _key; + + bool _secure = false; + + std::unique_ptr _server; +}; + +} // namespace Ardos + +#endif // ARDOS_WEB_PANEL_H