From 1245d1388b003c46092937def7041917aecec8de Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Thu, 30 May 2024 13:38:26 +0200 Subject: [PATCH] netbase: extend CreateSock() to support creating arbitrary sockets Allow the callers of `CreateSock()` to pass all 3 arguments to the `socket(2)` syscall. This makes it possible to create sockets of any domain/type/protocol. --- src/net.cpp | 2 +- src/netbase.cpp | 34 ++++++++++++++++++---------------- src/netbase.h | 10 ++++++---- src/test/fuzz/fuzz.cpp | 5 +++-- src/test/fuzz/i2p.cpp | 2 +- src/test/i2p_tests.cpp | 13 ++++--------- 6 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index de974f39cba72..990c58ee3d88f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -3029,7 +3029,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - std::unique_ptr sock = CreateSock(addrBind.GetSAFamily()); + std::unique_ptr sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); if (!sock) { strError = strprintf(Untranslated("Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError())); LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original); diff --git a/src/netbase.cpp b/src/netbase.cpp index ff46061d3d88b..fcbdb43e2a970 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -487,24 +487,23 @@ bool Socks5(const std::string& strDest, uint16_t port, const ProxyCredentials* a } } -std::unique_ptr CreateSockOS(sa_family_t address_family) +std::unique_ptr CreateSockOS(int domain, int type, int protocol) { // Not IPv4, IPv6 or UNIX - if (address_family == AF_UNSPEC) return nullptr; - - int protocol{IPPROTO_TCP}; -#if HAVE_SOCKADDR_UN - if (address_family == AF_UNIX) protocol = 0; -#endif + if (domain == AF_UNSPEC) return nullptr; // Create a socket in the specified address family. - SOCKET hSocket = socket(address_family, SOCK_STREAM, protocol); + SOCKET hSocket = socket(domain, type, protocol); if (hSocket == INVALID_SOCKET) { return nullptr; } auto sock = std::make_unique(hSocket); + if (domain != AF_INET && domain != AF_INET6 && domain != AF_UNIX) { + return sock; + } + // Ensure that waiting for I/O on this socket won't result in undefined // behavior. if (!sock->IsSelectable()) { @@ -529,18 +528,21 @@ std::unique_ptr CreateSockOS(sa_family_t address_family) } #if HAVE_SOCKADDR_UN - if (address_family == AF_UNIX) return sock; + if (domain == AF_UNIX) return sock; #endif - // Set the no-delay option (disable Nagle's algorithm) on the TCP socket. - const int on{1}; - if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { - LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n"); + if (protocol == IPPROTO_TCP) { + // Set the no-delay option (disable Nagle's algorithm) on the TCP socket. + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n"); + } } + return sock; } -std::function(const sa_family_t&)> CreateSock = CreateSockOS; +std::function(int, int, int)> CreateSock = CreateSockOS; template static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) { @@ -609,7 +611,7 @@ static bool ConnectToSocket(const Sock& sock, struct sockaddr* sockaddr, socklen std::unique_ptr ConnectDirectly(const CService& dest, bool manual_connection) { - auto sock = CreateSock(dest.GetSAFamily()); + auto sock = CreateSock(dest.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP); if (!sock) { LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort()); return {}; @@ -637,7 +639,7 @@ std::unique_ptr Proxy::Connect() const if (!m_is_unix_socket) return ConnectDirectly(proxy, /*manual_connection=*/true); #if HAVE_SOCKADDR_UN - auto sock = CreateSock(AF_UNIX); + auto sock = CreateSock(AF_UNIX, SOCK_STREAM, 0); if (!sock) { LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", m_unix_socket_path); return {}; diff --git a/src/netbase.h b/src/netbase.h index 321c288f67bee..8ef6c28996051 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -262,16 +262,18 @@ CService LookupNumeric(const std::string& name, uint16_t portDefault = 0, DNSLoo CSubNet LookupSubNet(const std::string& subnet_str); /** - * Create a TCP or UNIX socket in the given address family. - * @param[in] address_family to use for the socket. + * Create a real socket from the operating system. + * @param[in] domain Communications domain, first argument to the socket(2) syscall. + * @param[in] type Type of the socket, second argument to the socket(2) syscall. + * @param[in] protocol The particular protocol to be used with the socket, third argument to the socket(2) syscall. * @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure */ -std::unique_ptr CreateSockOS(sa_family_t address_family); +std::unique_ptr CreateSockOS(int domain, int type, int protocol); /** * Socket factory. Defaults to `CreateSockOS()`, but can be overridden by unit tests. */ -extern std::function(const sa_family_t&)> CreateSock; +extern std::function(int, int, int)> CreateSock; /** * Create a socket and try to connect to the specified service. diff --git a/src/test/fuzz/fuzz.cpp b/src/test/fuzz/fuzz.cpp index 9a54a44bd3a26..c1c9945a04af8 100644 --- a/src/test/fuzz/fuzz.cpp +++ b/src/test/fuzz/fuzz.cpp @@ -101,8 +101,9 @@ void ResetCoverageCounters() {} void initialize() { - // Terminate immediately if a fuzzing harness ever tries to create a TCP socket. - CreateSock = [](const sa_family_t&) -> std::unique_ptr { std::terminate(); }; + // Terminate immediately if a fuzzing harness ever tries to create a socket. + // Individual tests can override this by pointing CreateSock to a mocked alternative. + CreateSock = [](int, int, int) -> std::unique_ptr { std::terminate(); }; // Terminate immediately if a fuzzing harness ever tries to perform a DNS lookup. g_dns_lookup = [](const std::string& name, bool allow_lookup) { diff --git a/src/test/fuzz/i2p.cpp b/src/test/fuzz/i2p.cpp index 3af5bed30afce..51517187a0ae3 100644 --- a/src/test/fuzz/i2p.cpp +++ b/src/test/fuzz/i2p.cpp @@ -27,7 +27,7 @@ FUZZ_TARGET(i2p, .init = initialize_i2p) // Mock CreateSock() to create FuzzedSock. auto CreateSockOrig = CreateSock; - CreateSock = [&fuzzed_data_provider](const sa_family_t&) { + CreateSock = [&fuzzed_data_provider](int, int, int) { return std::make_unique(fuzzed_data_provider); }; diff --git a/src/test/i2p_tests.cpp b/src/test/i2p_tests.cpp index d7249d88f4199..0512c6134fccf 100644 --- a/src/test/i2p_tests.cpp +++ b/src/test/i2p_tests.cpp @@ -39,15 +39,14 @@ class EnvTestingSetup : public BasicTestingSetup private: const BCLog::Level m_prev_log_level; - const std::function(const sa_family_t&)> m_create_sock_orig; + const decltype(CreateSock) m_create_sock_orig; }; BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup) BOOST_AUTO_TEST_CASE(unlimited_recv) { - // Mock CreateSock() to create MockSock. - CreateSock = [](const sa_family_t&) { + CreateSock = [](int, int, int) { return std::make_unique(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a')); }; @@ -69,7 +68,7 @@ BOOST_AUTO_TEST_CASE(unlimited_recv) BOOST_AUTO_TEST_CASE(listen_ok_accept_fail) { size_t num_sockets{0}; - CreateSock = [&num_sockets](const sa_family_t&) { + CreateSock = [&num_sockets](int, int, int) { // clang-format off ++num_sockets; // First socket is the control socket for creating the session. @@ -133,9 +132,7 @@ BOOST_AUTO_TEST_CASE(listen_ok_accept_fail) BOOST_AUTO_TEST_CASE(damaged_private_key) { - const auto CreateSockOrig = CreateSock; - - CreateSock = [](const sa_family_t&) { + CreateSock = [](int, int, int) { return std::make_unique("HELLO REPLY RESULT=OK VERSION=3.1\n" "SESSION STATUS RESULT=OK DESTINATION=\n"); }; @@ -172,8 +169,6 @@ BOOST_AUTO_TEST_CASE(damaged_private_key) BOOST_CHECK(!session.Connect(CService{}, conn, proxy_error)); } } - - CreateSock = CreateSockOrig; } BOOST_AUTO_TEST_SUITE_END()