CV2-4953: Storing ClassyCat data in Alegre #105

Draft · wants to merge 5 commits into base: master
1 change: 1 addition & 0 deletions .env_file.example
@@ -14,6 +14,7 @@ OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=XXX"
 HONEYCOMB_API_ENDPOINT="https://api.honeycomb.io"
 REDIS_URL="redis://redis:6379/0"
 CACHE_DEFAULT_TTL=86400
+ALEGRE_URL="http://host.docker.internal:3100"
 
 CLASSYCAT_OUTPUT_BUCKET="classycat-qa"
 CLASSYCAT_BATCH_SIZE_LIMIT=25
1 change: 1 addition & 0 deletions .env_file.test
@@ -14,6 +14,7 @@ OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=XXX"
 HONEYCOMB_API_ENDPOINT="https://api.honeycomb.io"
 REDIS_URL="redis://redis:6379/0"
 CACHE_DEFAULT_TTL=86400
+ALEGRE_URL="http://host.docker.internal:3100"
 
 CLASSYCAT_OUTPUT_BUCKET="classycat-qa"
 CLASSYCAT_BATCH_SIZE_LIMIT=25
27 changes: 23 additions & 4 deletions lib/model/classycat_classify.py
@@ -20,14 +20,16 @@ def get_client(self):
     def classify(self, task_prompt, items_count, max_tokens_per_item=200):
         pass
 
+
 class AnthropicClient(LLMClient):
     def __init__(self, model_name):
         super().__init__()
         self.model_name = model_name
 
     def get_client(self):
         if self.client is None:
-            self.client = Anthropic(api_key=os.environ.get('ANTHROPIC_API_KEY'), timeout=httpx.Timeout(60.0, read=60.0, write=60.0, connect=60.0), max_retries=0)
+            self.client = Anthropic(api_key=os.environ.get('ANTHROPIC_API_KEY'),
+                                    timeout=httpx.Timeout(60.0, read=60.0, write=60.0, connect=60.0), max_retries=0)
         return self.client
 
     def classify(self, task_prompt, items_count, max_tokens_per_item=200):
@@ -43,6 +45,7 @@ def classify(self, task_prompt, items_count, max_tokens_per_item=200):
 
         return completion.content[0].text
 
+
 class OpenRouterClient(LLMClient):
     def __init__(self, model_name):
         super().__init__()
@@ -65,7 +68,7 @@ def classify(self, task_prompt, items_count, max_tokens_per_item=200):
             max_tokens=(max_tokens_per_item * items_count) + 15,
             temperature=0.5
         )
-        # TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)
+        # TODO: record metric here with model name and number of items submitted (https://meedan.atlassian.net/browse/CV2-4987)
         return completion.choices[0].message.content


@@ -137,17 +140,33 @@ def classify_and_store_results(self, schema_id, items):

             result['labels'] = [label for label in result['labels'] if label in permitted_labels]
 
         # if there is at least one item with labels, save the results to s3
         if not all([len(result['labels']) == 0 for result in final_results]):
             results_file_id = str(uuid.uuid4())
             upload_file_to_s3(self.output_bucket, f"{schema_id}/{results_file_id}.json", json.dumps(final_results))
 
-        return final_results
+        # prepare the final results to be stored in alegre
+        # save "content" and "context"
+        # content is text, doc_id is the item's unique id, and context is input id, labels, schema_id, and model name
+        final_results_to_be_stored_in_alegre = {'documents': [
+            {'doc_id': str(uuid.uuid4()),  # adding a unique id for each item to not rely on the input id for uniqueness
+             'content': items[i]['text'],
[Review comment: @computermacgyver (Collaborator), Aug 12, 2024]
For the /text/similarity endpoint this parameter is text and not content. Let's double check this is the correct name for the bulk endpoint.

[Reply: Contributor Author]
At least for /text/bulk_similarity/ we replace text with content:

    # Rename "text" to "content" if present
    if 'text' in params:
        params['content'] = params.get('text')
        del params["text"]

We can use both content and text for bulk, but it ultimately gets renamed to content in Alegre.
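
For illustration, a minimal standalone sketch of that rename applied to a single document (sample values are invented for this sketch):

    # Mirror of Alegre's "text" -> "content" rename quoted above
    def normalize_doc(params: dict) -> dict:
        if 'text' in params:
            params['content'] = params.pop('text')  # pop removes "text" and returns its value
        return params

    doc = normalize_doc({'text': 'example item', 'context': {'input_id': '123'}})
    assert doc == {'content': 'example item', 'context': {'input_id': '123'}}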

+             'context': {
+                 'input_id': items[i]['id'],
[Review comment: Contributor]
I think it would be helpful to put a 'source': 'ClassyCat' in the context, so that if we later want to query these objects there is a key to do it.
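
Concretely, the context dict would then look something like this (a sketch of the reviewer's suggestion; the 'source' key is not part of this diff):

    'context': {
        'source': 'ClassyCat',  # suggested key for querying these documents later
        'input_id': items[i]['id'],
        'labels': final_results[i]['labels'],
        'schema_id': schema_id,
        'model_name': self.llm_client.model_name}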

+                 'labels': final_results[i]['labels'],
+                 'schema_id': schema_id,
+                 'model_name': self.llm_client.model_name}}
[Review comment: Contributor]
I think this needs to be the model name for the text vectorization? i.e. "paraphrase-multilingual-mpnet-base-v2"?
These are the parameters timpani used:
https://github.com/meedan/timpani/blob/6021d4fcb251d83ae48ed2c1566a16fad6971450/timpani/model_service/alegre_wrapper_service.py#L122

[Reply: Collaborator]
Within context, model_name (or any parameter) can be anything. We will, however, need to specify the models parameter for Alegre itself to know how to do vectorization. That parameter is currently missing.

[Reply: Contributor Author]
Good catch Scott, I did not know that we should specify model_name; on local the documents are being indexed only by Elasticsearch (full text search). I will update the code and redo the tests to make sure vectorization works on local.
+            for i in range(len(items))]}
+
+        # call alegre endpoint to store the results: /text/bulk_similarity/
+        alegre_url = os.getenv('ALEGRE_URL')
+        httpx.post(alegre_url + '/text/bulk_similarity/', json=final_results_to_be_stored_in_alegre)
[Review comment: Contributor]
Ooh, does this insert endpoint work? Maybe I can switch timpani to use it as well when we convert to batch mode!

[Reply: Collaborator]
We think so, but @DGaffney is currently moving all vectorization to Presto and will need to ensure this endpoint continues to work after that migration 😅

[Reply: Contributor Author]
Makes sense Scott, I added a comment on Devin's ticket to make sure he sees it.
@skyemeedan as of this moment the endpoint works fine locally for me, feel free to test it out!

[Reply: Collaborator]
I'm not sure we want to support this after we move to Presto vectorization. This is a blocking, bulk query, which breaks expectations in Presto-based vectorization doubly (non-blocking, and single query per item). This is, to my knowledge, the only location in Check (minus the re-index job in Check-API) that uses bulk_similarity, and if we could get away from this pattern, that would be vastly preferable. Is there any way we can move this? Otherwise we'll be signing up for new complicated measures to support this long term.

[Reply: Contributor]
• If bulk_similarity is not going to be supported, then I think we can't build the ClassyCat feature on top of it and will need to find another approach for this PR?
• My memory is that we decided that everything will be 'default bulk', so probably we need to adjust the design of new endpoints to support bulk?

[Reply: Collaborator]
I don't think it is reasonable to take on making async endpoints default bulk until after we have made text work on Presto. We've shunted secondary features into this refactor before, and the net effect is that they have greatly complicated our existing migration plan.

[Reply: Contributor]
If supporting bulk requires some refactoring of Presto payload structures, wouldn't it be easier to do that before we are supporting the full text processing in live?
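
If bulk storage does go away, a per-item alternative would look roughly like the following (a sketch only: it assumes /text/similarity accepts one document of this shape, and the text/content field naming discussed earlier would need to be checked against that endpoint):

    # hypothetical per-item storage, one blocking call per document
    for doc in final_results_to_be_stored_in_alegre['documents']:
        httpx.post(alegre_url + '/text/similarity/', json=doc)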


+        return final_results
 
     def schema_id_exists(self, schema_id):
         return file_exists_in_s3(self.output_bucket, f"{schema_id}.json")
 
 
     def process(self, message: Message) -> ClassyCatBatchClassificationResponse:
         # Example input:
         # {
11 changes: 9 additions & 2 deletions test/lib/model/test_classycat.py
@@ -279,14 +279,16 @@ def test_schema_lookup(self, file_exists_mock, load_file_from_s3_mock):
         self.assertEqual(file_exists_mock.call_count, 1)
         self.assertEqual(load_file_from_s3_mock.call_count, 1)
 
+    @patch('lib.model.classycat_classify.httpx.post')
     @patch('lib.model.classycat_classify.OpenRouterClient.classify')
     @patch('lib.model.classycat_classify.load_file_from_s3')
     @patch('lib.model.classycat_classify.upload_file_to_s3')
     @patch('lib.model.classycat_classify.file_exists_in_s3')
     def test_classify_success(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
-                              load_file_from_s3_mock, openrouter_classify_mock):
+                              load_file_from_s3_mock, openrouter_classify_mock, httpx_post_mock):
         file_exists_in_s3_mock.return_value = True
         upload_file_to_s3_mock.return_value = None
+        httpx_post_mock.return_value = None
         load_file_from_s3_mock.return_value = json.dumps(
             {
                 "schema_id": "123456",
@@ -427,6 +429,7 @@ def test_classify_success(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
         self.assertIn("Communalism", result.classification_results[0]['labels'])
         self.assertEqual(len(result.classification_results[0]['labels']), 2)
         self.assertEqual(upload_file_to_s3_mock.call_count, 1)
+        self.assertEqual(openrouter_classify_mock.call_count, 1)
 
     @patch('lib.model.classycat_classify.OpenRouterClient.classify')
     @patch('lib.model.classycat_classify.load_file_from_s3')
@@ -704,14 +707,16 @@ def test_classify_fail_wrong_number_of_results(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
 
         self.assertEqual(result.responseMessage, "Error classifying items: Not all items were classified successfully: input length 1, output length 2")
 
+    @patch('lib.model.classycat_classify.httpx.post')
     @patch('lib.model.classycat_classify.OpenRouterClient.classify')
     @patch('lib.model.classycat_classify.load_file_from_s3')
     @patch('lib.model.classycat_classify.upload_file_to_s3')
     @patch('lib.model.classycat_classify.file_exists_in_s3')
     def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
-                                                     load_file_from_s3_mock, openrouter_classify_mock):
+                                                     load_file_from_s3_mock, openrouter_classify_mock, httpx_post_mock):
         file_exists_in_s3_mock.return_value = True
         upload_file_to_s3_mock.return_value = None
+        httpx_post_mock.return_value = None
         load_file_from_s3_mock.return_value = json.dumps(
             {
                 "schema_id": "123456",
@@ -853,6 +858,8 @@ def test_classify_pass_some_out_of_schema_labels(self, file_exists_in_s3_mock, upload_file_to_s3_mock,
         self.assertListEqual(["Politics", "Communalism"], result.classification_results[0]['labels'])
         self.assertListEqual([], result.classification_results[1]['labels'])
         self.assertListEqual(["Politics"], result.classification_results[2]['labels'])
+        self.assertEqual(upload_file_to_s3_mock.call_count, 1)
+        self.assertEqual(openrouter_classify_mock.call_count, 1)
 
     @patch('lib.model.classycat_classify.OpenRouterClient.classify')
     @patch('lib.model.classycat_classify.load_file_from_s3')