fix: adjust batch for bgmv
drbh committed Jun 6, 2024
1 parent a864004 commit e0462af
Showing 2 changed files with 32 additions and 22 deletions.
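The "bgmv" in the commit title is the batched gather matrix-vector multiply kernel from the Punica multi-LoRA work: every token in a batch is multiplied by the LoRA weights of whichever adapter its request was routed to, selected by a per-token adapter index. A dense PyTorch sketch of that semantics, for orientation only (the real kernel is custom CUDA; the tensor names here are invented):

import torch

def bgmv_reference(x, lora_a, adapter_indices):
    # x: [num_tokens, hidden]; lora_a: [num_adapters, hidden, rank]
    # adapter_indices: [num_tokens], the adapter assigned to each token
    gathered = lora_a[adapter_indices.long()]  # [num_tokens, hidden, rank]
    return torch.bmm(x.unsqueeze(1), gathered).squeeze(1)  # [num_tokens, rank]

The fix below keeps adapter_indices consistent with the batch as it moves from prefill (one entry per input token) to decode (one entry per request).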
server/text_generation_server/models/flash_causal_lm.py (14 additions, 0 deletions)

@@ -1197,8 +1197,13 @@ def generate_token(
                 if prefill_logprobs
                 else speculative_logits
             )
+            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
+                len(batch)
+            )
+
         else:
             next_token_logits = out
+            next_adapter_indices = batch.adapter_meta.adapter_indices
 
         speculate = get_speculate()
         (
@@ -1220,6 +1225,14 @@ def generate_token(
         )
 
         if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
@@ -1289,6 +1302,7 @@ def generate_token(
         batch.position_ids = next_position_ids + accepted_ids
         batch.input_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
+        batch.adapter_meta.adapter_indices = next_adapter_indices
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
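Reviewer note on the prefill hunk above: during prefill, batch.adapter_meta.adapter_indices holds one entry per input token, but from the first decode step on each request contributes exactly one token, which is what the in-diff comment means by "all request lengths being 1 during decoding". The segment boundaries consumed by the LoRA kernels therefore have to be rebuilt once prefill finishes. A minimal stand-in for find_segments with the run-length semantics the hunk assumes (the real helper lives in the server's utils and may differ in detail):

import torch

def find_segments(adapter_indices):
    # Collapse runs of equal adapter indices into contiguous segments;
    # return boundary offsets plus the adapter id owning each segment.
    values = adapter_indices.tolist()
    segments, segment_adapters = [0], []
    for i, idx in enumerate(values):
        if i == 0 or idx != values[i - 1]:
            if i != 0:
                segments.append(i)
            segment_adapters.append(idx)
    segments.append(len(values))
    return segments, segment_adapters

# Two requests on adapter 0 and three on adapter 1, one token each after prefill:
indices = torch.tensor([0, 0, 1, 1, 1], dtype=torch.int32)
print(find_segments(indices))  # ([0, 2, 5], [0, 1])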
server/text_generation_server/server.py (18 additions, 22 deletions)

@@ -237,28 +237,24 @@ async def serve_inner(
             trust_remote_code,
         )
 
-        # TODO: avoid hacky hardcoded adapter id
-        adapter_parameters = AdapterParameters(
-            adapter_ids=lora_adapter_ids,
-            weights=[
-                # TODO: fill with actual weights
-                torch.tensor([1.0], dtype=torch.float32)
-            ],
-            merge_strategy=0,
-            density=0.0,
-            majority_sign_method=0,
-        )
-        adapter_source = None
-        adapter_index = 0
-        api_token = None
-
-        model.load_adapter(
-            adapter_parameters,
-            adapter_source,
-            adapter_index,
-            api_token,
-            False,
-        )
+        if len(lora_adapter_ids) > 0:
+            for index, adapter_id in enumerate(lora_adapter_ids):
+                # TODO: avoid hacky hardcoded adapter id
+                adapter_parameters = AdapterParameters(
+                    adapter_ids=[adapter_id],
+                    weights=[],
+                    merge_strategy=0,
+                    density=1.0,
+                    majority_sign_method=0,
+                )
+                adapter_index = index
+                model.load_adapter(
+                    adapter_parameters,
+                    None,  # adapter_source
+                    adapter_index,
+                    None,  # api_token
+                    False,  # dynamic
+                )
 
     except Exception:
         logger.exception("Error when initializing model")
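The server-side change replaces a single load_adapter call that packed every id into one AdapterParameters (with placeholder weights and adapter_index pinned to 0) with one call per adapter, each getting a unique index; those indices are what the per-token adapter_indices in flash_causal_lm.py refer back to. A sketch of the resulting call sequence for a hypothetical two-adapter launch (adapter ids invented, signature copied from the diff):

lora_adapter_ids = ["org/math-lora", "org/code-lora"]  # hypothetical ids

for index, adapter_id in enumerate(lora_adapter_ids):
    model.load_adapter(
        AdapterParameters(
            adapter_ids=[adapter_id],  # exactly one adapter per call
            weights=[],                # no merge weights for a single adapter
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        ),
        None,   # adapter_source
        index,  # adapter_index: 0 for math-lora, 1 for code-lora
        None,   # api_token
        False,  # dynamic
    )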
