Skip to content

Commit

Permalink
Fix Typing Issues (#22)
Browse files Browse the repository at this point in the history
Fixes a bunch of type issues identified by `mypy` and bumps minimum
python version to `3.10`. Also introduced `types.py` which can keep
track of types returned by `lotus` operations. There are still `mypy`
errors in the `models` directory but a lot of the model code will
undergo heavy rewrite anyways so leaving that for later. Once all the
`mypy` issues are resolved I will add it to the CI.
  • Loading branch information
sidjha1 authored Oct 28, 2024
1 parent 938a80d commit 8d61eb2
Show file tree
Hide file tree
Showing 33 changed files with 624 additions and 458 deletions.
27 changes: 27 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,33 @@ def test_join(setup_models):
assert joined_pairs == expected_pairs


def test_join_cascade(setup_models):
gpt_4o_mini, gpt_4o = setup_models
lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini)

data1 = {"School": ["UC Berkeley", "Stanford"]}
data2 = {"School Type": ["Public School", "Private School"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2)
join_instruction = "{School} is a {School Type}"
expected_pairs = set([("UC Berkeley", "Public School"), ("Stanford", "Private School")])

# All joins resolved by the helper model
joined_df, stats = df1.sem_join(df2, join_instruction, cascade_threshold=0, return_stats=True)
joined_pairs = set(zip(joined_df["School"], joined_df["School Type"]))
assert joined_pairs == expected_pairs
assert stats["filters_resolved_by_large_model"] == 0, stats
assert stats["filters_resolved_by_helper_model"] == 4, stats

# All joins resolved by the large model
joined_df, stats = df1.sem_join(df2, join_instruction, cascade_threshold=1.01, return_stats=True)
joined_pairs = set(zip(joined_df["School"], joined_df["School Type"]))
assert joined_pairs == expected_pairs
assert stats["filters_resolved_by_large_model"] == 4, stats
assert stats["filters_resolved_by_helper_model"] == 0, stats


def test_map_fewshot(setup_models):
gpt_4o_mini, _ = setup_models
lotus.settings.configure(lm=gpt_4o_mini)
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: |
Expand All @@ -43,7 +43,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: "3.9"
python: "3.10"

sphinx:
configuration: docs/conf.py
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

To set up for development, create a conda environment, install lotus, and install additional dev dependencies.
```
conda create -n lotus python=3.9 -y
conda create -n lotus python=3.10 -y
conda activate lotus
git clone [email protected]:stanford-futuredata/lotus.git
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ LOTUS offers a number of semantic operators in a Pandas-like API, some of which

# Installation
```
conda create -n lotus python=3.9 -y
conda create -n lotus python=3.10 -y
conda activate lotus
pip install lotus-ai
```
Expand Down
4 changes: 2 additions & 2 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Requirements
------------

* OS: MacOS, Linux
* Python: 3.9
* Python: 3.10

Install with pip
----------------
Expand All @@ -16,6 +16,6 @@ You can install Lotus using pip:

.. code-block:: console
$ conda create -n lotus python=3.9 -y
$ conda create -n lotus python=3.10 -y
$ conda activate lotus
$ pip install lotus-ai
18 changes: 9 additions & 9 deletions lotus/models/colbertv2_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

from lotus.models.rm import RM

Expand All @@ -8,9 +8,9 @@ class ColBERTv2Model(RM):
"""ColBERTv2 Model"""

def __init__(self, **kwargs):
self.docs: Optional[List[str]] = None
self.kwargs: Dict[str, Any] = {"doc_maxlen": 300, "nbits": 2, **kwargs}
self.index_dir: Optional[str] = None
self.docs: list[str] | None = None
self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2, **kwargs}
self.index_dir: str | None = None

from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, Run, RunConfig
Expand All @@ -21,7 +21,7 @@ def __init__(self, **kwargs):
self.Run = Run
self.RunConfig = RunConfig

def index(self, docs: List[str], index_dir: str, **kwargs: Dict[str, Any]) -> None:
def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None:
kwargs = {**self.kwargs, **kwargs}
checkpoint = "colbert-ir/colbertv2.0"

Expand All @@ -41,15 +41,15 @@ def load_index(self, index_dir: str) -> None:
with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "rb") as fp:
self.docs = pickle.load(fp)

def get_vectors_from_index(self, index_dir: str, ids: List[int]) -> List:
def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list:
raise NotImplementedError("This method is not implemented for ColBERTv2Model")

def __call__(
self,
queries: Union[str, List[str], List[List[float]]],
queries: str | list[str] | list[list[float]],
k: int,
**kwargs: Dict[str, Any],
) -> Tuple[List[float], List[int]]:
**kwargs: dict[str, Any],
) -> tuple[list[float], list[int]]:
if isinstance(queries, str):
queries = [queries]

Expand Down
6 changes: 2 additions & 4 deletions lotus/models/cross_encoder_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List, Optional

import torch
from sentence_transformers import CrossEncoder

Expand All @@ -17,15 +15,15 @@ class CrossEncoderModel(Reranker):
def __init__(
self,
model: str = "mixedbread-ai/mxbai-rerank-large-v1",
device: Optional[str] = None,
device: str | None = None,
**kwargs,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.model = CrossEncoder(model, device=device, **kwargs)

def __call__(self, query: str, docs: List[str], k: int) -> List[int]:
def __call__(self, query: str, docs: list[str], k: int) -> list[int]:
results = self.model.rank(query, docs, top_k=k)
results = [result["corpus_id"] for result in results]
return results
37 changes: 19 additions & 18 deletions lotus/models/e5_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

import numpy as np
import torch
Expand All @@ -14,18 +14,18 @@
class E5Model(RM):
"""E5 retriever model"""

def __init__(self, model: str = "intfloat/e5-base-v2", device: Optional[str] = None, **kwargs):
def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None, **kwargs: dict[str, Any]) -> None:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.model = AutoModel.from_pretrained(model).to(self.device)
self.faiss_index = None
self.index_dir = None
self.docs = None
self.kwargs = {"normalize": True, "index_type": "Flat", **kwargs}
self.batch_size = 100
self.vecs = None
self.index_dir: str | None = None
self.docs: list[str] | None = None
self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs}
self.batch_size: int = 100
self.vecs: np.ndarray[Any, np.dtype[np.float32]] | None = None

import faiss

Expand All @@ -45,7 +45,7 @@ def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.T
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:
def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> np.ndarray[Any, np.dtype[np.float32]]:
"""Run the embedding model.
Args:
Expand All @@ -55,10 +55,11 @@ def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:
Embeddings of the documents.
"""

kwargs = {**self.kwargs, **kwargs}
kwargs = {**self.kwargs, **dict(kwargs)}

batch_size = kwargs.get("batch_size", self.batch_size)

assert isinstance(batch_size, int), "batch_size must be an integer"

# Calculating the embedding dimension
total_docs = len(docs)
first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True).to(self.device)
Expand All @@ -79,7 +80,7 @@ def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:

return embeddings.numpy(force=True)

def index(self, docs: List[str], index_dir: str, **kwargs: Dict[str, Any]) -> None:
def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None:
# Make index directory
os.makedirs(index_dir, exist_ok=True)

Expand Down Expand Up @@ -110,17 +111,17 @@ def load_index(self, index_dir: str) -> None:
self.vecs = pickle.load(fp)

@classmethod
def get_vectors_from_index(self, index_dir: str, ids: List[int]) -> List:
def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list[np.ndarray[Any, np.dtype[np.float32]]]:
with open(f"{index_dir}/vecs", "rb") as fp:
vecs = pickle.load(fp)
vecs: np.ndarray[Any, np.dtype[np.float32]] = pickle.load(fp)

return vecs[ids]

def load_vecs(self, index_dir: str, ids: List[int]) -> List:
def load_vecs(self, index_dir: str, ids: list[int]) -> list:
"""loads vectors to the rm and returns them
Args:
index_dir (str): Directory of the index.
ids (List[int]): The ids of the vectors to retrieve
ids (list[int]): The ids of the vectors to retrieve
Returns:
The vectors matching the specified ids.
Expand All @@ -134,10 +135,10 @@ def load_vecs(self, index_dir: str, ids: List[int]) -> List:

def __call__(
self,
queries: Union[str, List[str], List[List[float]]],
queries: str | list[str] | list[list[float]],
k: int,
**kwargs: Dict[str, Any],
) -> Tuple[List[float], List[int]]:
**kwargs: dict[str, Any],
) -> tuple[list[float], list[int]]:
if isinstance(queries, str):
queries = [queries]

Expand Down
34 changes: 23 additions & 11 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union
from typing import Any


class LM(ABC):
Expand All @@ -9,42 +9,54 @@ def _init__(self):
pass

@abstractmethod
def count_tokens(self, prompt: Union[str, list]) -> int:
def count_tokens(self, prompt: str | list) -> int:
"""
Counts the number of tokens in the given prompt.
Args:
prompt (Union[str, list]): The prompt to count tokens for. This can be a string or a list of messages.
prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages.
Returns:
int: The number of tokens in the prompt.
"""
pass

def format_logprobs_for_cascade(self, logprobs: List) -> Tuple[List[List[str]], List[List[float]]]:
def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]:
"""
Formats the logprobs for the cascade.
Args:
logprobs (List): The logprobs to format.
logprobs (list): The logprobs to format.
Returns:
Tuple[List[List[str]], List[List[float]]]: A tuple containing the tokens and their corresponding confidences.
tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences.
"""
pass

@abstractmethod
def __call__(
self, messages_batch: Union[List, List[List]], **kwargs: Dict[str, Any]
) -> Union[List, Tuple[List, List]]:
self, messages_batch: list | list[list], **kwargs: dict[str, Any]
) -> list[str] | tuple[list[str], list[dict[str, Any]]]:
"""Invoke the LLM.
Args:
messages_batch (Union[List, List[List]]): Either one prompt or a list of prompts in message format.
kwargs (Dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.
messages_batch (list | list[list]): Either one prompt or a list of prompts in message format.
kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.
Returns:
Union[List, Tuple[List, List]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments,
list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments,
then a list of logprobs is also returned.
"""
pass

@property
@abstractmethod
def max_ctx_len(self) -> int:
"""The maximum context length of the LLM."""
pass

@property
@abstractmethod
def max_tokens(self) -> int:
"""The maximum number of tokens that can be generated by the LLM."""
pass
Loading

0 comments on commit 8d61eb2

Please sign in to comment.