Skip to content

Commit

Permalink
reduce transfer size per token.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Nov 21, 2024
1 parent dffc23d commit c2c2fa0
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 23 deletions.
1 change: 0 additions & 1 deletion src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {

TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
if (spec->archType == LLAMA) return buildLlamaArch(spec);
if (spec->archType == MIXTRAL) return buildMixtralArch(spec);
printf("Unsupported arch type: %d\n", spec->archType);
exit(EXIT_FAILURE);
}
Expand Down
6 changes: 3 additions & 3 deletions src/llama2-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void llamaQuantizeAtt(TASK_ARGS) {

void llamaSyncAtt(TASK_ARGS) {
TASK_VARIABLES;
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XBV_QUANTIZED);
syncSliceOfSlicedBuffer(nThreads, threadIndex, false, ctx, TB_SLICED_XBV_QUANTIZED);
}

void llamaDequantizeAtt(TASK_ARGS) {
Expand Down Expand Up @@ -195,7 +195,7 @@ void llamaQuantizeFfn2(TASK_ARGS) {

void llamaSyncFfn2(TASK_ARGS) {
TASK_VARIABLES;
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XBV_QUANTIZED);
syncSliceOfSlicedBuffer(nThreads, threadIndex, false, ctx, TB_SLICED_XBV_QUANTIZED);
}

void llamaDequantizeFfn2(TASK_ARGS) {
Expand Down Expand Up @@ -248,7 +248,7 @@ void llamaFinalize(TASK_ARGS) {

void llamaSyncLogits(TASK_ARGS) {
TASK_VARIABLES;
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_LOGITS);
syncSliceOfSlicedBuffer(nThreads, threadIndex, true, ctx, TB_SLICED_LOGITS);
}

TransformerArch buildLlamaArch(TransformerSpec* spec) {
Expand Down
4 changes: 2 additions & 2 deletions src/mixtral-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void grokQuantizeMoeMul(TASK_ARGS) {

void grokSyncMoeMulA(TASK_ARGS) {
TASK_VARIABLES;
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_HB_QUANTIZED);
syncSliceOfSlicedBuffer(nThreads, threadIndex, false, ctx, TB_SLICED_HB_QUANTIZED);
}

void grokSyncMoeMulRearrange(TASK_ARGS) {
Expand Down Expand Up @@ -155,7 +155,7 @@ void grokSyncMoeMulB(TASK_ARGS) {

void grokSyncMoeOutput(TASK_ARGS) {
TASK_VARIABLES;
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_XB2_QUANTIZED);
syncSliceOfSlicedBuffer(nThreads, threadIndex, false, ctx, TB_SLICED_XB2_QUANTIZED);
}

void grokDequantizeMoeOutput(TASK_ARGS) {
Expand Down
39 changes: 23 additions & 16 deletions src/tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,29 +71,36 @@ void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, Transformer
}
}

void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex) {
void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, bool onlyFromWorkerToRoot, TransformerContext* ctx, uint8_t bufferIndex) {
unsigned int nSocketsPerThread = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
if (nSocketsPerThread == 0) return;
size_t sliceBytes = ctx->transformer->buffer->getSlicedBytes(bufferIndex);

void* mySliceData = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex);
size_t sliceBytes = ctx->transformer->buffer->getSlicedBytes(bufferIndex);

SocketIo writeIos[nSocketsPerThread];
SocketIo readIos[nSocketsPerThread];
for (unsigned int i = 0; i < nSocketsPerThread; i++) {
unsigned int socketIndex = threadIndex + i * nThreads;
writeIos[i].socketIndex = socketIndex;
writeIos[i].data = mySliceData;
writeIos[i].size = sliceBytes;
if (!onlyFromWorkerToRoot || ctx->transformer->sliceIndex != 0) {
void* mySliceData = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex);

int sliceIndex = socketIndex >= ctx->transformer->sliceIndex ? socketIndex + 1 : socketIndex;
readIos[i].socketIndex = socketIndex;
readIos[i].data = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex);
readIos[i].size = sliceBytes;
SocketIo writeIos[nSocketsPerThread];
for (unsigned int i = 0; i < nSocketsPerThread; i++) {
unsigned int socketIndex = threadIndex + i * nThreads;
writeIos[i].socketIndex = socketIndex;
writeIos[i].data = mySliceData;
writeIos[i].size = sliceBytes;
}
ctx->socketPool->writeManyWithAlignment(nSocketsPerThread, writeIos);
}

ctx->socketPool->writeManyWithAlignment(nSocketsPerThread, writeIos);
ctx->socketPool->readManyWithAlignment(nSocketsPerThread, readIos);
if (!onlyFromWorkerToRoot || ctx->transformer->sliceIndex == 0) {
SocketIo readIos[nSocketsPerThread];
for (unsigned int i = 0; i < nSocketsPerThread; i++) {
unsigned int socketIndex = threadIndex + i * nThreads;
int sliceIndex = socketIndex >= ctx->transformer->sliceIndex ? socketIndex + 1 : socketIndex;
readIos[i].socketIndex = socketIndex;
readIos[i].data = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex);
readIos[i].size = sliceBytes;
}
ctx->socketPool->readManyWithAlignment(nSocketsPerThread, readIos);
}
}

void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex) {
Expand Down
2 changes: 1 addition & 1 deletion src/tasks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TransformerArch {
TransformerSpec* spec = transformer->spec; // printf("%s:%d\n", __FUNCTION__, ctx->currentBlockIndex); fflush(stdout);

void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, bool onlyFromWorkerToRoot, TransformerContext* ctx, uint8_t bufferIndex);
void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool skipMySlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
Expand Down

0 comments on commit c2c2fa0

Please sign in to comment.