Skip to content

Commit

Permalink
Turn all NamedTuples into dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
volkerstampa committed Oct 18, 2023
1 parent 608dc5f commit d408127
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 102 deletions.
4 changes: 2 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _build_query_parameters(self) -> Mapping[str, str]:
def _build_json_body(
self, request: AnyRequest, model: Optional[str]
) -> Mapping[str, Any]:
json_body = request.to_json()
json_body = dict(request.to_json())

if model is not None:
json_body["model"] = model
Expand Down Expand Up @@ -721,7 +721,7 @@ def _build_query_parameters(self) -> Mapping[str, str]:
def _build_json_body(
self, request: AnyRequest, model: Optional[str]
) -> Mapping[str, Any]:
json_body = request.to_json()
json_body = dict(request.to_json())

if model is not None:
json_body["model"] = model
Expand Down
23 changes: 18 additions & 5 deletions aleph_alpha_client/completion.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Mapping, Optional, Sequence

from aleph_alpha_client.prompt import Prompt


class CompletionRequest(NamedTuple):
@dataclass(frozen=True)
class CompletionRequest:
"""
Describes a completion request
Expand Down Expand Up @@ -212,13 +214,17 @@ class CompletionRequest(NamedTuple):
repetition_penalties_include_completion: bool = True
raw_completion: bool = False

def to_json(self) -> Dict[str, Any]:
def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in self._asdict().items() if v is not None}
payload["prompt"] = self.prompt.to_json()
return payload

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)


class CompletionResult(NamedTuple):
@dataclass(frozen=True)
class CompletionResult:
log_probs: Optional[Sequence[Mapping[str, Optional[float]]]] = None
completion: Optional[str] = None
completion_tokens: Optional[Sequence[str]] = None
Expand All @@ -235,8 +241,12 @@ def from_json(json: Dict[str, Any]) -> "CompletionResult":
raw_completion=json.get("raw_completion"),
)

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)

class CompletionResponse(NamedTuple):

@dataclass(frozen=True)
class CompletionResponse:
model_version: str
completions: Sequence[CompletionResult]
optimized_prompt: Optional[Prompt] = None
Expand All @@ -259,3 +269,6 @@ def to_json(self) -> Mapping[str, Any]:
**self._asdict(),
"completions": [completion._asdict() for completion in self.completions],
}

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)
14 changes: 10 additions & 4 deletions aleph_alpha_client/detokenization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, Dict, List, NamedTuple, Optional, Sequence
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Mapping, Optional, Sequence


class DetokenizationRequest(NamedTuple):
@dataclass(frozen=True)
class DetokenizationRequest:
"""Describes a detokenization request.
Parameters
Expand All @@ -14,12 +16,16 @@ class DetokenizationRequest(NamedTuple):

token_ids: Sequence[int]

def to_json(self) -> Dict[str, Any]:
def to_json(self) -> Mapping[str, Any]:
payload = self._asdict()
return payload

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)

class DetokenizationResponse(NamedTuple):

@dataclass(frozen=True)
class DetokenizationResponse:
result: str

@staticmethod
Expand Down
71 changes: 48 additions & 23 deletions aleph_alpha_client/embedding.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from dataclasses import asdict, dataclass
from enum import Enum
from typing import (
Any,
Dict,
List,
NamedTuple,
Mapping,
Optional,
Sequence,
Tuple,
)
from aleph_alpha_client.prompt import Prompt


class EmbeddingRequest(NamedTuple):
@dataclass(frozen=True)
class EmbeddingRequest:
"""
Embeds a text and returns vectors that can be used for downstream tasks (e.g. semantic similarity) and models (e.g. classifiers).
Expand Down Expand Up @@ -73,13 +75,18 @@ class EmbeddingRequest(NamedTuple):
contextual_control_threshold: Optional[float] = None
control_log_additive: Optional[bool] = True

def to_json(self) -> Dict[str, Any]:
payload = self._asdict()
payload["prompt"] = self.prompt.to_json()
return payload
def to_json(self) -> Mapping[str, Any]:
return {
**self._asdict(),
"prompt": self.prompt.to_json(),
}

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)

class EmbeddingResponse(NamedTuple):

@dataclass(frozen=True)
class EmbeddingResponse:
model_version: str
embeddings: Optional[Dict[Tuple[str, str], List[float]]]
tokens: Optional[List[str]]
Expand Down Expand Up @@ -120,7 +127,8 @@ class SemanticRepresentation(Enum):
Query = "query"


class SemanticEmbeddingRequest(NamedTuple):
@dataclass(frozen=True)
class SemanticEmbeddingRequest:
"""
Embeds a text and returns vectors that can be used for downstream tasks (e.g. semantic similarity) and models (e.g. classifiers).
Expand Down Expand Up @@ -181,14 +189,19 @@ class SemanticEmbeddingRequest(NamedTuple):
contextual_control_threshold: Optional[float] = None
control_log_additive: Optional[bool] = True

def to_json(self) -> Dict[str, Any]:
payload = self._asdict()
payload["representation"] = self.representation.value
payload["prompt"] = self.prompt.to_json()
return payload
def to_json(self) -> Mapping[str, Any]:
return {
**self._asdict(),
"representation": self.representation.value,
"prompt": self.prompt.to_json(),
}

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)


class BatchSemanticEmbeddingRequest(NamedTuple):
@dataclass(frozen=True)
class BatchSemanticEmbeddingRequest:
"""
Embeds multiple multi-modal prompts and returns their embeddings in the same order as they were supplied.
Expand Down Expand Up @@ -246,17 +259,22 @@ class BatchSemanticEmbeddingRequest(NamedTuple):
contextual_control_threshold: Optional[float] = None
control_log_additive: Optional[bool] = True

def to_json(self) -> Dict[str, Any]:
payload = self._asdict()
payload["representation"] = self.representation.value
payload["prompts"] = [prompt.to_json() for prompt in self.prompts]
return payload
def to_json(self) -> Mapping[str, Any]:
return {
**self._asdict(),
"representation": self.representation.value,
"prompts": [prompt.to_json() for prompt in self.prompts],
}

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)


EmbeddingVector = List[float]


class SemanticEmbeddingResponse(NamedTuple):
@dataclass(frozen=True)
class SemanticEmbeddingResponse:
"""
Response of a semantic embedding request
Expand All @@ -275,10 +293,15 @@ class SemanticEmbeddingResponse(NamedTuple):

@staticmethod
def from_json(json: Dict[str, Any]) -> "SemanticEmbeddingResponse":
return SemanticEmbeddingResponse(**json)
return SemanticEmbeddingResponse(
model_version=json["model_version"],
embedding=json["embedding"],
message=json.get("message"),
)


class BatchSemanticEmbeddingResponse(NamedTuple):
@dataclass(frozen=True)
class BatchSemanticEmbeddingResponse:
"""
Response of a batch semantic embedding request
Expand All @@ -294,7 +317,9 @@ class BatchSemanticEmbeddingResponse(NamedTuple):

@staticmethod
def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse":
return BatchSemanticEmbeddingResponse(**json)
return BatchSemanticEmbeddingResponse(
model_version=json["model_version"], embeddings=json["embeddings"]
)

@staticmethod
def _from_model_version_and_embeddings(
Expand Down
18 changes: 11 additions & 7 deletions aleph_alpha_client/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from dataclasses import dataclass, asdict
from typing import (
Any,
Dict,
NamedTuple,
Mapping,
Optional,
)
from aleph_alpha_client.prompt import Prompt


class EvaluationRequest(NamedTuple):
@dataclass(frozen=True)
class EvaluationRequest:
"""
Evaluates the model's likelihood to produce a completion given a prompt.
Expand Down Expand Up @@ -39,13 +41,15 @@ class EvaluationRequest(NamedTuple):
contextual_control_threshold: Optional[float] = None
control_log_additive: Optional[bool] = True

def to_json(self) -> Dict[str, Any]:
payload = self._asdict()
payload["prompt"] = self.prompt.to_json()
return payload
def to_json(self) -> Mapping[str, Any]:
return {**self._asdict(), "prompt": self.prompt.to_json()}

def _asdict(self) -> Mapping[str, Any]:
return asdict(self)

class EvaluationResponse(NamedTuple):

@dataclass(frozen=True)
class EvaluationResponse:
model_version: str
message: Optional[str]
result: Dict[str, Any]
Expand Down
Loading

0 comments on commit d408127

Please sign in to comment.