Added safe_mode flag and removed reset_stats
StanChan03 committed Nov 27, 2024
1 parent 186f902 · commit 9689736
Showing 1 changed file with 10 additions and 9 deletions.
lotus/models/lm.py (10 additions, 9 deletions)

--- a/lotus/models/lm.py
+++ b/lotus/models/lm.py
@@ -25,19 +25,23 @@ def __init__(
         max_batch_size: int = 64,
         tokenizer: Tokenizer | None = None,
         max_cache_size: int = 1024,
+        safe_mode: bool = False,
         **kwargs: dict[str, Any],
     ):
         self.model = model
         self.max_ctx_len = max_ctx_len
         self.max_tokens = max_tokens
         self.max_batch_size = max_batch_size
         self.tokenizer = tokenizer
+        self.safe_mode = safe_mode
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
         self.stats: LMStats = LMStats()
         self.cache = Cache(max_cache_size)
 
-    def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
+    def __call__(
+        self, messages: list[list[dict[str, str]]], safe_mode: bool = False, **kwargs: dict[str, Any]
+    ) -> LMOutput:
         all_kwargs = {**self.kwargs, **kwargs}
 
         # Set top_logprobs if logprobs requested
@@ -65,20 +69,17 @@ def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any
         logprobs = (
             [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
         )
-
-        self.print_total_usage()
-        self.reset_stats()
+        if self.safe_mode:
+            self.print_total_usage()
 
         return LMOutput(outputs=outputs, logprobs=logprobs)
 
     def _process_uncached_messages(self, uncached_data, all_kwargs):
         """Processes uncached messages in batches and returns responses."""
         uncached_responses = []
-        with tqdm(total=len(uncached_data), desc="Processing uncached messages") as pbar:
-            for i in range(0, len(uncached_data), self.max_batch_size):
-                batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
-                uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
-                pbar.update(len(batch))
+        for i in tqdm(range(0, len(uncached_data), self.max_batch_size), desc="Processing uncached messages"):
+            batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
+            uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
         return uncached_responses
 
     def _cache_response(self, response, hash):
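For context, a minimal usage sketch of the new flag follows. It is illustrative only and not part of the commit: the class name LM and the import path are assumed from lotus/models/lm.py, and the model name is a placeholder.

    # Illustrative usage sketch; assumes the class defined in lotus/models/lm.py
    # is LM and importable from lotus.models; "gpt-4o-mini" is a placeholder.
    from lotus.models import LM

    lm = LM(model="gpt-4o-mini", safe_mode=True)  # opt in to the usage report

    messages = [[{"role": "user", "content": "Hello!"}]]
    output = lm(messages)   # prints total usage only because safe_mode=True
    print(output.outputs)

With the default safe_mode=False, __call__ now returns without printing usage, and since reset_stats is no longer called per invocation, LMStats presumably accumulates across calls.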

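The batching loop in _process_uncached_messages was also simplified: instead of a manually updated progress bar, tqdm now wraps the range of batch start indices, so the bar advances once per batch rather than once per message. A standalone sketch of that idiom (illustrative data, not lotus code):

    # Sketch of the revised batching/progress idiom; the data and the doubling
    # step are stand-ins, not lotus internals.
    from tqdm import tqdm

    data = list(range(10))  # stand-in for the uncached (message, hash) pairs
    batch_size = 4
    results = []
    for i in tqdm(range(0, len(data), batch_size), desc="Processing uncached messages"):
        batch = data[i : i + batch_size]
        results.extend(x * 2 for x in batch)  # stand-in for batch_completion(...)

One observable difference: the bar's total is now the number of batches rather than the number of messages, so per-message progress granularity is traded for simpler code.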