
common : support HF download for vocoder
ggerganov committed Dec 18, 2024
1 parent a95191c commit c0df192
Showing 4 changed files with 50 additions and 20 deletions.
50 changes: 35 additions & 15 deletions common/arg.cpp
@@ -119,29 +119,33 @@ std::string common_arg::to_string() {
 // utils
 //
 
-static void common_params_handle_model_default(common_params & params) {
-    if (!params.hf_repo.empty()) {
+static void common_params_handle_model_default(
+        std::string & model,
+        std::string & model_url,
+        std::string & hf_repo,
+        std::string & hf_file) {
+    if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
-        if (params.hf_file.empty()) {
-            if (params.model.empty()) {
+        if (hf_file.empty()) {
+            if (model.empty()) {
                 throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
             }
-            params.hf_file = params.model;
-        } else if (params.model.empty()) {
+            hf_file = model;
+        } else if (model.empty()) {
             // this is to avoid different repo having same file name, or same file name in different subdirs
-            std::string filename = params.hf_repo + "_" + params.hf_file;
+            std::string filename = hf_repo + "_" + hf_file;
             // to make sure we don't have any slashes in the filename
             string_replace_all(filename, "/", "_");
-            params.model = fs_get_cache_file(filename);
+            model = fs_get_cache_file(filename);
         }
-    } else if (!params.model_url.empty()) {
-        if (params.model.empty()) {
-            auto f = string_split<std::string>(params.model_url, '#').front();
+    } else if (!model_url.empty()) {
+        if (model.empty()) {
+            auto f = string_split<std::string>(model_url, '#').front();
             f = string_split<std::string>(f, '?').front();
-            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
+            model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
-    } else if (params.model.empty()) {
-        params.model = DEFAULT_MODEL_PATH;
+    } else if (model.empty()) {
+        model = DEFAULT_MODEL_PATH;
     }
 }
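The only non-obvious step above is how a local cache name is derived when both --hf-repo and --hf-file are set. Here is a minimal standalone sketch of that derivation (the repo and file names are made-up placeholders, and fs_get_cache_file is omitted; it is assumed here to simply resolve the name under the local cache directory):

    // standalone illustration, not part of this commit
    #include <algorithm>
    #include <iostream>
    #include <string>

    int main() {
        std::string hf_repo = "some-org/some-vocoder";   // hypothetical --hf-repo-v value
        std::string hf_file = "vocoder-f16.gguf";        // hypothetical --hf-file-v value

        // same idea as: filename = hf_repo + "_" + hf_file; string_replace_all(filename, "/", "_");
        std::string filename = hf_repo + "_" + hf_file;
        std::replace(filename.begin(), filename.end(), '/', '_');

        std::cout << filename << '\n';  // prints: some-org_some-vocoder_vocoder-f16.gguf
        return 0;
    }

Prefixing the file name with the repo name keeps files that share a name across different repos (or subdirectories) from colliding in the cache.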

@@ -276,7 +280,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }
 
-    common_params_handle_model_default(params);
+    // TODO: refactor model params in a common struct
+    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file);
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -1581,6 +1587,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.hf_file = value;
         }
     ).set_env("LLAMA_ARG_HF_FILE"));
+    add_opt(common_arg(
+        {"-hfrv", "--hf-repo-v"}, "REPO",
+        "Hugging Face model repository for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_repo = value;
+        }
+    ).set_env("LLAMA_ARG_HF_REPO_V"));
+    add_opt(common_arg(
+        {"-hffv", "--hf-file-v"}, "FILE",
+        "Hugging Face model file for the vocoder model (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.vocoder.hf_file = value;
+        }
+    ).set_env("LLAMA_ARG_HF_FILE_V"));
     add_opt(common_arg(
         {"-hft", "--hf-token"}, "TOKEN",
         "Hugging Face access token (default: value from HF_TOKEN environment variable)",
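With both options registered, the vocoder model can be pulled from Hugging Face the same way as the main model. A hypothetical invocation of the TTS example might look like the following (the binary name, repositories and file names are placeholders, not taken from this commit):

    llama-tts -hfr <tts-repo> -hff <tts-model.gguf> \
              -hfrv <vocoder-repo> -hffv <vocoder-model.gguf> \
              -p "some text to speak"

The same values can also be supplied through the LLAMA_ARG_HF_REPO_V and LLAMA_ARG_HF_FILE_V environment variables set above.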
7 changes: 4 additions & 3 deletions common/common.cpp
@@ -1095,7 +1095,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
 #define CURL_MAX_RETRY 3
 #define CURL_RETRY_DELAY_SECONDS 2
 
-static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
+static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
     int remaining_attempts = max_attempts;
 
     while (remaining_attempts > 0) {
@@ -1119,7 +1119,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
 }
 
 static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
-
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
     if (!curl) {
@@ -1192,11 +1191,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
         std::string etag;
         std::string last_modified;
     };
+
     common_load_model_from_url_headers headers;
+
     {
         typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
         auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
-            common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
+            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
 
             static std::regex header_regex("([^:]+): (.*)\r\n");
             static std::regex etag_regex("ETag", std::regex_constants::icase);
6 changes: 5 additions & 1 deletion common/common.h
@@ -175,7 +175,11 @@ struct common_params_speculative {
 };
 
 struct common_params_vocoder {
-    std::string model = ""; // vocoder model for producing audio // NOLINT
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+
+    std::string model = ""; // model path // NOLINT
+    std::string model_url = ""; // model url to download // NOLINT
 };
 
 struct common_params {
7 changes: 6 additions & 1 deletion examples/tts/tts.cpp
@@ -461,7 +461,12 @@ int main(int argc, char ** argv) {
     model_ttc = llama_init_ttc.model;
     ctx_ttc = llama_init_ttc.context;
 
-    params.model = params.vocoder.model;
+    // TODO: refactor in a common struct
+    params.model = params.vocoder.model;
+    params.model_url = params.vocoder.model_url;
+    params.hf_repo = params.vocoder.hf_repo;
+    params.hf_file = params.vocoder.hf_file;
+
     params.embedding = true;
 
     common_init_result llama_init_cts = common_init_from_params(params);
