if lora not exists do not restart worker
mitya52 committed Mar 14, 2024
1 parent a8752d7 commit e108a81
Showing 2 changed files with 11 additions and 3 deletions.
9 changes: 7 additions & 2 deletions self_hosting_machinery/inference/inference_worker.py
@@ -3,6 +3,7 @@
 import time
 import signal
 import socket
+import traceback
 
 from refact_scratchpads_no_gpu.stream_results import infserver_session
 from refact_scratchpads_no_gpu.stream_results import validate_description_dict
@@ -104,8 +105,12 @@ def check_cancelled(*args, **kwargs):
                 "ts_first_token": 0,
                 "ts_batch_finished": 0,
             }
-            inference_model.lora_switch_according_to_request(request.get("lora_config", None))
-            inference_model.infer(request, upload_proxy, upload_proxy_args)
+            try:
+                inference_model.lora_switch_according_to_request(request.get("lora_config", None))
+                inference_model.infer(request, upload_proxy, upload_proxy_args)
+            except Exception as e:
+                log(f"inference failed with {e}")
+                log(traceback.format_exc())
         elif retcode == "WAIT":
             # Normal, no requests
             pass
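
The effect of this hunk: an exception raised while switching LoRA weights or running inference (for example, when the requested LoRA checkpoint does not exist) is logged instead of propagating out of the request loop, so the worker keeps polling instead of crashing and being restarted. A minimal sketch of the pattern, assuming a polling loop like the worker's; next_request, switch_lora, and run_inference are hypothetical stand-ins for the real methods:

import logging
import traceback

log = logging.getLogger("inference_worker")

def serve_forever(next_request, switch_lora, run_inference):
    while True:
        request = next_request()  # blocks until a request arrives, or None on WAIT
        if request is None:
            continue  # nothing to do, keep polling
        try:
            switch_lora(request.get("lora_config", None))
            run_inference(request)
        except Exception as e:
            # A bad or missing LoRA no longer escapes the loop,
            # so the worker process is not torn down and restarted.
            log.error("inference failed with %s", e)
            log.error(traceback.format_exc())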
5 changes: 4 additions & 1 deletion self_hosting_machinery/inference/lora_loader_mixin.py
@@ -122,7 +122,10 @@ def load_checkpoint(
             }
             lora_target_modules = adapter_config["target_modules"]
         else:
-            old_format_finetune_cp = _load_filename(load_path / "mp_rank_00_model_states.pt")
+            finetune_cps = [_load_filename(p) for p in load_cp_paths]
+            if len(finetune_cps) > 1:
+                raise NotImplementedError("Loading of sharded checkpoint is not implemented")
+            old_format_finetune_cp = finetune_cps[0]
             lora_cfg = old_format_finetune_cp['ds_config']['model_info']['lora']
             _, lora_target_modules = map_model_specific_params(
                 model_name=self.model_name,
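
Instead of hardcoding the single-shard filename mp_rank_00_model_states.pt, the loader now reads every path in load_cp_paths and fails loudly on sharded (model-parallel) checkpoints rather than silently loading only the first shard. A hedged sketch of the new branch, assuming load_cp_paths is a list of checkpoint paths discovered earlier and _load_filename deserializes a single file (e.g. via torch.load):

from pathlib import Path
from typing import Any, Callable, List

def load_old_format_checkpoint(
        load_cp_paths: List[Path],
        _load_filename: Callable[[Path], Any]):
    finetune_cps = [_load_filename(p) for p in load_cp_paths]
    if len(finetune_cps) > 1:
        # Merging model-parallel shards would need extra logic, so refuse explicitly.
        raise NotImplementedError("Loading of sharded checkpoint is not implemented")
    old_format_finetune_cp = finetune_cps[0]
    lora_cfg = old_format_finetune_cp['ds_config']['model_info']['lora']
    return old_format_finetune_cp, lora_cfg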
