Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: accelerator structure. #90

Merged
merged 8 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
make funcs-test
make quants-test
make tokenizer-test
make transformer-test
make commands-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
Expand All @@ -40,8 +40,8 @@ jobs:
run: ./quants-test
- name: tokenizer-test
run: ./tokenizer-test
- name: transformer-test
run: ./transformer-test
- name: commands-test
run: ./commands-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
Expand All @@ -64,7 +64,7 @@ jobs:
make funcs-test
make quants-test
make tokenizer-test
make transformer-test
make commands-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
Expand All @@ -73,8 +73,8 @@ jobs:
run: ./quants-test
- name: tokenizer-test
run: ./tokenizer-test
- name: transformer-test
run: ./transformer-test
- name: commands-test
run: ./commands-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
Expand Down
26 changes: 14 additions & 12 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ funcs: src/funcs.cpp
$(CXX) $(CXXFLAGS) -c src/funcs.cpp -o funcs.o
funcs-test: src/funcs-test.cpp funcs
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o
commands: src/commands.cpp
$(CXX) $(CXXFLAGS) -c src/commands.cpp -o commands.o
socket: src/socket.cpp
$(CXX) $(CXXFLAGS) -c src/socket.cpp -o socket.o
transformer: src/utils.cpp
Expand All @@ -33,20 +35,20 @@ tokenizer: src/tokenizer.cpp
app: src/app.cpp
$(CXX) $(CXXFLAGS) -c src/app.cpp -o app.o

dllama: src/apps/dllama/dllama.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama: src/apps/dllama/dllama.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)

funcs-test: src/funcs-test.cpp funcs utils quants
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o $(LIBS)
# Builds the quants-test binary from src/quants-test.cpp, linked against utils.o and quants.o.
# The test source itself is listed as the prerequisite so that editing it triggers a rebuild;
# src/quants.cpp is already tracked through the `quants` prerequisite.
quants-test: src/quants-test.cpp utils quants
	$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o $(LIBS)
tokenizer-test: src/tokenizer-test.cpp tokenizer funcs utils quants
$(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o utils.o quants.o $(LIBS)
transformer-test: src/transformer-test.cpp funcs utils quants transformer socket
$(CXX) $(CXXFLAGS) src/transformer-test.cpp -o transformer-test funcs.o utils.o quants.o transformer.o socket.o $(LIBS)
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer
$(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS)
grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks tokenizer
$(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS)
tokenizer-test: src/tokenizer-test.cpp tokenizer funcs commands utils quants
$(CXX) $(CXXFLAGS) src/tokenizer-test.cpp -o tokenizer-test tokenizer.o funcs.o commands.o utils.o quants.o $(LIBS)
commands-test: src/commands-test.cpp funcs commands utils quants transformer socket
$(CXX) $(CXXFLAGS) src/commands-test.cpp -o commands-test funcs.o commands.o utils.o quants.o transformer.o socket.o $(LIBS)
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks tokenizer
$(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS)
grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks tokenizer
$(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS)
2 changes: 1 addition & 1 deletion examples/macbeth.sh
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ Macbeth. Thou seest the moon"

echo "Generating, it can take a while..."

OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 8 --steps 2048 --model converter/dllama_meta-llama-3-8b_q40.bin --tokenizer converter/dllama_meta-llama3-tokenizer.t ) 2>&1)
OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 2 --steps 2048 --model models/llama3_8b_q40/dllama_model_llama3_8b_q40.m --tokenizer models/llama3_8b_q40/dllama_tokenizer_llama3_8b_q40.t --workers 127.0.0.1:9999 127.0.0.1:9998 127.0.0.1:9997 ) 2>&1)

echo "$OUTPUT"

Expand Down
7 changes: 4 additions & 3 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
exit(EXIT_FAILURE);
}

void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec)) {
void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc)) {
if (args->modelPath == NULL) {
throw std::runtime_error("Model is required");
}
Expand All @@ -119,14 +119,15 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
args->steps = spec.seqLen;
}

Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool);
AcceleratorContext acc(0, 1, NULL);
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, socketPool, &acc);
socketPool->setTurbo(true);

Inference inference = Inference(&arch, args->nThreads, &transformer, socketPool);

Sampler sampler(spec.vocabSize, args->temperature, args->topp, args->seed);

program(&inference, socketPool, &tokenizer, &sampler, args, &spec);
program(&inference, socketPool, &tokenizer, &sampler, args, &spec, &acc);

delete socketPool;
}
8 changes: 4 additions & 4 deletions src/app.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#ifndef FUNCS_HPP
#define FUNCS_HPP
#ifndef APP_HPP
#define APP_HPP

#include "quants.hpp"
#include "transformer.hpp"
#include "utils.hpp"
#include "socket.hpp"
#include "utils.hpp"
#include "app.hpp"
#include "transformer.hpp"
#include "tasks.hpp"
Expand Down Expand Up @@ -46,7 +46,7 @@ class TransformerArchFactory {

class App {
public:
static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec));
static void run(AppArgs* args, void (*program)(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc));
};

#endif
2 changes: 1 addition & 1 deletion src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ void handleModelsRequest(HttpRequest& request) {
"] }");
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
SocketServer* server = new SocketServer(args->port);

TokenizerChatStops stops(tokenizer);
Expand Down
7 changes: 4 additions & 3 deletions src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "../../tokenizer.hpp"
#include "../../app.hpp"

void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
if (args->prompt == NULL)
throw std::runtime_error("Prompt is required");

Expand Down Expand Up @@ -193,7 +193,7 @@ class Chat {
}
};

void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec, AcceleratorContext* acc) {
TokenizerChatStops stops(tokenizer);
ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
Expand All @@ -210,7 +210,8 @@ void worker(AppArgs* args) {
SocketServer server(args->port);
Socket socket = server.accept();
TransformerSpec spec;
Transformer transformer = Transformer::loadSlice(&spec, &socket);
AcceleratorContext acc(0, 1, NULL);
Transformer transformer = Transformer::loadSlice(&spec, &socket, &acc);
TransformerArch arch = TransformerArchFactory::create(&spec);

Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
Expand Down
85 changes: 85 additions & 0 deletions src/commands-test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "commands.hpp"
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Checks that applying a rope command slice-by-slice and thread-by-thread produces the same
// q/k vectors as the single-slice, single-thread run at the same position.
//
// arch:         1 selects LlamaRopeCommand, 2 selects FalconRopeCommand; any other value
//               aborts the process instead of using an uninitialized command pointer.
// nSliceTests:  slice counts tested are 1, 2, 4, ..., 2^(nSliceTests - 1).
// nPosTests:    positions are sampled every seqLen / nPosTests steps (clamped to at least 1);
//               must be positive.
// nThreadTests: thread counts tested are 1..nThreadTests.
//
// Prints progress for every combination and exits with EXIT_FAILURE on the first element
// that differs from the reference by more than 1e-6.
void testRopeSlice(int arch, const int nSliceTests, const int nPosTests, const int nThreadTests) {
    if (arch != 1 && arch != 2) {
        printf("Unsupported arch: %d\n", arch);
        exit(EXIT_FAILURE);
    }
    if (nPosTests <= 0) {
        printf("nPosTests must be positive: %d\n", nPosTests);
        exit(EXIT_FAILURE);
    }

    const int dim = 4096;
    const int headSize = 128;
    const int nKvHeads = 8;
    const int seqLen = 2048;
    const int nHeads = dim / headSize;
    const int kvDim = (dim * nKvHeads) / nHeads;
    // Kept as a float: the original stored the float literal in an int (same value, 10000).
    const float ropeTheta = 10000.0f;

    // A zero step would never advance `pos`, so clamp it when nPosTests exceeds seqLen.
    int posStep = seqLen / nPosTests;
    if (posStep < 1) posStep = 1;

    float* q = new float[dim];
    float* k = new float[kvDim];
    float* correctQ = new float[dim];
    float* correctK = new float[kvDim];

    for (int pos = 0; pos < seqLen; pos += posStep) {
        for (int si = 0; si < nSliceTests; si++) {
            // Exact integer power of two; avoids the double round-trip of pow(2, si).
            const int nSlices = 1 << si;

            for (int nThreads = 1; nThreads <= nThreadTests; nThreads++) {
                printf("pos=%d nSlices=%d threads=%d\n", pos, nSlices, nThreads);

                // Reset the inputs so every combination starts from the same vectors.
                for (int j = 0; j < dim; j++) q[j] = 1.0;
                for (int j = 0; j < kvDim; j++) k[j] = 1.0;

                for (slice_index_t sliceIndex = 0; sliceIndex < nSlices; sliceIndex++) {
                    RopeSlice slice(dim, kvDim, nKvHeads, nSlices, seqLen, headSize, ropeTheta, sliceIndex);
                    RopeCommand* rope = NULL;
                    if (arch == 1) {
                        rope = new LlamaRopeCommand(&slice);
                    } else {
                        // arch == 2, guaranteed by the validation at the top of the function.
                        rope = new FalconRopeCommand(&slice);
                    }

                    // Each slice operates on its own contiguous part of q and k.
                    float* qSlice = &q[(sliceIndex * dim) / nSlices];
                    float* kSlice = &k[(sliceIndex * kvDim) / nSlices];
                    for (int threadIndex = 0; threadIndex < nThreads; threadIndex++) {
                        rope->forward(true, qSlice, pos, nThreads, threadIndex);
                        rope->forward(false, kSlice, pos, nThreads, threadIndex);
                    }

                    delete rope;
                }

                if (si == 0 && nThreads == 1) {
                    // First combination for this position: record it as the reference.
                    memcpy(correctQ, q, dim * sizeof(float));
                    memcpy(correctK, k, kvDim * sizeof(float));
                } else {
                    for (int j = 0; j < dim; j++) {
                        if (fabs(q[j] - correctQ[j]) > 1e-6) {
                            printf("q[%d] mismatch: %f != %f (arch=%d)\n", j, q[j], correctQ[j], arch);
                            exit(EXIT_FAILURE);
                        }
                    }
                    for (int j = 0; j < kvDim; j++) {
                        if (fabs(k[j] - correctK[j]) > 1e-6) {
                            printf("k[%d] mismatch: %f != %f (arch=%d)\n", j, k[j], correctK[j], arch);
                            exit(EXIT_FAILURE);
                        }
                    }
                }
            }
        }
    }

    delete[] q;
    delete[] k;
    delete[] correctQ;
    delete[] correctK;
    printf("✅ ropeSlice (arch=%d)\n", arch);
}

// Test entry point: runs the rope-slice consistency check for each configured architecture.
int main() {
    // Columns: arch, nSliceTests, nPosTests, nThreadTests.
    // testRopeSlice maps arch 2 to FalconRopeCommand and arch 1 to LlamaRopeCommand.
    static const int configs[2][4] = {
        {2, 4, 6, 3},
        {1, 6, 4, 3},
    };
    const int nConfigs = sizeof(configs) / sizeof(configs[0]);

    for (int i = 0; i < nConfigs; i++) {
        testRopeSlice(configs[i][0], configs[i][1], configs[i][2], configs[i][3]);
    }
    return 0;
}
Loading
Loading