diff --git a/src/tasks.cpp b/src/tasks.cpp index 17c2322..2d7a629 100644 --- a/src/tasks.cpp +++ b/src/tasks.cpp @@ -49,13 +49,15 @@ void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, Transformer // root unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); - SocketIo ios[nSockets]; - for (int i = 0; i < nSockets; i++) { - ios[i].socketIndex = threadIndex + i * nThreads; - ios[i].data = buffer; - ios[i].size = bufferBytes; + if (nSockets > 0) { + SocketIo ios[nSockets]; + for (int i = 0; i < nSockets; i++) { + ios[i].socketIndex = threadIndex + i * nThreads; + ios[i].data = buffer; + ios[i].size = bufferBytes; + } + ctx->socketPool->writeMany(nSockets, ios); } - ctx->socketPool->writeMany(nSockets, ios); } else if (ctx->socket != NULL) { if (threadIndex != 0) return; @@ -70,54 +72,24 @@ void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, Tr // root unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); - SocketIo ios[nSockets]; - for (int i = 0; i < nSockets; i++) { - int socketIndex = threadIndex + i * nThreads; - uint8_t workerSliceIndex = socketIndex + 1; - ios[i].socketIndex = socketIndex; - ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex); - ios[i].size = bufferBytes; - } - - ctx->socketPool->readMany(nSockets, ios); - } else if (ctx->socket != NULL) { - if (threadIndex != 0) return; - - // worker - void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex); - ctx->socket->write(buffer, bufferBytes); - } -} - -void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex) { - size_t sliceBytes = ctx->transformer->buffer->getSlicedBytes(bufferIndex); - if (ctx->socketPool != NULL) { - // root - - unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); - SocketIo ios[nSockets]; - - for (uint8_t si = 0; si < ctx->transformer->spec->nSlices - 1; si++) { - for (unsigned int i = 0; i < nSockets; i++) { + if (nSockets > 0) { + SocketIo ios[nSockets]; + for (int i = 0; i < nSockets; i++) { int socketIndex = threadIndex + i * nThreads; uint8_t workerSliceIndex = socketIndex + 1; - slice_index_t sliceIndex = si < workerSliceIndex ? si : si + 1; ios[i].socketIndex = socketIndex; - ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex); - ios[i].size = sliceBytes; + ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex); + ios[i].size = bufferBytes; } - ctx->socketPool->writeMany(nSockets, ios); + + ctx->socketPool->readMany(nSockets, ios); } } else if (ctx->socket != NULL) { if (threadIndex != 0) return; // worker - for (slice_index_t sliceIndex = 0; sliceIndex < ctx->transformer->spec->nSlices; sliceIndex++) { - if (sliceIndex != ctx->transformer->sliceIndex) { - void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex); - ctx->socket->read(buffer, sliceBytes); - } - } + void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex); + ctx->socket->write(buffer, bufferBytes); } } @@ -167,13 +139,15 @@ void sendPos(TASK_ARGS) { if (ctx->socketPool != NULL) { unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0); - SocketIo ios[nSockets]; - for (int i = 0; i < nSockets; i++) { - ios[i].socketIndex = threadIndex + i * nThreads; - ios[i].data = &transformer->pos; - ios[i].size = sizeof(pos_t); + if (nSockets > 0) { + SocketIo ios[nSockets]; + for (int i = 0; i < nSockets; i++) { + ios[i].socketIndex = threadIndex + i * nThreads; + ios[i].data = &transformer->pos; + ios[i].size = sizeof(pos_t); + } + ctx->socketPool->writeMany(nSockets, ios); } - ctx->socketPool->writeMany(nSockets, ios); } } diff --git a/src/tasks.hpp b/src/tasks.hpp index 5fa679b..a1c409d 100644 --- a/src/tasks.hpp +++ b/src/tasks.hpp @@ -44,7 +44,6 @@ class TransformerArch { void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex); void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex); -void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex); void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex); void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool quantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex); void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool dequantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);