diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index e294957..6c4f357 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -214,7 +214,7 @@ def _build_query_parameters(self) -> Mapping[str, str]: def _build_json_body( self, request: AnyRequest, model: Optional[str] ) -> Mapping[str, Any]: - json_body = request.to_json() + json_body = dict(request.to_json()) if model is not None: json_body["model"] = model @@ -721,7 +721,7 @@ def _build_query_parameters(self) -> Mapping[str, str]: def _build_json_body( self, request: AnyRequest, model: Optional[str] ) -> Mapping[str, Any]: - json_body = request.to_json() + json_body = dict(request.to_json()) if model is not None: json_body["model"] = model diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py index fbca704..ac48523 100644 --- a/aleph_alpha_client/completion.py +++ b/aleph_alpha_client/completion.py @@ -1,9 +1,11 @@ -from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Mapping, Optional, Sequence from aleph_alpha_client.prompt import Prompt -class CompletionRequest(NamedTuple): +@dataclass(frozen=True) +class CompletionRequest: """ Describes a completion request @@ -212,13 +214,17 @@ class CompletionRequest(NamedTuple): repetition_penalties_include_completion: bool = True raw_completion: bool = False - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> Mapping[str, Any]: payload = {k: v for k, v in self._asdict().items() if v is not None} payload["prompt"] = self.prompt.to_json() return payload + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) + -class CompletionResult(NamedTuple): +@dataclass(frozen=True) +class CompletionResult: log_probs: Optional[Sequence[Mapping[str, Optional[float]]]] = None completion: Optional[str] = None completion_tokens: Optional[Sequence[str]] = None @@ -235,8 +241,12 @@ def from_json(json: Dict[str, Any]) -> "CompletionResult": raw_completion=json.get("raw_completion"), ) + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class CompletionResponse(NamedTuple): + +@dataclass(frozen=True) +class CompletionResponse: model_version: str completions: Sequence[CompletionResult] optimized_prompt: Optional[Prompt] = None @@ -259,3 +269,6 @@ def to_json(self) -> Mapping[str, Any]: **self._asdict(), "completions": [completion._asdict() for completion in self.completions], } + + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) diff --git a/aleph_alpha_client/detokenization.py b/aleph_alpha_client/detokenization.py index e63a5aa..52bd974 100644 --- a/aleph_alpha_client/detokenization.py +++ b/aleph_alpha_client/detokenization.py @@ -1,7 +1,9 @@ -from typing import Any, Dict, List, NamedTuple, Optional, Sequence +from dataclasses import dataclass, asdict +from typing import Any, Dict, List, Mapping, Optional, Sequence -class DetokenizationRequest(NamedTuple): +@dataclass(frozen=True) +class DetokenizationRequest: """Describes a detokenization request. Parameters @@ -14,12 +16,16 @@ class DetokenizationRequest(NamedTuple): token_ids: Sequence[int] - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> Mapping[str, Any]: payload = self._asdict() return payload + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class DetokenizationResponse(NamedTuple): + +@dataclass(frozen=True) +class DetokenizationResponse: result: str @staticmethod diff --git a/aleph_alpha_client/embedding.py b/aleph_alpha_client/embedding.py index a5637d3..d0fdb41 100644 --- a/aleph_alpha_client/embedding.py +++ b/aleph_alpha_client/embedding.py @@ -1,9 +1,10 @@ +from dataclasses import asdict, dataclass from enum import Enum from typing import ( Any, Dict, List, - NamedTuple, + Mapping, Optional, Sequence, Tuple, @@ -11,7 +12,8 @@ from aleph_alpha_client.prompt import Prompt -class EmbeddingRequest(NamedTuple): +@dataclass(frozen=True) +class EmbeddingRequest: """ Embeds a text and returns vectors that can be used for downstream tasks (e.g. semantic similarity) and models (e.g. classifiers). @@ -73,13 +75,18 @@ class EmbeddingRequest(NamedTuple): contextual_control_threshold: Optional[float] = None control_log_additive: Optional[bool] = True - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - payload["prompt"] = self.prompt.to_json() - return payload + def to_json(self) -> Mapping[str, Any]: + return { + **self._asdict(), + "prompt": self.prompt.to_json(), + } + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class EmbeddingResponse(NamedTuple): + +@dataclass(frozen=True) +class EmbeddingResponse: model_version: str embeddings: Optional[Dict[Tuple[str, str], List[float]]] tokens: Optional[List[str]] @@ -120,7 +127,8 @@ class SemanticRepresentation(Enum): Query = "query" -class SemanticEmbeddingRequest(NamedTuple): +@dataclass(frozen=True) +class SemanticEmbeddingRequest: """ Embeds a text and returns vectors that can be used for downstream tasks (e.g. semantic similarity) and models (e.g. classifiers). @@ -181,14 +189,19 @@ class SemanticEmbeddingRequest(NamedTuple): contextual_control_threshold: Optional[float] = None control_log_additive: Optional[bool] = True - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - payload["representation"] = self.representation.value - payload["prompt"] = self.prompt.to_json() - return payload + def to_json(self) -> Mapping[str, Any]: + return { + **self._asdict(), + "representation": self.representation.value, + "prompt": self.prompt.to_json(), + } + + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class BatchSemanticEmbeddingRequest(NamedTuple): +@dataclass(frozen=True) +class BatchSemanticEmbeddingRequest: """ Embeds multiple multi-modal prompts and returns their embeddings in the same order as they were supplied. @@ -246,17 +259,22 @@ class BatchSemanticEmbeddingRequest(NamedTuple): contextual_control_threshold: Optional[float] = None control_log_additive: Optional[bool] = True - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - payload["representation"] = self.representation.value - payload["prompts"] = [prompt.to_json() for prompt in self.prompts] - return payload + def to_json(self) -> Mapping[str, Any]: + return { + **self._asdict(), + "representation": self.representation.value, + "prompts": [prompt.to_json() for prompt in self.prompts], + } + + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) EmbeddingVector = List[float] -class SemanticEmbeddingResponse(NamedTuple): +@dataclass(frozen=True) +class SemanticEmbeddingResponse: """ Response of a semantic embedding request @@ -275,10 +293,15 @@ class SemanticEmbeddingResponse(NamedTuple): @staticmethod def from_json(json: Dict[str, Any]) -> "SemanticEmbeddingResponse": - return SemanticEmbeddingResponse(**json) + return SemanticEmbeddingResponse( + model_version=json["model_version"], + embedding=json["embedding"], + message=json.get("message"), + ) -class BatchSemanticEmbeddingResponse(NamedTuple): +@dataclass(frozen=True) +class BatchSemanticEmbeddingResponse: """ Response of a batch semantic embedding request @@ -294,7 +317,9 @@ class BatchSemanticEmbeddingResponse(NamedTuple): @staticmethod def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse": - return BatchSemanticEmbeddingResponse(**json) + return BatchSemanticEmbeddingResponse( + model_version=json["model_version"], embeddings=json["embeddings"] + ) @staticmethod def _from_model_version_and_embeddings( diff --git a/aleph_alpha_client/evaluation.py b/aleph_alpha_client/evaluation.py index 8d9377e..3f8ebf1 100644 --- a/aleph_alpha_client/evaluation.py +++ b/aleph_alpha_client/evaluation.py @@ -1,13 +1,15 @@ +from dataclasses import dataclass, asdict from typing import ( Any, Dict, - NamedTuple, + Mapping, Optional, ) from aleph_alpha_client.prompt import Prompt -class EvaluationRequest(NamedTuple): +@dataclass(frozen=True) +class EvaluationRequest: """ Evaluates the model's likelihood to produce a completion given a prompt. @@ -39,13 +41,15 @@ class EvaluationRequest(NamedTuple): contextual_control_threshold: Optional[float] = None control_log_additive: Optional[bool] = True - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - payload["prompt"] = self.prompt.to_json() - return payload + def to_json(self) -> Mapping[str, Any]: + return {**self._asdict(), "prompt": self.prompt.to_json()} + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class EvaluationResponse(NamedTuple): + +@dataclass(frozen=True) +class EvaluationResponse: model_version: str message: Optional[str] result: Dict[str, Any] diff --git a/aleph_alpha_client/explanation.py b/aleph_alpha_client/explanation.py index e2ca6ef..9c962c9 100644 --- a/aleph_alpha_client/explanation.py +++ b/aleph_alpha_client/explanation.py @@ -1,10 +1,10 @@ +from dataclasses import dataclass from enum import Enum from typing import ( Any, List, Dict, Mapping, - NamedTuple, Optional, Union, ) @@ -34,7 +34,8 @@ def to_json(self) -> str: return self.value -class CustomGranularity(NamedTuple): +@dataclass(frozen=True) +class CustomGranularity: """ Allows for passing a custom delimiter to determine the granularity to to explain the prompt by. The text of the prompt will be split by the @@ -58,7 +59,7 @@ class PromptGranularity(Enum): Sentence = "sentence" Paragraph = "paragraph" - def to_json(self): + def to_json(self) -> Mapping[str, Any]: return {"type": self.value} @@ -90,7 +91,8 @@ def to_json(self) -> str: return self.value -class ExplanationRequest(NamedTuple): +@dataclass(frozen=True) +class ExplanationRequest: """ Describes an Explanation request you want to make agains the API. @@ -177,7 +179,7 @@ class ExplanationRequest(NamedTuple): postprocessing: Optional[ExplanationPostprocessing] = None normalize: Optional[bool] = None - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> Mapping[str, Any]: payload: Dict[str, Any] = { "prompt": self.prompt.to_json(), "target": self.target, @@ -206,7 +208,8 @@ def to_json(self) -> Dict[str, Any]: return payload -class TextScore(NamedTuple): +@dataclass(frozen=True) +class TextScore: start: int length: int score: float @@ -220,7 +223,8 @@ def from_json(score: Any) -> "TextScore": ) -class TextScoreWithRaw(NamedTuple): +@dataclass(frozen=True) +class TextScoreWithRaw: start: int length: int score: float @@ -236,7 +240,8 @@ def from_text_score(score: TextScore, prompt: Text) -> "TextScoreWithRaw": ) -class ImageScore(NamedTuple): +@dataclass(frozen=True) +class ImageScore: left: float top: float width: float @@ -254,7 +259,8 @@ def from_json(score: Any) -> "ImageScore": ) -class TargetScore(NamedTuple): +@dataclass(frozen=True) +class TargetScore: start: int length: int score: float @@ -268,7 +274,8 @@ def from_json(score: Any) -> "TargetScore": ) -class TargetScoreWithRaw(NamedTuple): +@dataclass(frozen=True) +class TargetScoreWithRaw: start: int length: int score: float @@ -284,7 +291,8 @@ def from_target_score(score: TargetScore, target: str) -> "TargetScoreWithRaw": ) -class TokenScore(NamedTuple): +@dataclass(frozen=True) +class TokenScore: score: float @staticmethod @@ -294,7 +302,8 @@ def from_json(score: Any) -> "TokenScore": ) -class ImagePromptItemExplanation(NamedTuple): +@dataclass(frozen=True) +class ImagePromptItemExplanation: """ Explains the importance of an image prompt item. The amount of items in the "scores" array depends on the granularity setting. @@ -328,7 +337,8 @@ def in_pixels(self, prompt_item: PromptItem) -> "ImagePromptItemExplanation": ) -class TextPromptItemExplanation(NamedTuple): +@dataclass(frozen=True) +class TextPromptItemExplanation: """ Explains the importance of a text prompt item. The amount of items in the "scores" array depends on the granularity setting. @@ -355,7 +365,8 @@ def with_text(self, prompt: Text) -> "TextPromptItemExplanation": ) -class TargetPromptItemExplanation(NamedTuple): +@dataclass(frozen=True) +class TargetPromptItemExplanation: """ Explains the importance of text in the target string that came before the currently to-be-explained target token. The amount of items in the "scores" array depends on the @@ -383,7 +394,8 @@ def with_text(self, prompt: str) -> "TargetPromptItemExplanation": ) -class TokenPromptItemExplanation(NamedTuple): +@dataclass(frozen=True) +class TokenPromptItemExplanation: """Explains the importance of a request prompt item of type "token_ids". Will contain one floating point importance value for each token in the same order as in the original prompt. """ @@ -397,7 +409,8 @@ def from_json(item: Dict[str, Any]) -> "TokenPromptItemExplanation": ) -class Explanation(NamedTuple): +@dataclass(frozen=True) +class Explanation: """ Explanations for a given portion of the target. @@ -482,7 +495,8 @@ def with_text_from_prompt(self, prompt: Prompt, target: str) -> "Explanation": ) -class ExplanationResponse(NamedTuple): +@dataclass(frozen=True) +class ExplanationResponse: """ The top-level response data structure that will be returned from an explanation request. diff --git a/aleph_alpha_client/prompt.py b/aleph_alpha_client/prompt.py index fdab870..242cd78 100644 --- a/aleph_alpha_client/prompt.py +++ b/aleph_alpha_client/prompt.py @@ -8,7 +8,6 @@ Dict, List, Mapping, - NamedTuple, Optional, Sequence, Tuple, @@ -42,7 +41,8 @@ def to_json(self) -> str: return self.value -class TokenControl(NamedTuple): +@dataclass(frozen=True) +class TokenControl: """ Used for Attention Manipulation, for a given token index, you can supply the factor you want to adjust the attention by. @@ -69,7 +69,8 @@ def to_json(self) -> Mapping[str, Any]: return {"index": self.pos, "factor": self.factor} -class Tokens(NamedTuple): +@dataclass(frozen=True) +class Tokens: """ A list of token ids to be sent as part of a prompt. @@ -108,7 +109,8 @@ def from_token_ids(token_ids: Sequence[int]) -> "Tokens": return Tokens(token_ids, []) -class TextControl(NamedTuple): +@dataclass(frozen=True) +class TextControl: """ Attention manipulation for a Text PromptItem. @@ -141,7 +143,7 @@ class TextControl(NamedTuple): factor: float token_overlap: Optional[ControlTokenOverlap] = None - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> Mapping[str, Any]: payload: Dict[str, Any] = { "start": self.start, "length": self.length, @@ -152,7 +154,8 @@ def to_json(self) -> Dict[str, Any]: return payload -class Text(NamedTuple): +@dataclass(frozen=True) +class Text: """ A Text-prompt including optional controls for attention manipulation. @@ -186,7 +189,8 @@ def from_text(text: str) -> "Text": return Text(text, []) -class Cropping(NamedTuple): +@dataclass(frozen=True) +class Cropping: """ Describes a quadratic crop of the file. """ @@ -196,7 +200,8 @@ class Cropping(NamedTuple): size: int -class ImageControl(NamedTuple): +@dataclass(frozen=True) +class ImageControl: """ Attention manipulation for an Image PromptItem. @@ -261,7 +266,8 @@ def to_json(self) -> Mapping[str, Any]: return payload -class Image(NamedTuple): +@dataclass(frozen=True) +class Image: """ An image send as part of a prompt to a model. The image is represented as base64. @@ -395,7 +401,7 @@ def _get_url(cls, url: str) -> bytes: response.raise_for_status() return response.content - def to_json(self) -> Dict[str, Any]: + def to_json(self) -> Mapping[str, Any]: """ A dict if serialized to JSON is suitable as a prompt element """ diff --git a/aleph_alpha_client/qa.py b/aleph_alpha_client/qa.py index ecf0411..d8e8891 100644 --- a/aleph_alpha_client/qa.py +++ b/aleph_alpha_client/qa.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, Mapping, NamedTuple, Optional, Sequence +from dataclasses import asdict, dataclass +from typing import Any, Dict, Mapping, Optional, Sequence from aleph_alpha_client.document import Document -class QaRequest(NamedTuple): - """DEPRECATED: `QaRequest` is deprecated and will be removed in the next major release. New +@dataclass(frozen=True) +class QaRequest: + """DEPRECATED: `QaRequest` is deprecated and will be removed in the future. New methods of processing Q&A tasks will be provided before this is removed. Answers a question about a prompt. @@ -30,18 +32,24 @@ class QaRequest(NamedTuple): documents: Sequence[Document] max_answers: Optional[int] = None - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - payload["documents"] = [ - document._to_serializable_document() for document in self.documents - ] + def to_json(self) -> Mapping[str, Any]: + payload = { + **self._asdict(), + "documents": [ + document._to_serializable_document() for document in self.documents + ], + } if self.max_answers is None: del payload["max_answers"] return payload + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) + -class QaAnswer(NamedTuple): - """DEPRECATED: `QaAnswer` is deprecated and will be removed in the next major release. New +@dataclass(frozen=True) +class QaAnswer: + """DEPRECATED: `QaAnswer` is deprecated and will be removed in the future. New methods of processing Q&A tasks will be provided before this is removed. """ @@ -50,8 +58,9 @@ class QaAnswer(NamedTuple): evidence: str -class QaResponse(NamedTuple): - """DEPRECATED: `QaResponse` is deprecated and will be removed in the next major release. New +@dataclass(frozen=True) +class QaResponse: + """DEPRECATED: `QaResponse` is deprecated and will be removed in the future. New methods of processing Q&A tasks will be provided before this is removed. """ diff --git a/aleph_alpha_client/summarization.py b/aleph_alpha_client/summarization.py index a86a12b..84809dc 100644 --- a/aleph_alpha_client/summarization.py +++ b/aleph_alpha_client/summarization.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, Mapping, NamedTuple, Sequence +from dataclasses import asdict, dataclass +from typing import Any, Dict, Mapping, Sequence from aleph_alpha_client.document import Document -class SummarizationRequest(NamedTuple): - """DEPRECATED: `SummarizationRequest` is deprecated and will be removed in the next major release. New +@dataclass(frozen=True) +class SummarizationRequest: + """DEPRECATED: `SummarizationRequest` is deprecated and will be removed in the future. New methods of processing Summarization tasks will be provided before this is removed. Summarizes a document. @@ -35,14 +37,16 @@ class SummarizationRequest(NamedTuple): document: Document disable_optimizations: bool = False - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - payload["document"] = self.document._to_serializable_document() - return payload + def to_json(self) -> Mapping[str, Any]: + return {**self._asdict(), "document": self.document._to_serializable_document()} + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class SummarizationResponse(NamedTuple): - """DEPRECATED: `SummarizationResponse` is deprecated and will be removed in the next major release. New + +@dataclass(frozen=True) +class SummarizationResponse: + """DEPRECATED: `SummarizationResponse` is deprecated and will be removed in the future. New methods of processing Summarization tasks will be provided before this is removed. """ diff --git a/aleph_alpha_client/tokenization.py b/aleph_alpha_client/tokenization.py index 0ea084a..48944ff 100644 --- a/aleph_alpha_client/tokenization.py +++ b/aleph_alpha_client/tokenization.py @@ -1,7 +1,9 @@ -from typing import Any, Dict, NamedTuple, Optional, Sequence +from dataclasses import asdict, dataclass +from typing import Any, Dict, Mapping, Optional, Sequence -class TokenizationRequest(NamedTuple): +@dataclass(frozen=True) +class TokenizationRequest: """Describes a tokenization request. Parameters @@ -25,15 +27,20 @@ class TokenizationRequest(NamedTuple): tokens: bool token_ids: bool - def to_json(self) -> Dict[str, Any]: - payload = self._asdict() - return payload + def to_json(self) -> Mapping[str, Any]: + return self._asdict() + def _asdict(self) -> Mapping[str, Any]: + return asdict(self) -class TokenizationResponse(NamedTuple): + +@dataclass(frozen=True) +class TokenizationResponse: tokens: Optional[Sequence[str]] = None token_ids: Optional[Sequence[int]] = None @staticmethod def from_json(json: Dict[str, Any]) -> "TokenizationResponse": - return TokenizationResponse(**json) + return TokenizationResponse( + tokens=json.get("tokens"), token_ids=json.get("token_ids") + ) diff --git a/tests/test_embed.py b/tests/test_embed.py index 61ef650..fdd7750 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -115,13 +115,15 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ representation=SemanticRepresentation.Symmetric, ) model_name = "test_model" - expected_body = request.to_json() - expected_body["model"] = model_name + expected_body = { + **request.to_json(), + "model": model_name, + } httpserver.expect_ordered_request( "/batch_semantic_embed", method="POST", data=json.dumps(expected_body) ).respond_with_json({"model_version": "1", "embeddings": []}) async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1) - _resp = await async_client.batch_semantic_embed(request, model=model_name) + await async_client.batch_semantic_embed(request, model=model_name) # Client @@ -212,10 +214,9 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer): representation=SemanticRepresentation.Symmetric, ) model_name = "test_model" - expected_body = request.to_json() - expected_body["model"] = model_name + expected_body = {**request.to_json(), "model": model_name} httpserver.expect_ordered_request( "/batch_semantic_embed", method="POST", data=json.dumps(expected_body) ).respond_with_json({"model_version": "1", "embeddings": []}) sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1) - _resp = sync_client.batch_semantic_embed(request, model=model_name) + sync_client.batch_semantic_embed(request, model=model_name)