diff --git a/CMakeLists.txt b/CMakeLists.txt index 060f33d..9f53b4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,4 +2,6 @@ cmake_minimum_required (VERSION 3.8) project ("vsock-bridge") +set(CMAKE_CXX_FLAGS_DEBUG "-ggdb") + add_subdirectory ("vsock-bridge") diff --git a/vsock-bridge/include/buffer.h b/vsock-bridge/include/buffer.h index 7be5a76..02d1aa0 100644 --- a/vsock-bridge/include/buffer.h +++ b/vsock-bridge/include/buffer.h @@ -1,29 +1,30 @@ #pragma once -#include -#include +#include "logger.h" + +#include +#include +#include #include -#include #include -#include +#include +#include + +#include namespace vsockio { struct MemoryBlock { - MemoryBlock(uint8_t* startPtr, class MemoryArena* region) - : _startPtr(startPtr), _region(region) {} - - MemoryBlock(MemoryBlock&& other) : _startPtr(other._startPtr), _region(other._region) {} + MemoryBlock(int size, class MemoryArena* region) + : _startPtr(std::make_unique(size)), _region(region) {} - MemoryBlock(const MemoryBlock& other) = delete; - - uint8_t* offset(int x) + uint8_t* offset(int x) const { - return _startPtr + x; + return _startPtr.get() + x; } - uint8_t* _startPtr; + std::unique_ptr _startPtr; class MemoryArena* _region; }; @@ -31,27 +32,22 @@ namespace vsockio { std::vector _blocks; std::list _handles; - uint8_t* _memoryStart; - uint32_t _blockSizeInBytes; - int _numBlocks; - bool _initialized; + uint32_t _blockSizeInBytes = 0; + bool _initialized = false; - MemoryArena() - : _initialized(false), _numBlocks(0), _memoryStart(nullptr), _blocks{} {} + MemoryArena() = default; - void init(uint32_t blockSize, int numBlocks) + void init(int blockSize, int numBlocks) { if (_initialized) throw; Logger::instance->Log(Logger::INFO, "Thread-local memory arena init: blockSize=", blockSize, ", numBlocks=", numBlocks); - _numBlocks = numBlocks; _blockSizeInBytes = blockSize; - _memoryStart = static_cast(malloc(blockSize * numBlocks * sizeof(uint8_t))); for (int i = 0; i < numBlocks; i++) { - _blocks.emplace_back(MemoryBlock( startPtrOf(i), this )); + _blocks.emplace_back(blockSize, this); } for (int i = 0; i < numBlocks; i++) @@ -62,11 +58,6 @@ namespace vsockio _initialized = true; } - inline uint8_t* startPtrOf(int blockIndex) const - { - return _memoryStart + (blockIndex * _blockSizeInBytes); - } - MemoryBlock* get() { if (!_handles.empty()) @@ -77,7 +68,7 @@ namespace vsockio } else { - return new MemoryBlock(new uint8_t[_blockSizeInBytes], nullptr); + return new MemoryBlock(_blockSizeInBytes, nullptr); } } @@ -85,11 +76,10 @@ namespace vsockio { if (mb->_region == this) { - _handles.push_back(mb); + _handles.push_front(mb); } else if (mb->_region == nullptr) { - delete[] mb->_startPtr; delete mb; } else @@ -111,7 +101,7 @@ namespace vsockio MemoryBlock* _pages[MAX_PAGES]; MemoryArena* _arena; - Buffer(MemoryArena* arena) : _arena(arena), _pageCount{ 0 }, _cursor{ 0 }, _size{ 0 }, _pageSize(arena->blockSize()) {} + explicit Buffer(MemoryArena* arena) : _arena(arena), _pageCount{ 0 }, _cursor{ 0 }, _size{ 0 }, _pageSize(arena->blockSize()) {} Buffer(Buffer&& b) : _arena(b._arena), _pageCount(b._pageCount), _cursor(b._cursor), _size(b._size), _pageSize(b._arena->blockSize()) { @@ -122,7 +112,51 @@ namespace vsockio b._pageCount = 0; // prevent _pages being destructed by old object } - Buffer(const Buffer& _) = delete; + Buffer(const Buffer&) = delete; + Buffer& operator=(const Buffer&) = delete; + + ~Buffer() + { + for (int i = 0; i < _pageCount; i++) + { + _arena->put(_pages[i]); + } + } + + uint8_t* tail() const + { + return offset(_size); + } + + int remainingCapacity() const + { + return capacity() - _size; + } + + void produce(int size) + { + _size += size; + } + + bool ensureCapacity() + { + return remainingCapacity() > 0 || tryNewPage(); + } + + uint8_t* head() const + { + return offset(_cursor); + } + + int headLimit() const + { + return std::min(pageLimit(_cursor), _size - _cursor); + } + + void consume(int size) + { + _cursor += size; + } bool tryNewPage() { @@ -131,47 +165,39 @@ namespace vsockio return true; } - uint8_t* offset(ssize_t x) + uint8_t* offset(int x) const { return _pages[x / _pageSize]->offset(x % _pageSize); } - size_t capacity() const + int capacity() const { return _pageCount * _pageSize; } - void setCursor(size_t cursor) + int pageLimit(int x) const { - _cursor = cursor; + return _pageSize - (x % _pageSize); } - size_t cursor() const + int cursor() const { return _cursor; } - size_t pageLimit(ssize_t x) + int size() const { - return _pageSize - (x % _pageSize); + return _size; } - void setSize(size_t size) + bool empty() const { - _size = size; + return _size <= 0; } - size_t size() const + bool consumed() const { - return _size; - } - - ~Buffer() - { - for (int i = 0; i < _pageCount; i++) - { - _arena->put(_pages[i]); - } + return _cursor >= _size; } }; @@ -179,16 +205,16 @@ namespace vsockio { thread_local static MemoryArena* arena; - static Buffer* getBuffer() + static std::unique_ptr getBuffer() { - Buffer* b = new Buffer(arena); + auto b = std::make_unique(arena); b->tryNewPage(); return b; } - static Buffer* getEmptyBuffer() + static std::unique_ptr getEmptyBuffer() { - return new Buffer(arena); + return std::make_unique(arena); } }; diff --git a/vsock-bridge/include/channel.h b/vsock-bridge/include/channel.h index fda2326..62a17ca 100644 --- a/vsock-bridge/include/channel.h +++ b/vsock-bridge/include/channel.h @@ -1,9 +1,11 @@ #pragma once -#include -#include -#include -#include +#include "eventdef.h" +#include "logger.h" +#include "socket.h" +#include "threading.h" + +#include namespace vsockio { @@ -23,63 +25,61 @@ namespace vsockio int _id; BlockingQueue* _taskQueue; - Socket* _a; - Socket* _b; + std::unique_ptr _a; + std::unique_ptr _b; ChannelHandle _ha; ChannelHandle _hb; - DirectChannel(int id, Socket* a, Socket* b, BlockingQueue* taskQueue) + DirectChannel(int id, std::unique_ptr a, std::unique_ptr b, BlockingQueue* taskQueue) : _id(id) - , _a(a) - , _b(b) - , _ha(id, a->fd()) - , _hb(id, b->fd()) + , _a(std::move(a)) + , _b(std::move(b)) + , _ha(id, _a->fd()) + , _hb(id, _b->fd()) , _taskQueue(taskQueue) { - _a->setPeer(b); - _b->setPeer(a); + _a->setPeer(_b.get()); + _b->setPeer(_a.get()); } void handle(int fd, int evt) { - - Socket* s = _a->fd() == fd ? _a : (_b->fd() == fd ? _b : nullptr); + Socket* s = _a->fd() == fd ? _a.get() : (_b->fd() == fd ? _b.get() : nullptr); if (s == nullptr) { Logger::instance->Log(Logger::WARNING, "error in channel.handle: `id=", _id,"`, `fd=", fd, "` does not belong to this channel"); + return; } if (evt & IOEvent::Error) { - s->incrementEventCount(); - _taskQueue->enqueue(std::bind(&Socket::onError, s)); + Logger::instance->Log(Logger::DEBUG, "poll error for fd=", fd); + evt |= IOEvent::InputReady; + evt |= IOEvent::OutputReady; } - else + + if (evt & IOEvent::InputReady) { - if (evt & IOEvent::InputReady) - { - s->incrementEventCount(); - _taskQueue->enqueue(std::bind(&Socket::onInputReady, s)); - } + s->onIoEvent(); + _taskQueue->enqueue([=] { s->onInputReady(); }); + } - if (evt & IOEvent::OutputReady) - { - s->incrementEventCount(); - _taskQueue->enqueue(std::bind(&Socket::onOutputReady, s)); - } + if (evt & IOEvent::OutputReady) + { + s->onIoEvent(); + _taskQueue->enqueue([=] { s->onOutputReady(); }); } } - bool canBeTerminated() const + void terminate() { - return _a->closed() && _b->closed() && _a->ioEventCount() == 0 && _b->ioEventCount() == 0; + _taskQueue->enqueue([this] { delete this; }); } - - virtual ~DirectChannel() + + bool canBeTerminated() const { - delete _a; - delete _b; + return _a->closed() && _b->closed() && _a->ioEventCount() == 0 && _b->ioEventCount() == 0; } }; } \ No newline at end of file diff --git a/vsock-bridge/include/config.h b/vsock-bridge/include/config.h index 208b086..337cf50 100644 --- a/vsock-bridge/include/config.h +++ b/vsock-bridge/include/config.h @@ -1,8 +1,8 @@ #pragma once -#include -#include +#include #include +#include namespace vsockproxy { diff --git a/vsock-bridge/include/dispatcher.h b/vsock-bridge/include/dispatcher.h index 26d9377..7ef0868 100644 --- a/vsock-bridge/include/dispatcher.h +++ b/vsock-bridge/include/dispatcher.h @@ -1,141 +1,149 @@ #pragma once -#include +#include "channel.h" +#include "logger.h" +#include "poller.h" + +#include #include -#include #include -#include -#include +#include namespace vsockio { - struct ChannelIdListNode + struct ChannelNode { - int id; - bool inUse; - DirectChannel* channel; - ChannelIdListNode* next; - ChannelIdListNode* prev; - - ChannelIdListNode(int id) - : id(id), inUse(false), channel(nullptr), next(nullptr), prev(nullptr) {} + int _id; + std::unique_ptr _channel; + + explicit ChannelNode(int id) : _id(id) {} + + void reset() + { + _channel.reset(); + } + + bool inUse() const + { + return !!_channel; + } }; - struct ChannelIdList + class ChannelNodePool { - // [head][free(p)][free]...[free][tail] + public: + ChannelNodePool() = default; - ChannelIdListNode* head; - ChannelIdListNode* p; - ChannelIdListNode* tail; - uint32_t _nextId; + ChannelNodePool(const ChannelNodePool&) = delete; + ChannelNodePool& operator=(const ChannelNodePool&) = delete; - ChannelIdList() : head(nullptr), tail(nullptr), _nextId(0) + ~ChannelNodePool() { - head = new ChannelIdListNode(0); - tail = new ChannelIdListNode(0); - head->prev = nullptr; - head->next = tail; - tail->prev = head; - tail->next = nullptr; - p = tail; + for (auto* node : _freeList) + { + delete node; + } } - ChannelIdListNode* getNode() + struct ChannelNodeDeleter { - if (p == head || p == tail) + ChannelNodePool* _pool; + + void operator()(ChannelNode* node) { - // no free node because we do not have any or we used all up - putNode(new ChannelIdListNode(_nextId++)); + _pool->releaseNode(node); } + }; - auto* prev = p->prev; - auto* next = p->next; - next->prev = prev; - prev->next = next; - - auto* ret = p; - p = next; - ret->inUse = true; - return ret; - } + using ChannelNodePtr = std::unique_ptr; - void putNode(ChannelIdListNode* node) - { - node->inUse = false; - node->channel = nullptr; - node->next = tail; - node->prev = tail->prev; - tail->prev->next = node; - tail->prev = node; - - if (p == tail) + ChannelNodePtr getFreeNode() { + const ChannelNodeDeleter deleter{this}; + + if (_freeList.empty()) { - p = node; + return ChannelNodePtr(new ChannelNode(_nextNodeId++), deleter); } + + auto* node = _freeList.front(); + _freeList.pop_front(); + return ChannelNodePtr(node, deleter); } - ~ChannelIdList() + void releaseNode(ChannelNode* node) { - auto* ptr = head; - while (ptr != nullptr) - { - auto* x = ptr; - ptr = ptr->next; - delete x; - } + if (node == nullptr) return; + node->reset(); + _freeList.push_front(node); } + + private: + int _nextNodeId = 0; + std::forward_list _freeList; }; struct Dispatcher { Poller* _poller; - VsbEvent* _events; - std::unordered_map _channels; - ChannelIdList _idman; + std::vector _events; + ChannelNodePool _idman; + std::unordered_map _channels; BlockingQueue> _tasksToRun; int maxNewConnectionPerLoop = 20; int scanAndCleanInterval = 20; - uint64_t _lastScanAndCleanGen; - uint64_t _currentGen; + int _currentGen = 0; int _name; Dispatcher(Poller* poller) : Dispatcher(0, poller) {} - Dispatcher(int name, Poller* poller) : _name(name), _poller(poller), _events(new VsbEvent[poller->maxEventsPerPoll()]) {} - - ChannelIdListNode* prepareChannel() + Dispatcher(int name, Poller* poller) : _name(name), _poller(poller), _events(poller->maxEventsPerPoll()) {} + + int name() const { - return _idman.getNode(); + return _name; } - void makeChannel(ChannelIdListNode* node, Socket* a, Socket* b, BlockingQueue* _taskQueue) + void postAddChannel(std::unique_ptr&& ap, std::unique_ptr(bp)) { - Logger::instance->Log(Logger::DEBUG, "creating channel id=", node->id, ", a.fd=", a->fd(), ", b.fd=", b->fd()); - auto* c = new DirectChannel(node->id, a, b, _taskQueue); - _channels[node->id] = node; - node->channel = c; - _poller->add(a->fd(), (void*)&c->_ha, IOEvent::InputReady | IOEvent::OutputReady); - _poller->add(b->fd(), (void*)&c->_hb, IOEvent::InputReady | IOEvent::OutputReady); + // Dispatcher::taskloop manages the channel map attached to the dispatcher + // connectToPeer modifies the map so we request taskloop thread to run it + runOnTaskLoop([this, ap = std::move(ap), bp = std::move(bp)]() mutable { addChannel(std::move(ap), std::move(bp)); }); } - - void destroyChannel(ChannelIdListNode* node) + + ChannelNode* addChannel(std::unique_ptr ap, std::unique_ptr bp) { - Logger::instance->Log(Logger::DEBUG, "destroying channel id=", node->channel->_id); - _idman.putNode(node); - delete node->channel; + ChannelNodePool::ChannelNodePtr node = _idman.getFreeNode(); + BlockingQueue* taskQueue = ThreadPool::getTaskQueue(node->_id); + + Logger::instance->Log(Logger::DEBUG, "creating channel id=", node->_id, ", a.fd=", ap->fd(), ", b.fd=", bp->fd()); + node->_channel = std::make_unique(node->_id, std::move(ap), std::move(bp), taskQueue); + + const auto& c = *node->_channel; + c._a->setPoller(_poller); + c._b->setPoller(_poller); + if (!_poller->add(c._a->fd(), (void*)&c._ha, IOEvent::InputReady | IOEvent::OutputReady) || + !_poller->add(c._b->fd(), (void*)&c._hb, IOEvent::InputReady | IOEvent::OutputReady)) + { + return nullptr; + } + + auto* const n = node.get(); + _channels[n->_id] = std::move(node); + return n; } - void runOnTaskLoop(std::function action) + template + void runOnTaskLoop(T&& action) { - _tasksToRun.enqueue(action); + auto wrapper = std::make_shared(std::forward(action)); + _tasksToRun.enqueue([wrapper] { (*wrapper)(); }); } void run() { - Logger::instance->Log(Logger::DEBUG, "dispatcher started"); + Logger::instance->Log(Logger::DEBUG, "dispatcher ", name(), " started"); for (;;) { taskloop(); @@ -144,70 +152,83 @@ namespace vsockio void taskloop() { - // Phase 1. poll IO events - int eventCount = _poller->poll(_events, getTimeout()); - if (eventCount == -1) - { + // handle events on existing channels + poll(); + + // complete new channels + processQueuedTasks(); + + // collect terminated channels + cleanup(); + } + + void poll() + { + const int eventCount = _poller->poll(_events.data(), getTimeout()); + if (eventCount == -1) { Logger::instance->Log(Logger::CRITICAL, "Poller returns error."); return; } - // Phase 2. find corresponding handler and process events - for (int i = 0; i < eventCount; i++) - { - auto* handle = static_cast(_events[i].data); + for (int i = 0; i < eventCount; i++) { + auto *handle = static_cast(_events[i].data); auto it = _channels.find(handle->channelId); - if (it == _channels.end() || !it->second->inUse || it->second->channel == nullptr) - { + if (it == _channels.end() || !it->second->inUse()) { Logger::instance->Log(Logger::WARNING, "Channel ID ", handle->channelId, " does not exist."); continue; } - auto* channel = it->second->channel; - channel->handle(handle->fd, _events[i].ioFlags); + auto &channel = *it->second->_channel; + channel.handle(handle->fd, _events[i].ioFlags); } + } - // Phase 3. complete newcoming connections - for (int i = 0; i < maxNewConnectionPerLoop; i++) - { + void processQueuedTasks() + { + for (int i = 0; i < maxNewConnectionPerLoop; i++) { // must check task count first, since we don't wanna block here - if (_tasksToRun.count() > 0) - { + if (!_tasksToRun.empty()) { auto action = _tasksToRun.dequeue(); action(); - } - else - { + } else { break; } } + } - // Phase 4. clean & remove terminated channels - if (_currentGen - _lastScanAndCleanGen == scanAndCleanInterval) + void cleanup() + { + if (_currentGen >= scanAndCleanInterval) { - std::vector keysToRemove; - for (auto it = _channels.begin(); it != _channels.end(); it++) + for (auto it = _channels.begin(); it != _channels.end(); ) { - auto* ch = it->second->channel; - if (ch != nullptr && ch->canBeTerminated()) + auto* node = it->second.get(); + if (!node->inUse() || node->_channel->canBeTerminated()) { - keysToRemove.push_back(it->first); - destroyChannel(it->second); + Logger::instance->Log(Logger::DEBUG, "destroying channel id=", it->first); + + // any resources allocated on channel thread must be freed there + if (node->inUse()) + { + node->_channel.release()->terminate(); + } + + it = _channels.erase(it); + } + else + { + ++it; } } - for (int id : keysToRemove) - { - _channels.erase(id); - } - _lastScanAndCleanGen = _currentGen; + _currentGen = 0; } _currentGen++; } int getTimeout() const { - bool hasPendingTask = - (_currentGen - _lastScanAndCleanGen == scanAndCleanInterval) || - (_tasksToRun.count() > 0); + const bool hasPendingTask = + (_currentGen >= scanAndCleanInterval) || + (!_tasksToRun.empty()); return hasPendingTask ? 0 : 16; } diff --git a/vsock-bridge/include/endpoint.h b/vsock-bridge/include/endpoint.h index 392c2dc..d4ea831 100644 --- a/vsock-bridge/include/endpoint.h +++ b/vsock-bridge/include/endpoint.h @@ -1,9 +1,13 @@ #pragma once -#include -#include #include +#include +#include +#include +#include +#include + namespace vsockio { struct Endpoint diff --git a/vsock-bridge/include/epoll_poller.h b/vsock-bridge/include/epoll_poller.h index a1f021b..e51ec5c 100644 --- a/vsock-bridge/include/epoll_poller.h +++ b/vsock-bridge/include/epoll_poller.h @@ -1,9 +1,12 @@ #pragma once -#include -#include +#include "logger.h" +#include "poller.h" + #include -#include +#include + +#include namespace vsockio { @@ -11,7 +14,6 @@ namespace vsockio { int _epollFd; - uint64_t _pollCounter; std::unique_ptr _epollEvents; EpollPoller(int maxEvents) : _epollEvents(new epoll_event[maxEvents]) @@ -25,28 +27,46 @@ namespace vsockio } } - int add(int fd, void* handler, uint32_t events) override + bool add(int fd, void* handler, uint32_t events) override { epoll_event ev; memset(&ev, 0, sizeof(epoll_event)); ev.data.ptr = handler; ev.events = vsb2epoll(events); - return epoll_ctl(_epollFd, EPOLL_CTL_ADD, fd, &ev); + if (epoll_ctl(_epollFd, EPOLL_CTL_ADD, fd, &ev) != 0) + { + const int err = errno; + Logger::instance->Log(Logger::ERROR, "epoll_ctl failed to add fd=", fd, ": ", strerror(err)); + return false; + } + + return true; } - int update(int fd, void* handler, uint32_t events) override + bool update(int fd, void* handler, uint32_t events) override { epoll_event ev; memset(&ev, 0, sizeof(epoll_event)); ev.data.ptr = handler; ev.events = vsb2epoll(events); - return epoll_ctl(_epollFd, EPOLL_CTL_MOD, fd, &ev); + if (epoll_ctl(_epollFd, EPOLL_CTL_MOD, fd, &ev) != 0) + { + const int err = errno; + Logger::instance->Log(Logger::ERROR, "epoll_ctl failed to update fd=", fd, ": ", strerror(err)); + return false; + } + + return true; } void remove(int fd) override { epoll_event ev; - epoll_ctl(_epollFd, EPOLL_CTL_DEL, fd, &ev); + if (epoll_ctl(_epollFd, EPOLL_CTL_DEL, fd, &ev) != 0) + { + const int err = errno; + Logger::instance->Log(Logger::ERROR, "epoll_ctl failed to delete fd=", fd, ": ", strerror(err)); + } } int poll(VsbEvent* outEvents, int timeout) override diff --git a/vsock-bridge/include/listener.h b/vsock-bridge/include/listener.h index de18177..4080481 100644 --- a/vsock-bridge/include/listener.h +++ b/vsock-bridge/include/listener.h @@ -1,68 +1,62 @@ #pragma once -#include +#include "channel.h" +#include "dispatcher.h" +#include "endpoint.h" +#include "epoll_poller.h" +#include "logger.h" + +#include + #include #include +#include #include -#include -#include -#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include +#include +#include namespace vsockio { - struct IOControl - { - static int setNonBlocking(int fd) - { - int flags = fcntl(fd, F_GETFL, 0); - if (flags == -1) - { - int err = errno; - Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); - return -1; - } - if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) - { - int err = errno; - Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); - return -1; - } - return 0; - } - - static int setBlocking(int fd) - { - int flags = fcntl(fd, F_GETFL, 0); - if (flags == -1) - { - int err = errno; - Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); - return -1; - } - if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == -1) - { - int err = errno; - Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); - return -1; - } - return 0; - } - }; + struct IOControl { + static bool setNonBlocking(int fd) { + const int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) { + const int err = errno; + Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); + return false; + } + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) { + const int err = errno; + Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); + return false; + } + return true; + } + + static int setBlocking(int fd) { + const int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) { + int err = errno; + Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); + return false; + } + if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == -1) { + int err = errno; + Logger::instance->Log(Logger::ERROR, "fcntl error: ", strerror(err)); + return false; + } + return true; + } + }; struct Listener { const int MAX_POLLER_EVENTS = 256; - const int SO_BACKLOG = 5; + const int SO_BACKLOG = 64; Listener(std::unique_ptr&& listenEndpoint, std::unique_ptr&& connectEndpoint, std::vector& dispatchers) : _fd(-1) @@ -73,42 +67,59 @@ namespace vsockio , _dispatchers(dispatchers) , _dispatcherIdRr(0) { - int fd = _listenEp->getSocket(); + const int fd = _listenEp->getSocket(); + if (fd < 0) + { + throw std::runtime_error("failed to get listener socket"); + } int enable = 1; if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)) < 0) { - Logger::instance->Log(Logger::ERROR, "error setting SO_REUSEADDR"); - close(fd); + close(fd); + throw std::runtime_error("error setting SO_REUSEADDR"); } std::pair addressAndLen = _listenEp->getAddress(); if (bind(fd, addressAndLen.first, addressAndLen.second) < 0) { - int err = errno; - Logger::instance->Log(Logger::ERROR, "Failed to bind new Listener on ", _listenEp->describe(), ": ", strerror(err)); - close(fd); - return; + const int err = errno; + close(fd); + Logger::instance->Log(Logger::ERROR, "failed to bind on ", _listenEp->describe(), ": ", strerror(err)); + throw std::runtime_error("failed to bind"); } /* listener fd is blocking intentially */ - IOControl::setBlocking(fd); + if (!IOControl::setBlocking(fd)) + { + throw std::runtime_error("failed to set blocking"); + } _fd = fd; } + Listener(const Listener&) = delete; + Listener& operator=(const Listener&) = delete; + + ~Listener() + { + if (_fd >= 0) + { + close(_fd); + } + } + void run() { if (listen(_fd, SO_BACKLOG) < 0) { - int err = errno; - Logger::instance->Log(Logger::ERROR, "Failed to listen on ", _listenEp->describe(), ": ", strerror(err)); - close(_fd); - return; + const int err = errno; + Logger::instance->Log(Logger::ERROR, "failed to listen on ", _listenEp->describe(), ": ", strerror(err)); + throw std::runtime_error("failed to listen"); } - Logger::instance->Log(Logger::DEBUG, "listening on ", _listenEp->describe(), ", fd=", _fd); + Logger::instance->Log(Logger::INFO, "listening on ", _listenEp->describe(), ", fd=", _fd); // accept loop for (;;) @@ -121,11 +132,11 @@ namespace vsockio { // accepted connection should have the same protocol with listen endpoint auto addrAndLen = _listenEpClone->getWritableAddress(); - int clientFd = accept(_fd, addrAndLen.first, &addrAndLen.second); + const int clientFd = accept(_fd, addrAndLen.first, &addrAndLen.second); if (clientFd == -1) { - int err = errno; + const int err = errno; if (err == EAGAIN || err == EWOULDBLOCK) { // nothing to accept @@ -134,56 +145,64 @@ namespace vsockio else { // accept failed - Logger::instance->Log(Logger::ERROR, "error during accept: ", strerror(err)); + Logger::instance->Log(Logger::ERROR, "error during accept (fd=", _fd, "): ", strerror(err)); return; } } - IOControl::setNonBlocking(clientFd); - Socket* socket = new Socket(clientFd, SocketImpl::singleton); + auto inPeer = std::make_unique(clientFd, *SocketImpl::singleton); + if (!IOControl::setNonBlocking(clientFd)) + { + Logger::instance->Log(Logger::ERROR, "failed to set non-blocking mode (fd=", clientFd, ")"); + return; + } - int dpId = (_dispatcherIdRr++) % _dispatchers.size(); - auto* dp = _dispatchers[dpId]; + auto outPeer = connectToPeer(); + if (!outPeer) + { + return; + } - ChannelIdListNode* idHandle = dp->prepareChannel(); - Logger::instance->Log(Logger::DEBUG, "Dispatcher ", dpId, " will handle channel ", idHandle->id); - // Dispatcher::taskloop manages the channel map attached to the dispatcher - // connectToPeer modifies the map so we request taskloop thread to run it - dp->runOnTaskLoop([this, socket, idHandle, dp]() { connectToPeer(socket, idHandle, dp); }); - } + const int dpId = (_dispatcherIdRr++) % _dispatchers.size(); + auto* const dp = _dispatchers[dpId]; - void connectToPeer(Socket* inPeer, ChannelIdListNode* idHandle, Dispatcher* dispatcher) - { - int fd = _connectEp->getSocket(); + Logger::instance->Log(Logger::DEBUG, "Dispatcher ", dpId, " will handle channel for accepted connection fd=", inPeer->fd(), ", peer fd=", outPeer->fd()); + dp->postAddChannel(std::move(inPeer), std::move(outPeer)); + } + + std::unique_ptr connectToPeer() + { + const int fd = _connectEp->getSocket(); if (fd == -1) { - Logger::instance->Log(Logger::ERROR, "creating new socket failed."); - inPeer->shutdown(); - delete inPeer; - return; + Logger::instance->Log(Logger::ERROR, "creating remote socket failed"); + return nullptr; } - IOControl::setNonBlocking(fd); + auto peer = std::make_unique(fd, *SocketImpl::singleton); + + if (!IOControl::setNonBlocking(fd)) + { + Logger::instance->Log(Logger::ERROR, "failed to set non-blocking mode (fd=", fd, ")"); + return nullptr; + } auto addrAndLen = _connectEp->getAddress(); int status = connect(fd, addrAndLen.first, addrAndLen.second); if (status == 0 || (status = errno) == EINPROGRESS) { - Logger::instance->Log(Logger::DEBUG, "connected to remote endpoint with status=", status); - Socket* outPeer = new Socket(fd, SocketImpl::singleton); - dispatcher->makeChannel(idHandle, inPeer, outPeer, ThreadPool::getTaskQueue(idHandle->id)); + Logger::instance->Log(Logger::DEBUG, "connected to remote endpoint (fd=", fd, ") with status=", status); + return peer; } else { - Logger::instance->Log(Logger::WARNING, "failed to connect to remote endpoint"); - close(fd); - inPeer->shutdown(); - delete inPeer; + Logger::instance->Log(Logger::WARNING, "failed to connect to remote endpoint (fd=", fd, "): ", strerror(status)); + return nullptr; } } - inline bool listening() const { return _fd > 0; } + inline bool listening() const { return _fd >= 0; } int _fd; std::unique_ptr _listenEp; diff --git a/vsock-bridge/include/logger.h b/vsock-bridge/include/logger.h index c2862ce..7eb44fe 100644 --- a/vsock-bridge/include/logger.h +++ b/vsock-bridge/include/logger.h @@ -1,20 +1,22 @@ #pragma once -#include +#include +#include +#include #include +#include #include -#include #include -#include +#include + #include struct LoggingStream { - virtual std::ostream& getStream(int level) = 0; + virtual std::ostream& startLog(int level) = 0; }; -struct Logger -{ +struct Logger { enum { DEBUG = 0, INFO = 1, @@ -30,62 +32,74 @@ struct Logger Logger() : _streamProvider(nullptr), _minLevel(DEBUG) {} - void setMinLevel(int minLevel) - { + void setMinLevel(int minLevel) { _minLevel = minLevel; } + static const char *getLogLevelStr(int level) + { + switch (level) + { + case DEBUG: return "DEBG"; + case INFO: return "INFO"; + case WARNING: return "WARN"; + case ERROR: return "ERRR"; + case CRITICAL: return "CRIT"; + default: return "UNKN"; + } + } + void setStreamProvider(LoggingStream* streamProvider) { _streamProvider = streamProvider; } template - void Log(int level, T0&& m0) + void Log(int level, const T0& m0) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->getStream(level) << m0 << std::endl; + _streamProvider->startLog(level) << m0 << std::endl; } template - void Log(int level, T0&& m0, T1&& m1) + void Log(int level, const T0& m0, const T1& m1) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->getStream(level) << m0 << m1 << std::endl; + _streamProvider->startLog(level) << m0 << m1 << std::endl; } template - void Log(int level, T0&& m0, T1&& m1, T2&& m2) + void Log(int level, const T0& m0, const T1& m1, const T2& m2) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->getStream(level) << m0 << m1 << m2 << std::endl; + _streamProvider->startLog(level) << m0 << m1 << m2 << std::endl; } template - void Log(int level, T0&& m0, T1&& m1, T2&& m2, T3&& m3) + void Log(int level, const T0& m0, const T1& m1, const T2& m2, const T3& m3) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->getStream(level) << m0 << m1 << m2 << m3 << std::endl; + _streamProvider->startLog(level) << m0 << m1 << m2 << m3 << std::endl; } template - void Log(int level, T0&& m0, T1&& m1, T2&& m2, T3&& m3, T4&& m4) + void Log(int level, const T0& m0, const T1& m1, const T2& m2, const T3& m3, const T4& m4) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->getStream(level) << m0 << m1 << m2 << m3 << m4 << std::endl; + _streamProvider->startLog(level) << m0 << m1 << m2 << m3 << m4 << std::endl; } template - void Log(int level, T0&& m0, T1&& m1, T2&& m2, T3&& m3, T4&& m4, T5&& m5) + void Log(int level, const T0& m0, const T1& m1, const T2& m2, const T3& m3, const T4& m4, const T5& m5) { if (level < _minLevel || _streamProvider == nullptr) return; std::lock_guard lk(_lock); - _streamProvider->getStream(level) << m0 << m1 << m2 << m3 << m4 << m5 << std::endl; + _streamProvider->startLog(level) << m0 << m1 << m2 << m3 << m4 << m5 << std::endl; } }; @@ -95,16 +109,16 @@ class RSyslogBuf : public std::stringbuf public: virtual int sync() override { - std::lock_guard lock(mut); - syslog(log_level, this->str().c_str()); + std::lock_guard lock(_mut); + syslog(_logLevel, this->str().c_str()); this->str(std::string()); return 0; } - RSyslogBuf(int level) : log_level(level) {} + explicit RSyslogBuf(int level) : _logLevel(level) {} private: - std::mutex mut; - int log_level; + std::mutex _mut; + int _logLevel; }; struct NullStream : public std::ostream { @@ -116,27 +130,35 @@ struct NullStream : public std::ostream { struct StdoutLogger : public LoggingStream { - std::ostream& getStream(int level) override + std::ostream& startLog(int level) override { - return std::cout; + const std::time_t t = std::time(0); + const std::tm* const now = std::localtime(&t); + auto& s = std::cout; + const auto prevfill = s.fill('0'); + s << std::setw(4) << (now->tm_year + 1900) << '-' << std::setw(2) << (now->tm_mon + 1) << '-' << std::setw(2) << now->tm_mday + << ' ' << std::setw(2) << now->tm_hour << ':' << std::setw(2) << now->tm_min << ':' << std::setw(2) << now->tm_sec + << " [" << Logger::getLogLevelStr(level) << "] "; + s.fill(prevfill); + return s; } }; struct RSyslogLogger : public LoggingStream { - std::ostream& getStream(int level) override + std::ostream& startLog(int level) override { - if (level == Logger::DEBUG) return debug; - if (level == Logger::INFO) return info; - if (level == Logger::WARNING) return warn; - if (level == Logger::ERROR) return error; - if (level == Logger::CRITICAL) return critical; - return null_stream; + if (level == Logger::DEBUG) return _debug; + if (level == Logger::INFO) return _info; + if (level == Logger::WARNING) return _warn; + if (level == Logger::ERROR) return _error; + if (level == Logger::CRITICAL) return _critical; + return _nullStream; } - RSyslogLogger(const char* name) - : debugb(LOG_DEBUG), infob(LOG_INFO), warnb(LOG_WARNING), errorb(LOG_ERR), criticalb(LOG_CRIT) - , debug(&debugb), info(&infob), warn(&warnb), error(&errorb), critical(&criticalb) + explicit RSyslogLogger(const char* name) + : _debugb(LOG_DEBUG), _infob(LOG_INFO), _warnb(LOG_WARNING), _errorb(LOG_ERR), _criticalb(LOG_CRIT) + , _debug(&_debugb), _info(&_infob), _warn(&_warnb), _error(&_errorb), _critical(&_criticalb) { openlog(name, LOG_CONS | LOG_PID | LOG_NDELAY, LOG_USER); } @@ -146,15 +168,15 @@ struct RSyslogLogger : public LoggingStream closelog(); } - std::ostream debug; - std::ostream info; - std::ostream warn; - std::ostream error; - std::ostream critical; - RSyslogBuf debugb; - RSyslogBuf infob; - RSyslogBuf warnb; - RSyslogBuf errorb; - RSyslogBuf criticalb; - NullStream null_stream; + RSyslogBuf _debugb; + RSyslogBuf _infob; + RSyslogBuf _warnb; + RSyslogBuf _errorb; + RSyslogBuf _criticalb; + std::ostream _debug; + std::ostream _info; + std::ostream _warn; + std::ostream _error; + std::ostream _critical; + NullStream _nullStream; }; \ No newline at end of file diff --git a/vsock-bridge/include/peer.h b/vsock-bridge/include/peer.h index b1b5159..6ea5ce6 100644 --- a/vsock-bridge/include/peer.h +++ b/vsock-bridge/include/peer.h @@ -1,12 +1,14 @@ #pragma once +#include "buffer.h" + #include +#include #include #include -#include #include -#include +#include namespace vsockio { @@ -17,7 +19,7 @@ namespace vsockio std::list _list; - int _count; + ssize_t _count; UniquePtrQueue() : _count(0) {} @@ -31,7 +33,7 @@ namespace vsockio return _list.front(); } - void enqueue(TPtr& value) + void enqueue(TPtr&& value) { _count++; _list.push_back(std::move(value)); @@ -52,84 +54,74 @@ namespace vsockio }; template - struct Peer + class Peer { - Peer(int inputState, int outputState) - : _inputReady(inputState), _outputReady(outputState) - , _inputClosed(false), _outputClosed(false) - , _queueFull(false), _peer(nullptr), _ioEventCount(0) {} + public: + Peer() = default; + + Peer(const Peer&) = delete; + Peer& operator=(const Peer&) = delete; + + virtual ~Peer() {} void onInputReady() { + assert(_peer != nullptr); + _inputReady = true; - bool continuation = false; - do - { - readFromInput(continuation); - if (continuation) - { - _peer->writeToOutput(continuation); - } - } while (continuation); - _ioEventCount--; + while (readFromInput() && _peer->writeToOutput()) + ; + --_ioEventCount; } void onOutputReady() { + assert(_peer != nullptr); + _outputReady = true; - bool continuation = false; - do - { - writeToOutput(continuation); - if (continuation) - { - _peer->readFromInput(continuation); - } - } while (continuation); - _ioEventCount--; + while (writeToOutput() && _peer->readFromInput()) + ; + --_ioEventCount; } - void onError() + inline void setPeer(Peer* p) { - shutdown(); - _ioEventCount--; + _peer = p; } - virtual void shutdown() = 0; - - virtual void onPeerShutdown() = 0; + inline void onIoEvent() { ++_ioEventCount; } - virtual void readFromInput(bool& continuation) = 0; + inline int ioEventCount() const { return _ioEventCount.load(); } - virtual void writeToOutput(bool& continuation) = 0; + virtual void close() = 0; - virtual void queue(TBuf& buffer) = 0; + bool closed() const { return _inputClosed && _outputClosed; } bool inputClosed() const { return _inputClosed; } bool outputClosed() const { return _outputClosed; } - bool closed() const { return _inputClosed && _outputClosed; } + virtual void onPeerClosed() = 0; - bool queueFull() const { return _queueFull; } + virtual void queue(TBuf&& buffer) = 0; - virtual ~Peer() {} - - inline void setPeer(Peer* p) { _peer = p; } + bool queueFull() const { return _queueFull; } + virtual bool queueEmpty() const = 0; - inline void incrementEventCount() { _ioEventCount++; } + protected: + virtual bool readFromInput() = 0; - inline int ioEventCount() const { return _ioEventCount.load(); } + virtual bool writeToOutput() = 0; protected: - bool _inputReady; - bool _outputReady; - int _inputClosed; - int _outputClosed; - bool _queueFull; - Peer* _peer; + bool _inputReady = false; + bool _outputReady = false; + bool _inputClosed = false; + bool _outputClosed = false; + bool _queueFull = false; + Peer* _peer = nullptr; private: - std::atomic_int _ioEventCount; + std::atomic_int _ioEventCount{0}; }; } \ No newline at end of file diff --git a/vsock-bridge/include/poller.h b/vsock-bridge/include/poller.h index 1382091..e8af5a4 100644 --- a/vsock-bridge/include/poller.h +++ b/vsock-bridge/include/poller.h @@ -1,14 +1,14 @@ #pragma once -#include +#include "eventdef.h" namespace vsockio { struct Poller { - virtual int add(int fd, void* handler, uint32_t events) = 0; + virtual bool add(int fd, void* handler, uint32_t events) = 0; - virtual int update(int fd, void* handler, uint32_t events) = 0; + virtual bool update(int fd, void* handler, uint32_t events) = 0; virtual void remove(int fd) = 0; diff --git a/vsock-bridge/include/proxy.h b/vsock-bridge/include/proxy.h deleted file mode 100644 index 591e0a8..0000000 --- a/vsock-bridge/include/proxy.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include - -namespace vsockio -{ - struct Proxy - { - int _numProcThreads; - }; -} \ No newline at end of file diff --git a/vsock-bridge/include/socket.h b/vsock-bridge/include/socket.h index 147d327..fd0a03b 100644 --- a/vsock-bridge/include/socket.h +++ b/vsock-bridge/include/socket.h @@ -1,7 +1,10 @@ #pragma once +#include "peer.h" +#include "poller.h" + #include -#include +#include namespace vsockio { @@ -25,37 +28,47 @@ namespace vsockio close(closeImpl) {} }; - struct Socket : public Peer> + class Socket : public Peer> { - void readFromInput(bool& continuation) override; + public: + Socket(int fd, SocketImpl& impl); - void writeToOutput(bool& continuation) override; + Socket(const Socket&) = delete; + Socket& operator=(const Socket&) = delete; - void shutdown() override; + ~Socket(); - void onPeerShutdown() override; + inline int fd() const { return _fd; } - void queue(std::unique_ptr& buffer) override; + void close() override; - std::unique_ptr read(); + bool queueEmpty() const override { return _sendQueue.empty(); } - void send(std::unique_ptr& buffer); + void setPoller(Poller* poller) + { + _poller = poller; + } - void closeInput(); + protected: + bool readFromInput() override; - void closeOutput(); + bool writeToOutput() override; - inline int fd() const { return _fd; } + void onPeerClosed() override; - Socket(int fd, SocketImpl* impl); + void queue(std::unique_ptr&& buffer) override; - Socket(const Socket& _) = delete; + private: + std::unique_ptr read(); - ~Socket(); + void send(Buffer& buffer); + + void closeInput(); private: - SocketImpl* _impl; + SocketImpl& _impl; UniquePtrQueue _sendQueue; int _fd; + Poller* _poller = nullptr; }; } \ No newline at end of file diff --git a/vsock-bridge/include/threading.h b/vsock-bridge/include/threading.h index 689b79a..ffe45d5 100644 --- a/vsock-bridge/include/threading.h +++ b/vsock-bridge/include/threading.h @@ -1,12 +1,13 @@ #pragma once +#include +#include #include +#include +#include #include #include -#include -#include #include -#include namespace vsockio { @@ -24,8 +25,8 @@ namespace vsockio { { std::lock_guard lk(_queueLock); - _list.push_back(value); - _count++; + _list.push_back(std::move(value)); + ++_count; } _signal.notify_one(); } @@ -40,7 +41,7 @@ namespace vsockio T p = _list.front(); _list.pop_front(); - _count--; + --_count; lk.unlock(); return p; @@ -53,25 +54,31 @@ namespace vsockio bool empty() const { - return _list.empty(); + return _count <= 0; } }; struct WorkerThread { std::function _initCallback; - BlockingQueue>* _taskQueue; - bool _retired; - std::thread* t; + BlockingQueue> _taskQueue; + bool _retired = false; + uint64_t _eventsProcessed = 0; + std::thread t; - uint64_t _eventsProcessed; uint64_t eventsProcessed() const { return _eventsProcessed; } WorkerThread(std::function initCallback) - : _initCallback(initCallback), _taskQueue(new BlockingQueue>), _retired(false) + : _initCallback(initCallback), t([this] { run(); }) { - _eventsProcessed = 0; - t = new std::thread(&WorkerThread::run, this); + } + + ~WorkerThread() + { + if (t.joinable()) + { + t.join(); + } } void run() @@ -80,7 +87,7 @@ namespace vsockio while (!_retired) { - auto action = _taskQueue->dequeue(); + auto action = _taskQueue.dequeue(); action(); _eventsProcessed++; } @@ -89,18 +96,18 @@ namespace vsockio void stop() { _retired = true; - _taskQueue->enqueue([](){}); + _taskQueue.enqueue([](){}); } - BlockingQueue>* getQueue() const + BlockingQueue>* getQueue() { - return _taskQueue; + return &_taskQueue; } }; struct ThreadPool { - static std::vector threads; + static std::vector> threads; static BlockingQueue>* getTaskQueue(int taskId) { return threads[taskId % ThreadPool::threads.size()]->getQueue(); diff --git a/vsock-bridge/include/vsock-bridge.h b/vsock-bridge/include/vsock-bridge.h index e89e31e..a6ecbb7 100644 --- a/vsock-bridge/include/vsock-bridge.h +++ b/vsock-bridge/include/vsock-bridge.h @@ -1,9 +1,10 @@ #pragma once -#include -#include -#include -#include +#include "config.h" +#include "peer.h" +#include "listener.h" +#include "logger.h" +#include "socket.h" + #include -#include #include \ No newline at end of file diff --git a/vsock-bridge/src/config.cpp b/vsock-bridge/src/config.cpp index 964941b..ab83780 100644 --- a/vsock-bridge/src/config.cpp +++ b/vsock-bridge/src/config.cpp @@ -1,9 +1,10 @@ -#include -#include -#include +#include "config.h" +#include "logger.h" + #include -#include +#include #include +#include namespace vsockproxy { diff --git a/vsock-bridge/src/epoll_poller.cpp b/vsock-bridge/src/epoll_poller.cpp index e69de29..7341257 100644 --- a/vsock-bridge/src/epoll_poller.cpp +++ b/vsock-bridge/src/epoll_poller.cpp @@ -0,0 +1 @@ +#include "epoll_poller.h" diff --git a/vsock-bridge/src/global.cpp b/vsock-bridge/src/global.cpp index 7445678..f5ef2d5 100644 --- a/vsock-bridge/src/global.cpp +++ b/vsock-bridge/src/global.cpp @@ -7,7 +7,7 @@ using namespace vsockio; thread_local MemoryArena* BufferManager::arena = new MemoryArena(); -std::vector ThreadPool::threads; +std::vector> ThreadPool::threads; SocketImpl* SocketImpl::singleton = new SocketImpl( /*read: */ [](int fd, void* buf, int len) { return ::read(fd, buf, len); }, diff --git a/vsock-bridge/src/socket.cpp b/vsock-bridge/src/socket.cpp index 11e5b47..cb79d65 100644 --- a/vsock-bridge/src/socket.cpp +++ b/vsock-bridge/src/socket.cpp @@ -1,69 +1,71 @@ -#include -#include +#include "logger.h" +#include "socket.h" + +#include #include +#include + namespace vsockio { - Socket::Socket(int fd, SocketImpl* impl) + Socket::Socket(int fd, SocketImpl& impl) : _fd(fd) , _impl(impl) - , Peer(false, false) - {} - - void Socket::readFromInput(bool& continuation) { - continuation = false; + assert(_fd >= 0); + } - if (_peer->closed()) + bool Socket::readFromInput() + { + if (_peer->outputClosed() && !inputClosed()) { - Logger::instance->Log(Logger::DEBUG, "shutdown 1"); - shutdown(); + Logger::instance->Log(Logger::DEBUG, "[socket] readToInput detected output peer closed, closing input (fd=", _fd, ")"); + closeInput(); + return false; } - if (_inputClosed) return; + if (_inputClosed) return false; + + bool hasInput = false; while (!_inputClosed && _inputReady && !_peer->queueFull()) { std::unique_ptr buffer{ read() }; - if (buffer != nullptr && buffer->size() > 0) + if (buffer && !buffer->empty()) { - _peer->queue(buffer); - continuation = true; + _peer->queue(std::move(buffer)); + hasInput = true; } } - if (_inputClosed && !_peer->outputClosed()) + if (_inputClosed) { - Logger::instance->Log(Logger::DEBUG, "[socket] sending termination from (fd=", _fd, ")"); - std::unique_ptr termination{ BufferManager::getEmptyBuffer() }; - _peer->queue(termination); - - Logger::instance->Log(Logger::DEBUG, "shutdown 2"); - shutdown(); - continuation = true; + Logger::instance->Log(Logger::DEBUG, "[socket] readToInput detected input closed, closing (fd=", _fd, ")"); + close(); } + + return hasInput; } - void Socket::writeToOutput(bool& continuation) + bool Socket::writeToOutput() { - continuation = false; + if (_outputClosed) return false; - if (_outputClosed) return; while (!_outputClosed && _outputReady && !_sendQueue.empty()) { std::unique_ptr& buffer = _sendQueue.front(); // received termination signal from peer - if (buffer->size() == 0) + if (buffer->empty()) { Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput dequeued a termination buffer (fd=", _fd, ")"); - Logger::instance->Log(Logger::DEBUG, "shutdown 3"); - shutdown(); + _sendQueue.dequeue(); + close(); break; } else { - send(buffer); - if (buffer->cursor() == buffer->size()) + send(*buffer); + if (buffer->consumed()) { _sendQueue.dequeue(); _queueFull = false; @@ -71,18 +73,28 @@ namespace vsockio } } - if (_peer->closed() && _sendQueue.empty()) + if (_peer->closed()) { - Logger::instance->Log(Logger::DEBUG, "shutdown 4"); - shutdown(); + if (_sendQueue.empty()) + { + Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput detected input peer is closed, closing (fd=", _fd, ")"); + close(); + } + else if (!_peer->queueEmpty()) + { + // Peer has some queued data they never received + // Assuming this data is critical for the protocol, it should be ok to abort the connection straight away + Logger::instance->Log(Logger::DEBUG, "[socket] writeToOutput detected input peer is closed while having data remaining, closing (fd=", _fd, ")"); + close(); + } } - continuation = _sendQueue.empty(); + return _sendQueue.empty(); } - void Socket::queue(std::unique_ptr& buffer) + void Socket::queue(std::unique_ptr&& buffer) { - _sendQueue.enqueue(buffer); + _sendQueue.enqueue(std::move(buffer)); // to simplify logic we allow only 1 buffer for socket sinks _queueFull = true; @@ -91,20 +103,19 @@ namespace vsockio std::unique_ptr Socket::read() { std::unique_ptr buffer{ BufferManager::getBuffer() }; - ssize_t bytesRead; - ssize_t totalBytes = 0; while (true) { - bytesRead = _impl->read(_fd, buffer->offset(totalBytes), (ssize_t)buffer->capacity() - totalBytes); + const int bytesRead = _impl.read(_fd, buffer->tail(), buffer->remainingCapacity()); int err = 0; if (bytesRead > 0) { // New content read // update byte count and enlarge buffer if needed - totalBytes += bytesRead; - if (totalBytes == buffer->capacity() && !buffer->tryNewPage()) + //Logger::instance->Log(Logger::DEBUG, "[socket] read returns ", bytesRead, " (fd=", _fd, ")"); + buffer->produce(bytesRead); + if (!buffer->ensureCapacity()) { break; } @@ -112,9 +123,8 @@ namespace vsockio else if (bytesRead == 0) { // Source closed - // Shutdown ourself and queue close message to peer - Logger::instance->Log(Logger::DEBUG, "[socket] read returns 0 (fd=", _fd, ")"); + Logger::instance->Log(Logger::DEBUG, "[socket] read returns 0, closing input (fd=", _fd, ")"); closeInput(); break; } @@ -129,27 +139,20 @@ namespace vsockio { // Error - Logger::instance->Log(Logger::WARNING, "error on read: ", strerror(err)); + Logger::instance->Log(Logger::WARNING, "[socket] error on read, closing input (fd=", _fd, "): ", strerror(err)); closeInput(); break; } } - buffer->setSize(totalBytes); - buffer->setCursor(0); return buffer; } - void Socket::send(std::unique_ptr& buffer) + void Socket::send(Buffer& buffer) { - ssize_t bytesWritten; - ssize_t totalBytes = buffer->cursor(); - while (true) + while (!buffer.consumed()) { - ssize_t pageLimit = buffer->pageLimit(totalBytes); - ssize_t dataRemaining = buffer->size() - totalBytes; - ssize_t lengthToWrite = pageLimit < dataRemaining ? pageLimit : dataRemaining; - bytesWritten = _impl->write(_fd, buffer->offset(totalBytes), lengthToWrite); + const int bytesWritten = _impl.write(_fd, buffer.head(), buffer.headLimit()); int err = 0; if (bytesWritten > 0) @@ -157,17 +160,12 @@ namespace vsockio // Some data written to downstream // log bytes written and move cursor forward - totalBytes += bytesWritten; - buffer->setCursor(totalBytes); - if (totalBytes == buffer->size()) - { - break; - } + //Logger::instance->Log(Logger::DEBUG, "[socket] write returns ", bytesWritten, " (fd=", _fd, ")"); + buffer.consume(bytesWritten); } else if((err = errno) == EAGAIN || err == EWOULDBLOCK) { // Write blocked - buffer->setCursor(totalBytes); _outputReady = false; break; } @@ -175,10 +173,8 @@ namespace vsockio { // Error - Logger::instance->Log(Logger::WARNING, "error on send: ", strerror(err)); - buffer->setCursor(totalBytes); - Logger::instance->Log(Logger::DEBUG, "shutdown 5"); - shutdown(); + Logger::instance->Log(Logger::WARNING, "[socket] error on send, closing (fd=", _fd, "): ", strerror(err)); + close(); break; } } @@ -190,16 +186,7 @@ namespace vsockio _inputClosed = true; } - void Socket::closeOutput() - { - if (_inputClosed) - { - Logger::instance->Log(Logger::DEBUG, "shutdown 6"); - shutdown(); - } - } - - void Socket::shutdown() + void Socket::close() { _inputReady = false; _outputReady = false; @@ -208,27 +195,49 @@ namespace vsockio { _inputClosed = true; _outputClosed = true; - Logger::instance->Log(Logger::DEBUG, "socket shutdown, fd=", _fd); - _impl->close(_fd); + if (_poller) + { + // epoll is meant to automatically deregister sockets on close, but apparently some systems + // have bugs around this, so do it explicitly + Logger::instance->Log(Logger::DEBUG, "[socket] remove from poller (fd=", _fd, ")"); + _poller->remove(_fd); + } + + Logger::instance->Log(Logger::DEBUG, "[socket] close, fd=", _fd); + _impl.close(_fd); if (_peer != nullptr) { - _peer->onPeerShutdown(); + _peer->onPeerClosed(); } } } - void Socket::onPeerShutdown() + void Socket::onPeerClosed() { if (!closed()) { - Logger::instance->Log(Logger::DEBUG, "[socket] sending termination from (fd=", _fd, ")"); + Logger::instance->Log(Logger::DEBUG, "[socket] sending termination for (fd=", _fd, ")"); std::unique_ptr termination{ BufferManager::getEmptyBuffer() }; - queue(termination); - bool _; - writeToOutput(_); + queue(std::move(termination)); + + // force process the queue + _outputReady = true; + writeToOutput(); } } - Socket::~Socket() {} + Socket::~Socket() + { + if (!closed()) + { + Logger::instance->Log(Logger::WARNING, "[socket] closing on destruction (fd=", _fd, ")"); + close(); + } + + if (_peer != nullptr) + { + _peer->setPeer(nullptr); + } + } } \ No newline at end of file diff --git a/vsock-bridge/src/vsock-bridge.cpp b/vsock-bridge/src/vsock-bridge.cpp index 1258d9d..f5bdda2 100644 --- a/vsock-bridge/src/vsock-bridge.cpp +++ b/vsock-bridge/src/vsock-bridge.cpp @@ -56,12 +56,12 @@ void start_services(std::vector& services, int numIOThreads for (int i = 0; i < numWorkers; i++) { - WorkerThread* t = new WorkerThread( + auto t = std::make_unique( /*init:*/ []() { BufferManager::arena->init(512, 2000); } ); - ThreadPool::threads.push_back(t); + ThreadPool::threads.push_back(std::move(t)); } for (auto& sd : services) @@ -69,7 +69,7 @@ void start_services(std::vector& services, int numIOThreads std::vector* dispatchers = new std::vector(); for (int i = 0; i < 1; i++) { - Dispatcher* d = new Dispatcher(i + 1, new EpollPoller(VSB_MAX_POLL_EVENTS)); + Dispatcher* d = new Dispatcher(i, new EpollPoller(VSB_MAX_POLL_EVENTS)); dispatchers->push_back(d); } @@ -190,7 +190,7 @@ int main(int argc, char* argv[]) { quit_bad_args("invalid log level, must be 0, 1, 2, 3 or 4", false); } - if (min_log_level < 0 && min_log_level > 4) + if (min_log_level < 0 || min_log_level > 4) { quit_bad_args("invalid log level, must be 0, 1, 2, 3 or 4", false); } @@ -210,7 +210,7 @@ int main(int argc, char* argv[]) { quit_bad_args("invalid io thread count, must be number > 0", false); } - if (min_log_level < 0 && min_log_level > 4) + if (num_iothreads <= 0) { quit_bad_args("invalid io thread count, must be number > 0", false); } diff --git a/vsock-bridge/test/mocks.h b/vsock-bridge/test/mocks.h index 6a9e6d2..b45725e 100644 --- a/vsock-bridge/test/mocks.h +++ b/vsock-bridge/test/mocks.h @@ -26,22 +26,22 @@ struct MockPoller : public Poller _maxEvents = maxEvents; } - int add(int fd, void* handler, uint32_t events) override + bool add(int fd, void* handler, uint32_t events) override { Logger::instance->Log(Logger::INFO, "add: ", fd, ",", (uint64_t)handler, ",", events); _fdMap[fd].fd = fd; _fdMap[fd].handler = handler; _fdMap[fd].listeningEvents = events; - return 0; + return true; } - int update(int fd, void* handler, uint32_t events) override + bool update(int fd, void* handler, uint32_t events) override { Logger::instance->Log(Logger::INFO, "update: ", fd, ",", (uint64_t)handler, ",", events); _fdMap[fd].fd = fd; _fdMap[fd].handler = handler; _fdMap[fd].listeningEvents = events; - return 0; + return true; } void remove(int fd) override diff --git a/vsock-bridge/test/testmain.cpp b/vsock-bridge/test/testmain.cpp index 7070e4d..da8a9e7 100644 --- a/vsock-bridge/test/testmain.cpp +++ b/vsock-bridge/test/testmain.cpp @@ -3,10 +3,11 @@ #include #include -#include -#include #include +#include #include +#include +#include #include #include @@ -17,7 +18,7 @@ using namespace vsockio; -std::vector ThreadPool::threads; +std::vector> ThreadPool::threads; thread_local MemoryArena* BufferManager::arena = new MemoryArena(); TEST_CASE("Queue works", "[queue]") @@ -34,7 +35,7 @@ TEST_CASE("Queue works", "[queue]") for (int i = 0; i < 5; i++) { - q.enqueue(pNumbers[i]); + q.enqueue(std::move(pNumbers[i])); } for (int i = 0; i < 5; i++) @@ -63,12 +64,12 @@ TEST_CASE("Buffer works", "[buffer]") } ssize_t c = b->_pageSize / 2; - b->setCursor(c); - REQUIRE(b->cursor() == c); - - b->setSize(c); + b->produce(c); REQUIRE(b->size() == c); + b->consume(c); + REQUIRE(b->cursor() == c); + REQUIRE(b->tryNewPage()); REQUIRE(b->capacity() == 2 * b->_pageSize); REQUIRE(b->size() == c); @@ -146,7 +147,7 @@ TEST_CASE("Slow write", "[peer]") return 16; }; - impl.close = [](int _) {return 0; }; + impl.close = [](int _) { return 0; }; testContext.reset(); @@ -155,8 +156,8 @@ TEST_CASE("Slow write", "[peer]") int fd_a = mock_sock_default(AF_INET, SOCK_STREAM, 0); int fd_b = mock_sock_default(AF_INET, SOCK_STREAM, 0); - Socket a(fd_a, &impl); - Socket b(fd_b, &impl); + Socket a(fd_a, impl); + Socket b(fd_b, impl); a.setPeer(&b); b.setPeer(&a); @@ -193,8 +194,8 @@ TEST_CASE("Fast close", "[peer]") int fd_a = mock_sock_default(AF_INET, SOCK_STREAM, 0); int fd_b = mock_sock_default(AF_INET, SOCK_STREAM, 0); - Socket a(fd_a, &impl); - Socket b(fd_b, &impl); + Socket a(fd_a, impl); + Socket b(fd_b, impl); a.setPeer(&b); b.setPeer(&a); @@ -235,8 +236,8 @@ TEST_CASE("Correct content", "[peer]") int fd_a = mock_sock_default(AF_INET, SOCK_STREAM, 0); int fd_b = mock_sock_default(AF_INET, SOCK_STREAM, 0); - Socket a(fd_a, &impl); - Socket b(fd_b, &impl); + Socket a(fd_a, impl); + Socket b(fd_b, impl); a.setPeer(&b); b.setPeer(&a); @@ -287,8 +288,8 @@ TEST_CASE("No early close", "[peer]") int fd_a = mock_sock_default(AF_INET, SOCK_STREAM, 0); int fd_b = mock_sock_default(AF_INET, SOCK_STREAM, 0); - Socket a(fd_a, &impl); - Socket b(fd_b, &impl); + Socket a(fd_a, impl); + Socket b(fd_b, impl); a.setPeer(&b); b.setPeer(&a); @@ -305,25 +306,24 @@ TEST_CASE("No early close", "[peer]") TEST_CASE("Threaded IO", "[threading]") { - WorkerThread* wt = new WorkerThread([](){}); - std::thread t(&WorkerThread::run, wt); + WorkerThread wt([](){}); - auto* q = wt->_taskQueue; - q->enqueue([]() { - std::cout << "threading test, block worker thread for 1 second" << std::endl; + auto& q = wt._taskQueue; + q.enqueue([]() { + std::cout << "threading test, block main thread for 1 second" << std::endl; }); std::this_thread::sleep_for(std::chrono::seconds(1)); - q->enqueue([wt]() { + q.enqueue([&wt]() { std::cout << "retiring in 1 second..." << std::endl; std::this_thread::sleep_for(std::chrono::seconds(1)); - wt->stop(); + wt.stop(); std::cout << "worker thread stopped." << std::endl; }); std::cout << "main thread joined." << std::endl; - t.join(); + wt.t.join(); std::cout << "main thread exiting." << std::endl; } @@ -373,9 +373,11 @@ TEST_CASE("Queue tasks in channel", "[channel]") int fd_a = mock_sock_default(AF_INET, SOCK_STREAM, 0); int fd_b = mock_sock_default(AF_INET, SOCK_STREAM, 0); - Socket* a = new Socket(fd_a, &impl); - Socket* b = new Socket(fd_b, &impl); - DirectChannel c(1, a, b, &tQueue); + auto ap = std::make_unique(fd_a, impl); + auto bp = std::make_unique(fd_b, impl); + auto* a = ap.get(); + auto* b = bp.get(); + DirectChannel c(1, std::move(ap), std::move(bp), &tQueue); c.handle(b->fd(), IOEvent::InputReady); c.handle(a->fd(), IOEvent::OutputReady); @@ -400,44 +402,52 @@ TEST_CASE("Queue tasks in channel", "[channel]") REQUIRE(dest == source); } - -TEST_CASE("Id LinkedList behavior", "[processor]") +TEST_CASE("ChannelNodePool behavior", "[processor]") { - ChannelIdList ls; - - // [head][0][1][2] - - auto* id0 = ls.getNode(); - auto* id1 = ls.getNode(); - auto* id2 = ls.getNode(); - - REQUIRE(id0->id == 0); - REQUIRE(id1->id == 1); - REQUIRE(id2->id == 2); - REQUIRE(id0->inUse); - REQUIRE(id1->inUse); - REQUIRE(id2->inUse); - - ls.putNode(id1); - ls.putNode(id0); - auto* id1_ = ls.getNode(); - REQUIRE(id1_->id == 1); - REQUIRE(id1_->inUse); - - auto* id0_ = ls.getNode(); - REQUIRE(id0_->id == 0); - REQUIRE(id0_->inUse); - - auto* id3_ = ls.getNode(); - REQUIRE(id3_->id == 3); - REQUIRE(id3_->inUse); + SocketImpl impl( + [](int fd, void* buf, int len) { return 0; }, + [](int fd, void* buf, int len) { return 0; }, + [](int fd) { return 0; } ); + + ChannelNodePool pool; + + auto node0 = pool.getFreeNode(); + auto node1 = pool.getFreeNode(); + auto node2 = pool.getFreeNode(); + + REQUIRE(node0->_id == 0); + REQUIRE(node1->_id == 1); + REQUIRE(node2->_id == 2); + REQUIRE(!node0->inUse()); + REQUIRE(!node1->inUse()); + REQUIRE(!node2->inUse()); + + auto client = std::make_unique(1, impl); + auto server = std::make_unique(2, impl); + BlockingQueue> taskQueue; + node1->_channel = std::make_unique(123, std::move(client), std::move(server), &taskQueue); + REQUIRE(node1->inUse()); + + node0.reset(); + node1.reset(); + auto node1_ = pool.getFreeNode(); + REQUIRE(node1_->_id == 1); + REQUIRE(!node1_->inUse()); + + auto node0_ = pool.getFreeNode(); + REQUIRE(node0_->_id == 0); + REQUIRE(!node0_->inUse()); + + auto node3_ = pool.getFreeNode(); + REQUIRE(node3_->_id == 3); + REQUIRE(!node3_->inUse()); } void createWorkerThreads(int numThreads) { for (int i = 0; i < numThreads; i++) { - ThreadPool::threads.push_back(new WorkerThread([]() { BufferManager::arena->init(1024, 20); })); + ThreadPool::threads.push_back(std::make_unique([]() { BufferManager::arena->init(1024, 20); })); } } @@ -446,7 +456,6 @@ void terminateWorkerThreads() for (int i = 0; i < ThreadPool::threads.size(); i++) { ThreadPool::threads[i]->stop(); - ThreadPool::threads[i]->t->join(); } ThreadPool::threads.clear(); } @@ -480,19 +489,18 @@ TEST_CASE("Dispatcher", "[dispatcher]") MockPoller poller(20); Dispatcher ex(&poller); - auto* channel = ex.prepareChannel(); - Socket* client = new Socket(1, &impl); - Socket* server = new Socket(2, &impl); - ex.makeChannel(channel, client, server, ThreadPool::getTaskQueue(channel->id)); + auto client = std::make_unique(1, impl); + auto server = std::make_unique(2, impl); + auto* channel = ex.addChannel(std::move(client), std::move(server)); REQUIRE(poller._fdMap.find(1) != poller._fdMap.end()); REQUIRE(poller._fdMap.find(2) != poller._fdMap.end()); REQUIRE(poller._fdMap[1].listeningEvents == (uint32_t)(IOEvent::InputReady | IOEvent::OutputReady)); REQUIRE(poller._fdMap[2].listeningEvents == (uint32_t)(IOEvent::InputReady | IOEvent::OutputReady)); - REQUIRE(((ChannelHandle*)poller._fdMap[1].handler)->channelId == channel->id); - REQUIRE(((ChannelHandle*)poller._fdMap[2].handler)->channelId == channel->id); + REQUIRE(((ChannelHandle*)poller._fdMap[1].handler)->channelId == channel->_id); + REQUIRE(((ChannelHandle*)poller._fdMap[2].handler)->channelId == channel->_id); - auto* queue = ThreadPool::getTaskQueue(channel->id); + auto* queue = ThreadPool::getTaskQueue(channel->_id); REQUIRE(queue->empty()); poller.setInputReady(1, true); poller.setOutputReady(2, true); @@ -505,9 +513,16 @@ TEST_CASE("Dispatcher", "[dispatcher]") std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - poller.setOutputReady(2, true); + REQUIRE(channel->inUse() == true); + REQUIRE(channel->_channel->canBeTerminated()); + ex.taskloop(); - REQUIRE(channel->inUse == false); + for (; retry < 100 && channel->inUse(); retry++) + { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + REQUIRE(channel->inUse() == false); for (int i = 0; i < ThreadPool::threads.size(); i++) {