Skip to content

Commit

Permalink
feat: --max-seq-len argument. (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Jul 29, 2024
1 parent 57e3807 commit 71135e6
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.steps = 0;
args.seed = (unsigned long long)time(NULL);
args.chatTemplateType = TEMPLATE_UNKNOWN;
args.maxSeqLen = 0;
args.useDiscForKvCache = false;

int i = 1;
Expand Down Expand Up @@ -99,6 +100,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.seed = atoll(value);
} else if (strcmp(name, "--chat-template") == 0) {
args.chatTemplateType = parseChatTemplateType(value);
} else if (strcmp(name, "--max-seq-len") == 0) {
args.maxSeqLen = (unsigned int)atoi(value);
} else if (strcmp(name, "--kv-cache-storage") == 0) {
args.useDiscForKvCache = strcmp(value, "disc") == 0;
} else {
Expand Down Expand Up @@ -128,7 +131,7 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts);
unsigned int nSlices = args->nWorkers + 1;

TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->weightsFloatType, args->bufferFloatType);
TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->maxSeqLen, args->weightsFloatType, args->bufferFloatType);
TransformerArch arch = TransformerArchFactory::create(&spec);
Tokenizer tokenizer(args->tokenizerPath, spec.vocabSize);

Expand Down
1 change: 1 addition & 0 deletions src/app.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class AppArgs {
bool benchmark;
unsigned long long seed;
ChatTemplateType chatTemplateType;
unsigned int maxSeqLen;

// worker
int port;
Expand Down
2 changes: 1 addition & 1 deletion src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class Chat {
int nInputTokens;
tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false);

pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, (int)pos + nInputTokens - 1);
pos_t userPromptEndPos = (pos_t)std::min<unsigned int>(spec->seqLen, pos + nInputTokens - 1);
for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) {
inference->infer(inputTokens[i], pos);
token = inputTokens[i + 1];
Expand Down
2 changes: 1 addition & 1 deletion src/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ LlamaRopeCommand::LlamaRopeCommand(RopeSlice *slice) {

size_t cacheBytes = slice->seqLen * slice->sliceDim * sizeof(float);
cache = (float*)newBuffer(cacheBytes);
printf("🕒 ropeCache: %ld kB\n", cacheBytes / 1024);
printf("🕒 ropeCacheSize: %ld kB\n", cacheBytes / 1024);

for (pos_t pos = 0; pos < slice->seqLen; pos++) {
for (unsigned int i = slice->kvDimStart; i < slice->qDimEnd; i += 2) {
Expand Down
9 changes: 8 additions & 1 deletion src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#define IS_ROOT_SLICE(sliceIndex) (sliceIndex == 0)

TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) {
TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, const unsigned int maxSeqLen, FloatType weightsFloatType, FloatType bufferFloatType) {
TransformerSpec spec;
memset(&spec, 0, sizeof(TransformerSpec));
spec.hiddenAct = SILU;
Expand Down Expand Up @@ -95,6 +95,10 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i
}
}

spec.origSeqLen = spec.seqLen;
if (maxSeqLen > 0 && spec.seqLen > maxSeqLen) {
spec.seqLen = maxSeqLen;
}
spec.headSize = spec.dim / spec.nHeads;
spec.kvDim = (spec.dim * spec.nKvHeads) / spec.nHeads;
spec.weightsFloatType = weightsFloatType;
Expand Down Expand Up @@ -131,6 +135,9 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i
printf("💡 nActiveExperts: %d\n", spec.nActiveExperts);
}
printf("💡 vocabSize: %d\n", spec.vocabSize);
if (spec.seqLen != spec.origSeqLen) {
printf("💡 origSeqLen: %d\n", spec.origSeqLen);
}
printf("💡 seqLen: %d\n", spec.seqLen);
printf("💡 nSlices: %d\n", spec.nSlices);
printf("💡 ropeTheta: %.1f\n", spec.ropeTheta);
Expand Down
5 changes: 3 additions & 2 deletions src/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ struct TransformerSpec {
int nKvHeads;
int nExperts;
int nActiveExperts;
int seqLen;
unsigned int origSeqLen; // Original model context length
unsigned int seqLen; // Limited context length by the `--max-seq-len` argument
int hiddenDim;
TransformerHiddenAct hiddenAct;
int kvDim;
Expand Down Expand Up @@ -197,7 +198,7 @@ class Transformer {

~Transformer();

static TransformerSpec loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType);
static TransformerSpec loadSpecFromFile(const char* path, const unsigned int nSlices, const unsigned int maxSeqLen, FloatType weightsFloatType, FloatType bufferFloatType);
static Transformer loadRootFromFile(const char* path, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool);
static Transformer loadRoot(char* data, TransformerSpec* spec, TransformerConfig* config, SocketPool* socketPool);
static Transformer loadSlice(TransformerSpec* spec, TransformerConfig* config, Socket* socket);
Expand Down

0 comments on commit 71135e6

Please sign in to comment.