
# Add `**kwargs` back to `NVFoundationLLMClient.generate_batch()` and `generate_batch_async()` (#1967)

This PR restores the `**kwargs` argument to the `NVFoundationLLMClient.generate_batch()` and `generate_batch_async()` methods; removing it had introduced a regression.

Closes #1961 
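For context, the regression is a plain Python signature error: once `**kwargs` was dropped, any caller forwarding extra model arguments fails with a `TypeError` before a request is ever issued. A minimal standalone sketch (the simplified function below is hypothetical, not the Morpheus implementation):

```python
# Hypothetical stand-in for the method after **kwargs was removed.
def generate_batch(inputs: dict[str, list], return_exceptions: bool = False) -> list[str]:
    # Real implementation issues LLM requests; echoing prompts is enough
    # to demonstrate the signature behavior.
    return list(inputs["prompt"])


# A caller forwarding a model argument such as temperature now fails:
try:
    generate_batch({"prompt": ["hello"]}, temperature=0.7)
except TypeError as exc:
    print(f"regression: {exc}")  # unexpected keyword argument 'temperature'
```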

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Ashley Song (https://github.com/ashsong-nv)

Approvers:
  - David Gardner (https://github.com/dagardner-nv)

URL: #1967
ashsong-nv authored Oct 22, 2024
1 parent 674e629 commit 2cfcc69
Showing 1 changed file with 12 additions and 4 deletions.
```diff
@@ -131,7 +131,10 @@ def generate_batch(self,
                        return_exceptions: bool = False) -> list[str] | list[str | BaseException]:
         ...

-    def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> list[str] | list[str | BaseException]:
+    def generate_batch(self,
+                       inputs: dict[str, list],
+                       return_exceptions=False,
+                       **kwargs) -> list[str] | list[str | BaseException]:
         """
         Issue a request to generate a list of responses based on a list of prompts.
@@ -141,6 +144,8 @@ def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> li
             Inputs containing prompt data.
         return_exceptions : bool
             Whether to return exceptions in the output list or raise them immediately.
+        **kwargs
+            Additional keyword arguments for generate batch.
         """

         # Note: We dont want to use the generate_multiple implementation from nemollm because there is no retry logic.
@@ -152,7 +157,7 @@ def generate_batch(self, inputs: dict[str, list], return_exceptions=False) -> li
                 "If an exception is raised for any item, the function will exit and raise that exception.")

         prompts = [StringPromptValue(text=p) for p in inputs[self._prompt_key]]
-        final_kwargs = self._model_kwargs
+        final_kwargs = {**self._model_kwargs, **kwargs}

         responses = []
         try:
@@ -182,7 +187,8 @@ async def generate_batch_async(self,

     async def generate_batch_async(self,
                                    inputs: dict[str, list],
-                                   return_exceptions=False) -> list[str] | list[str | BaseException]:
+                                   return_exceptions=False,
+                                   **kwargs) -> list[str] | list[str | BaseException]:
         """
         Issue an asynchronous request to generate a list of responses based on a list of prompts.
@@ -192,6 +198,8 @@ async def generate_batch_async(self,
             Inputs containing prompt data.
         return_exceptions : bool
             Whether to return exceptions in the output list or raise them immediately.
+        **kwargs
+            Additional keyword arguments for generate batch async.
         """

         # Note: We dont want to use the generate_multiple implementation from nemollm because there is no retry logic.
@@ -203,7 +211,7 @@ async def generate_batch_async(self,
                 "If an exception is raised for any item, the function will exit and raise that exception.")

         prompts = [StringPromptValue(text=p) for p in inputs[self._prompt_key]]
-        final_kwargs = self._model_kwargs
+        final_kwargs = {**self._model_kwargs, **kwargs}

         responses = []
         try:
```
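The merge `{**self._model_kwargs, **kwargs}` gives per-call arguments precedence over the client's stored model kwargs, because later entries in a dict display win on key collision. A small illustration of that merge semantics, with made-up argument names and values:

```python
# Illustrative stand-ins for the client's stored defaults and a caller's
# per-call overrides (names and values are assumptions, not Morpheus config).
model_kwargs = {"temperature": 0.0, "max_tokens": 128}
call_kwargs = {"temperature": 0.7}

# Later unpacking wins on duplicate keys, mirroring the diff's
# final_kwargs = {**self._model_kwargs, **kwargs}.
final_kwargs = {**model_kwargs, **call_kwargs}
print(final_kwargs)  # {'temperature': 0.7, 'max_tokens': 128}
```

Note that the merge builds a new dict, so the stored `model_kwargs` are left unmodified for subsequent calls.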