Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Jan 4, 2025
1 parent 3744bc4 commit a8153cc
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 48 deletions.
1 change: 1 addition & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
params.vocoder.hf_repo = "ggml-org/WavTokenizer";
params.vocoder.hf_file = "WavTokenizer-Large-75-F16.gguf";
params.ctx_shift = false; // for better results
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));

Expand Down
50 changes: 47 additions & 3 deletions examples/server/public_tts/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<title>llama.cpp TTS</title>
<style>
body {
font-family: 'Courier New', Courier, monospace;
font-family: monospace;
margin: 2em;
}
</style>
Expand All @@ -16,10 +16,15 @@ <h1>llama.cpp TTS</h1>

Input text:<br/>
<textarea id="input" rows="4" cols="50">Hello world</textarea><br/>
<button id="btn_speak" onclick="speak()">Speak</button>
<button id="btn_speak" onclick="speak()">Speak</button><br/>
<br/>
<p id="status">Status: ready</p><br/>
<p id="output"></p>

<script>
const input_el = document.getElementById('input');
const output_el = document.getElementById('output');
const status_el = document.getElementById('status');
const btn_speak_el = document.getElementById('btn_speak');

let working = false;
Expand All @@ -32,6 +37,9 @@ <h1>llama.cpp TTS</h1>
working = true;
input_el.disabled = true;
btn_speak_el.disabled = true;
status_el.textContent = 'Status: generating...';

const input = input_el.value.trim();

try {
const res = await fetch('/v1/audio/speech', {
Expand All @@ -40,7 +48,7 @@ <h1>llama.cpp TTS</h1>
'Content-Type': 'application/json'
},
body: JSON.stringify({
input: input_el.value.trim(),
input,
response_format: 'wav',
}),
});
Expand All @@ -50,19 +58,55 @@ <h1>llama.cpp TTS</h1>
const url = URL.createObjectURL(blob);
const audio = new Audio(url);
audio.play();
status_el.textContent = 'Status: playing...';
audio.addEventListener('ended', () => {
status_el.textContent = 'Status: ready';
});
echoTimings(res.headers, input);
} else {
const text = await res.text();
throw new Error(`Failed to generate speech: ${text}`);
}
} catch (e) {
console.error(e);
alert(e.message);
status_el.textContent = 'Status: ready';
}

working = false;
input_el.disabled = false;
btn_speak_el.disabled = false;
}

// Render timing info from the custom response headers (X-timings-ttc,
// X-timings-voc, X-timings-spec) into the output element.
// Falls back to a plain message if any header is missing or malformed
// (Headers.get() returns null, so JSON.parse / Object.entries throw and
// land in the catch block).
function echoTimings(headers, input_txt) {
    try {
        const timingsTTC = JSON.parse(headers.get('X-timings-ttc'));
        const timingsVoc = JSON.parse(headers.get('X-timings-voc'));
        const timingsSpec = JSON.parse(headers.get('X-timings-spec'));
        // escape both the user input AND the timing keys before injecting
        // into innerHTML, consistent with how input_txt is handled
        output_el.innerHTML = `
        <b>Input text:</b> ${escapeHtml(input_txt)}<br/>
        <b>Timings:</b><br/>
        <b>TTC:</b>
        <ul>
            ${Object.entries(timingsTTC).map(([k, v]) => `<li>${escapeHtml(k)}: ${v.toFixed(2)} ms</li>`).join('')}
        </ul>
        <b>Voc:</b> ${timingsVoc.t_voc_ms.toFixed(2)} ms<br/>
        <b>Spec:</b> ${timingsSpec.t_spec_ms.toFixed(2)} ms
        `;
    } catch (e) {
        console.error(e);
        output_el.innerHTML = 'No timings data is available.';
    }
}

// Escape the five HTML-special characters so arbitrary text can be
// safely interpolated into innerHTML. Single-pass replacement via a
// character-class regex and an entity lookup table.
function escapeHtml(unsafe) {
    const entities = {
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&#039;',
    };
    return unsafe.replace(/[&<>"']/g, (ch) => entities[ch]);
}
</script>
</body>
</html>
78 changes: 53 additions & 25 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ struct server_task_result_embd : server_task_result {
struct server_task_result_tts_embd : server_task_result {
int index = 0;
std::vector<float> embd;
double t_ms = 0.0;

virtual int get_index() override {
return index; // unused
Expand Down Expand Up @@ -1749,11 +1750,15 @@ struct server_context {

if (!params.vocoder.model.empty()) {
common_params v_params = params_base;
v_params.model = params.vocoder.model;
v_params.model_url = params.vocoder.model_url;
v_params.hf_repo = params.vocoder.hf_repo;
v_params.hf_file = params.vocoder.hf_file;
v_params.embedding = true;
v_params.model = params.vocoder.model;
v_params.model_url = params.vocoder.model_url;
v_params.hf_repo = params.vocoder.hf_repo;
v_params.hf_file = params.vocoder.hf_file;
v_params.embedding = true;
v_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
// make sure the vocoder has the sufficient batch size
v_params.n_batch = v_params.n_ctx;
v_params.n_ubatch = v_params.n_ctx;
llama_init_vocoder = common_init_from_params(v_params);
}

Expand Down Expand Up @@ -2606,9 +2611,18 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_TTS_EMBD:
{
const auto ctx_cts = llama_init_vocoder.context.get();
const int n_ubatch = llama_n_ubatch(ctx_cts);
const int n_codes = (int) task.prompt_tokens.size();
if (n_codes > n_ubatch) {
send_error(task, string_format("Number of codes (%d) exceeds the maximum ubatch of vocoder model (%d)", n_codes, n_ubatch), ERROR_TYPE_INVALID_REQUEST);
break;
}

std::vector<float> embd;
SRV_DBG("tts_get_embd with %d tokens", (int) task.prompt_tokens.size());
int status = tts_get_embd(llama_init_vocoder.context.get(), task.prompt_tokens, embd);
uint64_t t_start = ggml_time_us();
SRV_DBG("tts_get_embd with %d codes", n_codes);
int status = tts_get_embd(ctx_cts, task.prompt_tokens, embd);
if (status != 0) {
send_error(task, string_format("Failed to get TTS embedding, status code = %d", status), ERROR_TYPE_SERVER);
break;
Expand All @@ -2620,6 +2634,7 @@ struct server_context {
auto res = std::make_unique<server_task_result_tts_embd>();
res->id = task.id;
res->embd = std::move(embd);
res->t_ms = (ggml_time_us() - t_start) / 1e3;
queue_results.send(std::move(res));
} break;
}
Expand Down Expand Up @@ -4149,8 +4164,9 @@ int main(int argc, char ** argv) {
return;
}

llama_tokens audio_tokens;
// convert text to audio token
llama_tokens codes;
result_timings ttc_timings;
// convert text to codes
{
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = ctx_server.queue_tasks.get_new_id();
Expand All @@ -4174,32 +4190,35 @@ int main(int argc, char ** argv) {
const server_task_result_cmpl_final * result = dynamic_cast<server_task_result_cmpl_final*>(raw_result.get());
GGML_ASSERT(result != nullptr);
GGML_ASSERT(!result->tokens.empty());
audio_tokens = std::move(result->tokens);
codes = std::move(result->tokens);

// debug
SRV_DBG("codes str (before filter) = %s\n", common_detokenize(ctx_server.ctx, audio_tokens, true).c_str());
// SRV_DBG("codes str (before filter) = %s\n", common_detokenize(ctx_server.ctx, codes, true).c_str());

// post-process audio tokens
// post-process codes
// remove all non-audio tokens (i.e. < 151672 || > 155772)
audio_tokens.erase(std::remove_if(
audio_tokens.begin(),
audio_tokens.end(),
codes.erase(std::remove_if(
codes.begin(),
codes.end(),
[](llama_token t) { return t < 151672 || t > 155772; }),
audio_tokens.end());
SRV_DBG("codes size = %d\n", (int) audio_tokens.size());
codes.end());
SRV_DBG("codes size = %d\n", (int) codes.size());

ttc_timings = std::move(result->timings);
}

// debug
SRV_DBG("codes str = %s\n", common_detokenize(ctx_server.ctx, audio_tokens, true).c_str());
// SRV_DBG("codes str = %s\n", common_detokenize(ctx_server.ctx, codes, true).c_str());

// convert audio token to embeddings
int n_embd = llama_n_embd(ctx_server.model);
// convert codes to embeddings
int n_embd = llama_n_embd(ctx_server.llama_init_vocoder.model.get());
int n_codes = -1;
double t_voc_ms = 0.0;
std::vector<float> embd;
{
server_task task = server_task(SERVER_TASK_TYPE_TTS_EMBD);
task.id = ctx_server.queue_tasks.get_new_id();
task.prompt_tokens = std::move(audio_tokens);
task.prompt_tokens = std::move(codes);

ctx_server.queue_results.add_waiting_tasks({task});
ctx_server.queue_tasks.post(task);
Expand All @@ -4215,16 +4234,25 @@ int main(int argc, char ** argv) {
GGML_ASSERT(!result->embd.empty());

// flatten the array
n_codes = result->embd.size() / n_embd;
embd = std::move(result->embd);
SRV_DBG("tts embd n_code = %d\n", n_codes);
SRV_DBG("tts embd size = %zu\n", embd.size());
n_codes = result->embd.size() / n_embd;
embd = std::move(result->embd);
t_voc_ms = result->t_ms;
SRV_DBG("tts embd n_code = %d\n", n_codes);
SRV_DBG("tts embd size = %zu\n", embd.size());
SRV_DBG("tts embd t_voc_ms = %lf\n", t_voc_ms);
GGML_ASSERT(n_codes > 0);
}

// convert embeddings to wav
// will be freed by chunked_content_provider
const auto t_spec_start = ggml_time_us();
std::vector<float> audio = tts_embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
double t_spec_ms = (ggml_time_us() - t_spec_start) / 1e3;

// for now, we can only leave timings in response headers, mostly for debugging
res.set_header("X-timings-ttc", ttc_timings.to_json().dump());
res.set_header("X-timings-voc", (json{{ "t_voc_ms", t_voc_ms }}).dump());
res.set_header("X-timings-spec", (json{{ "t_spec_ms", t_spec_ms }}).dump());

const auto chunked_content_provider = [audio = std::move(audio)](size_t, httplib::DataSink & sink) mutable {
// TODO: some how reuse save_wav16 instead of duplicating the code here
Expand Down
1 change: 1 addition & 0 deletions examples/tts/tts-impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <string>
#include <thread>
#include <vector>
#include <cstring>

//
// Terminal utils
Expand Down
25 changes: 5 additions & 20 deletions examples/tts/tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,39 +346,24 @@ int main(int argc, char ** argv) {
LOG_INF("%s: codes audio size: %d\n", __func__, (int) codes.size());
}

for (auto & token : codes) {
token -= 151672;
}

const auto t_voc_start = ggml_time_us();

const int n_codes = codes.size();

llama_batch batch = llama_batch_init(n_codes, 0, 1);

for (size_t i = 0; i < codes.size(); ++i) {
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
}
GGML_ASSERT(batch.n_tokens == n_codes);

if (llama_decode(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
std::vector<float> embd;
if (tts_get_embd(ctx_cts, codes, embd) != 0) {
LOG_ERR("%s: tts_get_embd() failed\n", __func__);
return 1;
}

llama_synchronize(ctx_cts);

LOG_INF("%s: time for vocoder: %.3f ms\n", __func__, (ggml_time_us() - t_voc_start) / 1000.0f);

const auto t_spec_start = ggml_time_us();

#if 1
// spectral operations
const int n_embd = llama_n_embd(model_cts);
const float * embd = llama_get_embeddings(ctx_cts);

auto audio = tts_embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);
const int n_codes = codes.size();

auto audio = tts_embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
#else
// read the spectrogram from a file for debugging purposes
std::vector<float> audio;
Expand Down

0 comments on commit a8153cc

Please sign in to comment.