diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index a1e79e105db..fe229853c2e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -695,6 +695,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             )
             cumulative_adapter_indices_size = adapter_end_index
             adapter_set.update(batch.adapter_meta.adapter_set)
+            adapter_segment_builder.concat(
+                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+            )
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -739,7 +742,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             else None
         )
 
-        _adapter_segments, _adapter_segment_indices = adapter_segment_builder.build()
+        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
 
         return cls(
             batch_id=batches[0].batch_id,
@@ -771,6 +774,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
+            adapter_meta=AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            ),
         )
 
     def __len__(self):
@@ -1225,14 +1234,6 @@ def generate_token(
         )
 
         if prefill:
-            # adjust segment lengths to account for all request lengths being 1 during decoding
-            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
-            batch.adapter_meta.adapter_segments = torch.tensor(
-                adapter_segments,
-                dtype=torch.int32,
-                device=batch.adapter_meta.adapter_segments.device,
-            )
-
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
@@ -1277,6 +1278,12 @@ def generate_token(
                     # In decode, we do not need this as we can just increment position ids
                     next_position_ids[i] = batch.position_ids[end_index - 1]
 
+                    # Initialize adapter indices
+                    # In decode, we only have one token per row in the batch, so grab last index
+                    next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                        end_index - 1
+                    ]
+
                     # Used to gather prefill logprobs
                     # Copy batch.input_ids to prefill_token_indices
                     if prefill_logprobs:
@@ -1304,6 +1311,15 @@ def generate_token(
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
+        if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
         if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index bb3951e7737..c4411b745cd 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -239,10 +239,11 @@ async def serve_inner(
 
         if len(lora_adapter_ids) > 0:
             for index, adapter_id in enumerate(lora_adapter_ids):
-                # TODO: avoid hacky hardcoded adapter id
+                # TODO: improve non-merged adapter loading and, long term,
+                # improve adapter loading as a whole
                 adapter_parameters = AdapterParameters(
                     adapter_ids=[adapter_id],
-                    weights=[],
+                    weights=None,  # will be set to 1
                     merge_strategy=0,
                     density=1.0,
                     majority_sign_method=0,
diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py
index 53e92180deb..1d3a644269b 100644
--- a/server/text_generation_server/utils/adapter.py
+++ b/server/text_generation_server/utils/adapter.py
@@ -162,8 +162,6 @@ def load_module_map(
     api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    print("adapter_id", adapter_id)
-
     revision = "main"
 
     adapter_config = LoraConfig.load(adapter_id, api_token)