Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️Pydantic V2 and SQLAlchemy warning fixes #6877

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,6 @@ def LOG_LEVEL(self) -> LogLevel: # noqa: N802
def _valid_log_level(cls, value: str) -> str:
return cls.validate_log_level(value)

@field_validator("SERVICE_TRACKING_HEARTBEAT", mode="before")
@classmethod
def _validate_interval(
cls, value: str | datetime.timedelta
) -> int | datetime.timedelta:
if isinstance(value, str):
return int(value)
return value


def get_application_settings(app: FastAPI) -> ApplicationSettings:
return cast(ApplicationSettings, app.state.settings)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import dataclasses
import datetime
from collections.abc import Awaitable, Callable
from typing import Final
from unittest.mock import MagicMock
Expand Down Expand Up @@ -36,7 +37,9 @@ def wallet_id(faker: Faker, request: pytest.FixtureRequest) -> WalletID | None:
return faker.pyint(min_value=1) if request.param == "with_wallet" else None


_FAST_TIME_BEFORE_TERMINATION_SECONDS: Final[int] = 10
_FAST_TIME_BEFORE_TERMINATION_SECONDS: Final[datetime.timedelta] = datetime.timedelta(
seconds=10
)


@pytest.fixture
Expand Down Expand Up @@ -149,7 +152,7 @@ async def test_cluster_management_core_properly_removes_unused_instances(
mocked_dask_ping_scheduler.is_scheduler_busy.reset_mock()

# running the cluster management task after the heartbeat came in shall not remove anything
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await cluster_heartbeat(initialized_app, user_id=user_id, wallet_id=wallet_id)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
Expand All @@ -161,7 +164,7 @@ async def test_cluster_management_core_properly_removes_unused_instances(
mocked_dask_ping_scheduler.is_scheduler_busy.reset_mock()

# after waiting the termination time, running the task shall remove the cluster
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
ec2_client, instances=created_clusters, state="terminated"
Expand Down Expand Up @@ -201,7 +204,7 @@ async def test_cluster_management_core_properly_removes_workers_on_shutdown(
ec2_client, instance_ids=worker_instance_ids, state="running"
)
# after waiting the termination time, running the task shall remove the cluster
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
ec2_client, instances=created_clusters, state="terminated"
Expand Down Expand Up @@ -314,7 +317,7 @@ async def test_cluster_management_core_removes_broken_clusters_after_some_delay(
mocked_dask_ping_scheduler.is_scheduler_busy.reset_mock()

# waiting for the termination time will now terminate the cluster
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS + 1)
await asyncio.sleep(_FAST_TIME_BEFORE_TERMINATION_SECONDS.total_seconds() + 1)
await check_clusters(initialized_app)
await _assert_cluster_exist_and_state(
ec2_client, instances=created_clusters, state="terminated"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,7 @@ async def apply(
project_id: ProjectID,
iteration: Iteration,
) -> None:
"""schedules a pipeline for a given user, project and iteration.

Arguments:
wake_up_callback -- a callback function that is called in a separate thread everytime a pipeline node is completed
"""
"""apply the scheduling of a pipeline for a given user, project and iteration."""
with log_context(
_logger,
level=logging.INFO,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ async def list(
return [
CompRunsAtDB.model_validate(row)
async for row in conn.execute(
sa.select(comp_runs).where(sa.and_(*conditions))
sa.select(comp_runs).where(
sa.and_(True, *conditions) # noqa: FBT003
)
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ async def _get_service_details(
node.version,
product_name,
)
obj: ServiceMetaDataPublished = ServiceMetaDataPublished.model_construct(
**service_details
)
obj: ServiceMetaDataPublished = ServiceMetaDataPublished(**service_details)
return obj


Expand All @@ -105,7 +103,7 @@ def _compute_node_requirements(
node_defined_resources[resource_name] = node_defined_resources.get(
resource_name, 0
) + min(resource_value.limit, resource_value.reservation)
return NodeRequirements.model_validate(node_defined_resources)
return NodeRequirements(**node_defined_resources)


def _compute_node_boot_mode(node_resources: ServiceResourcesDict) -> BootMode:
Expand Down Expand Up @@ -146,12 +144,12 @@ async def _get_node_infos(
None,
)

result: tuple[ServiceMetaDataPublished, ServiceExtras, SimcoreServiceLabels] = (
await asyncio.gather(
_get_service_details(catalog_client, user_id, product_name, node),
director_client.get_service_extras(node.key, node.version),
director_client.get_service_labels(node),
)
result: tuple[
ServiceMetaDataPublished, ServiceExtras, SimcoreServiceLabels
] = await asyncio.gather(
_get_service_details(catalog_client, user_id, product_name, node),
director_client.get_service_extras(node.key, node.version),
director_client.get_service_labels(node),
)
return result

Expand Down Expand Up @@ -189,7 +187,7 @@ async def _generate_task_image(
data.update(envs=_compute_node_envs(node_labels))
if node_extras and node_extras.container_spec:
data.update(command=node_extras.container_spec.command)
return Image.model_validate(data)
return Image(**data)


async def _get_pricing_and_hardware_infos(
Expand Down Expand Up @@ -247,9 +245,9 @@ async def _get_pricing_and_hardware_infos(
return pricing_info, hardware_info


_RAM_SAFE_MARGIN_RATIO: Final[float] = (
0.1 # NOTE: machines always have less available RAM than advertised
)
_RAM_SAFE_MARGIN_RATIO: Final[
float
] = 0.1 # NOTE: machines always have less available RAM than advertised
_CPUS_SAFE_MARGIN: Final[float] = 0.1


Expand All @@ -267,11 +265,11 @@ async def _update_project_node_resources_from_hardware_info(
if not hardware_info.aws_ec2_instances:
return
try:
unordered_list_ec2_instance_types: list[EC2InstanceTypeGet] = (
await get_instance_type_details(
rabbitmq_rpc_client,
instance_type_names=set(hardware_info.aws_ec2_instances),
)
unordered_list_ec2_instance_types: list[
EC2InstanceTypeGet
] = await get_instance_type_details(
rabbitmq_rpc_client,
instance_type_names=set(hardware_info.aws_ec2_instances),
)

assert unordered_list_ec2_instance_types # nosec
Expand Down Expand Up @@ -347,7 +345,7 @@ async def generate_tasks_list_from_project(
list_comp_tasks = []

unique_service_key_versions: set[ServiceKeyVersion] = {
ServiceKeyVersion.model_construct(
ServiceKeyVersion(
key=node.key, version=node.version
) # the service key version is frozen
for node in project.workbench.values()
Expand All @@ -366,9 +364,7 @@ async def generate_tasks_list_from_project(

for internal_id, node_id in enumerate(project.workbench, 1):
node: Node = project.workbench[node_id]
node_key_version = ServiceKeyVersion.model_construct(
key=node.key, version=node.version
)
node_key_version = ServiceKeyVersion(key=node.key, version=node.version)
node_details, node_extras, node_labels = key_version_to_node_infos.get(
node_key_version,
(None, None, None),
Expand Down Expand Up @@ -434,8 +430,8 @@ async def generate_tasks_list_from_project(
task_db = CompTaskAtDB(
project_id=project.uuid,
node_id=NodeID(node_id),
schema=NodeSchema.model_validate(
node_details.model_dump(
schema=NodeSchema(
**node_details.model_dump(
exclude_unset=True, by_alias=True, include={"inputs", "outputs"}
)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def minimal_configuration(
redis_service: RedisSettings,
monkeypatch: pytest.MonkeyPatch,
faker: Faker,
with_disabled_auto_scheduling: mock.Mock,
with_disabled_scheduler_publisher: mock.Mock,
):
monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SIDECAR_ENABLED", "false")
monkeypatch.setenv("COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED", "1")
Expand Down Expand Up @@ -588,11 +590,7 @@ async def test_create_computation_with_wallet(

@pytest.mark.parametrize(
"default_pricing_plan",
[
PricingPlanGet.model_construct(
**PricingPlanGet.model_config["json_schema_extra"]["examples"][0]
)
],
[PricingPlanGet(**PricingPlanGet.model_config["json_schema_extra"]["examples"][0])],
)
async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_raises_422(
minimal_configuration: None,
Expand Down Expand Up @@ -631,7 +629,7 @@ async def test_create_computation_with_wallet_with_invalid_pricing_unit_name_rai
@pytest.mark.parametrize(
"default_pricing_plan",
[
PricingPlanGet.model_construct(
PricingPlanGet(
**PricingPlanGet.model_config["json_schema_extra"]["examples"][0] # type: ignore
)
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1606,7 +1606,9 @@ async def _return_random_task_result(job_id) -> TaskOutputData:
@pytest.fixture
def with_fast_service_heartbeat_s(monkeypatch: pytest.MonkeyPatch) -> int:
seconds = 1
monkeypatch.setenv("SERVICE_TRACKING_HEARTBEAT", f"{seconds}")
monkeypatch.setenv(
"SERVICE_TRACKING_HEARTBEAT", f"{datetime.timedelta(seconds=seconds)}"
)
return seconds


Expand Down
6 changes: 3 additions & 3 deletions services/director-v2/tests/unit/with_dbs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ async def _(user: dict[str, Any], **cluster_kwargs) -> Cluster:
.where(clusters.c.id == created_cluster.id)
):
access_rights_in_db[row.gid] = {
"read": row[cluster_to_groups.c.read],
"write": row[cluster_to_groups.c.write],
"delete": row[cluster_to_groups.c.delete],
"read": row.read,
"write": row.write,
"delete": row.delete,
}

return Cluster(
Expand Down
Loading