add news retriever (run-llama#13934)
* add news retriever

* Integrate News API to YouRetriever

* Update You Retriever notebook

* 🎨 Rename `endpoint_type` to `endpoint`

* Return news metadata

* fixup! 🎨 Rename `endpoint_type` to `endpoint`

* ⬆️ Bump package version

---------

Co-authored-by: Christopher Tee <[email protected]>
rmcarthur and christeefy authored Jun 16, 2024
1 parent 00e118c commit 554bfae
Showing 3 changed files with 155 additions and 25 deletions.
75 changes: 59 additions & 16 deletions docs/docs/examples/retrievers/you_retriever.ipynb
@@ -36,44 +36,88 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"import os\n",
 "from llama_index.retrievers.you import YouRetriever"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "bda2c5e0",
+"metadata": {},
+"source": [
+"### Retrieve from You.com's Search API"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
 "id": "a38b87b3-c94e-4311-8335-86c6b0f32463",
 "metadata": {},
 "outputs": [],
 "source": [
-"you_api_key = \"\" or os.environ[\"YOU_API_KEY\"]\n",
+"you_api_key = \"\" or os.environ[\"YDC_API_KEY\"]\n",
 "\n",
-"retriever = YouRetriever(api_key=you_api_key)"
+"retriever = YouRetriever(endpoint=\"search\", api_key=you_api_key) # default"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "id": "bbfc0fe3-7c64-4d5d-8190-f80e31d35b4c",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"The beaches and underwater world off the coast of Florida provide endless opportunities of play in the ocean. ... Glacier Bay is a living laboratory with ongoing research and study by scientists on a wide range of ocean-related issues. ... A picture of coastal life, Fire Island offers rich biological diversity and the beautiful landscapes that draw us all to the ocean.\n",
+"A military veteran, Jose Sarria also became a prominent advocate for Latinos, immigrants, and the LGBTQ community in San Francisco. ... Explore the history of the LGBTQ community on Governors Island and Henry Gurber's work in protecting gay rights.\n",
+"Explore the history of the LGBTQ community on Governors Island and Henry Gurber's work in protecting gay rights. ... Sylvia Rivera was an advocate for transgender rights and LGBTQ+ communities, and was an active participant of the Stonewall uprising.\n"
+]
+}
+],
 "source": [
-"retrieved_results = retriever.retrieve(\"national parks in the US\")"
+"retrieved_results = retriever.retrieve(\"national parks in the US\")\n",
+"print(retrieved_results[0].get_content())"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "069c4adb",
+"metadata": {},
+"source": [
+"### Retrieve from You.com's News API"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "3142a3af-d9a0-4fc1-a6a4-f42eb11a9099",
+"id": "47a7c7d3",
 "metadata": {},
 "outputs": [],
 "source": [
-"print(retrieved_results[0].get_content())\n",
-"\n",
-"from llama_index.core.response.notebook_utils import display_source_node\n",
+"you_api_key = \"\" or os.environ[\"YDC_API_KEY\"]\n",
 "\n",
-"# for n in retrieved_results:\n",
-"# display_source_node(n)"
+"retriever = YouRetriever(endpoint=\"news\", api_key=you_api_key)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "f9eedea5",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"But seven months after the October announcement, the Fed's key interest rate — the federal funds rate — is still stuck at 5.25% to 5.5%, where it has been since July 2023. U.S. interest rates are built with the fed funds rate as the foundation.\n"
+]
+}
+],
+"source": [
+"retrieved_results = retriever.retrieve(\"Fed interest rates\")\n",
+"print(retrieved_results[0].get_content())"
+]
+},
 {
@@ -93,9 +137,8 @@
 "source": [
 "from llama_index.core.query_engine import RetrieverQueryEngine\n",
 "\n",
-"query_engine = RetrieverQueryEngine.from_args(\n",
-" retriever,\n",
-")"
+"retriever = YouRetriever()\n",
+"query_engine = RetrieverQueryEngine.from_args(retriever)"
 ]
 },
 {
@@ -108,7 +151,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"The United States has 63 national parks, which are protected areas operated by the National Park Service. These parks are designated for their natural beauty, unique geological features, diverse ecosystems, and recreational opportunities. They are typically larger and more popular destinations compared to other units of the National Park System. National monuments, on the other hand, are also protected for their historical or archaeological significance. Some national parks are paired with national preserves, which have different levels of protection but are administered together. The national parks in the United States cover a total area of approximately 52.4 million acres.\n"
+"There are 63 national parks in the United States, each established to preserve unique landscapes, wildlife, and historical sites for the enjoyment of present and future generations. These parks are managed by the National Park Service, which aims to conserve the scenery and natural and historic objects within the parks. National parks offer a wide range of activities such as hiking, camping, wildlife viewing, and learning about the natural world. Some of the most visited national parks include Great Smoky Mountains, Yellowstone, and Zion, while others like Gates of the Arctic see fewer visitors due to their remote locations. Each national park has its own distinct features and attractions, contributing to the diverse tapestry of protected lands across the country.\n"
 ]
 }
 ],
@@ -120,9 +163,9 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "llama_index_v2",
+"display_name": "you-llamaindex",
 "language": "python",
-"name": "llama_index_v2"
+"name": "python3"
 },
 "language_info": {
 "codemirror_mode": {
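Note on the News API cells above: the notebook prints only the article text, while the retriever change in this commit also attaches each article's `url` and `age` to the node. The following is a minimal sketch of that mapping using plain Python types, so it runs without an API key. The sample payload is invented for illustration, and its shape (`news` → `results`, with `description`, `url`, `age` per article) is inferred from the keys the retriever code reads, not from the API reference. `news_to_records` is a hypothetical helper, not part of the package.

```python
from typing import Any, Dict, List, Tuple

# Invented sample payload; only the keys read by the retriever code are included.
sample_response: Dict[str, Any] = {
    "news": {
        "results": [
            {
                "description": "Example article summary.",
                "url": "https://example.com/article",
                "age": "2 hours ago",
            }
        ]
    }
}


def news_to_records(results: Dict[str, Any]) -> List[Tuple[str, Dict[str, str]]]:
    """Hypothetical helper: pair each article's text with its url/age metadata."""
    records = []
    for article in results["news"]["results"]:
        metadata = {"url": article["url"], "age": article["age"]}
        records.append((article["description"], metadata))
    return records


records = news_to_records(sample_response)
for text, metadata in records:
    print(text, metadata)
```

On real retrieval results the same fields should be readable from the node's metadata (for example `retrieved_results[0].node.metadata`), assuming `extra_info` is still accepted as the alias for `metadata` in the installed llama_index core version.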
Original file line number Diff line number Diff line change
@@ -2,9 +2,11 @@
 
 import logging
 import os
-from typing import List, Optional
+import warnings
+from typing import Any, Dict, List, Literal, Optional
 
 import requests
+
 from llama_index.core.base.base_retriever import BaseRetriever
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
@@ -13,24 +15,109 @@
 
 
 class YouRetriever(BaseRetriever):
-    """You retriever."""
+    """
+    Retriever for You.com's Search and News API.
+    [API reference](https://documentation.you.com/api-reference/)
+    Args:
+        api_key: you.com API key, if `YDC_API_KEY` is not set in the environment
+        endpoint: you.com endpoints
+        num_web_results: The max number of web results to return, must be under 20
+        safesearch: Safesearch settings, one of "off", "moderate", "strict", defaults to moderate
+        country: Country code, ex: 'US' for United States, see API reference for more info
+        search_lang: (News API) Language codes, ex: 'en' for English, see API reference for more info
+        ui_lang: (News API) User interface language for the response, ex: 'en' for English, see API reference for more info
+        spellcheck: (News API) Whether to spell check query or not, defaults to True
+    """
 
     def __init__(
         self,
         api_key: Optional[str] = None,
         callback_manager: Optional[CallbackManager] = None,
+        endpoint: Literal["search", "news"] = "search",
+        num_web_results: Optional[int] = None,
+        safesearch: Optional[Literal["off", "moderate", "strict"]] = None,
+        country: Optional[str] = None,
+        search_lang: Optional[str] = None,
+        ui_lang: Optional[str] = None,
+        spellcheck: Optional[bool] = None,
     ) -> None:
         """Init params."""
-        self._api_key = api_key or os.environ["YOU_API_KEY"]
+        # Should deprecate `YOU_API_KEY` in favour of `YDC_API_KEY` for standardization purposes
+        self._api_key = api_key or os.getenv("YOU_API_KEY") or os.environ["YDC_API_KEY"]
         super().__init__(callback_manager)
 
+        if endpoint not in ("search", "news"):
+            raise ValueError('`endpoint` must be either "search" or "news"')
+
+        # Raise warning if News API-specific fields are set but endpoint is not "news"
+        if endpoint != "news":
+            news_api_fields = (search_lang, ui_lang, spellcheck)
+            for field in news_api_fields:
+                if field:
+                    warnings.warn(
+                        (
+                            f"News API-specific field '{field}' is set but `{endpoint=}`. "
+                            "This will have no effect."
+                        ),
+                        UserWarning,
+                    )
+
+        self.endpoint = endpoint
+        self.num_web_results = num_web_results
+        self.safesearch = safesearch
+        self.country = country
+        self.search_lang = search_lang
+        self.ui_lang = ui_lang
+        self.spellcheck = spellcheck
+
+    def _generate_params(self, query: str) -> Dict[str, Any]:
+        params = {"safesearch": self.safesearch, "country": self.country}
+
+        if self.endpoint == "search":
+            params.update(
+                query=query,
+                num_web_results=self.num_web_results,
+            )
+        elif self.endpoint == "news":
+            params.update(
+                q=query,
+                count=self.num_web_results,
+                search_lang=self.search_lang,
+                ui_lang=self.ui_lang,
+                spellcheck=self.spellcheck,
+            )
+
+        # Remove `None` values
+        return {k: v for k, v in params.items() if v is not None}
+
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Retrieve."""
         headers = {"X-API-Key": self._api_key}
-        results = requests.get(
-            f"https://api.ydc-index.io/search?query={query_bundle.query_str}",
+        params = self._generate_params(query_bundle.query_str)
+        response = requests.get(
+            f"https://api.ydc-index.io/{self.endpoint}",
+            params=params,
             headers=headers,
-        ).json()
+        )
+        response.raise_for_status()
+        results = response.json()
 
+        nodes: List[TextNode] = []
+        if self.endpoint == "search":
+            for hit in results["hits"]:
+                nodes.append(
+                    TextNode(
+                        text="\n".join(hit["snippets"]),
+                    )
+                )
+        else:  # news endpoint
+            for article in results["news"]["results"]:
+                node = TextNode(
+                    text=article["description"],
+                    extra_info={"url": article["url"], "age": article["age"]},
+                )
+                nodes.append(node)
+
-        search_hits = ["\n".join(hit["snippets"]) for hit in results["hits"]]
-        return [NodeWithScore(node=TextNode(text=s), score=1.0) for s in search_hits]
+        return [NodeWithScore(node=node, score=1.0) for node in nodes]
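Review note on the retriever change: the parameter mapping and the News-only warning can be exercised without network access. The sketch below is a standalone approximation, not the package's code; `build_params` and `unused_news_fields` are hypothetical names introduced here. It deliberately differs from the warning loop in the diff in two ways, both offered as suggestions: it reports the parameter name rather than interpolating the parameter's value, and it tests `is not None` instead of truthiness so that `spellcheck=False` is also reported.

```python
from typing import Any, Dict, List, Optional


def build_params(
    endpoint: str,
    query: str,
    num_web_results: Optional[int] = None,
    safesearch: Optional[str] = None,
    country: Optional[str] = None,
    search_lang: Optional[str] = None,
    ui_lang: Optional[str] = None,
    spellcheck: Optional[bool] = None,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the endpoint-specific parameter names."""
    params: Dict[str, Any] = {"safesearch": safesearch, "country": country}
    if endpoint == "search":
        params.update(query=query, num_web_results=num_web_results)
    elif endpoint == "news":
        params.update(
            q=query,
            count=num_web_results,
            search_lang=search_lang,
            ui_lang=ui_lang,
            spellcheck=spellcheck,
        )
    else:
        raise ValueError('`endpoint` must be either "search" or "news"')
    # Drop unset options; False is kept because it is a real value.
    return {k: v for k, v in params.items() if v is not None}


def unused_news_fields(endpoint: str, **news_fields: Any) -> List[str]:
    """Names of News-only options that were set although endpoint is not 'news'."""
    if endpoint == "news":
        return []
    return [name for name, value in news_fields.items() if value is not None]


search_params = build_params("search", "national parks in the US", num_web_results=5)
news_params = build_params("news", "Fed interest rates", spellcheck=False)
ignored = unused_news_fields("search", search_lang=None, ui_lang=None, spellcheck=False)
```

A caller could then emit one `warnings.warn(...)` per name in `ignored`, which keeps the message readable when the value is `True` or a language code.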
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-retrievers-you"
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"

0 comments on commit 554bfae
