fix: adjust adapter_segments logic when in batch

drbh committed Jun 6, 2024
1 parent e0462af commit d103264

Showing 3 changed files with 28 additions and 13 deletions.
34 changes: 25 additions & 9 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -695,6 +695,9 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
             )
             cumulative_adapter_indices_size = adapter_end_index
             adapter_set.update(batch.adapter_meta.adapter_set)
+            adapter_segment_builder.concat(
+                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+            )
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
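
Context for the hunk above: as concatenate() walks the batches, each batch carries its own adapter segments (token-offset boundaries plus the adapter index that owns each span). Merging them means offsetting the incoming boundaries by the tokens already accumulated and fusing the junction when one batch ends and the next begins on the same adapter. A minimal list-based sketch of that merge logic, with hypothetical names (SegmentBuilderSketch is illustrative, not TGI's actual adapter_segment_builder class):

from typing import List, Tuple
import torch

class SegmentBuilderSketch:
    """Illustrative stand-in for the segment builder used in concatenate()."""

    def __init__(self) -> None:
        self.segments: List[int] = [0]        # boundaries; segment i spans segments[i]:segments[i+1]
        self.segment_indices: List[int] = []  # adapter index owning each segment

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]) -> None:
        offset = self.segments[-1]
        incoming = adapter_segments.tolist()            # always starts at 0, e.g. [0, 3, 5]
        new_bounds = [offset + b for b in incoming[1:]]
        if self.segment_indices and self.segment_indices[-1] == segment_indices[0]:
            # Same adapter on both sides of the junction: drop the boundary
            # between the previous batch and this one so the two spans fuse.
            self.segments.pop()
            segment_indices = segment_indices[1:]
        self.segments.extend(new_bounds)
        self.segment_indices.extend(segment_indices)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.tensor(self.segments, dtype=torch.int32), self.segment_indices

# e.g. batch A = 3 rows on adapter 0, batch B = 2 rows on adapter 0:
# the junction fuses, giving segments [0, 5] and segment_indices [0].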
@@ -739,7 +742,7 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
             else None
         )
 
-        _adapter_segments, _adapter_segment_indices = adapter_segment_builder.build()
+        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
 
         return cls(
             batch_id=batches[0].batch_id,
@@ -771,6 +774,12 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
+            adapter_meta=AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            ),
         )
 
     def __len__(self):
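
These two hunks are the core of the fix: the old code built the merged segments but discarded them (note the leading underscores in the removed line) and returned a concatenated batch with no adapter_meta, so generation after a batch concatenation ran with stale or missing segment data. The four fields now threaded through can be read directly off the diff; the dataclass shape below is a sketch inferred from those call sites, not copied from TGI:

from dataclasses import dataclass
from typing import List, Set
import torch

@dataclass
class AdapterBatchMetadataSketch:
    adapter_indices: torch.Tensor   # adapter index per row, e.g. tensor([0, 0, 0, 1, 1])
    adapter_set: Set[int]           # distinct adapters present, e.g. {0, 1}
    adapter_segments: torch.Tensor  # segment boundaries, e.g. tensor([0, 3, 5])
    segment_indices: List[int]      # adapter owning each segment, e.g. [0, 1]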
@@ -1225,14 +1234,6 @@ def generate_token(
         )
 
         if prefill:
-            # adjust segment lengths to account for all request lengths being 1 during decoding
-            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
-            batch.adapter_meta.adapter_segments = torch.tensor(
-                adapter_segments,
-                dtype=torch.int32,
-                device=batch.adapter_meta.adapter_segments.device,
-            )
-
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
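
This removal is a move, not a deletion: the same block reappears near the end of generate_token (last hunk of this file). The ordering is the point. Here the segments were recomputed from adapter_indices while it still held one entry per prefill token, i.e. before the batch is reshaped to one row per request for decoding. A worked example of what goes wrong under that reading:

# Hypothetical two-request prefill: request 0 has 3 prompt tokens on adapter 0,
# request 1 has 2 prompt tokens on adapter 1.
adapter_indices = [0, 0, 0, 1, 1]   # one entry per prompt token
# find_segments here -> boundaries [0, 3, 5], segment adapters [0, 1]

# After the decode-shape update, each request keeps a single row:
next_adapter_indices = [0, 1]
# The segments that match this shape are [0, 1, 2] with adapters [0, 1].
# Recomputing them before adapter_indices is rewritten (as the removed code
# did) leaves boundaries [0, 3, 5] attached to a 2-row decode batch.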
@@ -1277,6 +1278,12 @@ def generate_token(
                 # In decode, we do not need this as we can just increment position ids
                 next_position_ids[i] = batch.position_ids[end_index - 1]
 
+                # Initialize adapter indices
+                # In decode, we only have one token per row in the batch, so grab last index
+                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                    end_index - 1
+                ]
+
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
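
The new assignment mirrors the next_position_ids line just above it: the last prompt token of each request carries the adapter index that its single decode row should use. A self-contained sketch of that gather (the flat token layout and cu_seqlen boundaries are assumptions for illustration):

import torch

# Flattened prefill batch: request 0 = 3 tokens on adapter 0, request 1 = 2 tokens on adapter 1.
adapter_indices = torch.tensor([0, 0, 0, 1, 1])
cu_seqlen = [0, 3, 5]  # cumulative request boundaries

next_adapter_indices = torch.empty(len(cu_seqlen) - 1, dtype=adapter_indices.dtype)
for i in range(len(cu_seqlen) - 1):
    end_index = cu_seqlen[i + 1]
    # The last token of request i determines the adapter for its decode row.
    next_adapter_indices[i] = adapter_indices[end_index - 1]

print(next_adapter_indices.tolist())  # [0, 1]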
@@ -1304,6 +1311,15 @@ def generate_token(
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
+        if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
         if prefill and prefill_logprobs:
             # Get prefill logprobs
             prefill_logprobs_tensor = torch.log_softmax(out, -1)
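
For reference, find_segments groups consecutive rows that share an adapter index into (boundaries, owners) pairs, exactly the shape adapter_segments/segment_indices carry throughout this diff. A minimal re-implementation consistent with how the diff uses it, assuming a non-empty batch (a sketch, not TGI's exact helper):

from typing import List, Tuple
import torch

def find_segments_sketch(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    """Return (segments, segment_indices): rows segments[i]:segments[i+1]
    all use adapter segment_indices[i]."""
    indices = adapter_indices.tolist()
    segments = [0]
    segment_indices: List[int] = []
    for i in range(1, len(indices)):
        if indices[i] != indices[i - 1]:
            segments.append(i)
            segment_indices.append(indices[i - 1])
    segments.append(len(indices))
    segment_indices.append(indices[-1])
    return segments, segment_indices

# find_segments_sketch(torch.tensor([0, 0, 0, 1, 1])) -> ([0, 3, 5], [0, 1])
# find_segments_sketch(torch.tensor([0, 1]))          -> ([0, 1, 2], [0, 1])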
5 changes: 3 additions & 2 deletions server/text_generation_server/server.py
@@ -239,10 +239,11 @@ async def serve_inner(
 
         if len(lora_adapter_ids) > 0:
             for index, adapter_id in enumerate(lora_adapter_ids):
-                # TODO: avoid hacky hardcoded adapter id
+                # TODO: improve non merged adapter loading and long term
+                # improve adapter loading as a whole
                 adapter_parameters = AdapterParameters(
                     adapter_ids=[adapter_id],
-                    weights=[],
+                    weights=None,  # will be set to 1
                     merge_strategy=0,
                     density=1.0,
                     majority_sign_method=0,
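
One behavioral detail in this hunk: weights=[] claimed an explicit-but-empty weight list for the single adapter, whereas weights=None signals "no weights given" so downstream code can default to 1.0 per adapter, as the inline comment says. A hypothetical defaulting step consistent with that comment (the helper name is illustrative, not TGI's API):

from typing import List, Optional

def resolve_adapter_weights(adapter_ids: List[str], weights: Optional[List[float]]) -> List[float]:
    # No explicit weights: every adapter gets weight 1.0.
    if weights is None:
        return [1.0] * len(adapter_ids)
    assert len(weights) == len(adapter_ids), "one weight per adapter expected"
    return list(weights)

print(resolve_adapter_weights(["my-lora"], None))  # [1.0]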
2 changes: 0 additions & 2 deletions server/text_generation_server/utils/adapter.py
@@ -162,8 +162,6 @@ def load_module_map(
     api_token: str,
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    print("adapter_id", adapter_id)
-
     revision = "main"
 
     adapter_config = LoraConfig.load(adapter_id, api_token)
