Commit: initial chat functionality
jmansdorfer committed Nov 8, 2024
1 parent c4ac589 commit 5d3e250
Showing 6 changed files with 768 additions and 1 deletion.
511 changes: 511 additions & 0 deletions docs/docs/integrations/chat/predictionguard.ipynb

Large diffs are not rendered by default.

21 changes: 20 additions & 1 deletion docs/docs/integrations/providers/predictionguard.mdx
@@ -15,11 +15,31 @@ pip install predictionguard
## Prediction Guard LangChain Integrations
|API|Description|Endpoint Docs|Import|Example Usage|
|---|---|---|---|---|
|Chat|Build Chat Bots|[Chat](https://docs.predictionguard.com/api-reference/api-reference/chat-completions)|`from langchain_community.chat_models.predictionguard import ChatPredictionGuard`|[predictionguard.ipynb](/docs/integrations/chat/predictionguard)|
|Completions|Generate Text|[Completions](https://docs.predictionguard.com/api-reference/api-reference/completions)|`from langchain_community.llms.predictionguard import PredictionGuard`|[predictionguard.ipynb](/docs/integrations/llms/predictionguard)|
|Text Embedding|Embed Strings to Vectors|[Embeddings](https://docs.predictionguard.com/api-reference/api-reference/embeddings)|`from langchain_community.embeddings.predictionguard import PredictionGuardEmbeddings`|[predictionguard.ipynb](/docs/integrations/text_embedding/predictionguard)|

## Getting Started
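
All of the integrations below authenticate with a Prediction Guard API key, supplied either as the `predictionguard_api_key` constructor parameter or via the `PREDICTIONGUARD_API_KEY` environment variable. A minimal setup sketch:

```python
import os

# If predictionguard_api_key is not passed to a constructor, the client
# falls back to this environment variable.
os.environ["PREDICTIONGUARD_API_KEY"] = "<your API key>"
```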

## Chat Models

### Prediction Guard Chat

See a [usage example](/docs/integrations/chat/predictionguard).

```python
from langchain_community.chat_models.predictionguard import ChatPredictionGuard
```

#### Usage

```python
# If predictionguard_api_key is not passed, default behavior is to use the `PREDICTIONGUARD_API_KEY` environment variable.
chat = ChatPredictionGuard(model="Hermes-2-Pro-Llama-3-8B")

chat.invoke("Tell me a joke")
```
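
Prediction Guard's input and output checks can be attached via the `predictionguard_input` and `predictionguard_output` parameters. A sketch modeled on this commit's integration tests, blocking prompts that contain PII (a blocked request surfaces as a `ValueError`):

```python
chat = ChatPredictionGuard(
    model="Hermes-2-Pro-Llama-3-8B",
    predictionguard_input={"pii": "block"},
)

try:
    chat.invoke("Hello, my name is John Doe and my SSN is 111-22-3333")
except ValueError as e:
    print(e)  # e.g. "Could not make prediction. pii detected"
```

Streaming works through the standard `chat.stream(...)` interface, yielding chunks with string `content`.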

## Embedding Models

### Prediction Guard Embeddings
@@ -40,7 +60,6 @@ output = embeddings.embed_query(text)
```



## LLMs
### Prediction Guard LLM
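
A minimal sketch of the completions integration, using the import path from the table above (the model name here is illustrative):

```python
from langchain_community.llms.predictionguard import PredictionGuard

llm = PredictionGuard(model="Hermes-2-Pro-Llama-3-8B")
llm.invoke("Tell me a joke")
```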

5 changes: 5 additions & 0 deletions libs/community/langchain_community/chat_models/__init__.py
@@ -149,6 +149,9 @@
from langchain_community.chat_models.perplexity import (
ChatPerplexity,
)
from langchain_community.chat_models.predictionguard import (
ChatPredictionGuard,
)
from langchain_community.chat_models.premai import (
ChatPremAI,
)
@@ -226,6 +229,7 @@
"ChatOllama",
"ChatOpenAI",
"ChatPerplexity",
"ChatPredictionGuard",
"ChatPremAI",
"ChatSambaNovaCloud",
"ChatSambaStudio",
@@ -317,6 +321,7 @@
"ChatPremAI": "langchain_community.chat_models.premai",
"ChatLlamaCpp": "langchain_community.chat_models.llamacpp",
"ChatYi": "langchain_community.chat_models.yi",
"ChatPredictionGuard": "langchain_community.chat_models.predictionguard",
}


173 changes: 173 additions & 0 deletions libs/community/langchain_community/chat_models/predictionguard.py
@@ -0,0 +1,173 @@
import logging
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, ConfigDict, model_validator

from langchain_community.adapters.openai import (
convert_dict_to_message,
convert_message_to_dict,
)

logger = logging.getLogger(__name__)


class ChatPredictionGuard(BaseChatModel):
"""Prediction Guard chat models.
To use, you should have the ``predictionguard`` python package installed,
and the environment variable ``PREDICTIONGUARD_API_KEY`` set with your API key,
or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
chat = ChatPredictionGuard(
predictionguard_api_key="<your API key>",
model="Hermes-3-Llama-3.1-8B",
)
"""

client: Any = None

model: Optional[str] = "Hermes-3-Llama-3.1-8B"
"""Model name to use."""

max_tokens: Optional[int] = 256
"""The maximum number of tokens in the generated completion."""

temperature: Optional[float] = 0.75
"""The temperature parameter for controlling randomness in completions."""

top_p: Optional[float] = 0.1
"""The diversity of the generated text based on nucleus sampling."""

top_k: Optional[int] = None
"""The diversity of the generated text based on top-k sampling."""

stop: Optional[List[str]] = None
"""Default stop sequences."""

predictionguard_input: Optional[Dict[str, bool]] = None
"""The input check to run over the prompt before sending to the LLM."""

predictionguard_output: Optional[Dict[str, bool]] = None
"""The output check to run the LLM output against."""

predictionguard_api_key: Optional[str] = None
"""Prediction Guard API key."""

model_config = ConfigDict(extra="forbid")

@property
def _llm_type(self) -> str:
return "predictionguard-chat"

@model_validator(mode="before")
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
pg_api_key = get_from_dict_or_env(
values, "predictionguard_api_key", "PREDICTIONGUARD_API_KEY"
)

try:
from predictionguard import PredictionGuard

values["client"] = PredictionGuard(
api_key=pg_api_key,
)

except ImportError:
raise ImportError(
"Could not import predictionguard python package. "
"Please install it with `pip install predictionguard --upgrade`."
)

return values

def _get_parameters(self, **kwargs: Any) -> Dict[str, Any]:
# input kwarg conflicts with LanguageModelInput on BaseChatModel
input = kwargs.pop("predictionguard_input", self.predictionguard_input)
output = kwargs.pop("predictionguard_output", self.predictionguard_output)

params = {
**{
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"input": (
input.model_dump() if isinstance(input, BaseModel) else input
),
"output": (
output.model_dump() if isinstance(output, BaseModel) else output
),
},
**kwargs,
}

return params

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = [convert_message_to_dict(m) for m in messages]

params = self._get_parameters(**kwargs)
params["stream"] = True

result = self.client.chat.completions.create(
model=self.model,
messages=message_dicts,
**params,
)
for part in result:
# get the data from SSE
if "data" in part:
part = part["data"]
if len(part["choices"]) == 0:
continue
content = part["choices"][0]["delta"]["content"]
chunk = ChatGenerationChunk(
message=AIMessageChunk(id=part["id"], content=content)
)
yield chunk

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [convert_message_to_dict(m) for m in messages]
params = self._get_parameters(**kwargs)

response = self.client.chat.completions.create(
model=self.model,
messages=message_dicts,
**params,
)

generations = []
for res in response["choices"]:
if res.get("status", "").startswith("error: "):
err_msg = res["status"].removeprefix("error: ")
raise ValueError(f"Error from PredictionGuard API: {err_msg}")

message = convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)

return ChatResult(generations=generations)
@@ -0,0 +1,58 @@
"""Test Prediction Guard API wrapper"""

import pytest

from langchain_community.chat_models.predictionguard import ChatPredictionGuard


def test_predictionguard_call() -> None:
"""Test a valid call to Prediction Guard."""
chat = ChatPredictionGuard(
model="Hermes-3-Llama-3.1-8B", max_tokens=100, temperature=1.0
)

messages = [
(
"system",
"You are a helpful chatbot",
),
("human", "Tell me a joke."),
]

output = chat.invoke(messages)
assert isinstance(output.content, str)


def test_predictionguard_pii() -> None:
chat = ChatPredictionGuard(
model="Hermes-3-Llama-3.1-8B",
predictionguard_input={
"pii": "block",
},
max_tokens=100,
temperature=1.0,
)

messages = [
"Hello, my name is John Doe and my SSN is 111-22-3333",
]

with pytest.raises(ValueError, match=r"Could not make prediction. pii detected"):
chat.invoke(messages)


def test_predictionguard_stream() -> None:
"""Test a valid call with streaming to Prediction Guard"""

chat = ChatPredictionGuard(
model="Hermes-3-Llama-3.1-8B",
)

messages = [("system", "You are a helpful chatbot."), ("human", "Tell me a joke.")]

num_chunks = 0
for chunk in chat.stream(messages):
assert isinstance(chunk.content, str)
num_chunks += 1

assert num_chunks > 0
@@ -63,6 +63,7 @@
"ChatOctoAI",
"ChatSnowflakeCortex",
"ChatYi",
"ChatPredictionGuard",
]


