From aa0db5c0533f3c85049cdd2e1803cba24eca0307 Mon Sep 17 00:00:00 2001 From: xbasel <103044017+xbasel@users.noreply.github.com> Date: Wed, 13 Nov 2024 17:30:11 +0200 Subject: [PATCH] Preserve original fd blocking state in TLS I/O operations This change prevents unintended side effects on connection state and improves consistency with non-TLS sync operations. Signed-off-by: xbasel <103044017+xbasel@users.noreply.github.com> --- src/anet.c | 30 ++++++++++++++++++++++++++---- src/anet.h | 1 + src/tls.c | 21 ++++++++++++++++----- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/anet.c b/src/anet.c index d4ac698982..8dc06ca62e 100644 --- a/src/anet.c +++ b/src/anet.c @@ -70,17 +70,24 @@ int anetGetError(int fd) { return sockerr; } -int anetSetBlock(char *err, int fd, int non_block) { +static int anetGetSocketFlags(char *err, int fd) { int flags; - /* Set the socket blocking (if non_block is zero) or non-blocking. - * Note that fcntl(2) for F_GETFL and F_SETFL can't be - * interrupted by a signal. */ if ((flags = fcntl(fd, F_GETFL)) == -1) { anetSetError(err, "fcntl(F_GETFL): %s", strerror(errno)); return ANET_ERR; } + return flags; +} + +int anetSetBlock(char *err, int fd, int non_block) { + int flags = anetGetSocketFlags(err, fd); + + if (flags == ANET_ERR) { + return ANET_ERR; + } + /* Check if this flag has been set or unset, if so, * then there is no need to call fcntl to set/unset it again. */ if (!!(flags & O_NONBLOCK) == !!non_block) return ANET_OK; @@ -105,6 +112,21 @@ int anetBlock(char *err, int fd) { return anetSetBlock(err, fd, 0); } +int anetIsBlock(char *err, int fd) { + int flags = anetGetSocketFlags(err, fd); + + if (flags == ANET_ERR) { + return ANET_ERR; + } + + /* Check if the O_NONBLOCK flag is set */ + if (flags & O_NONBLOCK) { + return 0; /* Socket is non-blocking */ + } else { + return 1; /* Socket is blocking */ + } +} + /* Enable the FD_CLOEXEC on the given fd to avoid fd leaks. * This function should be invoked for fd's on specific places * where fork + execve system calls are called. */ diff --git a/src/anet.h b/src/anet.h index ab32f72e4b..b14b4bdaad 100644 --- a/src/anet.h +++ b/src/anet.h @@ -61,6 +61,7 @@ int anetTcpAccept(char *err, int serversock, char *ip, size_t ip_len, int *port) int anetUnixAccept(char *err, int serversock); int anetNonBlock(char *err, int fd); int anetBlock(char *err, int fd); +int anetIsBlock(char *err, int fd); int anetCloexec(int fd); int anetEnableTcpNoDelay(char *err, int fd); int anetDisableTcpNoDelay(char *err, int fd); diff --git a/src/tls.c b/src/tls.c index f1c82d35e4..216f7f247b 100644 --- a/src/tls.c +++ b/src/tls.c @@ -967,6 +967,10 @@ static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) return C_OK; } +static int isBlocking(tls_connection *conn) { + return anetIsBlock(NULL, conn->c.fd); +} + static void setBlockingTimeout(tls_connection *conn, long long timeout) { anetBlock(NULL, conn->c.fd); anetSendTimeout(NULL, conn->c.fd, timeout); @@ -1005,26 +1009,30 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port, static ssize_t connTLSSyncWrite(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *)conn_; - + int blocking = isBlocking(conn); setBlockingTimeout(conn, timeout); SSL_clear_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); ERR_clear_error(); int ret = SSL_write(conn->ssl, ptr, size); ret = updateStateAfterSSLIO(conn, ret, 0); SSL_set_mode(conn->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE); - unsetBlockingTimeout(conn); + if (!blocking) { + unsetBlockingTimeout(conn); + } return ret; } static ssize_t connTLSSyncRead(connection *conn_, char *ptr, ssize_t size, long long timeout) { tls_connection *conn = (tls_connection *)conn_; - + int blocking = isBlocking(conn); setBlockingTimeout(conn, timeout); ERR_clear_error(); int ret = SSL_read(conn->ssl, ptr, size); ret = updateStateAfterSSLIO(conn, ret, 0); - unsetBlockingTimeout(conn); + if (!blocking) { + unsetBlockingTimeout(conn); + } return ret; } @@ -1033,6 +1041,7 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l tls_connection *conn = (tls_connection *)conn_; ssize_t nread = 0; + int blocking = isBlocking(conn); setBlockingTimeout(conn, timeout); size--; @@ -1058,7 +1067,9 @@ static ssize_t connTLSSyncReadLine(connection *conn_, char *ptr, ssize_t size, l size--; } exit: - unsetBlockingTimeout(conn); + if (!blocking) { + unsetBlockingTimeout(conn); + } return nread; }