diff --git a/src/network/connection.cpp b/src/network/connection.cpp index 1f1ae69c8..65fd839ad 100644 --- a/src/network/connection.cpp +++ b/src/network/connection.cpp @@ -67,7 +67,11 @@ namespace Network { bool Connection::beginPacket() { if (m_IsBundle) { - m_BundlePacketPosition = 0; + m_BundleInnerStart = m_BundleInnerPosition; + // We have to advance two bytes, since the first two bytes are reserved for the + // packet length + m_BundleInnerPosition += 2; + return true; } @@ -84,21 +88,14 @@ bool Connection::beginPacket() { bool Connection::endPacket() { if (m_IsBundle) { - uint32_t innerPacketSize = m_BundlePacketPosition; - - MUST_TRANSFER_BOOL((innerPacketSize > 0)); + auto innerPacketSize = getBundleInnerSize(); - m_IsBundle = false; + // We have to go back to the start of the packet and write the size + convert_to_chars((uint16_t)innerPacketSize, m_Buf); + memcpy(m_BundleInnerStart, m_Buf, 2); - if (m_BundlePacketInnerCount == 0) { - sendPacketType(PACKET_BUNDLE); - sendPacketNumber(); - } - sendShort(innerPacketSize); - sendBytes(m_Packet, innerPacketSize); + m_BundlePacketCount++; - m_BundlePacketInnerCount++; - m_IsBundle = true; return true; } @@ -114,40 +111,68 @@ bool Connection::endPacket() { } bool Connection::beginBundle() { - MUST_TRANSFER_BOOL(m_ServerFeatures.has(ServerFeatures::PROTOCOL_BUNDLE_SUPPORT)); + MUST_TRANSFER_BOOL(m_ServerFeatures.has(EServerFeatureFlags::PROTOCOL_BUNDLE_SUPPORT + )); + + memset(m_Bundle, 0, sizeof(m_Bundle)); + m_BundleInnerPosition = m_Bundle; + m_BundlePacketCount = 0; + MUST_TRANSFER_BOOL(m_Connected); MUST_TRANSFER_BOOL(!m_IsBundle); - MUST_TRANSFER_BOOL(beginPacket()); m_IsBundle = true; - m_BundlePacketInnerCount = 0; + return true; } -bool Connection::endBundle() { - MUST_TRANSFER_BOOL(m_IsBundle); +bool Connection::sendBundle() { + MUST_TRANSFER_BOOL(m_ServerFeatures.has(EServerFeatureFlags::PROTOCOL_BUNDLE_SUPPORT + )); + MUST_TRANSFER_BOOL(!m_IsBundle); + MUST_TRANSFER_BOOL((m_BundlePacketCount > 0)); + MUST_TRANSFER_BOOL(beginPacket()); + + MUST_TRANSFER_BOOL(sendPacketType(PACKET_BUNDLE)); + MUST_TRANSFER_BOOL(sendPacketNumber()); + MUST_TRANSFER_BOOL(write(m_Bundle, getBundleSize())); + + MUST_TRANSFER_BOOL(endPacket()); + + return true; +} + +bool Connection::endBundle() { m_IsBundle = false; - MUST_TRANSFER_BOOL((m_BundlePacketInnerCount > 0)); + auto ret = sendBundle(); - return endPacket(); + memset(m_Buf, 0, sizeof(m_Buf)); + m_BundleInnerStart = m_Buf; + m_BundleInnerPosition = m_Buf; + m_BundlePacketCount = 0; + + return ret; } -size_t Connection::write(const uint8_t *buffer, size_t size) { +size_t Connection::write(const uint8_t* buffer, size_t size) { if (m_IsBundle) { - if (m_BundlePacketPosition + size > sizeof(m_Packet)) { + if (getBundleSize() + size > MAX_BUNDLE_SIZE) { + m_Logger.error("Bundled packet too large"); + + // TODO: Drop the currently forming packet + return 0; } - memcpy(m_Packet + m_BundlePacketPosition, buffer, size); - m_BundlePacketPosition += size; + + memcpy(m_BundleInnerPosition, buffer, size); + m_BundleInnerPosition += size; + return size; } - return m_UDP.write(buffer, size); -} -size_t Connection::write(uint8_t byte) { - return write(&byte, 1); + return m_UDP.write(buffer, size); } bool Connection::sendFloat(float f) { @@ -158,19 +183,19 @@ bool Connection::sendFloat(float f) { bool Connection::sendByte(uint8_t c) { return write(&c, 1) != 0; } -bool Connection::sendShort(uint16_t i) { +bool Connection::sendU16(uint16_t i) { convert_to_chars(i, m_Buf); return write(m_Buf, sizeof(i)) != 0; } -bool Connection::sendInt(uint32_t i) { +bool Connection::sendI32(int32_t i) { convert_to_chars(i, m_Buf); return write(m_Buf, sizeof(i)) != 0; } -bool Connection::sendLong(uint64_t l) { +bool Connection::sendU64(uint64_t l) { convert_to_chars(l, m_Buf); return write(m_Buf, sizeof(l)) != 0; @@ -185,9 +210,9 @@ bool Connection::sendPacketNumber() { return true; } - uint64_t pn = m_PacketNumber++; + auto pn = m_PacketNumber++; - return sendLong(pn); + return sendU64(pn); } bool Connection::sendShortString(const char* str) { @@ -199,18 +224,12 @@ bool Connection::sendShortString(const char* str) { return true; } -bool Connection::sendPacketType(uint8_t type) { - MUST_TRANSFER_BOOL(sendByte(0)); - MUST_TRANSFER_BOOL(sendByte(0)); - MUST_TRANSFER_BOOL(sendByte(0)); - - return sendByte(type); -} +bool Connection::sendPacketType(int32_t type) { return sendI32(type); } bool Connection::sendLongString(const char* str) { int size = strlen(str); - MUST_TRANSFER_BOOL(sendInt(size)); + MUST_TRANSFER_BOOL(sendI32(size)); return sendBytes((const uint8_t*)str, size); } @@ -376,7 +395,8 @@ void Connection::sendFeatureFlags() { MUST(sendPacketType(PACKET_FEATURE_FLAGS)); MUST(sendPacketNumber()); - MUST(write(FirmwareFeatures::flags.data(), FirmwareFeatures::flags.size())); + auto packedFeatures = m_FirmwareFeatures.pack(); + MUST(write(packedFeatures.data(), packedFeatures.size())); MUST(endPacket()); } @@ -391,17 +411,17 @@ void Connection::sendTrackerDiscovery() { MUST(sendPacketType(PACKET_HANDSHAKE)); // Packet number is always 0 for handshake - MUST(sendLong(0)); - MUST(sendInt(BOARD)); + MUST(sendU64(0)); + MUST(sendI32(BOARD)); // This is kept for backwards compatibility, // but the latest SlimeVR server will not initialize trackers // with firmware build > 8 until it recieves a sensor info packet - MUST(sendInt(static_cast(sensorManager.getSensorType(0)))); - MUST(sendInt(HARDWARE_MCU)); - MUST(sendInt(0)); - MUST(sendInt(0)); - MUST(sendInt(0)); - MUST(sendInt(FIRMWARE_BUILD_NUMBER)); + MUST(sendI32(static_cast(sensorManager.getSensorType(0)))); + MUST(sendI32(HARDWARE_MCU)); + MUST(sendI32(0)); + MUST(sendI32(0)); + MUST(sendI32(0)); + MUST(sendI32(FIRMWARE_BUILD_NUMBER)); MUST(sendShortString(FIRMWARE_VERSION)); // MAC address string MUST(sendBytes(mac, 6)); @@ -437,19 +457,19 @@ void Connection::sendInspectionRawIMUData( MUST(sendByte(sensorId)); MUST(sendByte(PACKET_INSPECTION_DATATYPE_INT)); - MUST(sendInt(rX)); - MUST(sendInt(rY)); - MUST(sendInt(rZ)); + MUST(sendI32(rX)); + MUST(sendI32(rY)); + MUST(sendI32(rZ)); MUST(sendByte(rA)); - MUST(sendInt(aX)); - MUST(sendInt(aY)); - MUST(sendInt(aZ)); + MUST(sendI32(aX)); + MUST(sendI32(aY)); + MUST(sendI32(aZ)); MUST(sendByte(aA)); - MUST(sendInt(mX)); - MUST(sendInt(mY)); - MUST(sendInt(mZ)); + MUST(sendI32(mX)); + MUST(sendI32(mY)); + MUST(sendI32(mZ)); MUST(sendByte(mA)); MUST(endPacket()); @@ -573,7 +593,7 @@ void Connection::searchForServer() { m_Connected = true; m_FeatureFlagsRequestAttempts = 0; - m_ServerFeatures = ServerFeatures { }; + m_ServerFeatures = FeatureFlags(); statusManager.setStatus(SlimeVR::Status::SERVER_CONNECTING, false); ledManager.off(); @@ -603,7 +623,11 @@ void Connection::searchForServer() { void Connection::reset() { m_Connected = false; - std::fill(m_AckedSensorState, m_AckedSensorState+MAX_IMU_COUNT, SensorStatus::SENSOR_OFFLINE); + std::fill( + m_AckedSensorState, + m_AckedSensorState + MAX_IMU_COUNT, + SensorStatus::SENSOR_OFFLINE + ); m_UDP.begin(m_ServerPort); @@ -614,18 +638,23 @@ void Connection::update() { auto & sensors = sensorManager.getSensors(); updateSensorState(sensors); - maybeRequestFeatureFlags(); if (!m_Connected) { searchForServer(); return; } + maybeRequestFeatureFlags(); + if (m_LastPacketTimestamp + TIMEOUT < millis()) { statusManager.setStatus(SlimeVR::Status::SERVER_CONNECTING, true); m_Connected = false; - std::fill(m_AckedSensorState, m_AckedSensorState+MAX_IMU_COUNT, SensorStatus::SENSOR_OFFLINE); + std::fill( + m_AckedSensorState, + m_AckedSensorState + MAX_IMU_COUNT, + SensorStatus::SENSOR_OFFLINE + ); m_Logger.warn("Connection to server timed out"); return; @@ -639,17 +668,7 @@ void Connection::update() { m_LastPacketTimestamp = millis(); int len = m_UDP.read(m_Packet, sizeof(m_Packet)); -#ifdef DEBUG_NETWORK - m_Logger.trace( - "Received %d bytes from %s, port %d", - packetSize, - m_UDP.remoteIP().toString().c_str(), - m_UDP.remotePort() - ); - m_Logger.traceArray("UDP packet contents: ", m_Packet, len); -#else (void)packetSize; -#endif switch (convert_chars(m_Packet)) { case PACKET_RECEIVE_HEARTBEAT: @@ -697,16 +716,17 @@ void Connection::update() { } bool hadFlags = m_ServerFeatures.isAvailable(); - uint32_t flagsLength = len - 12; - m_ServerFeatures = ServerFeatures::from(&m_Packet[12], flagsLength); + m_ServerFeatures + = FeatureFlags(&m_Packet[12], flagsLength); if (!hadFlags) { - #if PACKET_BUNDLING != PACKET_BUNDLING_DISABLED - if (m_ServerFeatures.has(ServerFeatures::PROTOCOL_BUNDLE_SUPPORT)) { - m_Logger.debug("Server supports packet bundling"); - } - #endif +#if PACKET_BUNDLING != PACKET_BUNDLING_DISABLED + if (m_ServerFeatures.has(EServerFeatureFlags::PROTOCOL_BUNDLE_SUPPORT + )) { + m_Logger.debug("Server supports packet bundling"); + } +#endif } break; diff --git a/src/network/connection.h b/src/network/connection.h index b0aefc3e5..33c910d43 100644 --- a/src/network/connection.h +++ b/src/network/connection.h @@ -26,18 +26,20 @@ #include #include +#include "featureflags.h" #include "globals.h" #include "quat.h" #include "sensors/sensor.h" #include "wifihandler.h" -#include "featureflags.h" namespace SlimeVR { namespace Network { class Connection { public: - Connection() { + Connection() + : m_FirmwareFeatures(m_EnabledFirmwareFeatures) { + #ifdef SERVER_IP m_ServerHost.fromString(SERVER_IP); #endif @@ -117,9 +119,7 @@ class Connection { ); #endif - const ServerFeatures& getServerFeatureFlags() { - return m_ServerFeatures; - } + const auto getServerFeatureFlags() { return m_ServerFeatures; } bool beginBundle(); bool endBundle(); @@ -131,16 +131,15 @@ class Connection { bool beginPacket(); bool endPacket(); - size_t write(const uint8_t *buffer, size_t size); - size_t write(uint8_t byte); + size_t write(const uint8_t* buffer, size_t size); - bool sendPacketType(uint8_t type); + bool sendPacketType(int32_t type); bool sendPacketNumber(); bool sendFloat(float f); bool sendByte(uint8_t c); - bool sendShort(uint16_t i); - bool sendInt(uint32_t i); - bool sendLong(uint64_t l); + bool sendU16(uint16_t c); + bool sendI32(int32_t i); + bool sendU64(uint64_t l); bool sendBytes(const uint8_t* c, size_t length); bool sendShortString(const char* str); bool sendLongString(const char* str); @@ -162,7 +161,11 @@ class Connection { SlimeVR::Logging::Logger m_Logger = SlimeVR::Logging::Logger("UDPConnection"); WiFiUDP m_UDP; - unsigned char m_Packet[128]; // buffer for incoming packets + /* + The current incoming packet that is being handled + TODO: remove this from the class and make it a local variable + */ + uint8_t m_Packet[128]; uint64_t m_PacketNumber = 0; int m_ServerPort = 6969; @@ -175,13 +178,31 @@ class Connection { uint8_t m_FeatureFlagsRequestAttempts = 0; unsigned long m_FeatureFlagsRequestTimestamp = millis(); - ServerFeatures m_ServerFeatures{}; + FeatureFlags m_ServerFeatures; + FeatureFlags m_FirmwareFeatures; bool m_IsBundle = false; - uint16_t m_BundlePacketPosition = 0; - uint16_t m_BundlePacketInnerCount = 0; + /* `53` is the maximum size of any packet that could be bundled, which is the + * inspection packet. The additional `2` bytes are from the field which describes + * how long the next bundled packet is. If you're having bundle size issues, then + * you forgot to increase MAX_IMU_COUNT in `defines.h`. */ + constexpr static size_t MAX_BUNDLE_SIZE = MAX_IMU_COUNT * (53 + 2); + /* The bundle that is currently being written. */ + uint8_t m_Bundle[MAX_BUNDLE_SIZE]; + /* Points into `m_Bundle` to indicate where we currently are. */ + uint8_t* m_BundleInnerPosition = m_Bundle; + /* Points to where we started writing the current packet into `m_Bundle`. */ + uint8_t* m_BundleInnerStart = m_Bundle; + /* Count of packets that are in the currently forming bundle. */ + uint16_t m_BundlePacketCount = 0; + + const size_t getBundleSize() const { return m_BundleInnerPosition - m_Bundle; } + const size_t getBundleInnerSize() const { + return m_BundleInnerPosition - m_BundleInnerStart - 2; + } + bool sendBundle(); - unsigned char m_Buf[8]; + uint8_t m_Buf[8]; }; } // namespace Network diff --git a/src/network/featureflags.h b/src/network/featureflags.h index b0820be40..64148437b 100644 --- a/src/network/featureflags.h +++ b/src/network/featureflags.h @@ -1,103 +1,102 @@ + /* - SlimeVR Code is placed under the MIT license - Copyright (c) 2023 SlimeVR Contributors - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - THE SOFTWARE. + SlimeVR Code is placed under the MIT license + Copyright (c) 2023 SlimeVR Contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. */ #ifndef SLIMEVR_FEATURE_FLAGS_H_ #define SLIMEVR_FEATURE_FLAGS_H_ -#include #include - -/** - * Bit packed flags, enum values start with 0 and indicate which bit it is. - * - * Change the enums and `flagsEnabled` inside to extend. -*/ -struct ServerFeatures { -public: - enum EServerFeatureFlags: uint32_t { - // Server can parse bundle packets: `PACKET_BUNDLE` = 100 (0x64). - PROTOCOL_BUNDLE_SUPPORT, - - // Add new flags here - - BITS_TOTAL, - }; - - bool has(EServerFeatureFlags flag) { - uint32_t bit = static_cast(flag); - return m_Available && (m_Flags[bit / 8] & (1 << (bit % 8))); - } - - /** - * Whether the server supports the "feature flags" feature, - * set to true when we've received flags packet from the server. - */ - bool isAvailable() { - return m_Available; - } - - static ServerFeatures from(uint8_t* received, uint32_t length) { - ServerFeatures res; - res.m_Available = true; - memcpy(res.m_Flags, received, std::min(static_cast(sizeof(res.m_Flags)), length)); - return res; - } - -private: - bool m_Available = false; - - uint8_t m_Flags[static_cast(EServerFeatureFlags::BITS_TOTAL) / 8 + 1]; +#include +#include +#include + +// I hate C++11 - they fixed this in C++14, but our compilers are old as the iceage +struct EnumClassHash { + template + std::size_t operator()(T t) const { + return static_cast(t); + } }; -class FirmwareFeatures { -public: - enum EFirmwareFeatureFlags: uint32_t { - // EXAMPLE_FEATURE, - B64_WIFI_SCANNING = 1, - - // Add new flags here +enum EServerFeatureFlags : uint32_t { + // Server can parse bundle packets: `PACKET_BUNDLE` = 100 (0x64). + PROTOCOL_BUNDLE_SUPPORT, - BITS_TOTAL, - }; + BITS_TOTAL, +}; - // Flags to send - static constexpr const std::initializer_list flagsEnabled = { - // EXAMPLE_FEATURE, - B64_WIFI_SCANNING, +enum class EFirmwareFeatureFlags : uint32_t { + // EXAMPLE_FEATURE, + B64_WIFI_SCANNING = 1, - // Add enabled flags here - }; + BITS_TOTAL, +}; - static constexpr auto flags = []{ - constexpr uint32_t flagsLength = EFirmwareFeatureFlags::BITS_TOTAL / 8 + 1; - std::array packed{}; +static const std::unordered_map + m_EnabledFirmwareFeatures = {{EFirmwareFeatureFlags::B64_WIFI_SCANNING, true}}; - for (uint32_t bit : flagsEnabled) { - packed[bit / 8] |= 1 << (bit % 8); - } +template +class FeatureFlags { +public: + static constexpr auto FLAG_BYTES + = ((static_cast(Flags::BITS_TOTAL)) + 7) / 8; + + FeatureFlags() + : m_Available(false) {} + FeatureFlags(uint8_t* packed, uint32_t length) + : m_Available(true) { + for (uint32_t bit = 0; bit < length * 8; bit++) { + auto posInPacked = bit / 8; + auto posInByte = bit % 8; + + m_Flags[static_cast(bit)] = packed[posInPacked] & (1 << posInByte); + } + } + FeatureFlags(std::unordered_map flags) + : m_Available(true) + , m_Flags(flags) {} + + std::array pack() { + std::array packed{}; + + for (auto& [flag, value] : m_Flags) { + auto posInPacked = static_cast(flag) / 8; + auto posInByte = static_cast(flag) % 8; + + if (value) { + packed[posInPacked] |= 1 << posInByte; + } + } + + return packed; + }; + + bool has(Flags flag) { return m_Flags[flag]; } + bool isAvailable() { return m_Available; } - return packed; - }(); +private: + bool m_Available = false; + std::unordered_map m_Flags{}; }; #endif