Skip to content

Commit

Permalink
feat: dev mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Jul 31, 2024
1 parent f18ac63 commit ee2c689
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,23 @@ MatmulCommand::MatmulCommand(const unsigned int n, const unsigned int d, const F
this->inputFloatType = inputFloatType;
this->weightsFloatType = weightsFloatType;
this->cpuSize = getBatchBytes(weightsFloatType, n, d);
#if ALLOC_MEMORY
this->cpuWeights = newBuffer(this->cpuSize);
#endif
};

// Releases the weights buffer held by this command.
// The buffer is freed only when ALLOC_MEMORY is enabled, i.e. when the constructor
// obtained it from newBuffer(). With ALLOC_MEMORY disabled nothing is released here,
// because cpuWeights then holds the pointer handed to loadWeights() rather than a copy.
MatmulCommand::~MatmulCommand() {
#if ALLOC_MEMORY
    freeBuffer(this->cpuWeights);
#endif
}

// Makes the weights at `source` available to this command.
//
// ALLOC_MEMORY enabled:  copies cpuSize bytes from `source` into cpuWeights.
// ALLOC_MEMORY disabled: stores `source` itself in cpuWeights (no copy), so the memory
//                        behind `source` has to stay valid while cpuWeights is in use.
//
// Returns cpuSize in both modes.
size_t MatmulCommand::loadWeights(const void* source) {
#if !ALLOC_MEMORY
    // Alias the caller's memory; the const qualifier is dropped to store the pointer.
    cpuWeights = const_cast<void*>(source);
#else
    memcpy(cpuWeights, source, cpuSize);
#endif
    return cpuSize;
}

Expand Down
20 changes: 18 additions & 2 deletions src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,10 @@ Transformer::Transformer(TransformerSpec* spec, TransformerConfig* config, slice
tokenEmbeddingTableBytes = spec->vocabSize * spec->dim * sizeof(float);
rmsFinalBytes = spec->dim * sizeof(float);

#if ALLOC_MEMORY
tokenEmbeddingTable = (float*)newBuffer(tokenEmbeddingTableBytes);
rmsFinal = (float*)newBuffer(rmsFinalBytes);

#endif
wclsMm = new MatmulCommand(spec->dim, spec->vocabSize, F32, spec->weightsFloatType);

x = (float*)newBuffer(spec->dim * sizeof(float));
Expand Down Expand Up @@ -272,8 +273,10 @@ Transformer::~Transformer() {
delete[] blocks;

if (IS_ROOT_SLICE(sliceIndex)) {
#if ALLOC_MEMORY
freeBuffer(tokenEmbeddingTable);
freeBuffer(rmsFinal);
#endif
delete wclsMm;

freeBuffer(x);
Expand All @@ -295,12 +298,14 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, TransformerConfig* con
rmsMoeBytes = spec->dim * sizeof(float);
rmsFfn2Bytes = spec->dim * sizeof(float);

#if ALLOC_MEMORY
rmsAtt = (float*)newBuffer(rmsAttBytes);
rmsFfn = (float*)newBuffer(rmsFfnBytes);
if (spec->archType == GROK1) {
rmsMoe = (float*)newBuffer(rmsMoeBytes);
rmsFfn2 = (float*)newBuffer(rmsFfn2Bytes);
}
#endif
}

kvCacheSlice = new KvCacheSlice(spec->kvDim, spec->seqLen, spec->nSlices);
Expand Down Expand Up @@ -360,6 +365,7 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, TransformerConfig* con
}

TransformerBlock::~TransformerBlock() {
#if ALLOC_MEMORY
if (IS_ROOT_SLICE(sliceIndex)) {
freeBuffer(rmsAtt);
freeBuffer(rmsFfn);
Expand All @@ -368,6 +374,7 @@ TransformerBlock::~TransformerBlock() {
freeBuffer(rmsFfn2);
}
}
#endif

delete kvCacheSlice;
if (config->useDiscForKvCache) {
Expand Down Expand Up @@ -423,6 +430,7 @@ TransformerBlock::~TransformerBlock() {
}

static size_t loadSlicedMatmulWeights(const uint8_t nSlices, MatmulSlice* slice, char* source, MatmulCommand* mm, SocketPool* socketPool) {
#if ALLOC_MEMORY
char* buffer = (char*)newBuffer(slice->sliceBytes);
size_t loadedBytes = 0;
for (uint8_t s = 0; s < nSlices; s++) {
Expand All @@ -437,10 +445,17 @@ static size_t loadSlicedMatmulWeights(const uint8_t nSlices, MatmulSlice* slice,
}
freeBuffer(buffer);
return loadedBytes;
#else
return mm->loadWeights(source);
#endif
}

// Provides `bytes` bytes of root weights from `source` through `*target`.
//
// ALLOC_MEMORY enabled:  copies the bytes into the buffer `*target` already points to.
// ALLOC_MEMORY disabled: repoints `*target` at `source` without copying anything.
//
// Returns `bytes` unchanged in both modes.
static size_t loadRootWeights(char** target, char* source, size_t bytes) {
#if !ALLOC_MEMORY
    *target = source;
#else
    memcpy(*target, source, bytes);
#endif
    return bytes;
}

Expand All @@ -456,8 +471,9 @@ Transformer Transformer::loadRootFromFile(const char* path, TransformerSpec* spe
char* weights = ((char*)file.data) + spec->headerSize;
Transformer transformer = Transformer::loadRoot((char*)weights, spec, config, socketPool);

#if ALLOC_MEMORY
closeMmapFile(&file);

#endif
return transformer;
}

Expand Down
2 changes: 2 additions & 0 deletions src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <cstdio>
#include "common/pthread.h"

// Compile-time switch controlling how model weights are held in memory.
//   true  - weight buffers are allocated with newBuffer() and the weights are copied into them.
//   false - weight pointers are set to the source data directly (no allocation, no copy), and
//           the mmap-ed model file is left open after loading.
//           NOTE(review): presumably the "dev mode" named in the commit title - confirm.
// This header must be included before any `#if ALLOC_MEMORY` check: an undefined macro
// evaluates to 0 in `#if`, which would silently select the `false` behaviour.
#define ALLOC_MEMORY true

#ifdef _WIN32
#include <windows.h>
#endif
Expand Down

0 comments on commit ee2c689

Please sign in to comment.