Skip to content

Commit

Permalink
feat: support llama 3.1.
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz committed Jul 24, 2024
1 parent 8c57298 commit 17743c2
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 5 deletions.
16 changes: 16 additions & 0 deletions converter/convert-hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def parseHiddenAct(act: str):
raise Exception(f'Unsupported hidden act: {act}')
return hiddenAct

def parseRopeType(rt: str):
    # Maps a HuggingFace `rope_scaling.rope_type` string to the numeric id
    # written into the distributed-llama model header (see TransformerRopeType).
    ropeType = {
        'llama3': 2, # LLAMA3_1
    }.get(rt)
    if (ropeType is None):
        # Report the unsupported input value, not the lookup result
        # (which is always None on this path).
        raise Exception(f'Unsupported rope type: {rt}')
    return ropeType

def loadConfig(folderPath: str, weightsFloatType: int):
allFiles = os.listdir(folderPath)
allFiles.sort()
Expand Down Expand Up @@ -178,6 +186,14 @@ def loadConfig(folderPath: str, weightsFloatType: int):
ropeTheta = config.get('rope_theta')
if (ropeTheta is not None):
result['rope_theta'] = int(ropeTheta)

ropeScaling = config.get('rope_scaling')
if (ropeScaling is not None):
result['rope_scaling_factor'] = int(ropeScaling['factor'])
result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor'])
result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor'])
result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings'])
result['rope_type'] = parseRopeType(ropeScaling['rope_type'])
return result

def printUsage():
Expand Down
7 changes: 6 additions & 1 deletion converter/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,12 @@ def writeHeader(file, params):
'max_seq_len': 10,
'hidden_act': 11,
'rope_theta': 12,
'weights_float_type': 13
'weights_float_type': 13,
'rope_scaling_factor': 14,
'rope_scaling_low_freq_factor': 15,
'rope_scaling_high_freq_factory': 16,
'rope_scaling_orig_max_seq_len': 17,
'rope_type': 18,
}
header = struct.pack('i', 0xA00ABCD)

Expand Down
2 changes: 1 addition & 1 deletion src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class Chat {
int nInputTokens;
tokenizer->encode((char*)inputPrompt.c_str(), inputTokens, &nInputTokens, true, false);

pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, pos + nInputTokens - 1);
pos_t userPromptEndPos = (pos_t)std::min(spec->seqLen, (int)pos + nInputTokens - 1);
for (pos_t i = 0; pos < userPromptEndPos; pos++, i++) {
inference->infer(inputTokens[i], pos);
token = inputTokens[i + 1];
Expand Down
48 changes: 48 additions & 0 deletions src/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,54 @@ void LlamaRopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nT
}
}

// Builds the llama 3.1 rope command; stores the slice geometry and the
// rope-scaling parameters read from the model header, then logs them.
Llama3_1RopeCommand::Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen)
    : slice(slice),
      ropeScalingFactor(ropeScalingFactor),
      ropeScalingLowFreqFactor(ropeScalingLowFreqFactor),
      ropeScalingHighFreqFactory(ropeScalingHighFreqFactory),
      ropeScalingOrigMaxSeqLen(ropeScalingOrigMaxSeqLen)
{
    printf("🕒 ropeScalingFactor: %f\n", ropeScalingFactor);
    printf("🕒 ropeScalingLowFreqFactor: %f\n", ropeScalingLowFreqFactor);
    printf("🕒 ropeScalingHighFreqFactory: %f\n", ropeScalingHighFreqFactory);
    printf("🕒 ropeScalingOrigMaxSeqLen: %d\n", ropeScalingOrigMaxSeqLen);
}

// Llama 3.1 context-extension frequency scaling (matches Meta's reference
// implementation and HF transformers' llama3 rope scaling): high-frequency
// components are kept, low-frequency components are divided by the scaling
// factor, and a smooth interpolation bridges the band in between.
float Llama3_1RopeCommand::scale(float freq) {
    // `freq` is an inverse frequency (1 / theta^(d/D)), so its wavelength is
    // 2*pi / freq. The original `2*pi * freq` is always <= 2*pi, which made
    // the first branch fire for every component and disabled scaling entirely.
    float waveLen = 2.0f * M_PI / freq;
    float lowFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingLowFreqFactor;
    float highFreqWavelen = ropeScalingOrigMaxSeqLen / ropeScalingHighFreqFactory;
    if (waveLen < highFreqWavelen) {
        // Short wavelength = high frequency: leave unchanged.
        return freq;
    } else if (waveLen > lowFreqWavelen) {
        // Long wavelength = low frequency: fully rescale.
        return freq / ropeScalingFactor;
    } else {
        // In-between band: interpolate between scaled and unscaled.
        float smooth = (ropeScalingOrigMaxSeqLen / waveLen - ropeScalingLowFreqFactor) / (ropeScalingHighFreqFactory - ropeScalingLowFreqFactor);
        return (1 - smooth) * freq / ropeScalingFactor + smooth * freq;
    }
}

// Applies llama 3.1 rotary position embedding in-place to the q or k vector.
// The work (pairs of adjacent floats) is split across threads.
void Llama3_1RopeCommand::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
    const unsigned int dim0Half = (isQ ? slice->qDim0 : slice->kvDim0) / 2;
    // NOTE(review): `shift` is computed but unused here, mirroring the other
    // rope commands — confirm whether qOrK indexing should use it.
    const unsigned int shift = isQ ? slice->qShift : 0;
    SPLIT_RANGE_TO_THREADS(s, e, 0, dim0Half, nThreads, threadIndex);
    const unsigned int iStart = s * 2;
    const unsigned int iEnd = e * 2;

    for (unsigned int i = iStart; i < iEnd; i += 2) {
        const unsigned int headDim = i % slice->headSize;
        // Llama 3.1 applies the context-extension scaling to the rotary
        // frequency itself, before the angle is computed — not to the rotated
        // vector components (scaling the outputs distorts the embedding norm).
        const float freq = scale(1.0f / powf(slice->ropeTheta, headDim / (float)slice->headSize));
        const float val = pos * freq;
        const float fcr = cosf(val);
        const float fci = sinf(val);

        float v0 = qOrK[i];
        float v1 = qOrK[i + 1];

        // Standard 2D rotation of each (v0, v1) pair by angle `val`.
        qOrK[i] = v0 * fcr - v1 * fci;
        qOrK[i + 1] = v0 * fci + v1 * fcr;
    }
}

// Falcon-style rope command: only needs the shared slice geometry.
FalconRopeCommand::FalconRopeCommand(RopeSlice *slice) : slice(slice) {}
Expand Down
15 changes: 14 additions & 1 deletion src/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// *Slice - calculates sizes, offsets, slice sizes etc. It is not responsible for memory allocation. It may help in the loading of data.
// *Command - allocates memory for weights, performs calculations.

typedef unsigned short pos_t;
typedef unsigned int pos_t;
typedef uint8_t slice_index_t;

class MatmulSlice {
Expand Down Expand Up @@ -106,6 +106,19 @@ class LlamaRopeCommand : public RopeCommand {
void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
};

// Rotary position embedding with llama 3.1's context-extension frequency
// scaling. Parameters come from the model header (rope_scaling section of
// the original HF config).
class Llama3_1RopeCommand : public RopeCommand {
private:
    RopeSlice* slice;
    float ropeScalingFactor;            // rope_scaling "factor"
    float ropeScalingLowFreqFactor;     // rope_scaling "low_freq_factor"
    float ropeScalingHighFreqFactory;   // NOTE(review): presumably meant "Factor"; the misspelling matches the header key written by the converter, so rename both together or not at all
    int ropeScalingOrigMaxSeqLen;       // rope_scaling "original_max_position_embeddings"
public:
    Llama3_1RopeCommand(RopeSlice *slice, float ropeScalingFactor, float ropeScalingLowFreqFactor, float ropeScalingHighFreqFactory, int ropeScalingOrigMaxSeqLen);
    void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
    // Maps an inverse frequency to its scaled value per the llama 3.1 scheme.
    float scale(float freq);
};

class FalconRopeCommand : public RopeCommand {
private:
RopeSlice* slice;
Expand Down
35 changes: 33 additions & 2 deletions src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

#define IS_ROOT_SLICE(sliceIndex) (sliceIndex == 0)

#define USE_DISC_FOR_KV_CACHE 1

TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) {
TransformerSpec spec;
memset(&spec, 0, sizeof(TransformerSpec));
spec.hiddenAct = SILU;
spec.ropeType = ROPE_UNKNOWN;
spec.ropeTheta = 10000.0f;

FILE* fd = fopen(path, "rb");
Expand Down Expand Up @@ -68,6 +71,11 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i
else if (key == HIDDEN_ACT) spec.hiddenAct = (TransformerHiddenAct)value;
else if (key == ROPE_THETA) spec.ropeTheta = (float)value;
else if (key == WEIGHTS_FLOAT_TYPE) weightsFloatType = (FloatType)value;
else if (key == ROPE_SCALING_FACTOR) spec.ropeScalingFactor = (float)value;
else if (key == ROPE_SCALING_LOW_FREQ_FACTOR) spec.ropeScalingLowFreqFactor = (float)value;
else if (key == ROPE_SCALING_HIGH_FREQ_FACTORY) spec.ropeScalingHighFreqFactory = (float)value;
else if (key == ROPE_SCALING_ORIG_MAX_SEQ_LEN) spec.ropeScalingOrigMaxSeqLen = value;
else if (key == ROPE_TYPE) spec.ropeType = (TransformerRopeType)value;
else {
throw std::runtime_error("Unsupported header key");
}
Expand All @@ -79,6 +87,16 @@ TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned i
if (weightsFloatType == FUNK)
throw std::runtime_error("Not specified weights float type");

if (spec.ropeType == ROPE_UNKNOWN) {
if (spec.archType == LLAMA) {
spec.ropeType = ROPE_LLAMA;
} else if (spec.archType == GROK1 || spec.archType == MIXTRAL) {
spec.ropeType = ROPE_FALCON;
} else {
throw std::runtime_error("Cannot resolve rope type from architecture");
}
}

spec.headSize = spec.dim / spec.nHeads;
spec.kvDim = (spec.dim * spec.nKvHeads) / spec.nHeads;
spec.weightsFloatType = weightsFloatType;
Expand Down Expand Up @@ -223,10 +241,14 @@ Transformer::Transformer(TransformerSpec* spec, slice_index_t sliceIndex) {
}

ropeSlice = new RopeSlice(spec->dim, spec->kvDim, spec->nKvHeads, spec->nSlices, spec->seqLen, spec->headSize, spec->ropeTheta, sliceIndex);
if (spec->archType == GROK1 || spec->archType == MIXTRAL) {
if (spec->ropeType == ROPE_FALCON) {
rope = new FalconRopeCommand(ropeSlice);
} else {
} else if (spec->ropeType == ROPE_LLAMA) {
rope = new LlamaRopeCommand(ropeSlice);
} else if (spec->ropeType == ROPE_LLAMA3_1) {
rope = new Llama3_1RopeCommand(ropeSlice, spec->ropeScalingFactor, spec->ropeScalingLowFreqFactor, spec->ropeScalingHighFreqFactory, spec->ropeScalingOrigMaxSeqLen);
} else {
throw std::runtime_error("Unsupported rope type");
}

TransformerBlock* b = blocks[0];
Expand Down Expand Up @@ -276,8 +298,13 @@ TransformerBlock::TransformerBlock(TransformerSpec* spec, slice_index_t sliceInd
}

kvCacheSlice = new KvCacheSlice(spec->kvDim, spec->seqLen, spec->nSlices);
#if USE_DISC_FOR_KV_CACHE
keyCache = (float*)allocateWritableMmapBuffer(kvCacheSlice->keyCacheSize);
valueCache = (float*)allocateWritableMmapBuffer(kvCacheSlice->valueCacheSize);
#else
keyCache = (float*)newBuffer(kvCacheSlice->keyCacheSize);
valueCache = (float*)newBuffer(kvCacheSlice->valueCacheSize);
#endif

multiHeadAttSlice = new MultiHeadAttSlice(spec->nHeads, spec->seqLen, spec->nSlices, sliceIndex);
att = (float*)newBuffer(multiHeadAttSlice->attSize);
Expand Down Expand Up @@ -337,8 +364,12 @@ TransformerBlock::~TransformerBlock() {
}

delete kvCacheSlice;
#if USE_DISC_FOR_KV_CACHE
// TODO
#else
freeBuffer(keyCache);
freeBuffer(valueCache);
#endif
delete multiHeadAttSlice;
freeBuffer(att);

Expand Down
17 changes: 17 additions & 0 deletions src/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ enum TransformerHeaderKey {
HIDDEN_ACT = 11,
ROPE_THETA = 12,
WEIGHTS_FLOAT_TYPE = 13,
ROPE_SCALING_FACTOR = 14,
ROPE_SCALING_LOW_FREQ_FACTOR = 15,
ROPE_SCALING_HIGH_FREQ_FACTORY = 16,
ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17,
ROPE_TYPE = 18,
};

struct TransformerFileOldHeader {
Expand All @@ -47,6 +52,13 @@ enum TransformerHiddenAct {
SILU = 1,
};

// Rope variant id stored in the model file header (ROPE_TYPE key = 18).
// The numeric values are part of the file format — do not renumber.
enum TransformerRopeType {
    ROPE_UNKNOWN = -1,   // header did not specify; resolved from the arch type at load
    ROPE_LLAMA = 0,      // classic llama rotary embedding
    ROPE_FALCON = 1,     // falcon-style rope
    ROPE_LLAMA3_1 = 2,   // llama rope + 3.1 context-extension frequency scaling
};

struct TransformerSpec {
size_t headerSize;
size_t fileSize;
Expand All @@ -65,6 +77,11 @@ struct TransformerSpec {
int kvDim;
int vocabSize;
float ropeTheta;
TransformerRopeType ropeType;
float ropeScalingFactor;
float ropeScalingLowFreqFactor;
float ropeScalingHighFreqFactory;
int ropeScalingOrigMaxSeqLen;

FloatType weightsFloatType;
FloatType bufferFloatType;
Expand Down
27 changes: 27 additions & 0 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,33 @@ void openMmapFile(MmapFile* file, const char* path, size_t size) {
#endif
}

// Counter used to generate a unique backing-file name per buffer.
// File-local: nothing outside this translation unit uses it.
static int writableMmapIndex = 0;

// Allocates a writable, file-backed buffer via mmap, letting large buffers
// (e.g. the KV cache) spill to disc instead of RAM.
// The backing file ("mmap-<n>" in the working directory) is unlinked as soon
// as the mapping is established: the mapping stays valid (POSIX keeps the
// file alive until it is unmapped), and the disc space is reclaimed when the
// process exits — previously these files accumulated forever.
// On any failure the process prints the errno message and exits.
void* allocateWritableMmapBuffer(size_t size) {
    char path[256];
    snprintf(path, sizeof(path), "mmap-%d", writableMmapIndex++);
    int fd = open(path, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
    if (fd == -1) {
        perror("open");
        exit(EXIT_FAILURE);
    }
    // Grow the file to the requested size so the whole mapping is backed.
    if (ftruncate(fd, size) == -1) {
        perror("ftruncate");
        close(fd);
        exit(EXIT_FAILURE);
    }
    void *addr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    if (addr == MAP_FAILED) {
        perror("mmap");
        close(fd);
        exit(EXIT_FAILURE);
    }
    // Safe after a successful mmap: the descriptor and directory entry are no
    // longer needed to keep the mapping alive.
    unlink(path);
    close(fd);
    return addr;

    // TODO: there is still no munmap() for this buffer; the mapping lives
    // until process exit.
}

void closeMmapFile(MmapFile* file) {
#ifdef _WIN32
UnmapViewOfFile(file->data);
Expand Down
2 changes: 2 additions & 0 deletions src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct MmapFile {
void openMmapFile(MmapFile* file, const char* path, size_t size);
void closeMmapFile(MmapFile* file);

void* allocateWritableMmapBuffer(size_t size);

typedef void (TaskLoopHandler)(unsigned int nThreads, unsigned int threadIndex, void* userData);
typedef struct {
TaskLoopHandler* handler;
Expand Down

0 comments on commit 17743c2

Please sign in to comment.