
[GDrive] Add caching of downloaded files #492

Open · wants to merge 7 commits into `main`
19 changes: 19 additions & 0 deletions gdrive/README.md
@@ -81,6 +81,25 @@ This connector also supports a few optional environment variables to configure t
1. `GDRIVE_SEARCH_LIMIT` - Number of results to return. Default is 10.
2. `GDRIVE_FOLDER_ID` - ID of the folder to search in. If not provided, the search will be performed in the whole drive.

## Caching

This connector has an optional caching feature, which caches the documents it downloads from Google Drive. This
prevents the same documents from being downloaded repeatedly when a user continues a conversation on the same
topic, or when multiple users ask questions about the same documents.

By default, caching is disabled. To enable it, set the environment variable `GDRIVE_CACHE_TYPE` to either:

* `memory` - cache documents in the Python process itself, using `TTLCache` from `cachetools`.
* `redis` - cache documents in Redis.

When caching is enabled, the following environment variables are also available (see the example `.env` below):

* `GDRIVE_REDIS_HOST` - Redis host to connect to (defaults to `localhost`)
* `GDRIVE_REDIS_PORT` - Redis port to connect to (defaults to `6379`)
* `GDRIVE_REDIS_DB` - Redis database number to connect to (defaults to `0`)
* `GDRIVE_CACHE_MAXSIZE` - Maximum number of documents to keep in the in-memory TTL cache (defaults to `1000`)
* `GDRIVE_CACHE_EXPIRE_TIME` - Number of seconds to cache documents for, with either backend (defaults to `3600`)
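
For example, a `.env` file that enables the Redis backend might look like the following (a sketch using the defaults above; the connector picks these values up at startup via `load_dotenv()`):

```
GDRIVE_CACHE_TYPE=redis
GDRIVE_REDIS_HOST=localhost
GDRIVE_REDIS_PORT=6379
GDRIVE_REDIS_DB=0
GDRIVE_CACHE_EXPIRE_TIME=3600
```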

## Development

Create a virtual environment and install dependencies with Poetry. We recommend using in-project virtual environments:
1,688 changes: 916 additions & 772 deletions gdrive/poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions gdrive/provider/__init__.py
@@ -1,9 +1,10 @@
import logging
import os
import nltk

import connexion
from dotenv import load_dotenv
from provider import cache


load_dotenv()

@@ -24,6 +25,7 @@ def __str__(self):

def create_app():
app = connexion.FlaskApp(__name__, specification_dir="../../.openapi")

app.add_api(
API_VERSION, resolver=connexion.resolver.RelativeResolver("provider.app")
)
@@ -32,5 +34,5 @@ def create_app():
config_prefix = os.path.split(os.getcwd())[1].upper()
flask_app.config.from_prefixed_env(config_prefix)
flask_app.config["APP_ID"] = config_prefix

cache.init(flask_app.config.get("CACHE_TYPE"), flask_app.config)
return flask_app
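
Note how the README's `GDRIVE_`-prefixed variables reach `cache.init`: `from_prefixed_env` strips the prefix, so `GDRIVE_CACHE_TYPE` becomes `flask_app.config["CACHE_TYPE"]`. A minimal standalone sketch of that mapping, assuming Flask's standard prefixed-env behavior:

```python
import os
from flask import Flask

# Hypothetical demo of the prefix stripping used in create_app() above.
os.environ["GDRIVE_CACHE_TYPE"] = "memory"

app = Flask(__name__)
app.config.from_prefixed_env("GDRIVE")

# The GDRIVE_ prefix is dropped, leaving the key that cache.init() reads.
assert app.config["CACHE_TYPE"] == "memory"
```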
1 change: 0 additions & 1 deletion gdrive/provider/app.py
@@ -1,5 +1,4 @@
import logging

from connexion.exceptions import Unauthorized
from flask import abort, request, current_app as app

90 changes: 90 additions & 0 deletions gdrive/provider/cache.py
@@ -0,0 +1,90 @@
from abc import abstractmethod

import redis
from cachetools import TTLCache
from flask import current_app as app


CACHE_TYPE_NONE = "none"
CACHE_TYPE_MEMORY = "memory"
CACHE_TYPE_REDIS = "redis"

DEFAULT_CACHE_EXPIRE_TIME = 3600


backend = None


class CacheBackend:
def get_cache_key(self, document_id: str) -> str:
return f"document_text_{document_id}"

@abstractmethod
def cache_document_text(self, document_id: str, text: str) -> None:
pass

@abstractmethod
def get_document_text(self, document_id: str) -> str:
pass


class MemoryBackend(CacheBackend):
def __init__(self, config):
self.ttl_cache = TTLCache(
config.get("CACHE_MAXSIZE") or 1000,
config.get("CACHE_EXPIRE_TIME") or DEFAULT_CACHE_EXPIRE_TIME,
)

def cache_document_text(self, document_id: str, text: str) -> None:
cache_key = self.get_cache_key(document_id)
self.ttl_cache[cache_key] = text

def get_document_text(self, document_id: str) -> str:
cache_key = self.get_cache_key(document_id)
return self.ttl_cache.get(cache_key)


class RedisBackend(CacheBackend):
def __init__(self, config):
self.r = redis.Redis(
host=config.get("REDIS_HOST") or "localhost",
port=config.get("REDIS_PORT") or 6379,
db=config.get("REDIS_DB") or 0,
)

self.expire_time = config.get("CACHE_EXPIRE_TIME") or DEFAULT_CACHE_EXPIRE_TIME

def cache_document_text(self, document_id: str, text: str) -> None:
cache_key = self.get_cache_key(document_id)
self.r.set(cache_key, text, self.expire_time)

def get_document_text(self, document_id: str) -> str:
cache_key = self.get_cache_key(document_id)
document_text = self.r.get(cache_key)
return document_text.decode() if document_text else None


CACHE_BACKENDS = {
CACHE_TYPE_MEMORY: MemoryBackend,
CACHE_TYPE_REDIS: RedisBackend,
}


def init(type: str, config) -> None:
global backend

if not type:
return

assert type in CACHE_BACKENDS, "Invalid cache backend"
backend = CACHE_BACKENDS[type](config)


def get_document_text(document_id: str) -> str:
assert backend, "Caching not configured"
return backend.get_document_text(document_id)


def cache_document_text(document_id: str, value: str):
assert backend, "Caching not configured"
backend.cache_document_text(document_id, value)
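
For reference, a minimal sketch of how this module is exercised end to end (using a plain dict in place of the Flask config, which works because `init` only calls `config.get`; the document id is illustrative):

```python
from provider import cache

# Enable the in-memory backend; keys mirror the GDRIVE_-prefixed env vars.
config = {
    "CACHE_TYPE": "memory",
    "CACHE_MAXSIZE": 1000,       # TTLCache entry limit
    "CACHE_EXPIRE_TIME": 3600,   # seconds before an entry expires
}

cache.init(config["CACHE_TYPE"], config)

cache.cache_document_text("doc-123", "text downloaded from Drive")
assert cache.get_document_text("doc-123") == "text downloaded from Drive"
```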
41 changes: 30 additions & 11 deletions gdrive/provider/provider.py
@@ -8,7 +8,7 @@
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError

from . import UpstreamProviderError, async_download
from . import UpstreamProviderError, async_download, cache

CSV_MIMETYPE = "text/csv"
TEXT_MIMETYPE = "text/plain"
@@ -29,14 +29,13 @@


def process_data_with_service(data, request_credentials: Credentials):
results = []
results: list[dict] = []
files = data.get("files", [])
if not files:
logger.debug("No files found.")
return results

id_to_urls = extract_links(files)
id_to_texts = async_download.perform(id_to_urls, request_credentials.token)
id_to_texts = retrieve_file_texts(files, request_credentials)

for _file in files:
id = _file.get("id")
@@ -67,19 +66,39 @@ def process_data_with_service(data, request_credentials: Credentials):
return results


def extract_links(files) -> [str, str]:
id_to_urls = dict()
def retrieve_file_texts(files, request_credentials: Credentials) -> dict[str, str]:
missing_ids_to_urls = {}
id_to_texts = {}

for _file in files:
export_links = _file.pop("exportLinks", {})
id = _file.get("id")

if id is None:
continue

if TEXT_MIMETYPE in export_links:
id_to_urls[id] = export_links[TEXT_MIMETYPE]
elif CSV_MIMETYPE in export_links:
id_to_urls[id] = export_links[CSV_MIMETYPE]
return id_to_urls
cached_text = cache.get_document_text(id) if cache.backend else None

if cached_text is not None:
id_to_texts[id] = cached_text
else:
if TEXT_MIMETYPE in export_links:
missing_ids_to_urls[id] = export_links[TEXT_MIMETYPE]
elif CSV_MIMETYPE in export_links:
missing_ids_to_urls[id] = export_links[CSV_MIMETYPE]

if missing_ids_to_urls:
downloaded_texts = async_download.perform(
missing_ids_to_urls, request_credentials.token
)

if cache.backend:
for document_id in downloaded_texts:
cache.cache_document_text(document_id, downloaded_texts[document_id])

id_to_texts.update(downloaded_texts)

return id_to_texts


def split_and_remove_stopwords(text: str):
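
The new `retrieve_file_texts` is a cache-aside read: serve hits from the cache, batch-download only the misses, then backfill the cache for the next request. A standalone sketch of the same flow (the callables are stand-ins for `cache.get_document_text`, `cache.cache_document_text`, and `async_download.perform`):

```python
from typing import Callable, Optional


def retrieve_texts(
    ids_to_urls: dict[str, str],
    get_cached: Callable[[str], Optional[str]],
    put_cached: Callable[[str, str], None],
    download: Callable[[dict[str, str]], dict[str, str]],
) -> dict[str, str]:
    texts: dict[str, str] = {}
    misses: dict[str, str] = {}

    # Serve whatever the cache already has; collect the rest as misses.
    for doc_id, url in ids_to_urls.items():
        cached = get_cached(doc_id)
        if cached is not None:
            texts[doc_id] = cached
        else:
            misses[doc_id] = url

    if misses:
        downloaded = download(misses)  # one batched fetch for all misses
        for doc_id, text in downloaded.items():
            put_cached(doc_id, text)   # backfill so the next read is a hit
        texts.update(downloaded)

    return texts
```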
3 changes: 3 additions & 0 deletions gdrive/pyproject.toml
@@ -18,10 +18,13 @@ google-auth-oauthlib = "^1.1.0"
gunicorn = "^22.0.0"
aiohttp = "^3.9.4"
nltk = "^3.8.1"
redis = "^5.0.8"
cachetools = "^5.5.0"

[tool.poetry.group.development.dependencies]
black = "^24.3.0"
types-requests = "^2.31.0.1"
types-cachetools = "^5.5.0"

[build-system]
requires = ["poetry-core"]