From 0929f4e23dffeea52b13bb5af2896bb9e8287d70 Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Tue, 4 Jun 2024 15:57:32 +0200 Subject: [PATCH] fix: windows wsa startup. (#88) --- src/apps/dllama-api/dllama-api.cpp | 3 +++ src/apps/dllama/dllama.cpp | 14 ++++++++++---- src/socket.cpp | 28 +++++++++++++++------------- src/socket.hpp | 4 ++++ 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp index d02b8fe..002206f 100644 --- a/src/apps/dllama-api/dllama-api.cpp +++ b/src/apps/dllama-api/dllama-api.cpp @@ -433,8 +433,11 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, int main(int argc, char *argv[]) { initQuants(); + initSockets(); AppArgs args = AppArgs::parse(argc, argv, false); App::run(&args, server); + + cleanupSockets(); return EXIT_SUCCESS; } diff --git a/src/apps/dllama/dllama.cpp b/src/apps/dllama/dllama.cpp index e37a29e..69d7814 100644 --- a/src/apps/dllama/dllama.cpp +++ b/src/apps/dllama/dllama.cpp @@ -219,27 +219,33 @@ void worker(AppArgs* args) { int main(int argc, char *argv[]) { initQuants(); + initSockets(); AppArgs args = AppArgs::parse(argc, argv, true); + bool success = false; if (args.mode != NULL) { if (strcmp(args.mode, "inference") == 0) { args.benchmark = true; App::run(&args, generate); - return EXIT_SUCCESS; + success = true; } else if (strcmp(args.mode, "generate") == 0) { args.benchmark = false; App::run(&args, generate); - return EXIT_SUCCESS; + success = true; } else if (strcmp(args.mode, "chat") == 0) { App::run(&args, chat); - return EXIT_SUCCESS; + success = true; } else if (strcmp(args.mode, "worker") == 0) { worker(&args); - return EXIT_SUCCESS; + success = true; } } + cleanupSockets(); + + if (success) + return EXIT_SUCCESS; fprintf(stderr, "Invalid usage\n"); return EXIT_FAILURE; } diff --git a/src/socket.cpp b/src/socket.cpp index 1409ea5..25ccddf 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -75,7 +75,6 @@ static inline void setReuseAddr(int socket) { int iresult = setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt)); if (iresult == SOCKET_ERROR) { closesocket(socket); - WSACleanup(); throw std::runtime_error("setsockopt failed: " + std::to_string(WSAGetLastError())); } #else @@ -133,6 +132,21 @@ static inline void readSocket(int socket, void* data, size_t size) { } } +void initSockets() { +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError())); + } +#endif +} + +void cleanupSockets() { +#ifdef _WIN32 + WSACleanup(); +#endif +} + ReadSocketException::ReadSocketException(int code, const char* message) { this->code = code; this->message = message; @@ -342,13 +356,6 @@ SocketServer::SocketServer(int port) { const char* host = "0.0.0.0"; struct sockaddr_in serverAddr; - #ifdef _WIN32 - WSADATA wsaData; - if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { - throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError())); - } - #endif - socket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (socket < 0) throw std::runtime_error("Cannot create socket"); @@ -365,7 +372,6 @@ SocketServer::SocketServer(int port) { if (bindResult == SOCKET_ERROR) { int error = WSAGetLastError(); closesocket(socket); - WSACleanup(); throw std::runtime_error("Cannot bind port: " + std::to_string(error)); } #else @@ -380,7 +386,6 @@ SocketServer::SocketServer(int port) { if (listenResult != 0) { #ifdef _WIN32 closesocket(socket); - WSACleanup(); throw std::runtime_error("Cannot listen on port: " + std::to_string(WSAGetLastError())); #else close(socket); @@ -393,8 +398,5 @@ SocketServer::SocketServer(int port) { SocketServer::~SocketServer() { shutdown(socket, 2); - #ifdef _WIN32 - WSACleanup(); - #endif close(socket); } diff --git a/src/socket.hpp b/src/socket.hpp index 82f4fd0..5849793 100644 --- a/src/socket.hpp +++ b/src/socket.hpp @@ -6,12 +6,16 @@ #include #include +void initSockets(); +void cleanupSockets(); + class ReadSocketException : public std::exception { public: int code; const char* message; ReadSocketException(int code, const char* message); }; + class WriteSocketException : public std::exception { public: int code;