Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor RUT to use new transactional context #6874

Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,11 @@
#

import logging
from collections.abc import AsyncGenerator, Callable
from typing import Annotated

from fastapi import Depends
from fastapi.requests import Request
from servicelib.fastapi.dependencies import get_app, get_reverse_url_mapper
from sqlalchemy.ext.asyncio import AsyncEngine

from ...services.modules.db.repositories._base import BaseRepository

logger = logging.getLogger(__name__)


Expand All @@ -23,15 +18,6 @@ def get_resource_tracker_db_engine(request: Request) -> AsyncEngine:
return engine


def get_repository(repo_type: type[BaseRepository]) -> Callable:
async def _get_repo(
engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)],
) -> AsyncGenerator[BaseRepository, None]:
yield repo_type(db_engine=engine)

return _get_repo


assert get_reverse_url_mapper # nosec
assert get_app # nosec

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@

from ...core.settings import ApplicationSettings
from ...services import pricing_plans, pricing_units, service_runs
from ...services.modules.db.repositories.resource_tracker import (
ResourceTrackerRepository,
)
from ...services.modules.s3 import get_s3_client

router = RPCRouter()
Expand All @@ -56,7 +53,7 @@ async def get_service_run_page(
return await service_runs.list_service_runs(
user_id=user_id,
product_name=product_name,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
limit=limit,
offset=offset,
wallet_id=wallet_id,
Expand Down Expand Up @@ -87,7 +84,7 @@ async def export_service_runs(
s3_region=s3_settings.S3_REGION,
user_id=user_id,
product_name=product_name,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
wallet_id=wallet_id,
access_all_wallet_usage=access_all_wallet_usage,
order_by=order_by,
Expand All @@ -111,7 +108,7 @@ async def get_osparc_credits_aggregated_usages_page(
return await service_runs.get_osparc_credits_aggregated_usages_page(
user_id=user_id,
product_name=product_name,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
aggregated_by=aggregated_by,
time_period=time_period,
limit=limit,
Expand All @@ -134,7 +131,7 @@ async def get_pricing_plan(
return await pricing_plans.get_pricing_plan(
product_name=product_name,
pricing_plan_id=pricing_plan_id,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -146,7 +143,7 @@ async def list_pricing_plans(
) -> list[PricingPlanGet]:
return await pricing_plans.list_pricing_plans_by_product(
product_name=product_name,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -158,7 +155,7 @@ async def create_pricing_plan(
) -> PricingPlanGet:
return await pricing_plans.create_pricing_plan(
data=data,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -172,7 +169,7 @@ async def update_pricing_plan(
return await pricing_plans.update_pricing_plan(
product_name=product_name,
data=data,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -191,7 +188,7 @@ async def get_pricing_unit(
product_name=product_name,
pricing_plan_id=pricing_plan_id,
pricing_unit_id=pricing_unit_id,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -205,7 +202,7 @@ async def create_pricing_unit(
return await pricing_units.create_pricing_unit(
product_name=product_name,
data=data,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -219,7 +216,7 @@ async def update_pricing_unit(
return await pricing_units.update_pricing_unit(
product_name=product_name,
data=data,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)


Expand All @@ -238,7 +235,7 @@ async def list_connected_services_to_pricing_plan_by_pricing_plan(
] = await pricing_plans.list_connected_services_to_pricing_plan_by_pricing_plan(
product_name=product_name,
pricing_plan_id=pricing_plan_id,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)
return output

Expand All @@ -257,5 +254,5 @@ async def connect_service_to_pricing_plan(
pricing_plan_id=pricing_plan_id,
service_key=service_key,
service_version=service_version,
resource_tracker_repo=ResourceTrackerRepository(db_engine=app.state.engine),
db_engine=app.state.engine,
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
ServiceRunStatus,
)
from pydantic import NonNegativeInt, PositiveInt
from sqlalchemy.ext.asyncio import AsyncEngine

from ..core.settings import ApplicationSettings
from ..models.credit_transactions import CreditTransactionCreditsAndStatusUpdate
from ..models.service_runs import ServiceRunStoppedAtUpdate
from .modules.db.repositories.resource_tracker import ResourceTrackerRepository
from .modules.db import credit_transactions_db, service_runs_db
from .utils import compute_service_run_credit_costs, make_negative

_logger = logging.getLogger(__name__)
Expand All @@ -23,7 +24,7 @@


async def _check_service_heartbeat(
resource_tracker_repo: ResourceTrackerRepository,
db_engine: AsyncEngine,
base_start_timestamp: datetime,
resource_usage_tracker_missed_heartbeat_interval: timedelta,
resource_usage_tracker_missed_heartbeat_counter_fail: NonNegativeInt,
Expand Down Expand Up @@ -55,21 +56,24 @@ async def _check_service_heartbeat(
missed_heartbeat_counter,
)
await _close_unhealthy_service(
resource_tracker_repo, service_run_id, base_start_timestamp
db_engine, service_run_id, base_start_timestamp
)
else:
_logger.warning(
"Service run id: %s missed heartbeat. Counter %s",
service_run_id,
missed_heartbeat_counter,
)
await resource_tracker_repo.update_service_missed_heartbeat_counter(
service_run_id, last_heartbeat_at, missed_heartbeat_counter
await service_runs_db.update_service_missed_heartbeat_counter(
db_engine,
service_run_id=service_run_id,
last_heartbeat_at=last_heartbeat_at,
missed_heartbeat_counter=missed_heartbeat_counter,
)


async def _close_unhealthy_service(
resource_tracker_repo: ResourceTrackerRepository,
db_engine: AsyncEngine,
service_run_id: ServiceRunId,
base_start_timestamp: datetime,
):
Expand All @@ -80,8 +84,8 @@ async def _close_unhealthy_service(
service_run_status=ServiceRunStatus.ERROR,
service_run_status_msg="Service missed more heartbeats. It's considered unhealthy.",
)
running_service = await resource_tracker_repo.update_service_run_stopped_at(
update_service_run_stopped_at
running_service = await service_runs_db.update_service_run_stopped_at(
db_engine, data=update_service_run_stopped_at
)

if running_service is None:
Expand All @@ -108,8 +112,8 @@ async def _close_unhealthy_service(
else CreditTransactionStatus.BILLED
),
)
await resource_tracker_repo.update_credit_transaction_credits_and_status(
update_credit_transaction
await credit_transactions_db.update_credit_transaction_credits_and_status(
db_engine, data=update_credit_transaction
)


Expand All @@ -118,27 +122,26 @@ async def periodic_check_of_running_services_task(app: FastAPI) -> None:

# This check runs across all products
app_settings: ApplicationSettings = app.state.settings
resource_tracker_repo: ResourceTrackerRepository = ResourceTrackerRepository(
db_engine=app.state.engine
)
_db_engine = app.state.engine

base_start_timestamp = datetime.now(tz=timezone.utc)

# Get all current running services (across all products)
total_count: PositiveInt = (
await resource_tracker_repo.total_service_runs_with_running_status_across_all_products()
total_count: PositiveInt = await service_runs_db.total_service_runs_with_running_status_across_all_products(
_db_engine
)

for offset in range(0, total_count, _BATCH_SIZE):
batch_check_services = await resource_tracker_repo.list_service_runs_with_running_status_across_all_products(
batch_check_services = await service_runs_db.list_service_runs_with_running_status_across_all_products(
_db_engine,
offset=offset,
limit=_BATCH_SIZE,
)

await asyncio.gather(
*(
_check_service_heartbeat(
resource_tracker_repo=resource_tracker_repo,
db_engine=_db_engine,
base_start_timestamp=base_start_timestamp,
resource_usage_tracker_missed_heartbeat_interval=app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_INTERVAL_SEC,
resource_usage_tracker_missed_heartbeat_counter_fail=app_settings.RESOURCE_USAGE_TRACKER_MISSED_HEARTBEAT_COUNTER_FAIL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,18 @@
)
from models_library.wallets import WalletID
from servicelib.rabbitmq import RabbitMQClient
from sqlalchemy.ext.asyncio import AsyncEngine

from ..api.rest.dependencies import get_repository
from ..api.rest.dependencies import get_resource_tracker_db_engine
from ..models.credit_transactions import CreditTransactionCreate
from .modules.db.repositories.resource_tracker import ResourceTrackerRepository
from .modules.db import credit_transactions_db
from .modules.rabbitmq import get_rabbitmq_client_from_request
from .utils import sum_credit_transactions_and_publish_to_rabbitmq


async def create_credit_transaction(
credit_transaction_create_body: CreditTransactionCreateBody,
resource_tracker_repo: Annotated[
ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository))
],
db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)],
rabbitmq_client: Annotated[
RabbitMQClient, Depends(get_rabbitmq_client_from_request)
],
Expand All @@ -47,12 +46,12 @@ async def create_credit_transaction(
created_at=credit_transaction_create_body.created_at,
last_heartbeat_at=credit_transaction_create_body.created_at,
)
transaction_id = await resource_tracker_repo.create_credit_transaction(
transaction_create
transaction_id = await credit_transactions_db.create_credit_transaction(
db_engine, data=transaction_create
)

await sum_credit_transactions_and_publish_to_rabbitmq(
resource_tracker_repo,
db_engine,
rabbitmq_client,
credit_transaction_create_body.product_name,
credit_transaction_create_body.wallet_id,
Expand All @@ -64,10 +63,8 @@ async def create_credit_transaction(
async def sum_credit_transactions_by_product_and_wallet(
product_name: ProductName,
wallet_id: WalletID,
resource_tracker_repo: Annotated[
ResourceTrackerRepository, Depends(get_repository(ResourceTrackerRepository))
],
db_engine: Annotated[AsyncEngine, Depends(get_resource_tracker_db_engine)],
) -> WalletTotalCredits:
return await resource_tracker_repo.sum_credit_transactions_by_product_and_wallet(
product_name, wallet_id
return await credit_transactions_db.sum_credit_transactions_by_product_and_wallet(
db_engine, product_name=product_name, wallet_id=wallet_id
)
Loading
Loading