Skip to content

Commit

Permalink
fix: windows wsa startup. (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Jun 4, 2024
1 parent 5f78ddd commit 0929f4e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 17 deletions.
3 changes: 3 additions & 0 deletions src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,11 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,

int main(int argc, char *argv[]) {
initQuants();
initSockets();

AppArgs args = AppArgs::parse(argc, argv, false);
App::run(&args, server);

cleanupSockets();
return EXIT_SUCCESS;
}
14 changes: 10 additions & 4 deletions src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,27 +219,33 @@ void worker(AppArgs* args) {

int main(int argc, char *argv[]) {
initQuants();
initSockets();

AppArgs args = AppArgs::parse(argc, argv, true);
bool success = false;

if (args.mode != NULL) {
if (strcmp(args.mode, "inference") == 0) {
args.benchmark = true;
App::run(&args, generate);
return EXIT_SUCCESS;
success = true;
} else if (strcmp(args.mode, "generate") == 0) {
args.benchmark = false;
App::run(&args, generate);
return EXIT_SUCCESS;
success = true;
} else if (strcmp(args.mode, "chat") == 0) {
App::run(&args, chat);
return EXIT_SUCCESS;
success = true;
} else if (strcmp(args.mode, "worker") == 0) {
worker(&args);
return EXIT_SUCCESS;
success = true;
}
}

cleanupSockets();

if (success)
return EXIT_SUCCESS;
fprintf(stderr, "Invalid usage\n");
return EXIT_FAILURE;
}
28 changes: 15 additions & 13 deletions src/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ static inline void setReuseAddr(int socket) {
int iresult = setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt));
if (iresult == SOCKET_ERROR) {
closesocket(socket);
WSACleanup();
throw std::runtime_error("setsockopt failed: " + std::to_string(WSAGetLastError()));
}
#else
Expand Down Expand Up @@ -133,6 +132,21 @@ static inline void readSocket(int socket, void* data, size_t size) {
}
}

void initSockets() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError()));
}
#endif
}

void cleanupSockets() {
#ifdef _WIN32
WSACleanup();
#endif
}

ReadSocketException::ReadSocketException(int code, const char* message) {
this->code = code;
this->message = message;
Expand Down Expand Up @@ -342,13 +356,6 @@ SocketServer::SocketServer(int port) {
const char* host = "0.0.0.0";
struct sockaddr_in serverAddr;

#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError()));
}
#endif

socket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (socket < 0)
throw std::runtime_error("Cannot create socket");
Expand All @@ -365,7 +372,6 @@ SocketServer::SocketServer(int port) {
if (bindResult == SOCKET_ERROR) {
int error = WSAGetLastError();
closesocket(socket);
WSACleanup();
throw std::runtime_error("Cannot bind port: " + std::to_string(error));
}
#else
Expand All @@ -380,7 +386,6 @@ SocketServer::SocketServer(int port) {
if (listenResult != 0) {
#ifdef _WIN32
closesocket(socket);
WSACleanup();
throw std::runtime_error("Cannot listen on port: " + std::to_string(WSAGetLastError()));
#else
close(socket);
Expand All @@ -393,8 +398,5 @@ SocketServer::SocketServer(int port) {

SocketServer::~SocketServer() {
shutdown(socket, 2);
#ifdef _WIN32
WSACleanup();
#endif
close(socket);
}
4 changes: 4 additions & 0 deletions src/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
#include <exception>
#include <vector>

void initSockets();
void cleanupSockets();

class ReadSocketException : public std::exception {
public:
int code;
const char* message;
ReadSocketException(int code, const char* message);
};

class WriteSocketException : public std::exception {
public:
int code;
Expand Down

0 comments on commit 0929f4e

Please sign in to comment.