Skip to content

Commit

Permalink
feat: download lora adapter weights from launcher (#2140)
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh authored Jul 1, 2024
1 parent 25f57e2 commit 0d97a93
Showing 1 changed file with 37 additions and 9 deletions.
46 changes: 37 additions & 9 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -898,13 +898,20 @@ enum LauncherError {
WebserverCannotStart,
}

fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
fn download_convert_model(
model_id: &str,
revision: Option<&str>,
trust_remote_code: bool,
huggingface_hub_cache: Option<&str>,
weights_cache_override: Option<&str>,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
// Enter download tracing span
let _span = tracing::span!(tracing::Level::INFO, "download").entered();

let mut download_args = vec![
"download-weights".to_string(),
args.model_id.to_string(),
model_id.to_string(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
Expand All @@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
];

// Model optional revision
if let Some(revision) = &args.revision {
if let Some(revision) = &revision {
download_args.push("--revision".to_string());
download_args.push(revision.to_string())
}

// Trust remote code for automatic peft fusion
if args.trust_remote_code {
if trust_remote_code {
download_args.push("--trust-remote-code".to_string());
}

Expand All @@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L

// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};

Expand All @@ -952,15 +959,15 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L

// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
if let Some(weights_cache_override) = &args.weights_cache_override {
if let Some(weights_cache_override) = &weights_cache_override {
envs.push((
"WEIGHTS_CACHE_OVERRIDE".into(),
weights_cache_override.into(),
));
};

// Start process
tracing::info!("Starting download process.");
tracing::info!("Starting check and download process for {model_id}");
let mut download_process = match Command::new("text-generation-server")
.args(download_args)
.env_clear()
Expand Down Expand Up @@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
loop {
if let Some(status) = download_process.try_wait().unwrap() {
if status.success() {
tracing::info!("Successfully downloaded weights.");
tracing::info!("Successfully downloaded weights for {model_id}");
break;
}

Expand Down Expand Up @@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
.expect("Error setting Ctrl-C handler");

// Download and convert model weights
download_convert_model(&args, running.clone())?;
download_convert_model(
&args.model_id,
args.revision.as_deref(),
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;

// Download and convert lora adapters if any
if let Some(lora_adapters) = &args.lora_adapters {
for adapter in lora_adapters.split(',') {
download_convert_model(
adapter,
None,
args.trust_remote_code,
args.huggingface_hub_cache.as_deref(),
args.weights_cache_override.as_deref(),
running.clone(),
)?;
}
}

if !running.load(Ordering::SeqCst) {
// Launcher was asked to stop
Expand Down

0 comments on commit 0d97a93

Please sign in to comment.