Skip to content

Commit

Permalink
feat: Update argilla to 2.0 in intelligence layer (#1007)
Browse files Browse the repository at this point in the history
* refactor: Remove DefaultArgillaClient + use Argilla native types
TASK: PHS-732
---------

Co-authored-by: Sebastian Niehus <[email protected]>
  • Loading branch information
MerlinKallenbornAA and SebastianNiehusAA authored Aug 28, 2024
1 parent 21cd529 commit 5854bca
Show file tree
Hide file tree
Showing 12 changed files with 864 additions and 1,680 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/sdk-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ jobs:
discovery.type: "single-node"
xpack.security.enabled: "false"
argilla:
image: argilla/argilla-server:v1.26.1
image: argilla/argilla-server:v1.29.1
ports:
- "6900:6900"
env:
Expand Down Expand Up @@ -202,7 +202,7 @@ jobs:
discovery.type: "single-node"
xpack.security.enabled: "false"
argilla:
image: argilla/argilla-server:v1.26.1
image: argilla/argilla-server:v1.29.1
ports:
- "6900:6900"
env:
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
- Abstract `LanguageModel` class to integrate with LLMs from any API
- Abstract `ChatModel` class to integrate with chat models from any API
- Every `LanguageModel` supports echo to retrieve log probs for an expected completion given a prompt
- Upgrade `ArgillaWrapperClient` to use Argilla v2.0.1

### Fixes
- increase number of returned `log_probs` in `EloQaEvaluationLogic` to avoid missing a valid answer

### Deprecations
...
- Removed `DefaultArgillaClient`

### Breaking Changes
- Upgrade argilla-server image version from `argilla-server:v1.26.1` to `argilla-server:v1.29.1`
...

## 5.1.0
Expand Down
1,200 changes: 595 additions & 605 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ opentelemetry-exporter-otlp-proto-http = "1.23.0"
rouge-score = "^0.1.2"
sacrebleu = "^2.4.3"
lingua-language-detector = "^2.0.2"
argilla = "^1.29.1"
argilla = "^2.0.1"

[tool.poetry.group.dev.dependencies]
# lint & format
Expand Down Expand Up @@ -85,7 +85,6 @@ addopts = "--capture=tee-sys"
filterwarnings = [ #ignore: message : warning : location ?
'ignore:.*\`general_plain_validator_function\` is deprecated.*',
'ignore::DeprecationWarning:.*importlib._bootstrap.*',
'ignore:.*DefaultArgillaClient.*:DeprecationWarning:'
]

[tool.ruff]
Expand Down
3 changes: 0 additions & 3 deletions src/intelligence_layer/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
from .argilla.argilla_wrapper_client import (
ArgillaWrapperClient as ArgillaWrapperClient,
)
from .argilla.default_client import DefaultArgillaClient as DefaultArgillaClient
from .argilla.default_client import Field as Field
from .argilla.default_client import Question as Question
from .base.json_serializable import JsonSerializable as JsonSerializable
from .base.json_serializable import SerializableDict as SerializableDict
from .document_index.document_index import CollectionPath as CollectionPath
Expand Down
204 changes: 118 additions & 86 deletions src/intelligence_layer/connectors/argilla/argilla_wrapper_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
)

import argilla as rg # type: ignore
from argilla.client.feedback.schemas.types import ( # type: ignore
AllowedFieldTypes,
AllowedQuestionTypes,
)

from intelligence_layer.connectors.argilla.argilla_client import (
ArgillaClient,
Expand All @@ -37,24 +33,24 @@ def __init__(
logging.WARNING
)

rg.init(
self.client = rg.Argilla(
api_url=api_url if api_url is not None else os.getenv("ARGILLA_API_URL"),
api_key=api_key if api_key is not None else os.getenv("ARGILLA_API_KEY"),
)

def create_dataset(
self,
workspace_id: str,
workspace_name: str,
dataset_name: str,
fields: Sequence[AllowedFieldTypes],
questions: Sequence[AllowedQuestionTypes],
fields: Sequence[rg.TextField],
questions: Sequence[rg.QuestionType],
) -> str:
"""Creates and publishes a new feedback dataset in Argilla.
Raises an error if the name exists already.
Args:
workspace_id: the name of the workspace the feedback dataset should be created in.
workspace_name: the name of the workspace the feedback dataset should be created in.
The user executing this request must have corresponding permissions for this workspace.
dataset_name: the name of the feedback-dataset to be created.
fields: all fields of this dataset.
Expand All @@ -63,27 +59,34 @@ def create_dataset(
Returns:
The id of the created dataset.
"""
dataset = rg.FeedbackDataset(
fields=fields, questions=questions, allow_extra_metadata=True
settings = rg.Settings(
fields=list(fields),
questions=list(questions),
allow_extra_metadata=True,
)
remote_datasets = dataset.push_to_argilla(

workspace = self.ensure_workspace_exists(workspace_name)

dataset = rg.Dataset(
name=dataset_name,
workspace=rg.Workspace.from_name(workspace_id),
show_progress=False,
settings=settings,
workspace=workspace,
client=self.client,
)
return str(remote_datasets.id)
dataset.create()
return str(dataset.id)

def ensure_dataset_exists(
self,
workspace_id: str,
workspace_name: str,
dataset_name: str,
fields: Sequence[AllowedFieldTypes],
questions: Sequence[AllowedQuestionTypes],
fields: Sequence[rg.TextField],
questions: Sequence[rg.QuestionType],
) -> str:
"""Retrieves an existing dataset or creates and publishes a new feedback dataset in Argilla.
Args:
workspace_id: the name of the workspace the feedback dataset should be created in.
workspace_name: the name of the workspace the feedback dataset should be created in.
The user executing this request must have corresponding permissions for this workspace.
dataset_name: the name of the feedback-dataset to be created.
fields: all fields of this dataset.
Expand All @@ -92,50 +95,51 @@ def ensure_dataset_exists(
Returns:
The id of the dataset to be retrieved .
"""
try:
return str(
rg.FeedbackDataset.from_argilla(
name=dataset_name, workspace=rg.Workspace.from_name(workspace_id)
).id
)
except ValueError:
pass
return self.create_dataset(workspace_id, dataset_name, fields, questions)
dataset = self.client.datasets(name=dataset_name, workspace=workspace_name)
return (
str(dataset.id)
if dataset
else self.create_dataset(workspace_name, dataset_name, fields, questions)
)

def add_record(self, dataset_id: str, record: RecordData) -> None:
self.add_records(dataset_id=dataset_id, records=[record])

def add_records(self, dataset_id: str, records: Sequence[RecordData]) -> None:
remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
if remote_dataset is None:
raise ValueError
argilla_records = [
rg.FeedbackRecord(
fields=record.content,
rg.Record(
fields=dict(record.content),
metadata={
**record.metadata,
"example_id": record.example_id,
},
)
for record in records
]
remote_dataset.add_records(argilla_records, show_progress=False)
remote_dataset.records.log(records=argilla_records)

def evaluations(self, dataset_id: str) -> Iterable[ArgillaEvaluation]:
remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
filtered_dataset = remote_dataset.filter_by(response_status="submitted")

for record in filtered_dataset.records:
submitted_response = next((response for response in record.responses), None)
if submitted_response is not None:
metadata = record.metadata
example_id = metadata.pop("example_id")
yield ArgillaEvaluation(
example_id=example_id,
record_id="ignored",
responses={
k: v.value for k, v in submitted_response.values.items()
},
metadata=metadata,
)

status_filter = rg.Filter([("response.status", "==", "submitted")])
query = rg.Query(filter=status_filter)

for record in remote_dataset.records(query=query):
metadata = record.metadata
example_id = metadata.pop("example_id")
yield ArgillaEvaluation(
example_id=example_id,
record_id="ignored",
responses={
response.question_name: response.value
for response in record.responses
if response is not None
},
metadata=metadata,
)

def split_dataset(self, dataset_id: str, n_splits: int) -> None:
"""Adds a new metadata property to the dataset and assigns a split to each record.
Expand All @@ -147,46 +151,49 @@ def split_dataset(self, dataset_id: str, n_splits: int) -> None:
n_splits: the number of splits to create
"""
remote_dataset = self._dataset_from_id(dataset_id=dataset_id)

name = "split"
metadata_config = remote_dataset.metadata_property_by_name(name)

metadata_config = remote_dataset.settings.metadata[name]

if n_splits == 1:
if metadata_config is None:
return
remote_dataset.delete_metadata_properties(name)
metadata_config.delete()
self._delete_metadata_from_records(remote_dataset, name)
return

if metadata_config is None:
config = rg.IntegerMetadataProperty(
name=name, visible_for_annotators=True, min=1, max=n_splits
)
remote_dataset.add_metadata_property(config)
else:
metadata_config.max = n_splits
remote_dataset.update_metadata_properties(metadata_config)
if metadata_config:
metadata_config.delete()
self._delete_metadata_from_records(remote_dataset, name)

self._update_record_metadata(n_splits, remote_dataset, name)

def _update_record_metadata(
self, n_splits: int, remote_dataset: rg.FeedbackDataset, metadata_name: str
self, n_splits: int, remote_dataset: rg.Dataset, metadata_name: str
) -> None:
modified_records = []
config = rg.IntegerMetadataProperty(
name=metadata_name, visible_for_annotators=True, min=1, max=n_splits
)
remote_dataset.settings.metadata = [config]
remote_dataset.update()
updated_records = []
for record, split in zip(
remote_dataset.records, itertools.cycle(range(1, n_splits + 1))
):
record.metadata[metadata_name] = split
modified_records.append(record)
remote_dataset.update_records(modified_records, show_progress=False)
updated_records.append(record)

remote_dataset.records.log(updated_records)

def _delete_metadata_from_records(
self, remote_dataset: rg.FeedbackDataset, metadata_name: str
self, remote_dataset: rg.Dataset, metadata_name: str
) -> None:
modified_records = []
for record in remote_dataset.records:
del record.metadata[metadata_name]
modified_records.append(record)
remote_dataset.update_records(modified_records, show_progress=False)
remote_dataset.records.log(modified_records)

def ensure_workspace_exists(self, workspace_name: str) -> str:
"""Retrieves the name of an argilla workspace with specified name or creates a new workspace if necessary.
Expand All @@ -197,11 +204,17 @@ def ensure_workspace_exists(self, workspace_name: str) -> str:
Returns:
The name of an argilla workspace with the given `workspace_name`.
"""
try:
workspace = rg.Workspace.from_name(workspace_name)
return str(workspace.name)
except ValueError:
return str(rg.Workspace.create(name=workspace_name).name)
workspace = self.client.workspaces(name=workspace_name)
if workspace:
return workspace_name

workspace = rg.Workspace(name=workspace_name, client=self.client)
workspace.create()
if not workspace:
raise ValueError(
f"Workspace with name {workspace_name} could not be created."
)
return str(workspace.name)

def records(self, dataset_id: str) -> Iterable[Record]:
remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
Expand All @@ -215,31 +228,50 @@ def records(self, dataset_id: str) -> Iterable[Record]:
for record in remote_dataset.records
)

def _create_evaluation(self, record_id: str, data: dict[str, Any]) -> None:
api_url = os.environ["ARGILLA_API_URL"]
if not api_url.endswith("/"):
api_url = api_url + "/"
rg.active_client().http_client.post(
f"{api_url}api/v1/records/{record_id}/responses",
json={
"status": "submitted",
"values": {
question_name: {"value": response_value}
for question_name, response_value in data.items()
},
},
)
def _create_evaluation(
self, dataset_id: str, record_id: str, data: dict[str, Any]
) -> None:
dataset = self._dataset_from_id(dataset_id=dataset_id)
if dataset is None:
raise ValueError(f"Dataset with id {dataset_id} does not exist.")
records = dataset.records

user_id = self.client.me.id
if user_id is None:
raise ValueError("user_id is not a UUID")

# argilla currently does not allow to retrieve a record by id
# This could be optimized (in a scenario for creating multiple evaluations) by passing a dict of record_ids to the function
# and update all the records for the given record id list.
for record in records:
if record.id == record_id:
for question_name, response_value in data.items():
response = rg.Response(
question_name=question_name,
value=response_value,
status="submitted",
user_id=user_id,
)
record.responses.add(response)
dataset.records.log([record])
return

def _delete_dataset(self, dataset_id: str) -> None:
remote_dataset = self._dataset_from_id(dataset_id=dataset_id)
remote_dataset.delete()

def _delete_workspace(self, workspace_name: str) -> None:
workspace = rg.Workspace.from_name(workspace_name)
datasets = rg.list_datasets(workspace=workspace.name)
for dataset in datasets:
workspace = self.client.workspaces(name=workspace_name)
if workspace is None:
raise ValueError("workspace does not exist.")
for dataset in workspace.datasets:
dataset.delete()
workspace.delete()

def _dataset_from_id(self, dataset_id: str) -> rg.FeedbackDataset:
return rg.FeedbackDataset.from_argilla(id=dataset_id)
def _dataset_from_id(self, dataset_id: str) -> rg.Dataset:
dataset = self.client.datasets(id=dataset_id)
if not dataset:
raise ValueError("Dataset is not existent")
# Fetch Settings from Dataset in order to pass questions, necessary since Argilla V2
dataset.settings.get()
return dataset
Loading

0 comments on commit 5854bca

Please sign in to comment.