Skip to content

Commit

Permalink
feat: reduction of writeMany/readMany calls. (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Aug 10, 2024
1 parent 668ea98 commit 3353d56
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 52 deletions.
76 changes: 25 additions & 51 deletions src/tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, Transformer
 // root

 unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
-SocketIo ios[nSockets];
-for (int i = 0; i < nSockets; i++) {
-ios[i].socketIndex = threadIndex + i * nThreads;
-ios[i].data = buffer;
-ios[i].size = bufferBytes;
+if (nSockets > 0) {
+SocketIo ios[nSockets];
+for (int i = 0; i < nSockets; i++) {
+ios[i].socketIndex = threadIndex + i * nThreads;
+ios[i].data = buffer;
+ios[i].size = bufferBytes;
+}
+ctx->socketPool->writeMany(nSockets, ios);
 }
-ctx->socketPool->writeMany(nSockets, ios);
 } else if (ctx->socket != NULL) {
 if (threadIndex != 0) return;

Expand All @@ -70,54 +72,24 @@ void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, Tr
 // root

 unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
-SocketIo ios[nSockets];
-for (int i = 0; i < nSockets; i++) {
-int socketIndex = threadIndex + i * nThreads;
-uint8_t workerSliceIndex = socketIndex + 1;
-ios[i].socketIndex = socketIndex;
-ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex);
-ios[i].size = bufferBytes;
-}
-
-ctx->socketPool->readMany(nSockets, ios);
-} else if (ctx->socket != NULL) {
-if (threadIndex != 0) return;
-
-// worker
-void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex);
-ctx->socket->write(buffer, bufferBytes);
-}
-}
-
-void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex) {
-size_t sliceBytes = ctx->transformer->buffer->getSlicedBytes(bufferIndex);
-if (ctx->socketPool != NULL) {
-// root
-
-unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
-SocketIo ios[nSockets];
-
-for (uint8_t si = 0; si < ctx->transformer->spec->nSlices - 1; si++) {
-for (unsigned int i = 0; i < nSockets; i++) {
+if (nSockets > 0) {
+SocketIo ios[nSockets];
+for (int i = 0; i < nSockets; i++) {
 int socketIndex = threadIndex + i * nThreads;
 uint8_t workerSliceIndex = socketIndex + 1;
-slice_index_t sliceIndex = si < workerSliceIndex ? si : si + 1;
 ios[i].socketIndex = socketIndex;
-ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex);
-ios[i].size = sliceBytes;
+ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex);
+ios[i].size = bufferBytes;
 }
-ctx->socketPool->writeMany(nSockets, ios);
+
+ctx->socketPool->readMany(nSockets, ios);
 }
 } else if (ctx->socket != NULL) {
 if (threadIndex != 0) return;

 // worker
-for (slice_index_t sliceIndex = 0; sliceIndex < ctx->transformer->spec->nSlices; sliceIndex++) {
-if (sliceIndex != ctx->transformer->sliceIndex) {
-void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex);
-ctx->socket->read(buffer, sliceBytes);
-}
-}
+void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex);
+ctx->socket->write(buffer, bufferBytes);
 }
 }

Expand Down Expand Up @@ -167,13 +139,15 @@ void sendPos(TASK_ARGS) {

 if (ctx->socketPool != NULL) {
 unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
-SocketIo ios[nSockets];
-for (int i = 0; i < nSockets; i++) {
-ios[i].socketIndex = threadIndex + i * nThreads;
-ios[i].data = &transformer->pos;
-ios[i].size = sizeof(pos_t);
+if (nSockets > 0) {
+SocketIo ios[nSockets];
+for (int i = 0; i < nSockets; i++) {
+ios[i].socketIndex = threadIndex + i * nThreads;
+ios[i].data = &transformer->pos;
+ios[i].size = sizeof(pos_t);
+}
+ctx->socketPool->writeMany(nSockets, ios);
 }
-ctx->socketPool->writeMany(nSockets, ios);
 }
 }

Expand Down
1 change: 0 additions & 1 deletion src/tasks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class TransformerArch {

 void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
 void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
-void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
 void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
 void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool quantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
 void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool dequantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
Expand Down

0 comments on commit 3353d56

Please sign in to comment.