Skip to content

Commit

Permalink
Merge pull request #662 from PrefectHQ/fix-problems
Browse files Browse the repository at this point in the history
need numpy for embeddings when http retrieval
  • Loading branch information
zzstoatzz authored Dec 1, 2023
2 parents 7a363f0 + 129eb6b commit 04ef671
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cookbook/slackbot/Dockerfile.slackbot
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ RUN apt-get update && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

RUN pip install ".[prefect]"
RUN pip install ".[slackbot]"

EXPOSE 4200

Expand Down
68 changes: 68 additions & 0 deletions cookbook/slackbot/keywords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from marvin import ai_fn
from marvin.utilities.slack import post_slack_message
from prefect import task
from prefect.blocks.system import JSON, Secret, String
from prefect.exceptions import ObjectNotFound

"""
Define a map between keywords and the relationships we want to check for
in a given message related to that keyword.
"""

# Groups of keywords to look for in an incoming message. The group at index i
# is paired with relationships[i] when get_reduced_kw_relationship_map zips the
# two tuples together, so both tuples must keep the same order and length.
keywords = (
    ("429", "rate limit"),
    ("SSO", "Single Sign On", "RBAC", "Roles", "Role Based Access Controls"),
)

# Statement to check against a message, one per keyword group above (same index).
relationships = (
    "The user is getting rate limited",
    "The user is asking about a paid feature",
)


async def get_reduced_kw_relationship_map() -> dict:
    """Return a flat mapping of every individual keyword to its relationship.

    The keyword/relationship configuration is read from the JSON block named
    "keyword-relationship-map". If loading raises ObjectNotFound or ValueError,
    the module-level ``keywords`` and ``relationships`` defaults are used and
    saved under that same block name.

    Keyword groups and relationships are paired by position; each keyword in a
    group maps to that group's relationship.
    """
    try:
        stored_block = await JSON.load("keyword-relationship-map")
        json_map = stored_block.value
    except (ObjectNotFound, ValueError):
        json_map = {"keywords": keywords, "relationships": relationships}
        await JSON(value=json_map).save("keyword-relationship-map")

    flattened = {}
    paired_groups = zip(json_map["keywords"], json_map["relationships"])
    for keyword_group, relationship in paired_groups:
        for keyword in keyword_group:
            # A keyword repeated in a later group overrides the earlier entry.
            flattened[keyword] = relationship
    return flattened


# The function has no body other than its docstring; the result is produced by
# the `ai_fn` decorator imported from marvin.
# NOTE(review): presumably `ai_fn` feeds the signature and docstring to the
# model as the prompt — confirm against marvin's docs. If so, the docstring
# wording is runtime behavior and should be edited with care.
@ai_fn
def activation_score(message: str, keyword: str, target_relationship: str) -> float:
    """Return a score between 0 and 1 indicating whether the target relationship exists
    between the message and the keyword"""


@task
async def handle_keywords(message: str, channel_name: str, asking_user: str, link: str):
    # Score every configured keyword that literally appears in `message`
    # against its target relationship, and post a single Slack notification
    # for the first keyword whose score exceeds 0.5.
    relationship_by_keyword = await get_reduced_kw_relationship_map()

    # Substring match is case-sensitive; matching is done up front, before any
    # scoring call is made.
    matched = {
        keyword: target_relationship
        for keyword, target_relationship in relationship_by_keyword.items()
        if keyword in message
    }

    for keyword, target_relationship in matched.items():
        if not target_relationship:
            continue

        score = activation_score(message, keyword, target_relationship)
        if score > 0.5:
            notification = (
                f"A user ({asking_user}) just asked a question in"
                f" {channel_name} that contains the keyword `{keyword}`, and I'm"
                f" {score*100:.0f}% sure that their message indicates the"
                f" following:\n\n**{target_relationship!r}**.\n\n[Go to"
                f" message]({link})"
            )
            # Destination channel and credentials are read from stored blocks
            # only once a notification is actually going out.
            channel_block = await String.load("ask-marvin-tests-channel-id")
            token_block = await Secret.load("slack-api-token")
            await post_slack_message(
                message=notification,
                channel_id=channel_block.value,
                auth_token=token_block.get(),
            )
            # One notification per message is enough; stop after the first hit.
            return
18 changes: 17 additions & 1 deletion cookbook/slackbot/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
import uvicorn
from cachetools import TTLCache
from fastapi import FastAPI, HTTPException, Request
from keywords import handle_keywords
from marvin import Assistant
from marvin.beta.assistants import Thread
from marvin.tools.github import search_github_issues
from marvin.tools.retrieval import multi_query_chroma
from marvin.utilities.logging import get_logger
from marvin.utilities.slack import SlackPayload, post_slack_message
from marvin.utilities.slack import (
SlackPayload,
get_channel_name,
get_workspace_info,
post_slack_message,
)
from prefect import flow, task
from prefect.states import Completed

Expand All @@ -31,6 +37,16 @@ async def handle_message(payload: SlackPayload):
assistant_thread = CACHE.get(thread, Thread())
CACHE[thread] = assistant_thread

await handle_keywords.submit(
message=cleaned_message,
channel_name=await get_channel_name(event.channel),
asking_user=event.user,
link=( # to user's message
f"{(await get_workspace_info()).get('url')}archives/"
f"{event.channel}/p{event.ts.replace('.', '')}"
),
)

with Assistant(
name="Marvin (from Hitchhiker's Guide to the Galaxy)",
tools=[task(multi_query_chroma), task(search_github_issues)],
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [

[project.optional-dependencies]
generator = ["datamodel-code-generator>=0.20.0"]
prefect = ["prefect>=2.14.9"]
dev = [
"marvin[tests]",
"black[jupyter]",
Expand All @@ -43,16 +44,14 @@ dev = [
"ruff",
]
tests = [
"marvin[openai,anthropic]",
"pytest-asyncio~=0.20",
"pytest-env>=0.8,<2.0",
"pytest-rerunfailures>=10,<13",
"pytest-sugar~=0.9",
"pytest~=7.3.1",
"pytest-timeout",
]

prefect = ["prefect>=2.14.5"]
slackbot = ["marvin[prefect]", "numpy"]

[project.urls]
Code = "https://github.com/prefecthq/marvin"
Expand Down
11 changes: 9 additions & 2 deletions src/marvin/tools/github.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from datetime import datetime
from typing import List, Optional

Expand All @@ -8,6 +9,13 @@
from marvin.utilities.strings import slice_tokens


def get_token() -> str:
    """Return the GitHub API token, or an empty string when none is configured.

    ``marvin.settings.github_token`` is preferred. When that attribute is
    missing, or is present but unset (``None`` or empty), the
    ``MARVIN_GITHUB_TOKEN`` environment variable is used instead.

    Returns:
        The configured token, or ``""`` if neither source provides one.
    """
    configured = getattr(marvin.settings, "github_token", None)
    # A present-but-empty setting must not mask the environment variable, and
    # the function should never hand back None to callers expecting a string.
    return configured or os.environ.get("MARVIN_GITHUB_TOKEN", "")


class GitHubUser(BaseModel):
"""GitHub user."""

Expand Down Expand Up @@ -60,8 +68,7 @@ async def search_github_issues(
"""
headers = {"Accept": "application/vnd.github.v3+json"}

if token := marvin.settings.github_token:
headers["Authorization"] = f"Bearer {token}"
headers["Authorization"] = f"Bearer {get_token()}"

async with httpx.AsyncClient() as client:
response = await client.get(
Expand Down
4 changes: 3 additions & 1 deletion src/marvin/tools/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ async def create_openai_embeddings(texts: list[str]) -> list[list[float]]:

return (
(
await AsyncOpenAI().embeddings.create(
await AsyncOpenAI(
api_key=marvin.settings.openai.api_key.get_secret_value()
).embeddings.create(
input=[text.replace("\n", " ") for text in texts],
model="text-embedding-ada-002",
)
Expand Down

0 comments on commit 04ef671

Please sign in to comment.