diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml
index 78ec9098..b41f9b52 100644
--- a/sdk/pyproject.toml
+++ b/sdk/pyproject.toml
@@ -25,3 +25,154 @@ moto = {extras = ["s3"], version = "^5.0.7"}
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+exclude = [".venv", "alembic"]
+required-version = ">=0.3.5"
+# Change the default line length (in characters).
+line-length = 88
+
+[tool.ruff.format]
+exclude = [".venv", "**/__init__.py"]
+quote-style = "single"
+indent-style = "space"
+docstring-code-format = true
+
+[tool.ruff.lint]
+select = [
+    "B002", # Python does not support the unary prefix increment
+    "B005", # Using .strip() with multi-character strings is misleading
+    "B007", # Loop control variable {name} not used within loop body
+    "B014", # Exception handler with duplicate exception
+    "B015", # Pointless comparison. Did you mean to assign a value? Otherwise, prepend assert or remove it.
+    "B018", # Found useless attribute access. Either assign it to a variable or remove it.
+    "B023", # Function definition does not bind loop variable {name}
+    "B026", # Star-arg unpacking after a keyword argument is strongly discouraged
+    "B032", # Possible unintentional type annotation (using :). Did you mean to assign (using =)?
+    "B904", # Use raise from to specify exception cause
+    "C", # complexity
+    "COM818", # Trailing comma on bare tuple prohibited
+    "D", # docstrings
+    "DTZ003", # Use datetime.now(tz=) instead of datetime.utcnow()
+    "DTZ004", # Use datetime.fromtimestamp(ts, tz=) instead of datetime.utcfromtimestamp(ts)
+    "E", # pycodestyle
+    "F", # pyflakes/autoflake
+    "G", # flake8-logging-format
+    "I", # isort
+    "ISC", # flake8-implicit-str-concat
+    "ICN001", # import conventions; {name} should be imported as {asname}
+    "LOG", # flake8-logging
+    "N804", # First argument of a class method should be named cls
+    "N805", # First argument of a method should be named self
+    "N815", # Variable {name} in class scope should not be mixedCase
+    "PERF", # Perflint
+    "PGH004", # Use specific rule codes when using noqa
+    "PIE", # flake8-pie
+    "PL", # pylint
+    "PT", # flake8-pytest-style
+    "RET", # flake8-return
+    "RSE", # flake8-raise
+    "RUF005", # Consider iterable unpacking instead of concatenation
+    "RUF006", # Store a reference to the return value of asyncio.create_task
+    # "RUF100", # Unused `noqa` directive; enable temporarily every now and then to clean them up
+    "S102", # Use of exec detected
+    "S103", # bad-file-permissions
+    "S108", # hardcoded-temp-file
+    "S306", # suspicious-mktemp-usage
+    "S307", # suspicious-eval-usage
+    "S313", # suspicious-xmlc-element-tree-usage
+    "S314", # suspicious-xml-element-tree-usage
+    "S315", # suspicious-xml-expat-reader-usage
+    "S316", # suspicious-xml-expat-builder-usage
+    "S317", # suspicious-xml-sax-usage
+    "S318", # suspicious-xml-mini-dom-usage
+    "S319", # suspicious-xml-pull-dom-usage
+    "S320", # suspicious-xmle-tree-usage
+    "S601", # paramiko-call
+    "S602", # subprocess-popen-with-shell-equals-true
+    "S604", # call-with-shell-equals-true
+    "S608", # hardcoded-sql-expression
+    "S609", # unix-command-wildcard-injection
+    "SIM", # flake8-simplify
+    "T100", # Trace found: {name} used
+    "T20", # flake8-print
+    "TID251", # Banned imports
+    "TRY", # tryceratops
+    "UP", # pyupgrade
+    "W", # pycodestyle
+]
+
+ignore = [
+    "D100", # Missing docstring in public module
+    "D101", # Missing docstring in public class
+    "D102", # Missing docstring in public method
+    "D103", # Missing docstring in public function
"D104", # Missing docstring in public package + "D107", # Missing docstring in `__init__` + "D202", # No blank lines allowed after function docstring + "D203", # 1 blank line required before class docstring + "D205", # 1 blank line required between summary line and description + "D213", # Multi-line docstring summary should start at the second line + "D400", # First line should end with a period + "D406", # Section name should end with a newline + "D407", # Section name underlining + "D415", # First line should end with a period, question mark, or exclamation point + "E501", # line too long + "E731", # do not assign a lambda expression, use a def + "PLC1901", # {existing} can be simplified to {replacement} as an empty string is falsey; too many false positives + "PLR0911", # Too many return statements ({returns} > {max_returns}) + "PLR0912", # Too many branches ({branches} > {max_branches}) + "PLR0913", # Too many arguments to function call ({c_args} > {max_args}) + "PLR0915", # Too many statements ({statements} > {max_statements}) + "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable + "PLW2901", # Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target + "PT004", # Fixture {fixture} does not return anything, add leading underscore + "PT011", # pytest.raises({exception}) is too broad, set the `match` parameter or use a more specific exception + "PT012", # `pytest.raises()` block should contain a single simple statement + "PT018", # Assertion should be broken down into multiple parts + "SIM102", # Use a single if statement instead of nested if statements + "SIM108", # Use ternary operator {contents} instead of if-else-block + "SIM115", # Use context handler for opening files + "TRY003", # Avoid specifying long messages outside the exception class + "TRY400", # Use `logging.exception` instead of `logging.error` + "UP006", # keep type annotation style as is + "UP007", # keep type annotation style as is + # Ignored due to performance: https://github.com/charliermarsh/ruff/issues/2923 + "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` + # Ignored due to incompatible with mypy: https://github.com/python/mypy/issues/15238 + "UP040", # Checks for use of TypeAlias annotation for declaring type aliases. 
+
+    # May conflict with the formatter, https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
+    "W191",
+    "E111",
+    "E114",
+    "E117",
+    "D206",
+    "D300",
+    "Q",
+    "COM812",
+    "COM819",
+    "ISC001",
+
+    # Disabled because ruff does not understand the type of __all__ generated by a function
+    "PLE0605"
+]
+
+[tool.ruff.lint.extend-per-file-ignores]
+"__init__.py" = ["I001"]
+
+[tool.ruff.lint.flake8-pytest-style]
+fixture-parentheses = false
+mark-parentheses = false
+
+[tool.ruff.lint.flake8-tidy-imports.banned-api]
+"async_timeout".msg = "use asyncio.timeout instead"
+"pytz".msg = "use zoneinfo instead"
+
+[tool.ruff.lint.isort]
+force-sort-within-sections = true
+combine-as-imports = true
+split-on-trailing-comma = false
+
+[tool.ruff.lint.mccabe]
+max-complexity = 25
\ No newline at end of file
diff --git a/sdk/radicalbit_platform_sdk/apis/__init__.py b/sdk/radicalbit_platform_sdk/apis/__init__.py
index 15d230d8..d6255496 100644
--- a/sdk/radicalbit_platform_sdk/apis/__init__.py
+++ b/sdk/radicalbit_platform_sdk/apis/__init__.py
@@ -2,4 +2,4 @@
 from .model_reference_dataset import ModelReferenceDataset
 from .model import Model
 
-__all__ = ["Model", "ModelCurrentDataset", "ModelReferenceDataset"]
+__all__ = ['Model', 'ModelCurrentDataset', 'ModelReferenceDataset']
diff --git a/sdk/radicalbit_platform_sdk/apis/model.py b/sdk/radicalbit_platform_sdk/apis/model.py
index 3564a37f..5e3d4211 100644
--- a/sdk/radicalbit_platform_sdk/apis/model.py
+++ b/sdk/radicalbit_platform_sdk/apis/model.py
@@ -1,26 +1,28 @@
+import os
+from typing import List, Optional
+from uuid import UUID
+
+import boto3
+from botocore.exceptions import ClientError as BotoClientError
+import pandas as pd
+from pydantic import ValidationError
+import requests
+
+from radicalbit_platform_sdk.apis import ModelCurrentDataset, ModelReferenceDataset
 from radicalbit_platform_sdk.commons import invoke
-from radicalbit_platform_sdk.apis import ModelReferenceDataset, ModelCurrentDataset
+from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
+    AwsCredentials,
     ColumnDefinition,
+    CurrentFileUpload,
+    DataType,
+    FileReference,
     Granularity,
     ModelDefinition,
     ModelType,
-    DataType,
     OutputType,
     ReferenceFileUpload,
-    CurrentFileUpload,
-    FileReference,
-    AwsCredentials,
 )
-from radicalbit_platform_sdk.errors import ClientError
-from botocore.exceptions import ClientError as BotoClientError
-from pydantic import ValidationError
-from typing import Optional, List
-from uuid import UUID
-import boto3
-import os
-import pandas as pd
-import requests
 
 
 class Model:
@@ -81,8 +83,8 @@ def delete(self) -> None:
         :return: None
         """
         invoke(
-            method="DELETE",
-            url=f"{self.__base_url}/api/models/{str(self.__uuid)}",
+            method='DELETE',
+            url=f'{self.__base_url}/api/models/{str(self.__uuid)}',
             valid_response_code=200,
             func=lambda _: None,
         )
@@ -93,7 +95,7 @@ def load_reference_dataset(
         self,
         file_name: str,
         bucket: str,
         object_name: Optional[str] = None,
         aws_credentials: Optional[AwsCredentials] = None,
-        separator: str = ",",
+        separator: str = ',',
     ) -> ModelReferenceDataset:
         """Upload reference dataset to an S3 bucket and then bind it inside the platform.
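The import reshuffle in the hunk above is exactly what the `[tool.ruff.lint.isort]` settings added to pyproject.toml produce: `force-sort-within-sections = true` interleaves `import x` and `from x import y` statements alphabetically inside each section, and sections run stdlib, third-party, first-party, separated by blank lines. A minimal sketch of the resulting convention, mirroring the model.py header (illustrative, not part of the patch):

    # Standard library: plain imports and from-imports sort together.
    import os
    from typing import List, Optional
    from uuid import UUID

    # Third-party packages form the second section.
    import boto3
    from botocore.exceptions import ClientError as BotoClientError
    import pandas as pd

    # First-party (project) modules come last.
    from radicalbit_platform_sdk.commons import invoke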
@@ -115,11 +117,11 @@ def load_reference_dataset(
 
         if set(required_headers).issubset(file_headers):
             if object_name is None:
-                object_name = f"{self.__uuid}/reference/{os.path.basename(file_name)}"
+                object_name = f'{self.__uuid}/reference/{os.path.basename(file_name)}'
 
             try:
                 s3_client = boto3.client(
-                    "s3",
+                    's3',
                     aws_access_key_id=(
                         None
                         if aws_credentials is None
@@ -142,30 +144,30 @@ def load_reference_dataset(
                     bucket,
                     object_name,
                     ExtraArgs={
-                        "Metadata": {
-                            "model_uuid": str(self.__uuid),
-                            "model_name": self.__name,
-                            "file_type": "reference",
+                        'Metadata': {
+                            'model_uuid': str(self.__uuid),
+                            'model_name': self.__name,
+                            'file_type': 'reference',
                         }
                     },
                 )
             except BotoClientError as e:
                 raise ClientError(
-                    f"Unable to upload file {file_name} to remote storage: {e}"
-                )
+                    f'Unable to upload file {file_name} to remote storage: {e}'
+                ) from e
 
             return self.__bind_reference_dataset(
-                f"s3://{bucket}/{object_name}", separator
-            )
-        else:
-            raise ClientError(
-                f"File {file_name} not contains all defined columns: {required_headers}"
+                f's3://{bucket}/{object_name}', separator
             )
+        raise ClientError(
+            f'File {file_name} does not contain all defined columns: {required_headers}'
+        ) from None
+
     def bind_reference_dataset(
         self,
         dataset_url: str,
         aws_credentials: Optional[AwsCredentials] = None,
-        separator: str = ",",
+        separator: str = ',',
     ) -> ModelReferenceDataset:
         """Bind an existing reference dataset file already uploded to S3 to a `Model`
 
@@ -175,11 +177,11 @@ def bind_reference_dataset(
         :return: An instance of `ModelReferenceDataset` representing the reference dataset
         """
 
-        url_parts = dataset_url.replace("s3://", "").split("/")
+        url_parts = dataset_url.replace('s3://', '').split('/')
 
         try:
             s3_client = boto3.client(
-                "s3",
+                's3',
                 aws_access_key_id=(
                     None if aws_credentials is None else aws_credentials.access_key_id
                 ),
@@ -194,27 +196,27 @@ def bind_reference_dataset(
             )
 
             chunks_iterator = s3_client.get_object(
-                Bucket=url_parts[0], Key="/".join(url_parts[1:])
-            )["Body"].iter_chunks()
+                Bucket=url_parts[0], Key='/'.join(url_parts[1:])
+            )['Body'].iter_chunks()
 
-            chunks = ""
-            for c in (chunk for chunk in chunks_iterator if "\n" not in chunks):
-                chunks += c.decode("UTF-8")
+            chunks = ''
+            for c in (chunk for chunk in chunks_iterator if '\n' not in chunks):
+                chunks += c.decode('UTF-8')
 
-            file_headers = chunks.split("\n")[0].split(separator)
+            file_headers = chunks.split('\n')[0].split(separator)
 
             required_headers = self.__required_headers()
 
             if set(required_headers).issubset(file_headers):
                 return self.__bind_reference_dataset(dataset_url, separator)
-            else:
-                raise ClientError(
-                    f"File {dataset_url} not contains all defined columns: {required_headers}"
-                )
+
+            raise ClientError(
+                f'File {dataset_url} does not contain all defined columns: {required_headers}'
+            ) from None
         except BotoClientError as e:
             raise ClientError(
-                f"Unable to get file {dataset_url} from remote storage: {e}"
-            )
+                f'Unable to get file {dataset_url} from remote storage: {e}'
+            ) from e
 
     def load_current_dataset(
         self,
@@ -223,7 +225,7 @@ def load_current_dataset(
         correlation_id_column: str,
         object_name: Optional[str] = None,
         aws_credentials: Optional[AwsCredentials] = None,
-        separator: str = ",",
+        separator: str = ',',
     ) -> ModelCurrentDataset:
         """Upload current dataset to an S3 bucket and then bind it inside the platform.
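The recurring pattern in these hunks is explicit exception chaining, which the B904 rule selected in pyproject.toml enforces. A minimal, self-contained sketch of the two idioms; the names here are illustrative stand-ins, not SDK code:

    class ClientError(Exception):
        """Domain-level error surfaced to SDK users."""

    def upload(file_name: str) -> None:
        try:
            raise TimeoutError('connect timeout')  # stand-in for a boto3 failure
        except TimeoutError as e:
            # `from e` records the original error as __cause__, so the traceback
            # reads "The above exception was the direct cause of ..." instead of
            # the misleading implicit "During handling ..." context.
            raise ClientError(f'Unable to upload {file_name}: {e}') from e

    # `raise ... from None`, used above for the missing-columns case, does the
    # opposite: it suppresses any in-flight context when it adds no signal.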
@@ -248,11 +250,11 @@ def load_current_dataset(
 
         if set(required_headers).issubset(file_headers):
             if object_name is None:
-                object_name = f"{self.__uuid}/current/{os.path.basename(file_name)}"
+                object_name = f'{self.__uuid}/current/{os.path.basename(file_name)}'
 
             try:
                 s3_client = boto3.client(
-                    "s3",
+                    's3',
                     aws_access_key_id=(
                         None
                         if aws_credentials is None
@@ -275,31 +277,31 @@ def load_current_dataset(
                     bucket,
                     object_name,
                     ExtraArgs={
-                        "Metadata": {
-                            "model_uuid": str(self.__uuid),
-                            "model_name": self.__name,
-                            "file_type": "reference",
+                        'Metadata': {
+                            'model_uuid': str(self.__uuid),
+                            'model_name': self.__name,
+                            'file_type': 'current',
                         }
                     },
                 )
             except BotoClientError as e:
                 raise ClientError(
-                    f"Unable to upload file {file_name} to remote storage: {e}"
-                )
+                    f'Unable to upload file {file_name} to remote storage: {e}'
+                ) from e
 
             return self.__bind_current_dataset(
-                f"s3://{bucket}/{object_name}", separator, correlation_id_column
-            )
-        else:
-            raise ClientError(
-                f"File {file_name} not contains all defined columns: {required_headers}"
+                f's3://{bucket}/{object_name}', separator, correlation_id_column
             )
+        raise ClientError(
+            f'File {file_name} does not contain all defined columns: {required_headers}'
+        ) from None
+
     def bind_current_dataset(
         self,
         dataset_url: str,
         correlation_id_column: str,
         aws_credentials: Optional[AwsCredentials] = None,
-        separator: str = ",",
+        separator: str = ',',
     ) -> ModelCurrentDataset:
         """Bind an existing current dataset file already uploded to S3 to a `Model`
 
@@ -310,11 +312,11 @@ def bind_current_dataset(
         :return: An instance of `ModelReferenceDataset` representing the reference dataset
         """
 
-        url_parts = dataset_url.replace("s3://", "").split("/")
+        url_parts = dataset_url.replace('s3://', '').split('/')
 
         try:
             s3_client = boto3.client(
-                "s3",
+                's3',
                 aws_access_key_id=(
                     None if aws_credentials is None else aws_credentials.access_key_id
                 ),
@@ -329,14 +331,14 @@ def bind_current_dataset(
             )
 
             chunks_iterator = s3_client.get_object(
-                Bucket=url_parts[0], Key="/".join(url_parts[1:])
-            )["Body"].iter_chunks()
+                Bucket=url_parts[0], Key='/'.join(url_parts[1:])
+            )['Body'].iter_chunks()
 
-            chunks = ""
-            for c in (chunk for chunk in chunks_iterator if "\n" not in chunks):
-                chunks += c.decode("UTF-8")
+            chunks = ''
+            for c in (chunk for chunk in chunks_iterator if '\n' not in chunks):
+                chunks += c.decode('UTF-8')
 
-            file_headers = chunks.split("\n")[0].split(separator)
+            file_headers = chunks.split('\n')[0].split(separator)
 
             required_headers = self.__required_headers()
             required_headers.append(correlation_id_column)
@@ -346,14 +348,14 @@ def bind_current_dataset(
                 return self.__bind_current_dataset(
                     dataset_url, separator, correlation_id_column
                 )
-            else:
-                raise ClientError(
-                    f"File {dataset_url} not contains all defined columns: {required_headers}"
-                )
+
+            raise ClientError(
+                f'File {dataset_url} does not contain all defined columns: {required_headers}'
+            ) from None
         except BotoClientError as e:
             raise ClientError(
-                f"Unable to get file {dataset_url} from remote storage: {e}"
-            )
+                f'Unable to get file {dataset_url} from remote storage: {e}'
+            ) from e
 
     def __bind_reference_dataset(
         self,
@@ -366,14 +368,14 @@ def __callback(response: requests.Response) -> ModelReferenceDataset:
                 return ModelReferenceDataset(
                     self.__base_url, self.__uuid, self.__model_type, response
                 )
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
 
         file_ref = FileReference(file_url=dataset_url, separator=separator)
 
         return invoke(
-            method="POST",
-            url=f"{self.__base_url}/api/models/{str(self.__uuid)}/reference/bind",
+            method='POST',
+            url=f'{self.__base_url}/api/models/{str(self.__uuid)}/reference/bind',
             valid_response_code=200,
             func=__callback,
             data=file_ref.model_dump_json(),
@@ -391,8 +393,8 @@ def __callback(response: requests.Response) -> ModelCurrentDataset:
                 return ModelCurrentDataset(
                     self.__base_url, self.__uuid, self.__model_type, response
                 )
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
 
         file_ref = FileReference(
             file_url=dataset_url,
@@ -401,8 +403,8 @@ def __callback(response: requests.Response) -> ModelCurrentDataset:
         )
 
         return invoke(
-            method="POST",
-            url=f"{self.__base_url}/api/models/{str(self.__uuid)}/current/bind",
+            method='POST',
+            url=f'{self.__base_url}/api/models/{str(self.__uuid)}/current/bind',
             valid_response_code=200,
             func=__callback,
             data=file_ref.model_dump_json(),
diff --git a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
index fc5f4c14..114533d0 100644
--- a/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
+++ b/sdk/radicalbit_platform_sdk/apis/model_current_dataset.py
@@ -1,17 +1,19 @@
+from typing import Optional
+from uuid import UUID
+
+from pydantic import ValidationError
+import requests
+
 from radicalbit_platform_sdk.commons import invoke
+from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
-    ModelType,
+    BinaryClassDrift,
     CurrentFileUpload,
-    JobStatus,
     DatasetStats,
     Drift,
-    BinaryClassDrift,
+    JobStatus,
+    ModelType,
 )
-from radicalbit_platform_sdk.errors import ClientError
-from pydantic import ValidationError
-from typing import Optional
-import requests
-from uuid import UUID
 
 
 class ModelCurrentDataset:
@@ -51,8 +53,7 @@ def status(self) -> str:
         return self.__status
 
     def statistics(self) -> Optional[DatasetStats]:
-        """
-        Get statistics about the current dataset
+        """Get statistics about the current dataset
 
         :return: The `DatasetStats` if exists
         """
@@ -62,17 +63,18 @@ def __callback(
         ) -> tuple[JobStatus, Optional[DatasetStats]]:
             try:
                 response_json = response.json()
-                job_status = JobStatus(response_json["jobStatus"])
-                if "statistics" in response_json:
+                job_status = JobStatus(response_json['jobStatus'])
+                if 'statistics' in response_json:
                     return job_status, DatasetStats.model_validate(
-                        response_json["statistics"]
+                        response_json['statistics']
                     )
-                else:
-                    return job_status, None
-            except KeyError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+
+            except KeyError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            else:
+                return job_status, None
 
         match self.__status:
             case JobStatus.ERROR:
@@ -80,16 +82,16 @@ def __callback(
             case JobStatus.SUCCEEDED:
                 if self.__statistics is None:
                     _, stats = invoke(
-                        method="GET",
-                        url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics",
+                        method='GET',
+                        url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics',
                         valid_response_code=200,
                         func=__callback,
                     )
                     self.__statistics = stats
             case JobStatus.IMPORTING:
                 status, stats = invoke(
-                    method="GET",
-                    url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics",
+                    method='GET',
+                    url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/statistics',
                     valid_response_code=200,
                     func=__callback,
                 )
@@ -99,8 +101,7 @@ def __callback(
         return self.__statistics
 
     def drift(self) -> Optional[Drift]:
-        """
-        Get drift about the current dataset
+        """Get drift about the current dataset
 
         :return: The `Drift` if exists
         """
@@ -110,23 +111,22 @@ def __callback(
         ) -> tuple[JobStatus, Optional[Drift]]:
             try:
                 response_json = response.json()
-                job_status = JobStatus(response_json["jobStatus"])
-                if "drift" in response_json:
+                job_status = JobStatus(response_json['jobStatus'])
+                if 'drift' in response_json:
                     if self.__model_type is ModelType.BINARY:
                         return (
                             job_status,
-                            BinaryClassDrift.model_validate(response_json["drift"]),
-                        )
-                    else:
-                        raise ClientError(
-                            "Unable to parse get metrics for not binary models"
+                            BinaryClassDrift.model_validate(response_json['drift']),
                         )
-                else:
-                    return job_status, None
-            except KeyError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+                    raise ClientError(
+                        'Unable to parse metrics for non-binary models'
+                    ) from None
+            except KeyError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            else:
+                return job_status, None
 
         match self.__status:
             case JobStatus.ERROR:
@@ -134,16 +134,16 @@ def __callback(
             case JobStatus.SUCCEEDED:
                 if self.__drift is None:
                     _, drift = invoke(
-                        method="GET",
-                        url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift",
+                        method='GET',
+                        url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift',
                         valid_response_code=200,
                         func=__callback,
                     )
                     self.__drift = drift
             case JobStatus.IMPORTING:
                 status, drift = invoke(
-                    method="GET",
-                    url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift",
+                    method='GET',
+                    url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/current/{str(self.__uuid)}/drift',
                     valid_response_code=200,
                     func=__callback,
                 )
diff --git a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
index 2d99e2bb..45424a4d 100644
--- a/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
+++ b/sdk/radicalbit_platform_sdk/apis/model_reference_dataset.py
@@ -1,19 +1,21 @@
+from typing import Optional
+from uuid import UUID
+
+from pydantic import ValidationError
+import requests
+
 from radicalbit_platform_sdk.commons import invoke
+from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
-    JobStatus,
-    ReferenceFileUpload,
+    BinaryClassificationDataQuality,
+    BinaryClassificationModelQuality,
+    DataQuality,
     DatasetStats,
+    JobStatus,
     ModelQuality,
-    DataQuality,
     ModelType,
-    BinaryClassificationModelQuality,
-    BinaryClassificationDataQuality,
+    ReferenceFileUpload,
 )
-from radicalbit_platform_sdk.errors import ClientError
-from pydantic import ValidationError
-from typing import Optional
-import requests
-from uuid import UUID
 
 
 class ModelReferenceDataset:
@@ -48,8 +50,7 @@ def status(self) -> str:
         return self.__status
 
     def statistics(self) -> Optional[DatasetStats]:
-        """
-        Get statistics about the current dataset
+        """Get statistics about the reference dataset
 
         :return: The `DatasetStats` if exists
         """
@@ -59,17 +60,17 @@ def __callback(
         ) -> tuple[JobStatus, Optional[DatasetStats]]:
             try:
                 response_json = response.json()
-                job_status = JobStatus(response_json["jobStatus"])
-                if "statistics" in response_json:
+                job_status = JobStatus(response_json['jobStatus'])
+                if 'statistics' in response_json:
                     return job_status, DatasetStats.model_validate(
-                        response_json["statistics"]
+                        response_json['statistics']
                     )
-                else:
-                    return job_status, None
-            except KeyError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+            except KeyError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            else:
+                return job_status, None
 
         match self.__status:
             case JobStatus.ERROR:
@@ -77,16 +78,16 @@ def __callback(
             case JobStatus.SUCCEEDED:
                 if self.__statistics is None:
                     _, stats = invoke(
-                        method="GET",
-                        url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/statistics",
+                        method='GET',
+                        url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/statistics',
                         valid_response_code=200,
                         func=__callback,
                     )
                     self.__statistics = stats
             case JobStatus.IMPORTING:
                 status, stats = invoke(
-                    method="GET",
-                    url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/statistics",
+                    method='GET',
+                    url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/statistics',
                     valid_response_code=200,
                     func=__callback,
                 )
@@ -96,8 +97,7 @@ def __callback(
         return self.__statistics
 
     def data_quality(self) -> Optional[DataQuality]:
-        """
-        Get data quality metrics about the current dataset
+        """Get data quality metrics about the reference dataset
 
         :return: The `DataQuality` if exists
         """
@@ -107,25 +107,24 @@ def __callback(
         ) -> tuple[JobStatus, Optional[DataQuality]]:
             try:
                 response_json = response.json()
-                job_status = JobStatus(response_json["jobStatus"])
-                if "dataQuality" in response_json:
+                job_status = JobStatus(response_json['jobStatus'])
+                if 'dataQuality' in response_json:
                     if self.__model_type is ModelType.BINARY:
                         return (
                             job_status,
                             BinaryClassificationDataQuality.model_validate(
-                                response_json["dataQuality"]
+                                response_json['dataQuality']
                             ),
                         )
-                    else:
-                        raise ClientError(
-                            "Unable to parse get metrics for not binary models"
-                        )
-                else:
-                    return job_status, None
-            except KeyError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+                    raise ClientError(
+                        'Unable to parse metrics for non-binary models'
+                    ) from None
+            except KeyError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            else:
+                return job_status, None
 
         match self.__status:
             case JobStatus.ERROR:
@@ -133,16 +132,16 @@ def __callback(
             case JobStatus.SUCCEEDED:
                 if self.__data_metrics is None:
                     _, metrics = invoke(
-                        method="GET",
-                        url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/data-quality",
+                        method='GET',
+                        url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/data-quality',
                         valid_response_code=200,
                         func=__callback,
                     )
                     self.__data_metrics = metrics
             case JobStatus.IMPORTING:
                 status, metrics = invoke(
-                    method="GET",
-                    url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/data-quality",
+                    method='GET',
+                    url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/data-quality',
                     valid_response_code=200,
                     func=__callback,
                 )
@@ -152,8 +151,7 @@ def __callback(
         return self.__data_metrics
 
     def model_quality(self) -> Optional[ModelQuality]:
-        """
-        Get model quality metrics about the current dataset
+        """Get model quality metrics about the reference dataset
 
         :return: The `ModelQuality` if exists
         """
@@ -163,25 +161,24 @@ def __callback(
         ) -> tuple[JobStatus, Optional[ModelQuality]]:
             try:
                 response_json = response.json()
-                job_status = JobStatus(response_json["jobStatus"])
-                if "modelQuality" in response_json:
+                job_status = JobStatus(response_json['jobStatus'])
+                if 'modelQuality' in response_json:
                     if self.__model_type is ModelType.BINARY:
                         return (
                             job_status,
                             BinaryClassificationModelQuality.model_validate(
-                                response_json["modelQuality"]
+                                response_json['modelQuality']
                             ),
                         )
-                    else:
-                        raise ClientError(
-                            "Unable to parse get metrics for not binary models"
-                        )
-                else:
-                    return job_status, None
-            except KeyError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
-            except ValidationError as _:
-                raise ClientError(f"Unable to parse response: {response.text}")
+                    raise ClientError(
+                        'Unable to parse metrics for non-binary models'
+                    ) from None
+            except KeyError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            except ValidationError as e:
+                raise ClientError(f'Unable to parse response: {response.text}') from e
+            else:
+                return job_status, None
 
         match self.__status:
             case JobStatus.ERROR:
@@ -189,16 +186,16 @@ def __callback(
             case JobStatus.SUCCEEDED:
                 if self.__model_metrics is None:
                     _, metrics = invoke(
-                        method="GET",
-                        url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/model-quality",
+                        method='GET',
+                        url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/model-quality',
                         valid_response_code=200,
                         func=__callback,
                     )
                     self.__model_metrics = metrics
             case JobStatus.IMPORTING:
                 status, metrics = invoke(
-                    method="GET",
-                    url=f"{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/model-quality",
+                    method='GET',
+                    url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/reference/model-quality',
                     valid_response_code=200,
                    func=__callback,
                 )
diff --git a/sdk/radicalbit_platform_sdk/client.py b/sdk/radicalbit_platform_sdk/client.py
index 36b54346..4d48f226 100644
--- a/sdk/radicalbit_platform_sdk/client.py
+++ b/sdk/radicalbit_platform_sdk/client.py
@@ -1,15 +1,17 @@
-from radicalbit_platform_sdk.commons import invoke
+from typing import List
+from uuid import UUID
+
+from pydantic import ValidationError
+import requests
+
 from radicalbit_platform_sdk.apis import Model
+from radicalbit_platform_sdk.commons import invoke
+from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
     CreateModel,
     ModelDefinition,
     PaginatedModelDefinitions,
 )
-from radicalbit_platform_sdk.errors import ClientError
-from pydantic import ValidationError
-from typing import List
-from uuid import UUID
-import requests
 
 
 class Client:
@@ -21,12 +23,12 @@ def __callback(response: requests.Response) -> Model:
         try:
             response_model = ModelDefinition.model_validate(response.json())
             return Model(self.__base_url, response_model)
-        except ValidationError as _:
- raise ClientError(f"Unable to parse response: {response.text}") + except ValidationError as e: + raise ClientError(f'Unable to parse response: {response.text}') from e return invoke( - method="POST", - url=f"{self.__base_url}/api/models", + method='POST', + url=f'{self.__base_url}/api/models', valid_response_code=201, func=__callback, data=model.model_dump_json(), @@ -37,34 +39,31 @@ def __callback(response: requests.Response) -> Model: try: response_model = ModelDefinition.model_validate(response.json()) return Model(self.__base_url, response_model) - except ValidationError as _: - raise ClientError(f"Unable to parse response: {response.text}") + except ValidationError as e: + raise ClientError(f'Unable to parse response: {response.text}') from e return invoke( - method="GET", - url=f"{self.__base_url}/api/models/{str(id)}", + method='GET', + url=f'{self.__base_url}/api/models/{str(id)}', valid_response_code=200, func=__callback, ) def search_models(self) -> List[Model]: - def __callback(response: requests.Response) -> Model: + def __callback(response: requests.Response) -> List[Model]: try: paginated_response = PaginatedModelDefinitions.model_validate( response.json() ) - return list( - map( - lambda model: Model(self.__base_url, model), - paginated_response.items, - ) - ) - except ValidationError as _: - raise ClientError(f"Unable to parse response: {response.text}") + return [ + Model(self.__base_url, model) for model in paginated_response.items + ] + except ValidationError as e: + raise ClientError(f'Unable to parse response: {response.text}') from e return invoke( - method="GET", - url=f"{self.__base_url}/api/models", + method='GET', + url=f'{self.__base_url}/api/models', valid_response_code=200, func=__callback, ) diff --git a/sdk/radicalbit_platform_sdk/commons/__init__.py b/sdk/radicalbit_platform_sdk/commons/__init__.py index 4f501872..d748f621 100644 --- a/sdk/radicalbit_platform_sdk/commons/__init__.py +++ b/sdk/radicalbit_platform_sdk/commons/__init__.py @@ -1,3 +1,3 @@ from .rest_utils import invoke -__all__ = ["invoke"] +__all__ = ['invoke'] diff --git a/sdk/radicalbit_platform_sdk/commons/rest_utils.py b/sdk/radicalbit_platform_sdk/commons/rest_utils.py index 97a857ab..3f0376b9 100644 --- a/sdk/radicalbit_platform_sdk/commons/rest_utils.py +++ b/sdk/radicalbit_platform_sdk/commons/rest_utils.py @@ -1,12 +1,13 @@ +from typing import Callable + +import requests + from radicalbit_platform_sdk.errors import ( - NetworkError, APIError, + NetworkError, ServerError, UnhandledResponseCode, ) -from typing import Callable - -import requests def invoke( @@ -17,19 +18,19 @@ def invoke( except requests.RequestException as e: raise NetworkError( f'Network error: {{"method": "{method}", "url": "{url}", "exception": "{e}"}}' - ) + ) from e match response.status_code: case code if 500 <= code <= 599: raise ServerError( f'Server Error: {{"method": "{method}", "url": "{url}", "status_code": "{response.status_code}", "response": "{response.text}"}}' - ) + ) from None case code if 400 <= code <= 499: raise APIError( f'API Error: {{"method": "{method}", "url": "{url}", "status_code": "{response.status_code}", "response": "{response.text}"}}' - ) + ) from None case code if code == valid_response_code: return func(response) case _: raise UnhandledResponseCode( f'Unhandled Response Code Error: {{"method": "{method}", "url": "{url}", "status_code": "{response.status_code}", "response": "{response.text}"}}' - ) + ) from None diff --git a/sdk/radicalbit_platform_sdk/errors/__init__.py 
diff --git a/sdk/radicalbit_platform_sdk/errors/__init__.py b/sdk/radicalbit_platform_sdk/errors/__init__.py
index b42a6f1f..9595a15d 100644
--- a/sdk/radicalbit_platform_sdk/errors/__init__.py
+++ b/sdk/radicalbit_platform_sdk/errors/__init__.py
@@ -1,17 +1,17 @@
 from .errors import (
+    APIError,
+    ClientError,
     Error,
     NetworkError,
-    ClientError,
-    APIError,
     ServerError,
     UnhandledResponseCode,
 )
 
 __all__ = [
-    "Error",
-    "NetworkError",
-    "ClientError",
-    "APIError",
-    "ServerError",
-    "UnhandledResponseCode",
+    'Error',
+    'NetworkError',
+    'ClientError',
+    'APIError',
+    'ServerError',
+    'UnhandledResponseCode',
 ]
diff --git a/sdk/radicalbit_platform_sdk/models/__init__.py b/sdk/radicalbit_platform_sdk/models/__init__.py
index c05b67ad..34b28e9e 100644
--- a/sdk/radicalbit_platform_sdk/models/__init__.py
+++ b/sdk/radicalbit_platform_sdk/models/__init__.py
@@ -1,83 +1,83 @@
-from .model_definition import (
-    OutputType,
-    Granularity,
-    CreateModel,
-    ModelDefinition,
-    PaginatedModelDefinitions,
-)
-from .file_upload_result import ReferenceFileUpload, CurrentFileUpload, FileReference
+from .aws_credentials import AwsCredentials
+from .column_definition import ColumnDefinition
 from .data_type import DataType
-from .model_type import ModelType
-from .job_status import JobStatus
-from .dataset_stats import DatasetStats
-from .dataset_model_quality import (
-    ModelQuality,
-    BinaryClassificationModelQuality,
-    MultiClassModelQuality,
-    RegressionModelQuality,
-)
 from .dataset_data_quality import (
-    DataQuality,
     BinaryClassificationDataQuality,
-    MultiClassDataQuality,
-    RegressionDataQuality,
+    CategoricalFeatureMetrics,
+    CategoryFrequency,
+    ClassMedianMetrics,
     ClassMetrics,
+    DataQuality,
+    FeatureMetrics,
     MedianMetrics,
     MissingValue,
-    ClassMedianMetrics,
-    FeatureMetrics,
+    MultiClassDataQuality,
     NumericalFeatureMetrics,
-    CategoryFrequency,
-    CategoricalFeatureMetrics,
+    RegressionDataQuality,
 )
 from .dataset_drift import (
+    BinaryClassDrift,
+    Drift,
     DriftAlgorithm,
-    FeatureDriftCalculation,
     FeatureDrift,
-    Drift,
-    BinaryClassDrift,
+    FeatureDriftCalculation,
     MultiClassDrift,
     RegressionDrift,
 )
-from .column_definition import ColumnDefinition
-from .aws_credentials import AwsCredentials
+from .dataset_model_quality import (
+    BinaryClassificationModelQuality,
+    ModelQuality,
+    MultiClassModelQuality,
+    RegressionModelQuality,
+)
+from .dataset_stats import DatasetStats
+from .file_upload_result import CurrentFileUpload, FileReference, ReferenceFileUpload
+from .job_status import JobStatus
+from .model_definition import (
+    CreateModel,
+    Granularity,
+    ModelDefinition,
+    OutputType,
+    PaginatedModelDefinitions,
+)
+from .model_type import ModelType
 
 __all__ = [
-    "OutputType",
-    "Granularity",
-    "CreateModel",
-    "ModelDefinition",
-    "ColumnDefinition",
-    "JobStatus",
-    "DataType",
-    "ModelType",
-    "DatasetStats",
-    "ModelQuality",
-    "BinaryClassificationModelQuality",
-    "MultiClassModelQuality",
-    "RegressionModelQuality",
-    "DataQuality",
-    "BinaryClassificationDataQuality",
-    "MultiClassDataQuality",
-    "RegressionDataQuality",
-    "ClassMetrics",
-    "MedianMetrics",
-    "MissingValue",
-    "ClassMedianMetrics",
-    "FeatureMetrics",
-    "NumericalFeatureMetrics",
-    "CategoryFrequency",
-    "CategoricalFeatureMetrics",
-    "DriftAlgorithm",
-    "FeatureDriftCalculation",
-    "FeatureDrift",
-    "Drift",
-    "BinaryClassDrift",
-    "MultiClassDrift",
-    "RegressionDrift",
-    "PaginatedModelDefinitions",
-    "ReferenceFileUpload",
-    "CurrentFileUpload",
-    "FileReference",
-    "AwsCredentials",
+    'OutputType',
+    'Granularity',
+    'CreateModel',
+    'ModelDefinition',
+    'ColumnDefinition',
+    'JobStatus',
+    'DataType',
+    'ModelType',
+    'DatasetStats',
+    'ModelQuality',
+    'BinaryClassificationModelQuality',
+    'MultiClassModelQuality',
+    'RegressionModelQuality',
+    'DataQuality',
+    'BinaryClassificationDataQuality',
+    'MultiClassDataQuality',
+    'RegressionDataQuality',
+    'ClassMetrics',
+    'MedianMetrics',
+    'MissingValue',
+    'ClassMedianMetrics',
+    'FeatureMetrics',
+    'NumericalFeatureMetrics',
+    'CategoryFrequency',
+    'CategoricalFeatureMetrics',
+    'DriftAlgorithm',
+    'FeatureDriftCalculation',
+    'FeatureDrift',
+    'Drift',
+    'BinaryClassDrift',
+    'MultiClassDrift',
+    'RegressionDrift',
+    'PaginatedModelDefinitions',
+    'ReferenceFileUpload',
+    'CurrentFileUpload',
+    'FileReference',
+    'AwsCredentials',
 ]
diff --git a/sdk/radicalbit_platform_sdk/models/data_type.py b/sdk/radicalbit_platform_sdk/models/data_type.py
index fe7f7458..1b0b0035 100644
--- a/sdk/radicalbit_platform_sdk/models/data_type.py
+++ b/sdk/radicalbit_platform_sdk/models/data_type.py
@@ -2,6 +2,6 @@
 
 class DataType(str, Enum):
-    TABULAR = "TABULAR"
-    TEXT = "TEXT"
-    IMAGE = "IMAGE"
+    TABULAR = 'TABULAR'
+    TEXT = 'TEXT'
+    IMAGE = 'IMAGE'
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
index c437394b..2f890443 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_data_quality.py
@@ -1,6 +1,7 @@
+from typing import List, Optional, Union
+
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
-from typing import List, Optional, Union
 
 
 class ClassMetrics(BaseModel):
@@ -51,7 +52,7 @@ class Histogram(BaseModel):
 
 class NumericalFeatureMetrics(FeatureMetrics):
-    type: str = "numerical"
+    type: str = 'numerical'
     mean: Optional[float] = None
     std: Optional[float] = None
     min: Optional[float] = None
@@ -72,7 +73,7 @@ class CategoryFrequency(BaseModel):
 
 class CategoricalFeatureMetrics(FeatureMetrics):
-    type: str = "categorical"
+    type: str = 'categorical'
     category_frequency: List[CategoryFrequency]
     distinct_value: int
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_drift.py b/sdk/radicalbit_platform_sdk/models/dataset_drift.py
index c69524bb..503761bb 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_drift.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_drift.py
@@ -6,8 +6,8 @@
 
 class DriftAlgorithm(str, Enum):
-    KS = "KS"
-    CHI2 = "CHI2"
+    KS = 'KS'
+    CHI2 = 'CHI2'
 
 
 class FeatureDriftCalculation(BaseModel):
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
index b5c8bab2..f48c7437 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
@@ -1,6 +1,7 @@
+from typing import Optional
+
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
-from typing import Optional
 
 
 class ModelQuality(BaseModel):
diff --git a/sdk/radicalbit_platform_sdk/models/dataset_stats.py b/sdk/radicalbit_platform_sdk/models/dataset_stats.py
index 115dd496..4339717c 100644
--- a/sdk/radicalbit_platform_sdk/models/dataset_stats.py
+++ b/sdk/radicalbit_platform_sdk/models/dataset_stats.py
@@ -1,4 +1,5 @@
 from typing import Optional
+
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
diff --git a/sdk/radicalbit_platform_sdk/models/file_upload_result.py b/sdk/radicalbit_platform_sdk/models/file_upload_result.py
index a1ddff18..1e4b0545 100644
--- a/sdk/radicalbit_platform_sdk/models/file_upload_result.py
+++ b/sdk/radicalbit_platform_sdk/models/file_upload_result.py
@@ -1,8 +1,10 @@
+from typing import Optional
+from uuid import UUID
+
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
+
 from radicalbit_platform_sdk.models.job_status import JobStatus
-from typing import Optional
-from uuid import UUID
 
 
 class FileUploadResult(BaseModel):
@@ -26,7 +28,7 @@ class CurrentFileUpload(FileUploadResult):
 
 class FileReference(BaseModel):
     file_url: str
-    separator: str = ","
+    separator: str = ','
     correlation_id_column: Optional[str] = None
 
     model_config = ConfigDict(
diff --git a/sdk/radicalbit_platform_sdk/models/job_status.py b/sdk/radicalbit_platform_sdk/models/job_status.py
index 9bbc154f..e8bae428 100644
--- a/sdk/radicalbit_platform_sdk/models/job_status.py
+++ b/sdk/radicalbit_platform_sdk/models/job_status.py
@@ -2,6 +2,6 @@
 
 class JobStatus(str, Enum):
-    IMPORTING = "IMPORTING"
-    SUCCEEDED = "SUCCEEDED"
-    ERROR = "ERROR"
+    IMPORTING = 'IMPORTING'
+    SUCCEEDED = 'SUCCEEDED'
+    ERROR = 'ERROR'
diff --git a/sdk/radicalbit_platform_sdk/models/model_definition.py b/sdk/radicalbit_platform_sdk/models/model_definition.py
index b97ba3a6..49ca7140 100644
--- a/sdk/radicalbit_platform_sdk/models/model_definition.py
+++ b/sdk/radicalbit_platform_sdk/models/model_definition.py
@@ -1,12 +1,14 @@
-import uuid as uuid_lib
 from enum import Enum
-from radicalbit_platform_sdk.models.data_type import DataType
-from radicalbit_platform_sdk.models.model_type import ModelType
-from radicalbit_platform_sdk.models.column_definition import ColumnDefinition
 from typing import List, Optional
+import uuid as uuid_lib
+
 from pydantic import BaseModel, ConfigDict, Field
 from pydantic.alias_generators import to_camel
 
+from radicalbit_platform_sdk.models.column_definition import ColumnDefinition
+from radicalbit_platform_sdk.models.data_type import DataType
+from radicalbit_platform_sdk.models.model_type import ModelType
+
 
 class OutputType(BaseModel):
     prediction: ColumnDefinition
@@ -17,10 +19,10 @@ class OutputType(BaseModel):
 
 class Granularity(str, Enum):
-    HOUR = "HOUR"
-    DAY = "DAY"
-    WEEK = "WEEK"
-    MONTH = "MONTH"
+    HOUR = 'HOUR'
+    DAY = 'DAY'
+    WEEK = 'WEEK'
+    MONTH = 'MONTH'
 
 
 class BaseModelDefinition(BaseModel):
@@ -38,6 +40,7 @@ class BaseModelDefinition(BaseModel):
         timestamp: The column used to store the when prediction was done
         frameworks: An optional field to describe the frameworks used by the model
         algorithm: An optional field to ecplane the algorithm used by the model
+
     """
 
     name: str
@@ -63,8 +66,8 @@ class CreateModel(BaseModelDefinition):
 
 class ModelDefinition(BaseModelDefinition):
     uuid: uuid_lib.UUID = Field(default_factory=lambda: uuid_lib.uuid4())
-    created_at: str = Field(alias="createdAt")
-    updated_at: str = Field(alias="updatedAt")
+    created_at: str = Field(alias='createdAt')
+    updated_at: str = Field(alias='updatedAt')
 
     model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
diff --git a/sdk/radicalbit_platform_sdk/models/model_type.py b/sdk/radicalbit_platform_sdk/models/model_type.py
index 4e81c085..48cdb051 100644
--- a/sdk/radicalbit_platform_sdk/models/model_type.py
+++ b/sdk/radicalbit_platform_sdk/models/model_type.py
@@ -2,6 +2,6 @@
 
 class ModelType(str, Enum):
-    REGRESSION = "REGRESSION"
-    BINARY = "BINARY"
-    MULTI_CLASS = "MULTI_CLASS"
+    REGRESSION = 'REGRESSION'
+    BINARY = 'BINARY'
+    MULTI_CLASS = 'MULTI_CLASS'
diff --git a/sdk/tests/apis/model_current_dataset_test.py b/sdk/tests/apis/model_current_dataset_test.py
index 9e7ef6a6..d39a5501 100644
--- a/sdk/tests/apis/model_current_dataset_test.py
+++ b/sdk/tests/apis/model_current_dataset_test.py
@@ -1,15 +1,23 @@
-from radicalbit_platform_sdk.apis import ModelCurrentDataset
-from radicalbit_platform_sdk.models import CurrentFileUpload, ModelType, JobStatus, DriftAlgorithm
-from radicalbit_platform_sdk.errors import ClientError
-import responses
 import unittest
 import uuid
 
+import pytest
+import responses
+
+from radicalbit_platform_sdk.apis import ModelCurrentDataset
+from radicalbit_platform_sdk.errors import ClientError
+from radicalbit_platform_sdk.models import (
+    CurrentFileUpload,
+    DriftAlgorithm,
+    JobStatus,
+    ModelType,
+)
+
 
 class ModelCurrentDatasetTest(unittest.TestCase):
     @responses.activate
     def test_statistics_ok(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         n_variables = 10
@@ -27,19 +35,18 @@ def test_statistics_ok(self):
             ModelType.BINARY,
             CurrentFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
-                correlation_id_column="column",
+                path='s3://bucket/file.csv',
+                date='2014',
+                correlation_id_column='column',
                 status=JobStatus.IMPORTING,
             ),
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/statistics",
-                "status": 200,
-                "body": f"""{{
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/statistics',
+            status=200,
+            body=f"""{{
                 "datetime": "something_not_used",
                 "jobStatus": "SUCCEEDED",
                 "statistics": {{
                    "nVariables": {n_variables},
                    "nObservations": {n_observations},
                    "missingCells": {missing_cells},
                    "missingCellsPerc": {missing_cells_perc},
                    "duplicateRows": {duplicate_rows},
                    "duplicateRowsPerc": {duplicate_rows_perc},
                    "numeric": {numeric},
                    "categorical": {categorical},
                    "datetime": {datetime}
                }}
            }}""",
-            }
         )
 
         stats = model_reference_dataset.statistics()
 
@@ -72,7 +78,7 @@ def test_statistics_ok(self):
 
     @responses.activate
     def test_statistics_validation_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelCurrentDataset(
@@ -81,28 +87,26 @@ def test_statistics_validation_error(self):
             ModelType.BINARY,
             CurrentFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
-                correlation_id_column="column",
+                path='s3://bucket/file.csv',
+                date='2014',
+                correlation_id_column='column',
                 status=JobStatus.IMPORTING,
             ),
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/statistics",
-                "status": 200,
-                "body": '{"statistics": "wrong"}',
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/statistics',
+            status=200,
+            body='{"statistics": "wrong"}',
        )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model_reference_dataset.statistics()
 
     @responses.activate
     def test_statistics_key_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelCurrentDataset(
@@ -111,28 +115,26 @@ def test_statistics_key_error(self):
             ModelType.BINARY,
             CurrentFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
-                correlation_id_column="column",
+                path='s3://bucket/file.csv',
+                date='2014',
+                correlation_id_column='column',
                 status=JobStatus.IMPORTING,
             ),
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/statistics", - "status": 200, - "body": '{"wrong": "json"}', - } + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/statistics', + status=200, + body='{"wrong": "json"}', ) - with self.assertRaises(ClientError): + with pytest.raises(ClientError): model_reference_dataset.statistics() @responses.activate def test_drift_ok(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() model_reference_dataset = ModelCurrentDataset( @@ -141,19 +143,18 @@ def test_drift_ok(self): ModelType.BINARY, CurrentFileUpload( uuid=import_uuid, - path="s3://bucket/file.csv", - date="2014", - correlation_id_column="column", + path='s3://bucket/file.csv', + date='2014', + correlation_id_column='column', status=JobStatus.IMPORTING, ), ) responses.add( - **{ - "method": responses.GET, - "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift", - "status": 200, - "body": """{ + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift', + status=200, + body="""{ "jobStatus": "SUCCEEDED", "drift": { "featureMetrics": [ @@ -172,17 +173,16 @@ def test_drift_ok(self): ] } }""", - } ) drift = model_reference_dataset.drift() assert len(drift.feature_metrics) == 3 - assert drift.feature_metrics[1].feature_name == "city" + assert drift.feature_metrics[1].feature_name == 'city' assert drift.feature_metrics[1].drift_calc.type == DriftAlgorithm.CHI2 assert drift.feature_metrics[1].drift_calc.value == 0.12 assert drift.feature_metrics[1].drift_calc.has_drift is False - assert drift.feature_metrics[2].feature_name == "age" + assert drift.feature_metrics[2].feature_name == 'age' assert drift.feature_metrics[2].drift_calc.type == DriftAlgorithm.KS assert drift.feature_metrics[2].drift_calc.value == 0.92 assert drift.feature_metrics[2].drift_calc.has_drift is True @@ -190,7 +190,7 @@ def test_drift_ok(self): @responses.activate def test_drift_validation_error(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() model_reference_dataset = ModelCurrentDataset( @@ -199,28 +199,26 @@ def test_drift_validation_error(self): ModelType.BINARY, CurrentFileUpload( uuid=import_uuid, - path="s3://bucket/file.csv", - date="2014", - correlation_id_column="column", + path='s3://bucket/file.csv', + date='2014', + correlation_id_column='column', status=JobStatus.IMPORTING, ), ) responses.add( - **{ - "method": responses.GET, - "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift", - "status": 200, - "body": '{"statistics": "wrong"}', - } + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift', + status=200, + body='{"statistics": "wrong"}', ) - with self.assertRaises(ClientError): + with pytest.raises(ClientError): model_reference_dataset.drift() @responses.activate def test_drift_key_error(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() model_reference_dataset = ModelCurrentDataset( @@ -229,21 +227,19 @@ def test_drift_key_error(self): ModelType.BINARY, CurrentFileUpload( uuid=import_uuid, - path="s3://bucket/file.csv", - date="2014", - correlation_id_column="column", + path='s3://bucket/file.csv', + date='2014', + correlation_id_column='column', status=JobStatus.IMPORTING, ), ) responses.add( - **{ - 
"method": responses.GET, - "url": f"{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift", - "status": 200, - "body": '{"wrong": "json"}', - } + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/current/{str(import_uuid)}/drift', + status=200, + body='{"wrong": "json"}', ) - with self.assertRaises(ClientError): + with pytest.raises(ClientError): model_reference_dataset.drift() diff --git a/sdk/tests/apis/model_reference_dataset_test.py b/sdk/tests/apis/model_reference_dataset_test.py index 052ef3ef..4a02a4aa 100644 --- a/sdk/tests/apis/model_reference_dataset_test.py +++ b/sdk/tests/apis/model_reference_dataset_test.py @@ -1,15 +1,18 @@ -from radicalbit_platform_sdk.apis import ModelReferenceDataset -from radicalbit_platform_sdk.models import ReferenceFileUpload, ModelType, JobStatus -from radicalbit_platform_sdk.errors import ClientError -import responses import unittest import uuid +import pytest +import responses + +from radicalbit_platform_sdk.apis import ModelReferenceDataset +from radicalbit_platform_sdk.errors import ClientError +from radicalbit_platform_sdk.models import JobStatus, ModelType, ReferenceFileUpload + class ModelReferenceDatasetTest(unittest.TestCase): @responses.activate def test_statistics_ok(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() n_variables = 10 @@ -27,18 +30,17 @@ def test_statistics_ok(self): ModelType.BINARY, ReferenceFileUpload( uuid=import_uuid, - path="s3://bucket/file.csv", - date="2014", + path='s3://bucket/file.csv', + date='2014', status=JobStatus.IMPORTING, ), ) responses.add( - **{ - "method": responses.GET, - "url": f"{base_url}/api/models/{str(model_id)}/reference/statistics", - "status": 200, - "body": f"""{{ + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/reference/statistics', + status=200, + body=f"""{{ "datetime": "something_not_used", "jobStatus": "SUCCEEDED", "statistics": {{ @@ -53,7 +55,6 @@ def test_statistics_ok(self): "datetime": {datetime} }} }}""", - } ) stats = model_reference_dataset.statistics() @@ -71,7 +72,7 @@ def test_statistics_ok(self): @responses.activate def test_statistics_validation_error(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() model_reference_dataset = ModelReferenceDataset( @@ -80,27 +81,25 @@ def test_statistics_validation_error(self): ModelType.BINARY, ReferenceFileUpload( uuid=import_uuid, - path="s3://bucket/file.csv", - date="2014", + path='s3://bucket/file.csv', + date='2014', status=JobStatus.IMPORTING, ), ) responses.add( - **{ - "method": responses.GET, - "url": f"{base_url}/api/models/{str(model_id)}/reference/statistics", - "status": 200, - "body": '{"statistics": "wrong"}', - } + method=responses.GET, + url=f'{base_url}/api/models/{str(model_id)}/reference/statistics', + status=200, + body='{"statistics": "wrong"}', ) - with self.assertRaises(ClientError): + with pytest.raises(ClientError): model_reference_dataset.statistics() @responses.activate def test_statistics_key_error(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() import_uuid = uuid.uuid4() model_reference_dataset = ModelReferenceDataset( @@ -109,27 +108,25 @@ def test_statistics_key_error(self): ModelType.BINARY, ReferenceFileUpload( uuid=import_uuid, - path="s3://bucket/file.csv", - date="2014", + path='s3://bucket/file.csv', + date='2014', status=JobStatus.IMPORTING, ), ) responses.add( - 
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/statistics",
-                "status": 200,
-                "body": '{"wrong": "json"}',
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/statistics',
+            status=200,
+            body='{"wrong": "json"}',
         )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model_reference_dataset.statistics()
 
     @responses.activate
     def test_model_metrics_ok(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         f1 = 0.75
@@ -156,18 +153,17 @@ def test_model_metrics_ok(self):
             ModelType.BINARY,
             ReferenceFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
+                path='s3://bucket/file.csv',
+                date='2014',
                 status=JobStatus.IMPORTING,
             ),
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/model-quality",
-                "status": 200,
-                "body": f"""{{
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/model-quality',
+            status=200,
+            body=f"""{{
                 "datetime": "something_not_used",
                 "jobStatus": "SUCCEEDED",
                 "modelQuality": {{
                    "falseNegativeCount": {false_negative_count}
                }}
            }}""",
-            }
         )
 
         metrics = model_reference_dataset.model_quality()
 
@@ -218,7 +213,7 @@ def test_model_metrics_ok(self):
 
     @responses.activate
     def test_model_metrics_validation_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelReferenceDataset(
@@ -227,27 +222,25 @@ def test_model_metrics_validation_error(self):
             ModelType.BINARY,
             ReferenceFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
+                path='s3://bucket/file.csv',
+                date='2014',
                 status=JobStatus.IMPORTING,
             ),
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/model-quality",
-                "status": 200,
-                "body": '{"modelQuality": "wrong"}',
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/model-quality',
+            status=200,
+            body='{"modelQuality": "wrong"}',
         )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model_reference_dataset.model_quality()
 
     @responses.activate
     def test_model_metrics_key_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelReferenceDataset(
@@ -256,27 +249,25 @@ def test_model_metrics_key_error(self):
             ModelType.BINARY,
             ReferenceFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
+                path='s3://bucket/file.csv',
+                date='2014',
                 status=JobStatus.IMPORTING,
             ),
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/model-quality",
-                "status": 200,
-                "body": '{"wrong": "json"}',
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/model-quality',
+            status=200,
+            body='{"wrong": "json"}',
         )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model_reference_dataset.model_quality()
 
     @responses.activate
     def test_data_quality_ok(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelReferenceDataset(
@@ -285,18 +276,17 @@ def test_data_quality_ok(self):
             ModelType.BINARY,
             ReferenceFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
+                path='s3://bucket/file.csv',
+                date='2014',
                 status=JobStatus.IMPORTING,
             ),
         )
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
-                "status": 200,
-                "body": """{
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/data-quality',
+            status=200,
+            body="""{
                 "datetime": "something_not_used",
                 "jobStatus": "SUCCEEDED",
                 "dataQuality": {
@@ -346,28 +336,27 @@ def test_data_quality_ok(self):
                     ]
                 }
             }""",
-            }
         )
 
         metrics = model_reference_dataset.data_quality()
 
         assert metrics.n_observations == 200
         assert len(metrics.class_metrics) == 2
-        assert metrics.class_metrics[0].name == "classA"
+        assert metrics.class_metrics[0].name == 'classA'
         assert metrics.class_metrics[0].count == 100
         assert metrics.class_metrics[0].percentage == 50.0
         assert len(metrics.feature_metrics) == 2
-        assert metrics.feature_metrics[0].feature_name == "age"
-        assert metrics.feature_metrics[0].type == "numerical"
+        assert metrics.feature_metrics[0].feature_name == 'age'
+        assert metrics.feature_metrics[0].type == 'numerical'
         assert metrics.feature_metrics[0].mean == 29.5
-        assert metrics.feature_metrics[1].feature_name == "gender"
-        assert metrics.feature_metrics[1].type == "categorical"
+        assert metrics.feature_metrics[1].feature_name == 'gender'
+        assert metrics.feature_metrics[1].type == 'categorical'
         assert metrics.feature_metrics[1].distinct_value == 2
         assert model_reference_dataset.status() == JobStatus.SUCCEEDED
 
     @responses.activate
     def test_data_quality_validation_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelReferenceDataset(
@@ -376,27 +365,25 @@ def test_data_quality_validation_error(self):
             ModelType.BINARY,
             ReferenceFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
+                path='s3://bucket/file.csv',
+                date='2014',
                 status=JobStatus.IMPORTING,
             ),
         )
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
-                "status": 200,
-                "body": '{"dataQuality": "wrong"}',
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/data-quality',
+            status=200,
+            body='{"dataQuality": "wrong"}',
         )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model_reference_dataset.data_quality()
 
     @responses.activate
     def test_data_quality_key_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         import_uuid = uuid.uuid4()
         model_reference_dataset = ModelReferenceDataset(
@@ -405,20 +392,18 @@ def test_data_quality_key_error(self):
             ModelType.BINARY,
             ReferenceFileUpload(
                 uuid=import_uuid,
-                path="s3://bucket/file.csv",
-                date="2014",
+                path='s3://bucket/file.csv',
+                date='2014',
                 status=JobStatus.IMPORTING,
             ),
         )
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/data-quality",
-                "status": 200,
-                "body": '{"wrong": "json"}',
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/data-quality',
+            status=200,
+            body='{"wrong": "json"}',
         )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model_reference_dataset.data_quality()
diff --git a/sdk/tests/apis/model_test.py b/sdk/tests/apis/model_test.py
index 0782c382..fd93c8b2 100644
--- a/sdk/tests/apis/model_test.py
+++ b/sdk/tests/apis/model_test.py
@@ -1,36 +1,39 @@
+import time
+import unittest
+import uuid
+
+import boto3
+from moto import mock_aws
+import pytest
+import responses
+
 from radicalbit_platform_sdk.apis import Model
+from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
-    ModelDefinition,
-    ModelType,
+    ColumnDefinition,
+    CurrentFileUpload,
     DataType,
-    JobStatus,
     Granularity,
-    ColumnDefinition,
+    JobStatus,
+    ModelDefinition,
+    ModelType,
     OutputType,
     ReferenceFileUpload,
-    CurrentFileUpload,
 )
-from radicalbit_platform_sdk.errors import ClientError
-import uuid
-import unittest
-import time
-import responses
-import boto3
-from moto import mock_aws
 
 
 class ModelTest(unittest.TestCase):
     @responses.activate
     def test_delete_model(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
-        column_def = ColumnDefinition(name="column", type="my_type")
+        column_def = ColumnDefinition(name='column', type='my_type')
         outputs = OutputType(prediction=column_def, output=[column_def])
         model = Model(
             base_url,
             ModelDefinition(
                 uuid=model_id,
-                name="My Model",
+                name='My Model',
                 model_type=ModelType.BINARY,
                 data_type=DataType.TABULAR,
                 granularity=Granularity.MONTH,
@@ -43,205 +46,197 @@ def test_delete_model(self):
             ),
         )
         responses.add(
-            **{
-                "method": responses.DELETE,
-                "url": f"{base_url}/api/models/{str(model_id)}",
-                "status": 200,
-            }
+            method=responses.DELETE,
+            url=f'{base_url}/api/models/{str(model_id)}',
+            status=200,
         )
 
         model.delete()
 
     @mock_aws
     @responses.activate
     def test_load_reference_dataset_without_object_name(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
-        bucket_name = "test-bucket"
-        file_name = "test.txt"
-        column_def = ColumnDefinition(name="prediction", type="float")
-        expected_path = f"s3://{bucket_name}/{model_id}/reference/{file_name}"
-        conn = boto3.resource("s3", region_name="us-east-1")
+        bucket_name = 'test-bucket'
+        file_name = 'test.txt'
+        column_def = ColumnDefinition(name='prediction', type='float')
+        expected_path = f's3://{bucket_name}/{model_id}/reference/{file_name}'
+        conn = boto3.resource('s3', region_name='us-east-1')
         conn.create_bucket(Bucket=bucket_name)
         model = Model(
             base_url,
             ModelDefinition(
                 uuid=model_id,
-                name="My Model",
+                name='My Model',
                 model_type=ModelType.BINARY,
                 data_type=DataType.TABULAR,
                 granularity=Granularity.HOUR,
                 features=[
-                    ColumnDefinition(name="first_name", type="str"),
-                    ColumnDefinition(name="age", type="int"),
+                    ColumnDefinition(name='first_name', type='str'),
+                    ColumnDefinition(name='age', type='int'),
                 ],
                 outputs=OutputType(prediction=column_def, output=[column_def]),
-                target=ColumnDefinition(name="adult", type="bool"),
-                timestamp=ColumnDefinition(name="created_at", type="str"),
+                target=ColumnDefinition(name='adult', type='bool'),
+                timestamp=ColumnDefinition(name='created_at', type='str'),
                 created_at=str(time.time()),
                 updated_at=str(time.time()),
             ),
         )
         response = ReferenceFileUpload(
-            uuid=uuid.uuid4(), path=expected_path, date="", status=JobStatus.IMPORTING
+            uuid=uuid.uuid4(), path=expected_path, date='', status=JobStatus.IMPORTING
        )
         responses.add(
-            **{
-                "method": responses.POST,
-                "url": f"{base_url}/api/models/{str(model_id)}/reference/bind",
-                "body": response.model_dump_json(),
-                "status": 200,
-                "content_type": "application/json",
-            }
+            method=responses.POST,
+            url=f'{base_url}/api/models/{str(model_id)}/reference/bind',
+            body=response.model_dump_json(),
+            status=200,
+            content_type='application/json',
         )
 
         response = model.load_reference_dataset(
"tests_resources/people.csv", bucket_name + 'tests_resources/people.csv', bucket_name ) assert response.path() == expected_path @mock_aws @responses.activate def test_load_reference_dataset_with_different_separator(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() - bucket_name = "test-bucket" - file_name = "test.txt" - column_def = ColumnDefinition(name="prediction", type="float") - expected_path = f"s3://{bucket_name}/{model_id}/reference/{file_name}" - conn = boto3.resource("s3", region_name="us-east-1") + bucket_name = 'test-bucket' + file_name = 'test.txt' + column_def = ColumnDefinition(name='prediction', type='float') + expected_path = f's3://{bucket_name}/{model_id}/reference/{file_name}' + conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) model = Model( base_url, ModelDefinition( uuid=model_id, - name="My Model", + name='My Model', model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.DAY, features=[ - ColumnDefinition(name="first_name", type="str"), - ColumnDefinition(name="age", type="int"), + ColumnDefinition(name='first_name', type='str'), + ColumnDefinition(name='age', type='int'), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name="adult", type="bool"), - timestamp=ColumnDefinition(name="created_at", type="str"), + target=ColumnDefinition(name='adult', type='bool'), + timestamp=ColumnDefinition(name='created_at', type='str'), created_at=str(time.time()), updated_at=str(time.time()), ), ) response = ReferenceFileUpload( - uuid=uuid.uuid4(), path=expected_path, date="", status=JobStatus.IMPORTING + uuid=uuid.uuid4(), path=expected_path, date='', status=JobStatus.IMPORTING ) responses.add( - **{ - "method": responses.POST, - "url": f"{base_url}/api/models/{str(model_id)}/reference/bind", - "body": response.model_dump_json(), - "status": 200, - "content_type": "application/json", - } + method=responses.POST, + url=f'{base_url}/api/models/{str(model_id)}/reference/bind', + body=response.model_dump_json(), + status=200, + content_type='application/json', ) response = model.load_reference_dataset( - "tests_resources/people_pipe_separated.csv", bucket_name, separator="|" + 'tests_resources/people_pipe_separated.csv', bucket_name, separator='|' ) assert response.path() == expected_path @mock_aws @responses.activate def test_load_reference_dataset_with_object_name(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() - bucket_name = "test-bucket" - file_name = "test.txt" - column_def = ColumnDefinition(name="prediction", type="float") - expected_path = f"s3://{bucket_name}/{file_name}" - conn = boto3.resource("s3", region_name="us-east-1") + bucket_name = 'test-bucket' + file_name = 'test.txt' + column_def = ColumnDefinition(name='prediction', type='float') + expected_path = f's3://{bucket_name}/{file_name}' + conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) model = Model( base_url, ModelDefinition( uuid=model_id, - name="My Model", + name='My Model', model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.WEEK, features=[ - ColumnDefinition(name="first_name", type="str"), - ColumnDefinition(name="age", type="int"), + ColumnDefinition(name='first_name', type='str'), + ColumnDefinition(name='age', type='int'), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name="adult", 
type="bool"), - timestamp=ColumnDefinition(name="created_at", type="str"), + target=ColumnDefinition(name='adult', type='bool'), + timestamp=ColumnDefinition(name='created_at', type='str'), created_at=str(time.time()), updated_at=str(time.time()), ), ) response = ReferenceFileUpload( - uuid=uuid.uuid4(), path=expected_path, date="", status=JobStatus.IMPORTING + uuid=uuid.uuid4(), path=expected_path, date='', status=JobStatus.IMPORTING ) responses.add( - **{ - "method": responses.POST, - "url": f"{base_url}/api/models/{str(model_id)}/reference/bind", - "body": response.model_dump_json(), - "status": 200, - "content_type": "application/json", - } + method=responses.POST, + url=f'{base_url}/api/models/{str(model_id)}/reference/bind', + body=response.model_dump_json(), + status=200, + content_type='application/json', ) response = model.load_reference_dataset( - "tests_resources/people.csv", bucket_name + 'tests_resources/people.csv', bucket_name ) assert response.path() == expected_path def test_load_reference_dataset_wrong_headers(self): - column_def = ColumnDefinition(name="prediction", type="float") + column_def = ColumnDefinition(name='prediction', type='float') model = Model( - "http://api:9000", + 'http://api:9000', ModelDefinition( uuid=uuid.uuid4(), - name="My Model", + name='My Model', model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.MONTH, features=[ - ColumnDefinition(name="first_name", type="str"), - ColumnDefinition(name="age", type="int"), + ColumnDefinition(name='first_name', type='str'), + ColumnDefinition(name='age', type='int'), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name="adult", type="bool"), - timestamp=ColumnDefinition(name="created_at", type="str"), + target=ColumnDefinition(name='adult', type='bool'), + timestamp=ColumnDefinition(name='created_at', type='str'), created_at=str(time.time()), updated_at=str(time.time()), ), ) - with self.assertRaises(ClientError): - model.load_reference_dataset("tests_resources/wrong.csv", "bucket_name") + with pytest.raises(ClientError): + model.load_reference_dataset('tests_resources/wrong.csv', 'bucket_name') @mock_aws @responses.activate def test_load_current_dataset_without_object_name(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() - bucket_name = "test-bucket" - file_name = "test.txt" - column_def = ColumnDefinition(name="prediction", type="float") - expected_path = f"s3://{bucket_name}/{model_id}/reference/{file_name}" - conn = boto3.resource("s3", region_name="us-east-1") + bucket_name = 'test-bucket' + file_name = 'test.txt' + column_def = ColumnDefinition(name='prediction', type='float') + expected_path = f's3://{bucket_name}/{model_id}/reference/{file_name}' + conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) model = Model( base_url, ModelDefinition( uuid=model_id, - name="My Model", + name='My Model', model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.DAY, features=[ - ColumnDefinition(name="first_name", type="str"), - ColumnDefinition(name="age", type="int"), + ColumnDefinition(name='first_name', type='str'), + ColumnDefinition(name='age', type='int'), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name="adult", type="bool"), - timestamp=ColumnDefinition(name="created_at", type="str"), + target=ColumnDefinition(name='adult', type='bool'), + 
timestamp=ColumnDefinition(name='created_at', type='str'), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -249,21 +244,19 @@ def test_load_current_dataset_without_object_name(self): response = CurrentFileUpload( uuid=uuid.uuid4(), path=expected_path, - date="", + date='', status=JobStatus.IMPORTING, - correlation_id_column="correlation", + correlation_id_column='correlation', ) responses.add( - **{ - "method": responses.POST, - "url": f"{base_url}/api/models/{str(model_id)}/current/bind", - "body": response.model_dump_json(), - "status": 200, - "content_type": "application/json", - } + method=responses.POST, + url=f'{base_url}/api/models/{str(model_id)}/current/bind', + body=response.model_dump_json(), + status=200, + content_type='application/json', ) response = model.load_current_dataset( - "tests_resources/people_current.csv", + 'tests_resources/people_current.csv', bucket_name, response.correlation_id_column, ) @@ -272,29 +265,29 @@ def test_load_current_dataset_without_object_name(self): @mock_aws @responses.activate def test_load_current_dataset_with_object_name(self): - base_url = "http://api:9000" + base_url = 'http://api:9000' model_id = uuid.uuid4() - bucket_name = "test-bucket" - file_name = "test.txt" - column_def = ColumnDefinition(name="prediction", type="float") - expected_path = f"s3://{bucket_name}/{file_name}" - conn = boto3.resource("s3", region_name="us-east-1") + bucket_name = 'test-bucket' + file_name = 'test.txt' + column_def = ColumnDefinition(name='prediction', type='float') + expected_path = f's3://{bucket_name}/{file_name}' + conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket=bucket_name) model = Model( base_url, ModelDefinition( uuid=model_id, - name="My Model", + name='My Model', model_type=ModelType.BINARY, data_type=DataType.TABULAR, granularity=Granularity.HOUR, features=[ - ColumnDefinition(name="first_name", type="str"), - ColumnDefinition(name="age", type="int"), + ColumnDefinition(name='first_name', type='str'), + ColumnDefinition(name='age', type='int'), ], outputs=OutputType(prediction=column_def, output=[column_def]), - target=ColumnDefinition(name="adult", type="bool"), - timestamp=ColumnDefinition(name="created_at", type="str"), + target=ColumnDefinition(name='adult', type='bool'), + timestamp=ColumnDefinition(name='created_at', type='str'), created_at=str(time.time()), updated_at=str(time.time()), ), @@ -302,48 +295,46 @@ def test_load_current_dataset_with_object_name(self): response = CurrentFileUpload( uuid=uuid.uuid4(), path=expected_path, - date="", + date='', status=JobStatus.IMPORTING, - correlation_id_column="correlation", + correlation_id_column='correlation', ) responses.add( - **{ - "method": responses.POST, - "url": f"{base_url}/api/models/{str(model_id)}/current/bind", - "body": response.model_dump_json(), - "status": 200, - "content_type": "application/json", - } + method=responses.POST, + url=f'{base_url}/api/models/{str(model_id)}/current/bind', + body=response.model_dump_json(), + status=200, + content_type='application/json', ) response = model.load_current_dataset( - "tests_resources/people_current.csv", + 'tests_resources/people_current.csv', bucket_name, response.correlation_id_column, ) assert response.path() == expected_path def test_load_current_dataset_wrong_headers(self): - column_def = ColumnDefinition(name="prediction", type="float") + column_def = ColumnDefinition(name='prediction', type='float') model = Model( - "http://api:9000", + 'http://api:9000', ModelDefinition( 
                 uuid=uuid.uuid4(),
-                name="My Model",
+                name='My Model',
                 model_type=ModelType.BINARY,
                 data_type=DataType.TABULAR,
                 granularity=Granularity.MONTH,
                 features=[
-                    ColumnDefinition(name="first_name", type="str"),
-                    ColumnDefinition(name="age", type="int"),
+                    ColumnDefinition(name='first_name', type='str'),
+                    ColumnDefinition(name='age', type='int'),
                 ],
                 outputs=OutputType(prediction=column_def, output=[column_def]),
-                target=ColumnDefinition(name="adult", type="bool"),
-                timestamp=ColumnDefinition(name="created_at", type="str"),
+                target=ColumnDefinition(name='adult', type='bool'),
+                timestamp=ColumnDefinition(name='created_at', type='str'),
                 created_at=str(time.time()),
                 updated_at=str(time.time()),
             ),
         )
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             model.load_current_dataset(
-                "tests_resources/wrong.csv", "bucket_name", "correlation"
+                'tests_resources/wrong.csv', 'bucket_name', 'correlation'
             )
diff --git a/sdk/tests/client_test.py b/sdk/tests/client_test.py
index 3b540a54..2949c0e1 100644
--- a/sdk/tests/client_test.py
+++ b/sdk/tests/client_test.py
@@ -1,41 +1,44 @@
+import time
+import unittest
+import uuid
+
+import pytest
+import responses
+
 from radicalbit_platform_sdk.client import Client
+from radicalbit_platform_sdk.errors import ClientError
 from radicalbit_platform_sdk.models import (
+    ColumnDefinition,
     CreateModel,
-    ModelDefinition,
     DataType,
-    ModelType,
     Granularity,
-    ColumnDefinition,
-    PaginatedModelDefinitions,
+    ModelDefinition,
+    ModelType,
     OutputType,
+    PaginatedModelDefinitions,
 )
-from radicalbit_platform_sdk.errors import ClientError
-import uuid
-import unittest
-import time
-import responses
 
 
 class ClientTest(unittest.TestCase):
     @responses.activate
     def test_get_model(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
-        name = "My super Model"
+        name = 'My super Model'
         model_type = ModelType.MULTI_CLASS
         data_type = DataType.IMAGE
         granularity = Granularity.HOUR
-        description = "some boring description about this model"
-        algorithm = "brainfucker"
-        frameworks = "mlflow"
-        feature_name = "age"
-        feature_type = "int"
-        output_name = "adult"
-        output_type = "bool"
-        target_name = "adult"
-        target_type = "bool"
-        timestamp_name = "when"
-        timestamp_type = "str"
+        description = 'some boring description about this model'
+        algorithm = 'brainfucker'
+        frameworks = 'mlflow'
+        feature_name = 'age'
+        feature_type = 'int'
+        output_name = 'adult'
+        output_type = 'bool'
+        target_name = 'adult'
+        target_type = 'bool'
+        timestamp_name = 'when'
+        timestamp_type = 'str'
         ts = str(time.time())
         json_string = f"""{{
             "uuid": "{str(model_id)}",
@@ -76,13 +79,11 @@ def test_get_model(self):
             "updatedAt": "{ts}"
         }}"""
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}",
-                "body": json_string,
-                "status": 200,
-                "content_type": "application/json",
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}',
+            body=json_string,
+            status=200,
+            content_type='application/json',
         )
         client = Client(base_url)
         model = client.get_model(id=model_id)
@@ -111,37 +112,35 @@ def test_get_model(self):
 
     @responses.activate
     def test_get_model_client_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model_id = uuid.uuid4()
         invalid_json = '{"name": "Random Name"}'
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models/{str(model_id)}",
-                "body": invalid_json,
-                "status": 200,
-                "content_type": "application/json",
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models/{str(model_id)}',
+            body=invalid_json,
+            status=200,
+            content_type='application/json',
         )
         client = Client(base_url)
 
-        with self.assertRaises(ClientError):
+        with pytest.raises(ClientError):
             client.get_model(id=model_id)
 
     @responses.activate
     def test_create_model(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         model = CreateModel(
-            name="My Model",
+            name='My Model',
             model_type=ModelType.BINARY,
             data_type=DataType.TABULAR,
             granularity=Granularity.WEEK,
-            features=[ColumnDefinition(name="feature_column", type="string")],
+            features=[ColumnDefinition(name='feature_column', type='string')],
             outputs=OutputType(
-                prediction=ColumnDefinition(name="result_column", type="int"),
-                output=[ColumnDefinition(name="result_column", type="int")],
+                prediction=ColumnDefinition(name='result_column', type='int'),
+                output=[ColumnDefinition(name='result_column', type='int')],
             ),
-            target=ColumnDefinition(name="target_column", type="string"),
-            timestamp=ColumnDefinition(name="tst_column", type="string"),
+            target=ColumnDefinition(name='target_column', type='string'),
+            timestamp=ColumnDefinition(name='tst_column', type='string'),
         )
         model_definition = ModelDefinition(
             name=model.name,
@@ -156,13 +155,11 @@ def test_create_model(self):
             updated_at=str(time.time()),
         )
         responses.add(
-            **{
-                "method": responses.POST,
-                "url": f"{base_url}/api/models",
-                "body": model_definition.model_dump_json(),
-                "status": 201,
-                "content_type": "application/json",
-            }
+            method=responses.POST,
+            url=f'{base_url}/api/models',
+            body=model_definition.model_dump_json(),
+            status=201,
+            content_type='application/json',
         )
 
         client = Client(base_url)
@@ -181,22 +178,22 @@ def test_create_model(self):
 
     @responses.activate
     def test_search_models(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         paginated_response = PaginatedModelDefinitions(
             items=[
                 ModelDefinition(
-                    name="My Model",
+                    name='My Model',
                     model_type=ModelType.BINARY,
                     data_type=DataType.TABULAR,
                     granularity=Granularity.DAY,
-                    features=[ColumnDefinition(name="feature_column", type="string")],
+                    features=[ColumnDefinition(name='feature_column', type='string')],
                     outputs=OutputType(
-                        prediction=ColumnDefinition(name="result_column", type="int"),
-                        output=[ColumnDefinition(name="result_column", type="int")],
+                        prediction=ColumnDefinition(name='result_column', type='int'),
+                        output=[ColumnDefinition(name='result_column', type='int')],
                     ),
-                    target=ColumnDefinition(name="target_column", type="string"),
-                    timestamp=ColumnDefinition(name="tst_column", type="string"),
+                    target=ColumnDefinition(name='target_column', type='string'),
+                    timestamp=ColumnDefinition(name='tst_column', type='string'),
                     created_at=str(time.time()),
                     updated_at=str(time.time()),
                 )
@@ -204,13 +201,11 @@ def test_search_models(self):
         )
 
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api/models",
-                "body": paginated_response.model_dump_json(),
-                "status": 200,
-                "content_type": "application/json",
-            }
+            method=responses.GET,
+            url=f'{base_url}/api/models',
+            body=paginated_response.model_dump_json(),
+            status=200,
+            content_type='application/json',
         )
 
         client = Client(base_url)
diff --git a/sdk/tests/commons/rest_utils_test.py b/sdk/tests/commons/rest_utils_test.py
index f89fee00..05c09469 100644
--- a/sdk/tests/commons/rest_utils_test.py
+++ b/sdk/tests/commons/rest_utils_test.py
@@ -1,3 +1,10 @@
+import random
+import unittest
+
+import pytest
+import requests
+import responses
+
 from radicalbit_platform_sdk.commons import invoke
 from radicalbit_platform_sdk.errors import (
     APIError,
@@ -5,77 +12,57 @@
     ServerError,
     UnhandledResponseCode,
 )
-import unittest
-import random
-import requests
-import responses
 
 
 class RestUtilsTest(unittest.TestCase):
     @responses.activate
     def test_invoke_network_error(self):
-        base_url = "http://api:80"
-        responses.add(
-            **{"method": responses.GET, "url": f"{base_url}0/api", "status": 200}
-        )
-        with self.assertRaises(NetworkError):
-            invoke("GET", f"{base_url}/api", 200, lambda resp: None)
+        base_url = 'http://api:80'
+        responses.add(method=responses.GET, url=f'{base_url}0/api', status=200)
+        with pytest.raises(NetworkError):
+            invoke('GET', f'{base_url}/api', 200, lambda resp: None)
 
     @responses.activate
     def test_invoke_server_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api",
-                "status": random.randint(500, 599),
-            }
+            method=responses.GET, url=f'{base_url}/api', status=random.randint(500, 599)
         )
-        with self.assertRaises(ServerError):
-            invoke("GET", f"{base_url}/api", 200, lambda resp: None)
+        with pytest.raises(ServerError):
+            invoke('GET', f'{base_url}/api', 200, lambda resp: None)
 
     @responses.activate
     def test_invoke_api_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api",
-                "status": random.randint(400, 499),
-            }
+            method=responses.GET, url=f'{base_url}/api', status=random.randint(400, 499)
         )
-        with self.assertRaises(APIError):
-            invoke("GET", f"{base_url}/api", 200, lambda resp: None)
+        with pytest.raises(APIError):
+            invoke('GET', f'{base_url}/api', 200, lambda resp: None)
 
     @responses.activate
    def test_invoke_unhandled_response_code_error(self):
-        base_url = "http://api:9000"
+        base_url = 'http://api:9000'
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api",
-                "status": random.randint(201, 299),
-            }
+            method=responses.GET, url=f'{base_url}/api', status=random.randint(201, 299)
         )
-        with self.assertRaises(UnhandledResponseCode):
-            invoke("GET", f"{base_url}/api", 200, lambda resp: None)
+        with pytest.raises(UnhandledResponseCode):
+            invoke('GET', f'{base_url}/api', 200, lambda resp: None)
 
     @responses.activate
     def test_invoke_ok(self):
-        base_url = "http://api:9000"
-        response_body = "Hooray, it works"
+        base_url = 'http://api:9000'
+        response_body = 'Hooray, it works'
         responses.add(
-            **{
-                "method": responses.GET,
-                "url": f"{base_url}/api",
-                "body": response_body,
-                "status": 200,
-                "content_type": "text/plain",
-            }
+            method=responses.GET,
+            url=f'{base_url}/api',
+            body=response_body,
+            status=200,
+            content_type='text/plain',
         )
 
         def __callback(response: requests.Response):
             return response.text
 
-        result = invoke("GET", f"{base_url}/api", 200, __callback)
+        result = invoke('GET', f'{base_url}/api', 200, __callback)
 
         assert result == response_body
diff --git a/sdk/tests/models/column_definition_test.py b/sdk/tests/models/column_definition_test.py
index e92b4514..9da442cf 100644
--- a/sdk/tests/models/column_definition_test.py
+++ b/sdk/tests/models/column_definition_test.py
@@ -1,12 +1,13 @@
-from radicalbit_platform_sdk.models import ColumnDefinition
 import json
 import unittest
 
+from radicalbit_platform_sdk.models import ColumnDefinition
+
 
 class ColumnDefinitionTest(unittest.TestCase):
     def test_from_json(self):
-        field_name = "age"
-        field_type = "int"
+        field_name = 'age'
+        field_type = 'int'
         json_string = f'{{"name": "{field_name}", "type": "{field_type}"}}'
 
         column_definition = ColumnDefinition.model_validate(json.loads(json_string))
 
         assert column_definition.name == field_name
diff --git a/sdk/tests/models/model_definition_test.py b/sdk/tests/models/model_definition_test.py
index 807352ba..c7a9a48b 100644
--- a/sdk/tests/models/model_definition_test.py
+++ b/sdk/tests/models/model_definition_test.py
@@ -1,33 +1,34 @@
+import json
+import time
+import unittest
+import uuid
+
 from radicalbit_platform_sdk.models import (
-    ModelDefinition,
     DataType,
-    ModelType,
     Granularity,
+    ModelDefinition,
+    ModelType,
 )
-import json
-import unittest
-import uuid
-import time
 
 
 class ModelDefinitionTest(unittest.TestCase):
     def test_model_definition_from_json(self):
         id = uuid.uuid4()
-        name = "My super Model"
+        name = 'My super Model'
         model_type = ModelType.BINARY
         data_type = DataType.TEXT
         granularity = Granularity.HOUR
-        description = "some boring description about this model"
-        algorithm = "brainfucker"
-        frameworks = "mlflow"
-        feature_name = "age"
-        feature_type = "int"
-        output_name = "adult"
-        output_type = "bool"
-        target_name = "adult"
-        target_type = "bool"
-        timestamp_name = "when"
-        timestamp_type = "str"
+        description = 'some boring description about this model'
+        algorithm = 'brainfucker'
+        frameworks = 'mlflow'
+        feature_name = 'age'
+        feature_type = 'int'
+        output_name = 'adult'
+        output_type = 'bool'
+        target_name = 'adult'
+        target_type = 'bool'
+        timestamp_name = 'when'
+        timestamp_type = 'str'
         ts = str(time.time())
         json_string = f"""{{
             "uuid": "{str(id)}",