Skip to content

Commit

Permalink
Merge pull request #4 from IABTechLab/ant-UID2-2320-refactoring
Browse files Browse the repository at this point in the history
UID2-2320 AWS vsock proxy refactoring and cleanup
  • Loading branch information
atarassov-ttd authored Dec 4, 2023
2 parents b4a5605 + 1c79a70 commit 070840f
Show file tree
Hide file tree
Showing 22 changed files with 779 additions and 638 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ cmake_minimum_required (VERSION 3.8)

project ("vsock-bridge")

set(CMAKE_CXX_FLAGS_DEBUG "-ggdb")

add_subdirectory ("vsock-bridge")
138 changes: 82 additions & 56 deletions vsock-bridge/include/buffer.h
Original file line number Diff line number Diff line change
@@ -1,57 +1,53 @@
#pragma once

#include <unistd.h>
#include <stdint.h>
#include "logger.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>
#include <list>
#include <logger.h>
#include <memory>
#include <vector>

#include <unistd.h>

namespace vsockio
{
struct MemoryBlock
{
MemoryBlock(uint8_t* startPtr, class MemoryArena* region)
: _startPtr(startPtr), _region(region) {}

MemoryBlock(MemoryBlock&& other) : _startPtr(other._startPtr), _region(other._region) {}
MemoryBlock(int size, class MemoryArena* region)
: _startPtr(std::make_unique<uint8_t[]>(size)), _region(region) {}

MemoryBlock(const MemoryBlock& other) = delete;

uint8_t* offset(int x)
uint8_t* offset(int x) const
{
return _startPtr + x;
return _startPtr.get() + x;
}

uint8_t* _startPtr;
std::unique_ptr<uint8_t[]> _startPtr;
class MemoryArena* _region;
};

struct MemoryArena
{
std::vector<MemoryBlock> _blocks;
std::list<MemoryBlock*> _handles;
uint8_t* _memoryStart;
uint32_t _blockSizeInBytes;
int _numBlocks;
bool _initialized;
uint32_t _blockSizeInBytes = 0;
bool _initialized = false;

MemoryArena()
: _initialized(false), _numBlocks(0), _memoryStart(nullptr), _blocks{} {}
MemoryArena() = default;

void init(uint32_t blockSize, int numBlocks)
void init(int blockSize, int numBlocks)
{
if (_initialized) throw;

Logger::instance->Log(Logger::INFO, "Thread-local memory arena init: blockSize=", blockSize, ", numBlocks=", numBlocks);

_numBlocks = numBlocks;
_blockSizeInBytes = blockSize;
_memoryStart = static_cast<uint8_t*>(malloc(blockSize * numBlocks * sizeof(uint8_t)));

for (int i = 0; i < numBlocks; i++)
{
_blocks.emplace_back(MemoryBlock( startPtrOf(i), this ));
_blocks.emplace_back(blockSize, this);
}

for (int i = 0; i < numBlocks; i++)
Expand All @@ -62,11 +58,6 @@ namespace vsockio
_initialized = true;
}

inline uint8_t* startPtrOf(int blockIndex) const
{
return _memoryStart + (blockIndex * _blockSizeInBytes);
}

MemoryBlock* get()
{
if (!_handles.empty())
Expand All @@ -77,19 +68,18 @@ namespace vsockio
}
else
{
return new MemoryBlock(new uint8_t[_blockSizeInBytes], nullptr);
return new MemoryBlock(_blockSizeInBytes, nullptr);
}
}

void put(MemoryBlock* mb)
{
if (mb->_region == this)
{
_handles.push_back(mb);
_handles.push_front(mb);
}
else if (mb->_region == nullptr)
{
delete[] mb->_startPtr;
delete mb;
}
else
Expand All @@ -111,7 +101,7 @@ namespace vsockio
MemoryBlock* _pages[MAX_PAGES];
MemoryArena* _arena;

Buffer(MemoryArena* arena) : _arena(arena), _pageCount{ 0 }, _cursor{ 0 }, _size{ 0 }, _pageSize(arena->blockSize()) {}
explicit Buffer(MemoryArena* arena) : _arena(arena), _pageCount{ 0 }, _cursor{ 0 }, _size{ 0 }, _pageSize(arena->blockSize()) {}

Buffer(Buffer&& b) : _arena(b._arena), _pageCount(b._pageCount), _cursor(b._cursor), _size(b._size), _pageSize(b._arena->blockSize())
{
Expand All @@ -122,7 +112,51 @@ namespace vsockio
b._pageCount = 0; // prevent _pages being destructed by old object
}

Buffer(const Buffer& _) = delete;
Buffer(const Buffer&) = delete;
Buffer& operator=(const Buffer&) = delete;

~Buffer()
{
for (int i = 0; i < _pageCount; i++)
{
_arena->put(_pages[i]);
}
}

uint8_t* tail() const
{
return offset(_size);
}

int remainingCapacity() const
{
return capacity() - _size;
}

void produce(int size)
{
_size += size;
}

bool ensureCapacity()
{
return remainingCapacity() > 0 || tryNewPage();
}

uint8_t* head() const
{
return offset(_cursor);
}

int headLimit() const
{
return std::min(pageLimit(_cursor), _size - _cursor);
}

void consume(int size)
{
_cursor += size;
}

bool tryNewPage()
{
Expand All @@ -131,64 +165,56 @@ namespace vsockio
return true;
}

uint8_t* offset(ssize_t x)
uint8_t* offset(int x) const
{
return _pages[x / _pageSize]->offset(x % _pageSize);
}

size_t capacity() const
int capacity() const
{
return _pageCount * _pageSize;
}

void setCursor(size_t cursor)
int pageLimit(int x) const
{
_cursor = cursor;
return _pageSize - (x % _pageSize);
}

size_t cursor() const
int cursor() const
{
return _cursor;
}

size_t pageLimit(ssize_t x)
int size() const
{
return _pageSize - (x % _pageSize);
return _size;
}

void setSize(size_t size)
bool empty() const
{
_size = size;
return _size <= 0;
}

size_t size() const
bool consumed() const
{
return _size;
}

~Buffer()
{
for (int i = 0; i < _pageCount; i++)
{
_arena->put(_pages[i]);
}
return _cursor >= _size;
}
};

struct BufferManager
{
thread_local static MemoryArena* arena;

static Buffer* getBuffer()
static std::unique_ptr<Buffer> getBuffer()
{
Buffer* b = new Buffer(arena);
auto b = std::make_unique<Buffer>(arena);
b->tryNewPage();
return b;
}

static Buffer* getEmptyBuffer()
static std::unique_ptr<Buffer> getEmptyBuffer()
{
return new Buffer(arena);
return std::make_unique<Buffer>(arena);
}
};

Expand Down
68 changes: 34 additions & 34 deletions vsock-bridge/include/channel.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#pragma once

#include <eventdef.h>
#include <socket.h>
#include <logger.h>
#include <threading.h>
#include "eventdef.h"
#include "logger.h"
#include "socket.h"
#include "threading.h"

#include <memory>

namespace vsockio
{
Expand All @@ -23,63 +25,61 @@ namespace vsockio
int _id;
BlockingQueue<TAction>* _taskQueue;

Socket* _a;
Socket* _b;
std::unique_ptr<Socket> _a;
std::unique_ptr<Socket> _b;
ChannelHandle _ha;
ChannelHandle _hb;

DirectChannel(int id, Socket* a, Socket* b, BlockingQueue<TAction>* taskQueue)
DirectChannel(int id, std::unique_ptr<Socket> a, std::unique_ptr<Socket> b, BlockingQueue<TAction>* taskQueue)
: _id(id)
, _a(a)
, _b(b)
, _ha(id, a->fd())
, _hb(id, b->fd())
, _a(std::move(a))
, _b(std::move(b))
, _ha(id, _a->fd())
, _hb(id, _b->fd())
, _taskQueue(taskQueue)

{
_a->setPeer(b);
_b->setPeer(a);
_a->setPeer(_b.get());
_b->setPeer(_a.get());
}

void handle(int fd, int evt)
{

Socket* s = _a->fd() == fd ? _a : (_b->fd() == fd ? _b : nullptr);
Socket* s = _a->fd() == fd ? _a.get() : (_b->fd() == fd ? _b.get() : nullptr);
if (s == nullptr)
{
Logger::instance->Log(Logger::WARNING, "error in channel.handle: `id=", _id,"`, `fd=", fd, "` does not belong to this channel");
return;
}

if (evt & IOEvent::Error)
{
s->incrementEventCount();
_taskQueue->enqueue(std::bind(&Socket::onError, s));
Logger::instance->Log(Logger::DEBUG, "poll error for fd=", fd);
evt |= IOEvent::InputReady;
evt |= IOEvent::OutputReady;
}
else

if (evt & IOEvent::InputReady)
{
if (evt & IOEvent::InputReady)
{
s->incrementEventCount();
_taskQueue->enqueue(std::bind(&Socket::onInputReady, s));
}
s->onIoEvent();
_taskQueue->enqueue([=] { s->onInputReady(); });
}

if (evt & IOEvent::OutputReady)
{
s->incrementEventCount();
_taskQueue->enqueue(std::bind(&Socket::onOutputReady, s));
}
if (evt & IOEvent::OutputReady)
{
s->onIoEvent();
_taskQueue->enqueue([=] { s->onOutputReady(); });
}
}

bool canBeTerminated() const
void terminate()
{
return _a->closed() && _b->closed() && _a->ioEventCount() == 0 && _b->ioEventCount() == 0;
_taskQueue->enqueue([this] { delete this; });
}
virtual ~DirectChannel()

bool canBeTerminated() const
{
delete _a;
delete _b;
return _a->closed() && _b->closed() && _a->ioEventCount() == 0 && _b->ioEventCount() == 0;
}
};
}
4 changes: 2 additions & 2 deletions vsock-bridge/include/config.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <stdint.h>
#include <vector>
#include <cstdint>
#include <string>
#include <vector>

namespace vsockproxy
{
Expand Down
Loading

0 comments on commit 070840f

Please sign in to comment.