Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into pr-osparc-deflate…
Browse files Browse the repository at this point in the history
…64-zip
  • Loading branch information
Andrei Neagu committed Dec 19, 2024
2 parents e69bc9a + d75b0a3 commit e282c95
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta

from fastapi import FastAPI
from models_library.resource_tracker import (
Expand All @@ -15,7 +15,11 @@
from ..core.settings import ApplicationSettings
from ..models.credit_transactions import CreditTransactionCreditsAndStatusUpdate
from ..models.service_runs import ServiceRunStoppedAtUpdate
from .modules.db import credit_transactions_db, service_runs_db
from .modules.db import (
credit_transactions_db,
licensed_items_checkouts_db,
service_runs_db,
)
from .utils import compute_service_run_credit_costs, make_negative

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -116,6 +120,11 @@ async def _close_unhealthy_service(
db_engine, data=update_credit_transaction
)

# 3. Release license seats in case some were checked out but not properly released.
await licensed_items_checkouts_db.force_release_license_seats_by_run_id(
db_engine, service_run_id=service_run_id
)


async def periodic_check_of_running_services_task(app: FastAPI) -> None:
_logger.info("Periodic check started")
Expand All @@ -124,7 +133,7 @@ async def periodic_check_of_running_services_task(app: FastAPI) -> None:
app_settings: ApplicationSettings = app.state.settings
_db_engine = app.state.engine

base_start_timestamp = datetime.now(tz=timezone.utc)
base_start_timestamp = datetime.now(tz=UTC)

# Get all current running services (across all products)
total_count: PositiveInt = await service_runs_db.total_service_runs_with_running_status_across_all_products(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from datetime import datetime
from typing import cast

Expand All @@ -8,6 +9,7 @@
LicensedItemCheckoutID,
)
from models_library.rest_ordering import OrderBy, OrderDirection
from models_library.services_types import ServiceRunID
from models_library.wallets import WalletID
from pydantic import NonNegativeInt
from servicelib.rabbitmq.rpc_interfaces.resource_usage_tracker.errors import (
Expand All @@ -27,6 +29,9 @@
LicensedItemCheckoutDB,
)

_logger = logging.getLogger(__name__)


_SELECTION_ARGS = (
resource_tracker_licensed_items_checkouts.c.licensed_item_checkout_id,
resource_tracker_licensed_items_checkouts.c.licensed_item_id,
Expand Down Expand Up @@ -214,3 +219,41 @@ async def get_currently_used_seats_for_item_and_wallet(
if total_sum is None:
return 0
return cast(int, total_sum)


async def force_release_license_seats_by_run_id(
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
service_run_id: ServiceRunID,
) -> None:
"""
Purpose: This function is utilized by a periodic heartbeat check task that monitors whether running services are
sending heartbeat signals. If heartbeat signals are not received within a specified timeframe and a service is
deemed unhealthy, this function ensures the proper release of any licensed seats that were not correctly released by
the unhealthy service.
Currently, this functionality is primarily used to handle the release of a single seat allocated to the VIP model.
"""
update_stmt = (
resource_tracker_licensed_items_checkouts.update()
.values(
modified=sa.func.now(),
stopped_at=sa.func.now(),
)
.where(
(
resource_tracker_licensed_items_checkouts.c.service_run_id
== service_run_id
)
& (resource_tracker_licensed_items_checkouts.c.stopped_at.is_(None))
)
.returning(sa.literal_column("*"))
)

async with transaction_context(engine, connection) as conn:
result = await conn.execute(update_stmt)
released_seats = result.fetchall()
if released_seats:
_logger.error(
"Force release of %s seats: %s", len(released_seats), released_seats
)
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
ServiceRunLastHeartbeatUpdate,
ServiceRunStoppedAtUpdate,
)
from .modules.db import credit_transactions_db, pricing_plans_db, service_runs_db
from .modules.db import (
credit_transactions_db,
licensed_items_checkouts_db,
pricing_plans_db,
service_runs_db,
)
from .modules.rabbitmq import RabbitMQClient, get_rabbitmq_client
from .utils import (
compute_service_run_credit_costs,
Expand Down Expand Up @@ -269,9 +274,15 @@ async def _process_stop_event(
running_service = await service_runs_db.update_service_run_stopped_at(
db_engine, data=update_service_run_stopped_at
)
await licensed_items_checkouts_db.force_release_license_seats_by_run_id(
db_engine, service_run_id=msg.service_run_id
)

if running_service is None:
_logger.error("Nothing to update. This should not happen investigate.")
_logger.error(
"Nothing to update. This should not happen investigate. service_run_id: %s",
msg.service_run_id,
)
return

if running_service.wallet_id and running_service.pricing_unit_cost is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# pylint:disable=unused-variable
# pylint:disable=unused-argument
# pylint:disable=redefined-outer-name
# pylint:disable=too-many-arguments


from datetime import UTC, datetime
from typing import Generator
from unittest import mock

import pytest
import sqlalchemy as sa
from models_library.basic_types import IDStr
from models_library.rest_ordering import OrderBy
from simcore_postgres_database.models.resource_tracker_licensed_items_checkouts import (
resource_tracker_licensed_items_checkouts,
)
from simcore_postgres_database.models.resource_tracker_service_runs import (
resource_tracker_service_runs,
)
from simcore_service_resource_usage_tracker.models.licensed_items_checkouts import (
CreateLicensedItemCheckoutDB,
)
from simcore_service_resource_usage_tracker.services.modules.db import (
licensed_items_checkouts_db,
)

pytest_simcore_core_services_selection = [
"postgres",
]
pytest_simcore_ops_services_selection = [
"adminer",
]


_USER_ID_1 = 1
_WALLET_ID = 6


@pytest.fixture()
def resource_tracker_service_run_id(
postgres_db: sa.engine.Engine, random_resource_tracker_service_run
) -> Generator[str, None, None]:
with postgres_db.connect() as con:
result = con.execute(
resource_tracker_service_runs.insert()
.values(
**random_resource_tracker_service_run(
user_id=_USER_ID_1, wallet_id=_WALLET_ID
)
)
.returning(resource_tracker_service_runs.c.service_run_id)
)
row = result.first()
assert row

yield row[0]

con.execute(resource_tracker_licensed_items_checkouts.delete())
con.execute(resource_tracker_service_runs.delete())


async def test_licensed_items_checkouts_db__force_release_license_seats_by_run_id(
mocked_redis_server: None,
mocked_setup_rabbitmq: mock.Mock,
resource_tracker_service_run_id,
initialized_app,
):
engine = initialized_app.state.engine

# SETUP
_create_license_item_checkout_db_1 = CreateLicensedItemCheckoutDB(
licensed_item_id="beb16d18-d57d-44aa-a638-9727fa4a72ef",
wallet_id=_WALLET_ID,
user_id=_USER_ID_1,
user_email="[email protected]",
product_name="osparc",
service_run_id=resource_tracker_service_run_id,
started_at=datetime.now(tz=UTC),
num_of_seats=1,
)
await licensed_items_checkouts_db.create(
engine, data=_create_license_item_checkout_db_1
)

_create_license_item_checkout_db_2 = _create_license_item_checkout_db_1.model_dump()
_create_license_item_checkout_db_2[
"licensed_item_id"
] = "b1b96583-333f-44d6-b1e0-5c0a8af555bf"
await licensed_items_checkouts_db.create(
engine,
data=CreateLicensedItemCheckoutDB.model_construct(
**_create_license_item_checkout_db_2
),
)

_create_license_item_checkout_db_3 = _create_license_item_checkout_db_1.model_dump()
_create_license_item_checkout_db_3[
"licensed_item_id"
] = "38a5ce59-876f-482a-ace1-d3b2636feac6"
checkout = await licensed_items_checkouts_db.create(
engine,
data=CreateLicensedItemCheckoutDB.model_construct(
**_create_license_item_checkout_db_3
),
)

_helper_time = datetime.now(UTC)
await licensed_items_checkouts_db.update(
engine,
licensed_item_checkout_id=checkout.licensed_item_checkout_id,
product_name="osparc",
stopped_at=_helper_time,
)

# TEST FORCE RELEASE LICENSE SEATS
await licensed_items_checkouts_db.force_release_license_seats_by_run_id(
engine, service_run_id=resource_tracker_service_run_id
)

# ASSERT
total, items = await licensed_items_checkouts_db.list_(
engine,
product_name="osparc",
filter_wallet_id=_WALLET_ID,
offset=0,
limit=5,
order_by=OrderBy(field=IDStr("started_at")),
)
assert total == 3
assert len(items) == 3

_helper_count = 0
for item in items:
assert isinstance(item.stopped_at, datetime)
if item.stopped_at > _helper_time:
_helper_count += 1

assert _helper_count == 2

0 comments on commit e282c95

Please sign in to comment.