Skip to content

Commit

Permalink
Merge pull request #459 from JetBrains/ab-fix-simple-socke
Browse files Browse the repository at this point in the history
Socket processing improvements.
  • Loading branch information
mirasrael authored Jan 3, 2024
2 parents 9e5a37c + 8cdf969 commit 95c5b78
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 160 deletions.
8 changes: 8 additions & 0 deletions rd-cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ if (CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS "7.0.0")
message(FATAL_ERROR "Insufficient clang version")
endif ()
if (CMAKE_BUILD_TYPE MATCHES "Debug")
option(USE_ADDRESS_SANITIZER "Use address sanitizer to troubleshoot invalid allocations" ON)
else ()
option(USE_ADDRESS_SANITIZER "Use address sanitizer to troubleshoot invalid allocations" OFF)
endif ()
if (USE_ADDRESS_SANITIZER)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -g")
endif()
endif ()
if (MINGW)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wa,-mbig-obj")
Expand Down
7 changes: 3 additions & 4 deletions rd-cpp/src/rd_core_cpp/src/test/cases/ViewableSetTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ using namespace rd;

TEST(viewable_set, advise)
{
std::unique_ptr<IViewableSet<int>> set = std::make_unique<ViewableSet<int>>();

std::vector<int> logAdvise;
std::vector<int> logView1;
std::vector<int> logView2;
std::unique_ptr<IViewableSet<int>> set = std::make_unique<ViewableSet<int>>();
LifetimeDefinition::use([&](Lifetime lt) {
set->advise(lt, [&](AddRemove kind, int const& v) { logAdvise.push_back(kind == AddRemove::ADD ? v : -v); });
set->view(lt, [&](Lifetime inner, int const& v) {
Expand Down Expand Up @@ -66,8 +65,8 @@ TEST(viewable_set, view)

std::unique_ptr<IViewableSet<int>> set = std::make_unique<ViewableSet<int>>();
std::vector<std::string> log;
auto x = LifetimeDefinition::use([&](Lifetime lifetime) {
set->view(lifetime, [&](Lifetime lt, int const& value) {
auto x = LifetimeDefinition::use([&](const Lifetime& lifetime) {
set->view(lifetime, [&](const Lifetime& lt, int const& value) {
log.push_back("View " + std::to_string(value));
lt->add_action([&]() { log.push_back("UnView " + std::to_string(value)); });
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ void ByteBufferAsyncProcessor::ThreadProc()
return;
}

while (data.empty() && queue.empty() || interrupt_balance != 0)
while ((data.empty() && queue.empty()) || interrupt_balance != 0)
{
if (state >= StateKind::Stopping)
{
Expand Down
29 changes: 18 additions & 11 deletions rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <SimpleSocket.h>
#include <ActiveSocket.h>
#include <PassiveSocket.h>
#include <SimpleSocketSender.h>

#include <utility>
#include <thread>
Expand Down Expand Up @@ -81,16 +82,16 @@ bool SocketWire::Base::send0(Buffer::ByteArray const& msg, sequence_number_t seq
send_package_header.write_integral(seqn);

RD_ASSERT_THROW_MSG(
socket_provider->Send(send_package_header.data(), send_package_header.get_position()) == PACKAGE_HEADER_LENGTH,
socket_sender->Send(send_package_header.data(), send_package_header.get_position()) == PACKAGE_HEADER_LENGTH,
this->id +
": failed to send header over the network"
", reason: " +
socket_provider->DescribeError())
socket_sender->DescribeError())

RD_ASSERT_THROW_MSG(socket_provider->Send(msg.data(), msglen) == msglen, this->id +
RD_ASSERT_THROW_MSG(socket_sender->Send(msg.data(), msglen) == msglen, this->id +
": failed to send package over the network"
", reason: " +
socket_provider->DescribeError());
socket_sender->DescribeError());
logger->info("{}: were sent {} bytes", this->id, msglen);
// RD_ASSERT_MSG(socketProvider->Flush(), "{}: failed to flush");
return true;
Expand Down Expand Up @@ -126,6 +127,7 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr<CActiveSocket> new_so
{
std::lock_guard<decltype(socket_send_lock)> guard(socket_send_lock);
socket_provider = std::move(new_socket);
socket_sender = std::make_unique<CSimpleSocketSender>(socket_provider);
socket_send_var.notify_all();
}
{
Expand All @@ -136,8 +138,8 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr<CActiveSocket> new_so
}
}

auto heartbeat = LifetimeDefinition::use([this](Lifetime heartbeatLifetime) {
const auto heartbeat = start_heartbeat(heartbeatLifetime).share();
const auto heartbeat = LifetimeDefinition::use([this](Lifetime heartbeatLifetime) {
const auto heartbeat = start_heartbeat(std::move(heartbeatLifetime)).share();

async_send_buffer.resume();

Expand All @@ -159,6 +161,11 @@ void SocketWire::Base::set_socket_provider(std::shared_ptr<CActiveSocket> new_so
{
logger->debug("{}: socket was already shut down", this->id);
}
else if (socket_provider->GetSocketError() == CSimpleSocket::SocketNotconnected)
{
logger->debug("{}: socket not connected (shutdown likely was initiated by client)");
socket_provider->Close();
}
else if (!socket_provider->Shutdown(CSimpleSocket::Both))
{
// double close?
Expand Down Expand Up @@ -393,14 +400,14 @@ void SocketWire::Base::ping() const
ping_pkg_header.write_integral(counterpart_timestamp);
{
std::lock_guard<decltype(socket_send_lock)> guard(socket_send_lock);
int32_t sent = socket_provider->Send(ping_pkg_header.data(), ping_pkg_header.get_position());
if (sent == 0 && !socket_provider->IsSocketValid())
int32_t sent = socket_sender->Send(ping_pkg_header.data(), ping_pkg_header.get_position());
if (sent == 0 && !socket_sender->IsSocketValid())
{
logger->debug("{}: failed to send ping over the network, reason: socket was shut down for sending", this->id);
return;
}
RD_ASSERT_THROW_MSG(sent == PACKAGE_HEADER_LENGTH,
fmt::format("{}: failed to send ping over the network, reason: {}", this->id, socket_provider->DescribeError()))
fmt::format("{}: failed to send ping over the network, reason: {}", this->id, socket_sender->DescribeError()))
}

++current_timestamp;
Expand All @@ -421,11 +428,11 @@ bool SocketWire::Base::send_ack(sequence_number_t seqn) const
ack_buffer.write_integral(seqn);
{
std::lock_guard<decltype(socket_send_lock)> guard(socket_send_lock);
RD_ASSERT_THROW_MSG(socket_provider->Send(ack_buffer.data(), ack_buffer.get_position()) == PACKAGE_HEADER_LENGTH,
RD_ASSERT_THROW_MSG(socket_sender->Send(ack_buffer.data(), ack_buffer.get_position()) == PACKAGE_HEADER_LENGTH,
this->id +
": failed to send ack over the network"
", reason: " +
socket_provider->DescribeError())
socket_sender->DescribeError())
}
return true;
}
Expand Down
3 changes: 3 additions & 0 deletions rd-cpp/src/rd_framework_cpp/src/main/wire/SocketWire.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ RD_PUSH_STL_EXPORTS_WARNINGS
class CSimpleSocket;
class CActiveSocket;
class CPassiveSocket;
class CSimpleSocketSender;

namespace rd
{
Expand All @@ -39,6 +40,8 @@ class RD_FRAMEWORK_API SocketWire
std::string id;
IScheduler* scheduler = nullptr;
std::shared_ptr<CSimpleSocket> socket_provider;
// we do use separate sender for socket_provider to avoid concurrent state modifications during contesting receive and send operations
std::unique_ptr<CSimpleSocketSender> socket_sender;

std::shared_ptr<CActiveSocket> socket;

Expand Down
2 changes: 2 additions & 0 deletions rd-cpp/thirdparty/clsocket/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ SET(CLSOCKET_HEADERS
src/PassiveSocket.h
src/SimpleSocket.h
src/StatTimer.h
src/SimpleSocketSender.h
)

SET(CLSOCKET_SOURCES
src/SimpleSocket.cpp
src/ActiveSocket.cpp
src/PassiveSocket.cpp
src/SimpleSocketSender.cpp
)

# mark headers as headers...
Expand Down
15 changes: 3 additions & 12 deletions rd-cpp/thirdparty/clsocket/src/ActiveSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ bool CActiveSocket::ConnectTCP(const char *pAddr, uint16_t nPort)
// Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only.
//
//------------------------------------------------------------------
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) ==
CSimpleSocket::SocketError)
Expand Down Expand Up @@ -121,8 +120,6 @@ bool CActiveSocket::ConnectTCP(const char *pAddr, uint16_t nPort)
bRetVal = true;
}

m_timer.SetEndTime();

return bRetVal;
}

Expand Down Expand Up @@ -170,8 +167,7 @@ bool CActiveSocket::ConnectUDP(const char *pAddr, uint16_t nPort)
// Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only.
//
//------------------------------------------------------------------
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) != CSimpleSocket::SocketError)
{
Expand All @@ -180,8 +176,6 @@ bool CActiveSocket::ConnectUDP(const char *pAddr, uint16_t nPort)

TranslateSocketError();

m_timer.SetEndTime();

return bRetVal;
}

Expand Down Expand Up @@ -228,8 +222,7 @@ bool CActiveSocket::ConnectRAW(const char *pAddr, uint16_t nPort)
// Connect to address "xxx.xxx.xxx.xxx" (IPv4) address only.
//
//------------------------------------------------------------------
m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

if (connect(m_socket, (struct sockaddr*)&m_stServerSockaddr, sizeof(m_stServerSockaddr)) != CSimpleSocket::SocketError)
{
Expand All @@ -238,8 +231,6 @@ bool CActiveSocket::ConnectRAW(const char *pAddr, uint16_t nPort)

TranslateSocketError();

m_timer.SetEndTime();

return bRetVal;
}

Expand Down
80 changes: 35 additions & 45 deletions rd-cpp/thirdparty/clsocket/src/PassiveSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,23 @@ bool CPassiveSocket::BindMulticast(const char *pInterface, const char *pGroup, u
//--------------------------------------------------------------------------
// Bind to the specified port
//--------------------------------------------------------------------------
if (bind(m_socket, (struct sockaddr *) &m_stMulticastGroup, sizeof(m_stMulticastGroup)) == 0) {
//----------------------------------------------------------------------
// Join the multicast group
//----------------------------------------------------------------------
m_stMulticastRequest.imr_multiaddr.s_addr = inet_addr(pGroup);
m_stMulticastRequest.imr_interface.s_addr = m_stMulticastGroup.sin_addr.s_addr;

if (SETSOCKOPT(m_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP,
(void *) &m_stMulticastRequest,
sizeof(m_stMulticastRequest)) == CSimpleSocket::SocketSuccess) {
bRetVal = true;
{
CStatTimerCookie timer_cookie(timer);
if (bind(m_socket, (struct sockaddr *) &m_stMulticastGroup, sizeof(m_stMulticastGroup)) == 0) {
//----------------------------------------------------------------------
// Join the multicast group
//----------------------------------------------------------------------
m_stMulticastRequest.imr_multiaddr.s_addr = inet_addr(pGroup);
m_stMulticastRequest.imr_interface.s_addr = m_stMulticastGroup.sin_addr.s_addr;

if (SETSOCKOPT(m_socket, IPPROTO_IP, IP_ADD_MEMBERSHIP,
(void *) &m_stMulticastRequest,
sizeof(m_stMulticastRequest)) == CSimpleSocket::SocketSuccess) {
bRetVal = true;
}
}

m_timer.SetEndTime();
}

m_timer.Initialize();
m_timer.SetStartTime();


//--------------------------------------------------------------------------
// If there was a new_socket error then close the new_socket to clean out the
// connection in the backlog.
Expand Down Expand Up @@ -152,30 +149,29 @@ bool CPassiveSocket::Listen(const char *pAddr, uint16_t nPort, int32_t nConnecti
}
}

m_timer.Initialize();
m_timer.SetStartTime();

//--------------------------------------------------------------------------
// Bind to the specified port
//--------------------------------------------------------------------------
if (bind(m_socket, (struct sockaddr *) &m_stServerSockaddr, sizeof(m_stServerSockaddr)) !=
CSimpleSocket::SocketError) {
socklen_t namelen = sizeof(m_stServerSockaddr);
if (getsockname(m_socket, (struct sockaddr *) &m_stServerSockaddr, &namelen) != CSimpleSocket::SocketError) {
if (m_nSocketType == CSimpleSocket::SocketTypeTcp) {
if (listen(m_socket, nConnectionBacklog) != CSimpleSocket::SocketError) {
{
CStatTimerCookie timer_cookie(timer);

//--------------------------------------------------------------------------
// Bind to the specified port
//--------------------------------------------------------------------------
if (bind(m_socket, (struct sockaddr *) &m_stServerSockaddr, sizeof(m_stServerSockaddr)) !=
CSimpleSocket::SocketError) {
socklen_t namelen = sizeof(m_stServerSockaddr);
if (getsockname(m_socket, (struct sockaddr *) &m_stServerSockaddr, &namelen) != CSimpleSocket::SocketError) {
if (m_nSocketType == CSimpleSocket::SocketTypeTcp) {
if (listen(m_socket, nConnectionBacklog) != CSimpleSocket::SocketError) {
bRetVal = true;
}
} else {
bRetVal = true;
}
} else {
bRetVal = true;
bRetVal = false;
}
} else {
bRetVal = false;
}
}

m_timer.SetEndTime();

//--------------------------------------------------------------------------
// If there was a new_socket error then close the new_socket to clean out the
// connection in the backlog.
Expand Down Expand Up @@ -213,10 +209,9 @@ CActiveSocket *CPassiveSocket::Accept() {
// Wait for incoming connection.
//--------------------------------------------------------------------------
if (pClientSocket != NULL) {
CSocketError socketErrno = SocketSuccess;
CSocketError socketErrno;

m_timer.Initialize();
m_timer.SetStartTime();
CStatTimerCookie timer_cookie(timer);

nClientSockLen = sizeof(m_stClientSockaddr);

Expand Down Expand Up @@ -246,8 +241,6 @@ CActiveSocket *CPassiveSocket::Accept() {

} while (socketErrno == CSimpleSocket::SocketInterrupted);

m_timer.SetEndTime();

if (socketErrno != CSimpleSocket::SocketSuccess) {
delete pClientSocket;
pClientSocket = NULL;
Expand All @@ -271,13 +264,10 @@ int32_t CPassiveSocket::Send(const uint8_t *pBuf, size_t bytesToSend) {
case CSimpleSocket::SocketTypeUdp: {
if (IsSocketValid()) {
if ((bytesToSend > 0) && (pBuf != NULL)) {
m_timer.Initialize();
m_timer.SetStartTime();

m_nBytesSent = static_cast<int32_t>(SENDTO(m_socket, pBuf, bytesToSend, 0,
reinterpret_cast<const sockaddr*>(&m_stClientSockaddr), sizeof(m_stClientSockaddr)));
CStatTimerCookie timer_cookie(timer);

m_timer.SetEndTime();
m_nBytesSent = static_cast<int32_t>(SENDTO(m_socket, pBuf, bytesToSend, 0,
reinterpret_cast<const sockaddr*>(&m_stClientSockaddr), sizeof(m_stClientSockaddr)));

if (m_nBytesSent == CSimpleSocket::SocketError) {
TranslateSocketError();
Expand Down
Loading

0 comments on commit 95c5b78

Please sign in to comment.