Merge pull request #12 from l3utterfly/merge-master-layla
Merge from upstream
l3utterfly authored Apr 25, 2024
2 parents ef64937 + 7a92678 commit c259b80
Showing 22 changed files with 1,992 additions and 337 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -1117,7 +1117,9 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m
- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a`
- See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices
- Matrix multiplication is unconventional: [`z = ggml_mul_mat(ctx, x, y)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means `zT = x @ yT`
- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$
![matmul](media/matmul.png)
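To make the shape convention concrete, here is a minimal sketch assuming the `ggml.h` API shipped in this repository; the 16 MiB buffer size is an arbitrary placeholder and the graph build/compute step is omitted:

```cpp
// Sketch only: illustrates the ggml_mul_mat shape convention described above.
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = { 16*1024*1024, NULL, false }; // placeholder scratch size
    struct ggml_context * ctx = ggml_init(ip);

    // ne[0] = columns (dimension 0), ne[1] = rows (dimension 1)
    struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8); // 8 rows x 64 columns
    struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4); // 4 rows x 64 columns

    // Operands must agree on dimension 0; the result has
    // ne[0] = A->ne[1] and ne[1] = B->ne[1], i.e. C^T = A B^T.
    struct ggml_tensor * C = ggml_mul_mat(ctx, A, B); // 4 rows x 8 columns

    ggml_free(ctx);
    return 0;
}
```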
### Docs
8 changes: 5 additions & 3 deletions common/common.cpp
@@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
invalid_param = true;
return true;
}
// This is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]);
return true;
}
if (arg == "-t" || arg == "--threads") {
@@ -2359,12 +2361,12 @@ std::vector<llama_token> llama_tokenize(
return result;
}

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), true);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
9 changes: 5 additions & 4 deletions common/common.h
@@ -86,8 +86,8 @@ struct gpt_params {

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings

// // sampling parameters
struct llama_sampling_params sparams;
@@ -238,11 +238,12 @@ std::vector<llama_token> llama_tokenize(
bool add_special,
bool parse_special = false);

// tokenizes a token into a piece
// tokenizes a token into a piece, optionally renders special/control tokens
// should work similar to Python's `tokenizer.id_to_piece`
std::string llama_token_to_piece(
const struct llama_context * ctx,
llama_token token);
llama_token token,
bool special = true);

// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
// that takes into account the tokenizer type and decides how to handle the leading space
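A usage sketch of the extended helper, assuming a loaded `llama_context * ctx` and a sampled `llama_token tok` from the surrounding program; everything else is omitted. The `special` flag selects whether special/control tokens are rendered into the piece:

```cpp
// Sketch only: ctx and tok are assumed to come from the caller.
std::string debug_piece = llama_token_to_piece(ctx, tok);        // special defaults to true: control tokens are rendered
std::string user_piece  = llama_token_to_piece(ctx, tok, false); // special = false: control tokens are not rendered
printf("debug: '%s'  user: '%s'\n", debug_piece.c_str(), user_piece.c_str());
```

The server example further down in this commit switches its per-token output to `special = false` for exactly this reason.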
13 changes: 12 additions & 1 deletion common/sampling.cpp
@@ -1,4 +1,6 @@
#define LLAMA_API_INTERNAL
#include "sampling.h"
#include <random>

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

result->prev.resize(params.n_prev);

llama_sampling_set_rng_seed(result, params.seed);

return result;
}

@@ -66,6 +70,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->cur.clear();
}

void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
seed = time(NULL);
}
ctx->rng.seed(seed);
}

void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
@@ -207,7 +218,7 @@ static llama_token llama_sampling_sample_impl(

sampler_queue(ctx_main, params, cur_p, min_keep);

id = llama_sample_token(ctx_main, &cur_p);
id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);

//{
// const int n_top = 10;
47 changes: 27 additions & 20 deletions common/sampling.h
@@ -4,9 +4,10 @@

#include "grammar-parser.h"

#include <random>
#include <string>
#include <vector>
#include <unordered_map>
#include <vector>

// sampler types
enum class llama_sampler_type : char {
@@ -20,25 +21,26 @@ enum class llama_sampler_type : char {

// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
float dry_multiplier = 0.0f; // 0.0f = disabled, recommended value: 0.8f
float dry_base = 1.75f;
int dry_allowed_length = 2;
@@ -83,6 +85,8 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;

std::mt19937 rng;
};

#include "common.h"
@@ -100,6 +104,9 @@ void llama_sampling_reset_grammar(struct llama_sampling_context * ctx);
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);

// Set the sampler seed
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);

// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);

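A minimal usage sketch of the per-context RNG introduced here, assuming the helpers declared above; model loading and the per-token decode loop are omitted:

```cpp
// Sketch only: ctx_main (a llama_context *) and the decode loop are assumed to exist elsewhere.
llama_sampling_params sparams;
sparams.seed = 42; // used by llama_sampling_init to seed ctx_sampling->rng

llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

// ... per-token loop calling llama_sampling_sample(ctx_sampling, ctx_main, nullptr) ...

// Re-seeding no longer goes through llama_set_rng_seed(ctx_main, ...):
// it targets the sampling context's own std::mt19937.
llama_sampling_set_rng_seed(ctx_sampling, 42); // LLAMA_DEFAULT_SEED falls back to time(NULL)

llama_sampling_free(ctx_sampling);
```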
101 changes: 101 additions & 0 deletions convert-hf-to-gguf.py
@@ -363,6 +363,16 @@ def _set_vocab_sentencepiece(self):
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
print(
f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]"
)
for i in range(1, pad_count + 1):
tokens.append(f"[PAD{i}]")
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)

assert len(tokens) == vocab_size

self.gguf_writer.add_tokenizer_model("llama")
@@ -1789,6 +1799,12 @@ def write_tensors(self):
class Qwen2Model(Model):
model_arch = gguf.MODEL_ARCH.QWEN2

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()


@Model.register("Qwen2MoeForCausalLM")
class Qwen2MoeModel(Model):
@@ -1979,6 +1995,91 @@ def set_gguf_parameters(self):
self.gguf_writer.add_add_bos_token(False)


@Model.register("Phi3ForCausalLM")
class Phi3MiniModel(Model):
model_arch = gguf.MODEL_ARCH.PHI3

def set_vocab(self):
from sentencepiece import SentencePieceProcessor

tokenizer_path = self.dir_model / 'tokenizer.model'

if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
sys.exit(1)

tokenizer = SentencePieceProcessor(str(tokenizer_path))

vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size

for token_id in range(tokenizer.vocab_size()):

piece = tokenizer.id_to_piece(token_id)
text = piece.encode("utf-8")
score = tokenizer.get_score(token_id)

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype

added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)

for key in added_tokens_json:
token_id = added_tokens_json[key]
if (token_id >= vocab_size):
print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue

tokens[token_id] = key.encode("utf-8")
scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])

rot_pct = 1.0
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
rms_eps = self.find_hparam(["rms_norm_eps"])

self.gguf_writer.add_name("Phi3")
self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))

self.gguf_writer.add_embedding_length(n_embd)
self.gguf_writer.add_feed_forward_length(8192)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
self.gguf_writer.add_file_type(self.ftype)


@Model.register("PlamoForCausalLM")
class PlamoModel(Model):
model_arch = gguf.MODEL_ARCH.PLAMO
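For reference, with `rot_pct = 1.0` the RoPE dimension count computed in `Phi3MiniModel.set_gguf_parameters` above reduces to the per-head dimension; the numbers below are placeholder sizes for illustration, not taken from any model config:

$\lfloor \text{rot\_pct} \cdot n_{\text{embd}} \rfloor \,//\, n_{\text{head}} = n_{\text{embd}} \,//\, n_{\text{head}}$, e.g. $3072 \,//\, 32 = 96$.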
1 change: 0 additions & 1 deletion examples/lookup/lookup-stats.cpp
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

// tokenize the prompt
1 change: 0 additions & 1 deletion examples/lookup/lookup.cpp
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_set_rng_seed(ctx, params.seed);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

// tokenize the prompt
1 change: 0 additions & 1 deletion examples/main/main.cpp
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
return 1;
}
session_tokens.resize(n_token_count_out);
llama_set_rng_seed(ctx, params.seed);
LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
}
}
10 changes: 5 additions & 5 deletions examples/server/public/index.html
@@ -881,11 +881,11 @@
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/^#{1,6} (.*)$/gim, '<h3>$1</h3>')
.replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')
.replace(/__(.*?)__/g, '<strong>$1</strong>')
.replace(/\*(.*?)\*/g, '<em>$1</em>')
.replace(/_(.*?)_/g, '<em>$1</em>')
.replace(/(^|\n)#{1,6} ([^\n]*)(?=([^`]*`[^`]*`)*[^`]*$)/g, '$1<h3>$2</h3>')
.replace(/\*\*(.*?)\*\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
.replace(/__(.*?)__(?=([^`]*`[^`]*`)*[^`]*$)/g, '<strong>$1</strong>')
.replace(/\*(.*?)\*(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
.replace(/_(.*?)_(?=([^`]*`[^`]*`)*[^`]*$)/g, '<em>$1</em>')
.replace(/```.*?\n([\s\S]*?)```/g, '<pre><code>$1</code></pre>')
.replace(/`(.*?)`/g, '<code>$1</code>')
.replace(/\n/gim, '<br />');
5 changes: 2 additions & 3 deletions examples/server/server.cpp
@@ -854,7 +854,7 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
slot.params.seed = json_value(data, "seed", default_params.seed);
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

@@ -1028,7 +1028,6 @@ struct server_context {
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}
llama_set_rng_seed(ctx, slot.params.seed);
}

slot.command = SLOT_COMMAND_LOAD_PROMPT;
@@ -1118,7 +1117,7 @@

bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = llama_token_to_piece(ctx, result.tok);
const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
slot.sampled = result.tok;

// search stop word and delete it
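With this change the seed is taken per request from the slot's sampling parameters instead of reseeding the shared context RNG. A hypothetical request body illustrating the field; the values are placeholders, and `prompt`, `n_predict`, and `seed` follow the server example's request schema:

```json
{
  "prompt": "Building a website can be done in 10 simple steps:",
  "n_predict": 32,
  "seed": 42
}
```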