diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index dd05a8cc7a4..a1e79e105db 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1197,8 +1197,13 @@ def generate_token(
                 if prefill_logprobs
                 else speculative_logits
             )
+            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
+                len(batch)
+            )
+
         else:
             next_token_logits = out
+            next_adapter_indices = batch.adapter_meta.adapter_indices
 
         speculate = get_speculate()
         (
@@ -1220,6 +1225,14 @@
         )
 
         if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
@@ -1289,6 +1302,7 @@
         batch.position_ids = next_position_ids + accepted_ids
         batch.input_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
+        batch.adapter_meta.adapter_indices = next_adapter_indices
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 0c376e70624..bb3951e7737 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -237,28 +237,24 @@ async def serve_inner(
                 trust_remote_code,
             )
 
-            # TODO: avoid hacky hardcoded adapter id
-            adapter_parameters = AdapterParameters(
-                adapter_ids=lora_adapter_ids,
-                weights=[
-                    # TODO: fill with actual weights
-                    torch.tensor([1.0], dtype=torch.float32)
-                ],
-                merge_strategy=0,
-                density=0.0,
-                majority_sign_method=0,
-            )
-            adapter_source = None
-            adapter_index = 0
-            api_token = None
-
-            model.load_adapter(
-                adapter_parameters,
-                adapter_source,
-                adapter_index,
-                api_token,
-                False,
-            )
+            if len(lora_adapter_ids) > 0:
+                for index, adapter_id in enumerate(lora_adapter_ids):
+                    # TODO: avoid hacky hardcoded adapter id
+                    adapter_parameters = AdapterParameters(
+                        adapter_ids=[adapter_id],
+                        weights=[],
+                        merge_strategy=0,
+                        density=1.0,
+                        majority_sign_method=0,
+                    )
+                    adapter_index = index
+                    model.load_adapter(
+                        adapter_parameters,
+                        None,  # adapter_source
+                        adapter_index,
+                        None,  # api_token
+                        False,  # dynamic
+                    )
 
         except Exception:
             logger.exception("Error when initializing model")
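
The prefill hunk in `flash_causal_lm.py` recomputes `batch.adapter_meta.adapter_segments` because segment boundaries derived from multi-token prompts no longer hold once every request emits exactly one token per decode step. Below is a minimal sketch of the contract this change assumes for `find_segments` (an illustrative reimplementation, not the project's own helper), followed by a prefill-versus-decode comparison with made-up values:

```python
# Assumed contract for find_segments: given a 1-D tensor of per-token adapter
# indices, return the boundary offsets of each contiguous run plus the adapter
# index owning each run. Illustrative sketch only.
from typing import List, Tuple

import torch


def find_segments(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    segments = [0]
    segment_indices = []
    values = adapter_indices.tolist()
    if not values:
        return segments, segment_indices
    current = values[0]
    for i, value in enumerate(values[1:], start=1):
        if value != current:
            # a new contiguous run of adapter indices starts at offset i
            segments.append(i)
            segment_indices.append(current)
            current = value
    segments.append(len(values))
    segment_indices.append(current)
    return segments, segment_indices


# Why the recomputation is needed (illustrative values): during prefill,
# request 0 contributes 3 prompt tokens on adapter 0 and request 1 contributes
# 2 on adapter 1; during decode, each request contributes exactly one token,
# so the segment boundaries must shrink accordingly.
prefill_indices = torch.tensor([0, 0, 0, 1, 1])
decode_indices = torch.tensor([0, 1])
print(find_segments(prefill_indices))  # ([0, 3, 5], [0, 1])
print(find_segments(decode_indices))   # ([0, 1, 2], [0, 1])
```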
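
The `server.py` hunk replaces the single hardcoded `AdapterParameters` with one per CLI-provided LoRA id, each loaded under its own index with full density and no merge weights. A hypothetical sketch of that loop's behavior, using a recording stub in place of a real model (the stub class and adapter ids below are assumptions for illustration; the `AdapterParameters` fields and `load_adapter` argument order mirror the diff):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class AdapterParameters:
    adapter_ids: List[str]
    weights: list
    merge_strategy: int
    density: float
    majority_sign_method: int


@dataclass
class RecordingModel:
    # hypothetical stand-in for the real model: records calls instead of
    # loading adapter weights
    loaded: list = field(default_factory=list)

    def load_adapter(self, params, adapter_source, adapter_index, api_token, dynamic):
        self.loaded.append((adapter_index, params.adapter_ids))


model = RecordingModel()
lora_adapter_ids = ["org/adapter-a", "org/adapter-b"]  # illustrative ids

if len(lora_adapter_ids) > 0:
    for index, adapter_id in enumerate(lora_adapter_ids):
        adapter_parameters = AdapterParameters(
            adapter_ids=[adapter_id],
            weights=[],
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )
        model.load_adapter(adapter_parameters, None, index, None, False)

print(model.loaded)  # [(0, ['org/adapter-a']), (1, ['org/adapter-b'])]
```

The upshot of the loop is that each adapter keeps a distinct index, which is what lets the per-token `adapter_indices` bookkeeping in `flash_causal_lm.py` route different requests to different adapters within one batch.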