From e93d1e6356a5d192f0b81f14dc18f7772affdaf4 Mon Sep 17 00:00:00 2001 From: Bart Tadych Date: Thu, 9 May 2024 19:35:14 +0200 Subject: [PATCH] feat: splitting RoPE into all nodes. (#38) --- .github/workflows/main.yml | 3 + Makefile | 4 +- examples/macbeth.sh | 248 ++++++++++++++++++------------------- src/llama2-tasks.cpp | 18 +-- src/llama2-tasks.hpp | 2 +- src/transformer-test.cpp | 79 ++++++++++++ src/transformer.cpp | 84 ++++++++----- src/transformer.hpp | 16 ++- 8 files changed, 280 insertions(+), 174 deletions(-) create mode 100644 src/transformer-test.cpp diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6e699b3..5ca9c6d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,12 +31,15 @@ jobs: make main make funcs-test make quants-test + make transformer-test make llama2-tasks-test make grok1-tasks-test - name: funcs-test run: ./funcs-test - name: quants-test run: ./quants-test + - name: transformer-test + run: ./transformer-test - name: llama2-tasks-test run: ./llama2-tasks-test - name: grok1-tasks-test diff --git a/Makefile b/Makefile index b61e212..af09cd8 100644 --- a/Makefile +++ b/Makefile @@ -27,9 +27,11 @@ tokenizer: src/tokenizer.cpp main: src/main.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer $(CXX) $(CXXFLAGS) src/main.cpp -o main utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o -lpthread funcs-test: src/funcs-test.cpp funcs utils quants - $(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o + $(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o -lpthread quants-test: src/quants.cpp utils quants $(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o -lpthread +transformer-test: src/transformer-test.cpp funcs utils quants transformer socket + $(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o -lpthread llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer $(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o -lpthread grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks tokenizer diff --git a/examples/macbeth.sh b/examples/macbeth.sh index f45b9a2..86aeaba 100644 --- a/examples/macbeth.sh +++ b/examples/macbeth.sh @@ -68,138 +68,128 @@ They smack of honour both. Go get him surgeons. 65 Who comes here?" GENERATED="Malcolm. The worthy Thane of Ross. -Duncan. What a haste looks through a troop? and when may -No sooner had this battle fought, than, ingrate and ungracious! 70 -Who leap'd my back, and thence hasten'd me away, -And follows so he wins. - -Malcolm. As I do live, my lord, so happily prosper I, without this blow might have o'erpaid the world: 75 -He loves our Majesty as boundlessly as we -Muster ourselves, and make a full battalion. -Duncan. Then enter, sir, and alone with me great battles 80 -I'll strain upon thy forehead, to this day -It is a faith, and makes the fire that burn in my veins: -Thou hast it now, king, afield to-morrow. -God be wi' you, father. - -Duncan. Farewell, farewell! or let me hear from you. 85 -[Exeunt] - -THE THIRD SCENE -Macbeth, Banquo, Ross, and Angus. -Macbeth. So fair and foul a day I have not seen. -It is calm, and yet is all together. -[Thunder, then rain] -Look, how the blood of Sweden flows from hence; -The time is free. I see the Capitol; 90 -The city of kingly eyes. -[Thunder, then lightening] -And the remote parts of Parliament, -Which now behold, now can behold no more 95 -In which time will appear how much -I have translated the flesh of Banquo -Into a crow that strays about the capital. -[Treble knocks] -The Prince of Cumberland; that is a step -On which I must fall down, or else o'erleap, 100 -For in my way it lies. Stars, hide your fires; -Let not light see my black and deep desires: -The eye wink at the hand; yet let that be -Which the eye fears to look upon. -Let it be a sin; 105 -That the so lusterous, and so bright, so good -Should but be seen a fellow to my crime; -And dupe so ruin'd. -[Treble knocks] -That my keen knife -See you satisfied. -[Treble knocks] 110 -Go to thy death. -[Re-enter a Servant] -How now, what names. -Servant. Name. Marrow, marrow; that is the very question that I put to thee, 115 -That is the very question that I put to thee, -Macbeth. Thou 'rt mad that thy sword is not temper'd. -What you lack of temperance, that I lack in valour. -Art not without ambition; but without 120 -The illness should attend it. What thou wouldst highly, -That wouldst thou boldly, and with thy virtues else -Wouldst thou have wildly holden; let fall thy hand; -[To the Servant] -Kilt him like a boar. 125 -Macbeth. From this time these woes we will re-assume: -Not from our fingers' ends? We still have left -A special will to thrust these thorns more firmly. -A little more the wisely. Gently the weather. 130 -[Treble knocks] -The Prince of Cumberland; that is a step -On which I must fall down, or else o'erleap, -For in my way it lies. Stars, hide your fires; -Let not light see my black and deep desires: 135 -The eye wink at the hand; yet let that be -Which the eye fears to look upon. -Let it be a sin; that the so lusterous, and so bright, so good -Should but be seen a fellow to my crime; 140 -And dupe so ruin'd. -[Treble knocks] -That my keen knife -See you satisfied. -[Treble knocks] -Go to thy death. -[Re-enter a Servant] -How now, what names. -Servant. Name. Marrow, marrow; that is the very question that I put to thee, 145 -That is the very question that I put to thee, -Macbeth. Thou 'rt mad that thy sword is not temper'd. -What you lack of temperance, that I lack in valour. -Art not without ambition; but without 150 -The illness should attend it. What thou wouldst highly, -That wouldst thou boldly, and with thy virtues else -Wouldst thou have wildly holden; let fall thy hand; -[To the Servant] -Kilt him like a boar. 155 -Macbeth. From this time these woes we will re-assume: -Not from our fingers' ends? We still have left -A special will to thrust these thorns more firmly. -A little more the wisely. Gently the weather. 160 -[Treble knocks] -Come, love, and we will a while chastise -That dares come to this. -[Re-enter a second Servant] -What is that which caugh your eyes? 165 -Second Servant. My young lord, I can tell. -To think that they may see such sights! -And yet not be the eyes itself that see but, as 'tis said, a man should be the righter part of nature; if he be such, he need not -come behindhand too. 170 -'Tis no time to cloak our faults. -[Re-enter a third Servant] -The very firstlings of my heart shall be -The firstlings of my head; I'll be their patriarch. -Come, put on gaiter; come, come, good mother, 175 -Damned entrance of weather! -[Thunder] -Come, get you to my woman's breasts; And on them give, and mercy onen me, let fall your holy disinclinations. -[Exeunt] -Act III. SCENE 1. -The scene opens with the arrival of the King and his entourage at the castle of the thane of Fife. King Duncan, having heard of Macbeth's new successes, asks his thanes to rejoice with him. Macbeth's great respect for the king makes him slightly uncomfortable. King Duncan's concern for Macbeth's wife and children is further evidence of the king's warmth and loving nature. Macbeth appears ill at ease, perhaps at Duncan's evident concern. His language is overly formal and self-conscious, while his wife speaks rather bluntly. After Macbeth, the king, and his attendants enter, Banquo asks how Lady Macbeth is. - -Macbeth. When I am gone, 180 -After life's fitful fever, he sleeps well, -Though the powers of the strong world do set themselves -Against his estate. -King Duncan. So well to do! -Had he his heart's desire, he 'd stoop -To what humility 185 -Might become the matter. -Macbeth. As the matter now I 've put it. -King Duncan. Well then, 190 -Since that you are a father, show the child -The taking off, and that which now you do 195 -Commit" +Duncan. What a haste looks through a duel's wounds! 70 +Some must be pac'd. +[Exit Ross] +See this encounter is like to the poring +On of a beggar's story, told by one +That means to pluck upon the heart the strings +And draw the tears thriftily. 75 +[Enter Lennox] +How goes the night, boy? + +Lennox. The night is long that none should wake. + +Duncan. You do not need to stare. The Moor +To know the man. 'Tis the Moors devices. 80 +[Exit Lennox] +By the happy right of mine own hands, +Strike all that live in this poor thing of mine. +'Tis calld the Eyrie, and I am sick at heart. +As hellish-devils do the damned souls +O'their bad lives, thus ill-breveted, linger +O'er lamps and forks and other instruments +That prove the stages of the night. 90 +Good sir, take note; I bid you farewell: +Come sleep, and cut short this nitty romance. +[He sleeps.] +If cravens, I bear them like the Minion of the moon, +With tiptoe foot he sneaks and starts to be a man. 95 +And when he is found asleep, awake him with this armed' s address: +That sleep which th'assassin hallowed, +Scotland, awake; your king is murder'd, sleep no more. 100 +*Furbish'd. Weapons polished for battle. +*Thriftily. Fastidiously, thoughtfully. +*Eyrie. Fortress; the lair of birds of prey. +*Minion. A braggart, a coward. + +1.5 + +Macbeth. So foul and fair a day I have not seen. 5 +Ross. Good morning, noble Macbeth. I come from Inverness, +And find our throne void, the arm'd rest you; 10 +My Lord of Cassil has resigned his life. +Macbeth. Whate'er you owe, in time repay, fair friends. +Note you the words; I pray you do. +Ross. I am your faithful servant, and will keep +My sworn reward upon your life; my lord. +Macbeth. You shall be well rewarded; stay the press, 20 +And I'll not fail. How now, good fellow? +Servant. Sir, his schoolmaster. 25 +Macbeth. Well, good, though, old. +Tell me, good fellow, how goes the night? 30 +Servant. There's marrygold and fire in your veins, my lord. +Macbeth. He does commend you; the weight of this old night's embargoes 35 +Did one hour's waste of time lay upon him. +I know when we are too safe, 'tis dangerous to be secure; +Therefore our fearful parts do brave the danger 40 +Which knows it not. I see you are a gentleman. +And a laudable one too; I am most off obliged. +Servant. I should be sorry, my good lord, to have had the labour 45 +To outlive this damned hour. 50 +Macbeth. What's done cannot be undone. To bed, to bed, to bed. +Servant. Will it please you to lie still? 55 +Macbeth. Lord, lord, my heart is in my mouth. All's true that ends well. +Servant. I thank you, fair, and leave you to the content. 60 +Macbeth. You see, my lord, it smokes, and shows no cause +Why the drone dies. 65 +Servant. Grief fills the room up of one vast stair, +And downs our vaults to the inconstant man above. 70 +Macbeth. Go bid thy masters and thy mistress say, 75 +I have power in earth to do so much. +There's comfort yet. They are assailable. Then say I, +Thus ye may answer. +Servant. He cannot be wronged; or being wronged, 80 +I cannot help him. 85 +Macbeth. You know but by this; as this, 90 +The Jew foole is hang'd. 95 +Servant. No more today, my lord. 100 +Macbeth. He does shame to tell him he loves him, but not remove him 105 +From his true place; no. +Servant. That's true, and now I remember the story 110 +Of that sign in Leo four diurnal courses +Returning in a constant motion were within 115 +A boare that had on Taurus' back tetracted; 120 +Or neuer, or but once in modulated accidence. 125 +Macbeth. Thou climd'st alone, ty'd to the stag's horn. +Servant. I was a bull, for this the goodly year. 130 +Come, put me in my place. +Macbeth. Now go to sleep. 135 +Servant. The west neuer sett before the equinox 140 +Till now; and sunnes look'd not theyr frequencie 145 +Upon our lappe till now, my lord. 150 +Macbeth. This game of chance you term a gong. +Servant. A gong is a scotch word for an egg. 155 +Macbeth. Peace, be still. 160 +Servant. I coniecture I smell the blood of an Englishman. 165 +Macbeth. The faith is murthered. +Servant. That murder'd in his sleep. 170 +Macbeth. And sleeping murdered. 175 +Servant. In the fair queen heere in his royal court. 180 +Macbeth. So great a mercy that it may last eternally. +Servant. The earth hath bubbles as the water hath, 185 +And these are of them. Whate'er we will do 190 +To mend the trespasses of the comming time 195 +Shall be the seedes of new mischefe, and shall beget 200 +The formes of the extinctnese, which we are now. 205 +Macbeth. We have scorch'd the snake, not kill'd it. 210 +Servant. They hunt it in the morn. Good gally, good lord! 215 +It weares a gilded snout. 220 +Macbeth. It is the very painting of your fear. 225 +Servant. This is the worst. 230 +Macbeth. A fair quater of a mile is yet to go. 235 +Servant. A mile and half. 240 +Macbeth. I have run fifteen miles to-day. +Servant. A calender's date. +Macbeth. A bigger patch, a bigger patch. 245 +Servant. Thirteen of more. 250 +Macbeth. Wast thou with him? 255 +Servant. No, nor he to night. 260 +Macbeth. Thou seest the moon" echo "Generating, it can take a while..." -OUTPUT=$(( ./main generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type q80 --nthreads 8 --steps 2048 --model converter/dllama_meta-llama-3-8b_q40.bin --tokenizer converter/dllama_meta-llama3-tokenizer.t ) 2>&1) +OUTPUT=$(( ./main generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 8 --steps 2048 --model converter/dllama_meta-llama-3-8b_q40.bin --tokenizer converter/dllama_meta-llama3-tokenizer.t ) 2>&1) echo "$OUTPUT" diff --git a/src/llama2-tasks.cpp b/src/llama2-tasks.cpp index c0b6135..01d651d 100644 --- a/src/llama2-tasks.cpp +++ b/src/llama2-tasks.cpp @@ -43,6 +43,14 @@ void llamaQkv(TASK_ARGS) { matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex); } +void llamaRope(TASK_ARGS) { + TASK_VARIABLES; + float* q = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex); + float* k = (float*)transformer->buffer->getSliced(TB_SLICED_K, transformer->sliceIndex); + transformer->ropeSlice->forward(true, q, transformer->pos, nThreads, threadIndex); + transformer->ropeSlice->forward(false, k, transformer->pos, nThreads, threadIndex); +} + void llamaQuantizeQkv(TASK_ARGS) { TASK_VARIABLES; quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_Q, TB_SLICED_Q_QUANTIZED); @@ -76,13 +84,6 @@ void llamaMultiheadAtt(TASK_ARGS) { } } -void llamaMultiheadAttRope(TASK_ARGS) { - TASK_VARIABLES; - float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q); - float* k = block->keyCache + transformer->pos * spec->kvDim; - rope(transformer->ropeCache, q, k, spec, transformer->pos, nThreads, threadIndex); -} - void llamaMultiheadAttJoin(TASK_ARGS) { TASK_VARIABLES; float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q); @@ -293,11 +294,11 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) { a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); a.I(llamaQkv, TASK_TYPE_INFERENCE); + a.I(llamaRope, TASK_TYPE_INFERENCE); a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaSyncQkv, TASK_TYPE_TRANSFER); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); - a.I(llamaMultiheadAttRope, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAttJoin, TASK_TYPE_INFERENCE); a.I(llamaQuantizeMultiheadAtt, TASK_TYPE_INFERENCE); a.I(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER); @@ -330,6 +331,7 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) { for (int i = 0; i < spec->nLayers; i++) { a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); a.W(llamaQkv, TASK_TYPE_INFERENCE); + a.W(llamaRope, TASK_TYPE_INFERENCE); a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE); a.W(llamaSyncQkv, TASK_TYPE_TRANSFER); a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER); diff --git a/src/llama2-tasks.hpp b/src/llama2-tasks.hpp index 91ddafb..d40910e 100644 --- a/src/llama2-tasks.hpp +++ b/src/llama2-tasks.hpp @@ -8,11 +8,11 @@ void llamaRmsAttNorm(TASK_ARGS); void llamaQuantizeRmsAtt(TASK_ARGS); void llamaSyncRmsAtt(TASK_ARGS); void llamaQkv(TASK_ARGS); +void llamaRope(TASK_ARGS); void llamaQuantizeQkv(TASK_ARGS); void llamaSyncQkv(TASK_ARGS); void llamaDequantizeQkv(TASK_ARGS); void llamaMultiheadAtt(TASK_ARGS); -void llamaMultiheadAttRope(TASK_ARGS); void llamaMultiheadAttJoin(TASK_ARGS); void llamaQuantizeMultiheadAtt(TASK_ARGS); void llamaSyncMultiheadAtt(TASK_ARGS); diff --git a/src/transformer-test.cpp b/src/transformer-test.cpp new file mode 100644 index 0000000..8505a52 --- /dev/null +++ b/src/transformer-test.cpp @@ -0,0 +1,79 @@ +#include "transformer.hpp" +#include +#include +#include + +void testRopeSlice() { + TransformerSpec spec; + spec.dim = 4096; + spec.headSize = 128; + spec.nKvHeads = 8; + spec.seqLen = 2048; + spec.nHeads = spec.dim / spec.headSize; + spec.kvDim = (spec.dim * spec.nKvHeads) / spec.nHeads; + spec.ropeTheta = 10000.0f; + + float* q = new float[spec.dim]; + float* k = new float[spec.kvDim]; + float* correctQ = new float[spec.dim]; + float* correctK = new float[spec.kvDim]; + const int nSliceTests = 5; + const int nPosTests = 6; + const int nThreadTests = 3; + + for (int pos = 0; pos < spec.seqLen; pos += spec.seqLen / nPosTests) { + for (int si = 0; si < nSliceTests; si++) { + spec.nSlices = pow(2, si); + + for (int nThreads = 1; nThreads <= nThreadTests; nThreads++) { + printf("pos=%d slices=%d threads=%d\n", pos, spec.nSlices, nThreads); + + for (int j = 0; j < spec.dim; j++) q[j] = (j / (float)spec.dim); + for (int j = 0; j < spec.kvDim; j++) k[j] = (j / (float)spec.kvDim); + + for (uint8_t sliceIndex = 0; sliceIndex < spec.nSlices; sliceIndex++) { + RopeSlice slice(&spec, sliceIndex); + for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) { + slice.forward( + true, + &q[sliceIndex * spec.dim / spec.nSlices], + pos, nThreads, threadIndex); + slice.forward( + false, + &k[sliceIndex * spec.kvDim / spec.nSlices], + pos, nThreads, threadIndex); + } + } + + if (si == 0 && nThreads == 1) { + memcpy(correctQ, q, spec.dim * sizeof(float)); + memcpy(correctK, k, spec.kvDim * sizeof(float)); + } else { + for (int j = 0; j < spec.dim; j++) { + if (fabs(q[j] - correctQ[j]) > 1e-6) { + printf("q[%d] mismatch: %f != %f\n", j, q[j], correctQ[j]); + exit(EXIT_FAILURE); + } + } + for (int j = 0; j < spec.kvDim; j++) { + if (fabs(k[j] - correctK[j]) > 1e-6) { + printf("k[%d] mismatch: %f != %f\n", j, k[j], correctK[j]); + exit(EXIT_FAILURE); + } + } + } + } + } + } + + delete[] q; + delete[] k; + delete[] correctQ; + delete[] correctK; + printf("✅ ropeSlice\n"); +} + +int main() { + testRopeSlice(); + return 0; +} diff --git a/src/transformer.cpp b/src/transformer.cpp index 30003c2..81041a3 100644 --- a/src/transformer.cpp +++ b/src/transformer.cpp @@ -43,39 +43,60 @@ size_t MatmulSlice::splitWeights(uint8_t sliceIndex, char* weights, char* weight return copiedBytes; } -void initRope(float* cache, TransformerSpec* spec) { +RopeSlice::RopeSlice(TransformerSpec* spec, uint8_t sliceIndex) { + assert(spec->dim >= spec->kvDim); + assert(spec->dim % spec->nSlices == 0); + assert(spec->kvDim % spec->nSlices == 0); + + qDim0 = spec->dim / spec->nSlices; + kvDim0 = spec->kvDim / spec->nSlices; + assert(qDim0 % 2 == 0); + assert(kvDim0 % 2 == 0); + int kvDim0From = kvDim0 * sliceIndex; + int qDim0From = qDim0 * sliceIndex; + int qDim0To = qDim0From + qDim0; + qOffset = qDim0From - kvDim0From; + cacheDim = qDim0To - kvDim0From; + assert(cacheDim % 2 == 0); + + size_t cacheBytes = spec->seqLen * cacheDim * sizeof(float); + cache = (float*)NEW_BUFFER(cacheBytes); + printf("🕒 ropeCache: %ld kB\n", cacheBytes / 1024); + for (pos_t pos = 0; pos < spec->seqLen; pos++) { - for (int i = 0; i < spec->dim; i += 2) { - int head_dim = i % spec->headSize; - float freq = 1.0f / powf(spec->ropeTheta, head_dim / (float)spec->headSize); + for (int i = kvDim0From; i < qDim0To; i += 2) { + int headDim = i % spec->headSize; + float freq = 1.0f / powf(spec->ropeTheta, headDim / (float)spec->headSize); float val = pos * freq; float fcr = cosf(val); float fci = sinf(val); - cache[pos * spec->dim + i] = fcr; - cache[pos * spec->dim + i + 1] = fci; + cache[pos * cacheDim + (i - kvDim0From)] = fcr; + cache[pos * cacheDim + (i - kvDim0From) + 1] = fci; } } } -void rope(float* cache, float* q, float* k, TransformerSpec* spec, pos_t pos, unsigned int nThreads, unsigned int threadIndex) { - int halfDim = spec->dim / 2; - int slice = halfDim / nThreads; +RopeSlice::~RopeSlice() { + FREE_BUFFER(cache); +} + +void RopeSlice::forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex) { + int d0 = isQ ? qDim0 : kvDim0; + int offset = isQ ? qOffset : 0; + int halfD0 = d0 / 2; + int slice = halfD0 / nThreads; int iStart = threadIndex * slice; - int iEnd = ((nThreads - 1 == threadIndex) ? halfDim : (iStart + slice)) * 2; + int iEnd = (nThreads - 1 == threadIndex) ? halfD0 : (iStart + slice); iStart *= 2; + iEnd *= 2; - // RoPE relative positional encoding: complex-valued rotate q and k in each head for (int i = iStart; i < iEnd; i += 2) { - float fcr = cache[pos * spec->dim + i]; - float fci = cache[pos * spec->dim + i + 1]; - int rotn = i < spec->kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only - for (int _v = 0; _v < rotn; _v++) { - float* vec = _v == 0 ? q : k; // the vector to rotate (query or key) - float v0 = vec[i]; - float v1 = vec[i+1]; - vec[i] = v0 * fcr - v1 * fci; - vec[i+1] = v0 * fci + v1 * fcr; - } + float fcr = cache[pos * cacheDim + offset + i]; + float fci = cache[pos * cacheDim + offset + i + 1]; + float v0 = qOrK[i]; + float v1 = qOrK[i+1]; + qOrK[i] = v0 * fcr - v1 * fci; + qOrK[i+1] = v0 * fci + v1 * fcr; } } @@ -280,14 +301,13 @@ Transformer::Transformer(TransformerSpec* spec, uint8_t sliceIndex) { #endif x = (float*)NEW_BUFFER(spec->dim * sizeof(float)); logits = (float*)NEW_BUFFER(spec->vocabSize * sizeof(float)); + } - // TODO: cache should be for all architectures - if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) { - ropeCache = (float*)NEW_BUFFER(spec->seqLen * spec->dim * sizeof(float)); - initRope(ropeCache, spec); - } else { - ropeCache = NULL; - } + // TODO: cache should be for all architectures + if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) { + ropeSlice = new RopeSlice(spec, sliceIndex); + } else { + ropeSlice = NULL; } } @@ -306,10 +326,10 @@ Transformer::~Transformer() { #endif FREE_BUFFER(x); FREE_BUFFER(logits); + } - if (ropeCache != NULL) { - FREE_BUFFER(ropeCache); - } + if (ropeSlice != NULL) { + delete ropeSlice; } } @@ -572,7 +592,7 @@ Transformer Transformer::loadRoot(char* data, TransformerSpec* spec, SocketPool* exit(EXIT_FAILURE); } - printf("⏩ Loaded %ld bytes\n", (long)(w - data)); + printf("⏩ Loaded %ld kB\n", (long)(w - data) / 1024); return transformer; } diff --git a/src/transformer.hpp b/src/transformer.hpp index 2df3fa8..0e655d7 100644 --- a/src/transformer.hpp +++ b/src/transformer.hpp @@ -84,8 +84,18 @@ struct TransformerSpec { uint8_t nSlices; }; -void initRope(float* cache, TransformerSpec* spec); -void rope(float* cache, float* q, float* k, TransformerSpec* spec, pos_t pos, unsigned int nThreads, unsigned int threadIndex); +class RopeSlice { +private: + float* cache; + int cacheDim; + int qDim0; + int qOffset; + int kvDim0; +public: + RopeSlice(TransformerSpec* spec, uint8_t sliceIndex); + ~RopeSlice(); + void forward(bool isQ, float* qOrK, pos_t pos, unsigned int nThreads, unsigned int threadIndex); +}; class TransformerBlock { public: @@ -190,7 +200,7 @@ class Transformer { float rms; float* x; float* logits; - float* ropeCache; + RopeSlice* ropeSlice; ~Transformer();