diff --git a/common/osu/difficultycalculator.py b/common/osu/difficultycalculator.py
index ec9120c..ac57ca0 100644
--- a/common/osu/difficultycalculator.py
+++ b/common/osu/difficultycalculator.py
@@ -33,8 +33,8 @@ class Score(NamedTuple):


 class Calculation(NamedTuple):
-    difficulty: float
-    performance: float
+    difficulty_values: dict[str, float]
+    performance_values: dict[str, float]


 class DifficultyCalculatorException(Exception):
@@ -90,7 +90,8 @@ def calculate_score(self, score: Score) -> Calculation:
         self.calculate()

         return Calculation(
-            difficulty=self.difficulty_total, performance=self.performance_total
+            difficulty_values={"total": self.difficulty_total},
+            performance_values={"total": self.performance_total},
         )

     def calculate_score_batch(self, scores: Iterable[Score]) -> list[Calculation]:
@@ -328,12 +329,12 @@ def calculate_score(self, score: Score) -> Calculation:
             data = response.json()
         except httpx.HTTPStatusError as e:
             raise CalculationException(
-                f"An error occured in calculating the beatmap {score.beatmap_id}"
+                f"An error occured in calculating the beatmap {score.beatmap_id}: {e.response.text}"
             ) from e

         return Calculation(
-            difficulty=data["difficulty"]["total"],
-            performance=data["performance"]["total"],
+            difficulty_values=data["difficulty"],
+            performance_values=data["performance"],
         )

     def calculate_score_batch(self, scores: Iterable[Score]) -> list[Calculation]:
@@ -346,13 +347,13 @@ def calculate_score_batch(self, scores: Iterable[Score]) -> list[Calculation]:
             data = response.json()
         except httpx.HTTPStatusError as e:
             raise CalculationException(
-                f"An error occured in calculating the beatmaps"
+                f"An error occured in calculating the beatmaps: {e.response.text}"
             ) from e

         return [
             Calculation(
-                difficulty=calculation_data["difficulty"]["total"],
-                performance=calculation_data["performance"]["total"],
+                difficulty_values=calculation_data["difficulty"],
+                performance_values=calculation_data["performance"],
             )
             for calculation_data in data
         ]
diff --git a/common/osu/test_difficultycalculator.py b/common/osu/test_difficultycalculator.py
index 60e5b8c..b3a3ed1 100644
--- a/common/osu/test_difficultycalculator.py
+++ b/common/osu/test_difficultycalculator.py
@@ -43,7 +43,8 @@ def test_calculate_score(self):
             combo=2000,
         )
         assert calc.calculate_score(score) == Calculation(
-            difficulty=5.919765949249268, performance=298.1595153808594
+            difficulty_values={"total": 5.919765949249268},
+            performance_values={"total": 298.1595153808594},
         )

     def test_calculate_score_batch(self):
@@ -71,9 +72,18 @@ def test_calculate_score_batch(self):
             ),
         ]
         assert calc.calculate_score_batch(scores) == [
-            Calculation(difficulty=5.919765949249268, performance=298.1595153808594),
-            Calculation(difficulty=6.20743465423584, performance=476.4307861328125),
-            Calculation(difficulty=6.20743465423584, performance=630.419677734375),
+            Calculation(
+                difficulty_values={"total": 5.919765949249268},
+                performance_values={"total": 298.1595153808594},
+            ),
+            Calculation(
+                difficulty_values={"total": 6.20743465423584},
+                performance_values={"total": 476.4307861328125},
+            ),
+            Calculation(
+                difficulty_values={"total": 6.20743465423584},
+                performance_values={"total": 630.419677734375},
+            ),
         ]

     @pytest.fixture
@@ -148,7 +158,8 @@ def test_calculate_score(self):
             combo=2000,
         )
         assert calc.calculate_score(score) == Calculation(
-            difficulty=6.264344677869616, performance=312.43705315450256
+            difficulty_values={"total": 6.264344677869616},
+            performance_values={"total": 312.43705315450256},
         )

     def test_calculate_score_batch(self):
@@ -176,9 +187,18 @@ def test_calculate_score_batch(self):
             ),
         ]
         assert calc.calculate_score_batch(scores) == [
-            Calculation(difficulty=6.264344677869616, performance=312.43705315450256),
-            Calculation(difficulty=6.531051472171891, performance=487.5904861756349),
-            Calculation(difficulty=6.531051472171891, performance=655.9388807525456),
+            Calculation(
+                difficulty_values={"total": 6.264344677869616},
+                performance_values={"total": 312.43705315450256},
+            ),
+            Calculation(
+                difficulty_values={"total": 6.531051472171891},
+                performance_values={"total": 487.5904861756349},
+            ),
+            Calculation(
+                difficulty_values={"total": 6.531051472171891},
+                performance_values={"total": 655.9388807525456},
+            ),
         ]

     @pytest.fixture
@@ -234,7 +254,19 @@ def test_version(self):
     def test_context_manager(self):
         with DifficalcyOsuDifficultyCalculator() as calc:
             assert calc.calculate_score(Score("307618")) == Calculation(
-                difficulty=4.4569433791337945, performance=135.0040504515237
+                difficulty_values={
+                    "aim": 2.08629357857818,
+                    "speed": 2.1778593015565684,
+                    "flashlight": 0,
+                    "total": 4.4569433791337945,
+                },
+                performance_values={
+                    "aim": 44.12278272319251,
+                    "speed": 50.54174287197802,
+                    "accuracy": 36.07670429437059,
+                    "flashlight": 0,
+                    "total": 135.0040504515237,
+                },
             )

     def test_invalid_beatmap(self):
@@ -253,7 +285,19 @@ def test_calculate_score(self):
             combo=2000,
         )
         assert calc.calculate_score(score) == Calculation(
-            difficulty=6.263707394408435, performance=312.36671287580185
+            difficulty_values={
+                "aim": 2.892063051954271,
+                "speed": 3.0958487396004704,
+                "flashlight": 0,
+                "total": 6.263707394408435,
+            },
+            performance_values={
+                "aim": 98.6032935956297,
+                "speed": 118.92511309917593,
+                "accuracy": 84.96884392557897,
+                "flashlight": 0,
+                "total": 312.36671287580185,
+            },
         )

     def test_calculate_score_batch(self):
@@ -281,7 +325,49 @@ def test_calculate_score_batch(self):
             ),
         ]
         assert calc.calculate_score_batch(scores) == [
-            Calculation(difficulty=6.263707394408435, performance=312.36671287580185),
-            Calculation(difficulty=6.530286188377548, performance=487.4810004992573),
-            Calculation(difficulty=6.530286188377548, performance=655.7872855036575),
+            Calculation(
+                difficulty_values={
+                    "aim": 2.892063051954271,
+                    "speed": 3.0958487396004704,
+                    "flashlight": 0,
+                    "total": 6.263707394408435,
+                },
+                performance_values={
+                    "aim": 98.6032935956297,
+                    "speed": 118.92511309917593,
+                    "accuracy": 84.96884392557897,
+                    "flashlight": 0,
+                    "total": 312.36671287580185,
+                },
+            ),
+            Calculation(
+                difficulty_values={
+                    "aim": 3.1381340530266333,
+                    "speed": 3.1129549941521066,
+                    "flashlight": 0,
+                    "total": 6.530286188377548,
+                },
+                performance_values={
+                    "aim": 153.058022351103,
+                    "speed": 153.10941688245896,
+                    "accuracy": 166.32370945374015,
+                    "flashlight": 0,
+                    "total": 487.4810004992573,
+                },
+            ),
+            Calculation(
+                difficulty_values={
+                    "aim": 3.1381340530266333,
+                    "speed": 3.1129549941521066,
+                    "flashlight": 0,
+                    "total": 6.530286188377548,
+                },
+                performance_values={
+                    "aim": 207.5808620241847,
+                    "speed": 215.2746980112218,
+                    "accuracy": 212.8087296294707,
+                    "flashlight": 0,
+                    "total": 655.7872855036575,
+                },
+            ),
         ]
diff --git a/profiles/management/commands/recalculate.py b/profiles/management/commands/recalculate.py
index 4a400b5..eeecb98 100644
--- a/profiles/management/commands/recalculate.py
+++ b/profiles/management/commands/recalculate.py
@@ -2,7 +2,6 @@

 from django.core.management.base import BaseCommand
 from django.core.paginator import Paginator
-from django.db import transaction
 from django.db.models import Count, QuerySet
 from tqdm import tqdm

@@ -11,16 +10,12 @@
     DifficultyCalculator,
     get_difficulty_calculator_class,
 )
-from common.osu.enums import Gamemode, Mods
+from common.osu.enums import Gamemode
 from leaderboards.models import Membership
-from profiles.models import (
-    Beatmap,
-    DifficultyCalculation,
-    DifficultyValue,
-    PerformanceCalculation,
-    PerformanceValue,
-    Score,
-    UserStats,
+from profiles.models import Beatmap, Score, UserStats
+from profiles.services import (
+    update_difficulty_calculations,
+    update_performance_calculations_for_unique_beatmap,
 )


@@ -28,7 +23,6 @@ class Command(BaseCommand):
     help = "Recalculates beatmap difficulty values, score performance values and user stats score styles"

     def add_arguments(self, parser):
-        parser.add_argument("gamemode", nargs=1, type=int)
         parser.add_argument(
             "--force",
             action="store_true",
@@ -45,9 +39,7 @@ def add_arguments(self, parser):
         )

     def handle(self, *args, **options):
-        gamemode = options["gamemode"][0]
         force = options["force"]
-        # the v2 flag is used to determine whether to use the new difficulty and performance models
         v2 = options["v2"]

         diffcalc_name = options["diffcalc"]
@@ -56,28 +48,24 @@ def handle(self, *args, **options):
         else:
             difficulty_calculator_class = DifficultyCalculator

-        if gamemode != difficulty_calculator_class.gamemode():
-            self.stdout.write(
-                self.style.ERROR(
-                    f"Gamemode {gamemode} is not supported by {difficulty_calculator_class.__name__}"
-                )
-            )
-            return
+        difficulty_calculator = difficulty_calculator_class()
+
+        gamemode = difficulty_calculator.gamemode()

         self.stdout.write(
             f"Gamemode: {Gamemode(gamemode).name}\n"
-            f"Difficulty Calculator Engine: {difficulty_calculator_class.engine()}\n"
-            f"Difficulty Calculator Version: {difficulty_calculator_class.version()}\n"
+            f"Difficulty Calculator Engine: {difficulty_calculator.engine()}\n"
+            f"Difficulty Calculator Version: {difficulty_calculator.version()}\n"
         )

         if v2:
             # Recalculate beatmaps
             beatmaps = Beatmap.objects.filter(gamemode=gamemode)
-            self.recalculate_beatmaps_v2(difficulty_calculator_class, beatmaps, force)
+            self.recalculate_beatmaps_v2(difficulty_calculator, beatmaps, force)

             # Recalculate scores
             scores = Score.objects.filter(gamemode=gamemode)
-            self.recalculate_scores_v2(difficulty_calculator_class, scores, force)
+            self.recalculate_scores_v2(difficulty_calculator, scores, force)
         else:
             # Recalculate beatmaps
             beatmaps = Beatmap.objects.filter(gamemode=gamemode)
@@ -235,7 +223,7 @@ def recalculate_scores(

     def recalculate_beatmaps_v2(
         self,
-        difficulty_calculator_class: Type[AbstractDifficultyCalculator],
+        difficulty_calculator: AbstractDifficultyCalculator,
         beatmaps: QuerySet[Beatmap],
         force: bool = False,
     ):
@@ -246,13 +234,12 @@ def recalculate_beatmaps_v2(

             with tqdm(desc="Beatmaps", total=beatmaps.count(), smoothing=0) as pbar:
                 for page in paginator:
-                    self.recalculate_beatmap_page_v2(
-                        difficulty_calculator_class, page, pbar
-                    )
+                    update_difficulty_calculations(page, difficulty_calculator)
+                    pbar.update(len(page))
         else:
             beatmaps_to_recalculate = beatmaps.exclude(
-                difficulty_calculations__calculator_engine=difficulty_calculator_class.engine(),
-                difficulty_calculations__calculator_version=difficulty_calculator_class.version(),
+                difficulty_calculations__calculator_engine=difficulty_calculator.engine(),
+                difficulty_calculations__calculator_version=difficulty_calculator.version(),
             )

             if beatmaps_to_recalculate.count() == 0:
@@ -273,9 +260,8 @@ def recalculate_beatmaps_v2(
                 smoothing=0,
             ) as pbar:
                 while len(page := beatmaps_to_recalculate[:2000]) > 0:
-                    self.recalculate_beatmap_page_v2(
-                        difficulty_calculator_class, page, pbar
-                    )
+                    update_difficulty_calculations(page, difficulty_calculator)
+                    pbar.update(len(page))

         self.stdout.write(
             self.style.SUCCESS(
@@ -283,54 +269,9 @@ def recalculate_beatmaps_v2(
             )
         )

-    @transaction.atomic
-    def recalculate_beatmap_page_v2(
-        self,
-        difficulty_calculator_class: Type[AbstractDifficultyCalculator],
-        page: Iterable[Beatmap],
-        progress_bar: tqdm,
-    ):
-        calculations = []
-        beatmap_ids = []
-        for beatmap in page:
-            calculations.append(
-                DifficultyCalculation(
-                    beatmap_id=beatmap.id,
-                    mods=Mods.NONE,
-                    calculator_engine=difficulty_calculator_class.engine(),
-                    calculator_version=difficulty_calculator_class.version(),
-                )
-            )
-            beatmap_ids.append(beatmap.id)
-
-        # Create calculations
-        DifficultyCalculation.objects.bulk_create(calculations, ignore_conflicts=True)
-        calculations = DifficultyCalculation.objects.filter(
-            beatmap_id__in=beatmap_ids,
-            mods=Mods.NONE,
-            calculator_engine=difficulty_calculator_class.engine(),
-            calculator_version=difficulty_calculator_class.version(),
-        )
-
-        # Perform calculations
-        values = []
-        for calculation in calculations:
-            values.extend(
-                calculation.calculate_difficulty_values(difficulty_calculator_class)
-            )
-            progress_bar.update()
-
-        # Create values
-        DifficultyValue.objects.bulk_create(
-            values,
-            update_conflicts=True,
-            update_fields=["value"],
-            unique_fields=["calculation_id", "name"],
-        )
-
     def recalculate_scores_v2(
         self,
-        difficulty_calculator_class: Type[AbstractDifficultyCalculator],
+        difficulty_calculator: AbstractDifficultyCalculator,
         scores: QuerySet[Score],
         force: bool = False,
     ):
@@ -341,8 +282,8 @@ def recalculate_scores_v2(
             initial = 0
         else:
             scores_to_recalculate = scores.exclude(
-                performance_calculations__calculator_engine=difficulty_calculator_class.engine(),
-                performance_calculations__calculator_version=difficulty_calculator_class.version(),
+                performance_calculations__calculator_engine=difficulty_calculator.engine(),
+                performance_calculations__calculator_version=difficulty_calculator.version(),
             )

             if scores_to_recalculate.count() == 0:
@@ -374,13 +315,13 @@ def recalculate_scores_v2(
                     unique_beatmap_scores = scores_to_recalculate.filter(
                         beatmap_id=unique_beatmap["beatmap_id"], mods=unique_beatmap["mods"]
                     )
-                    self.recalculate_scores_for_unique_beatmap_v2(
-                        difficulty_calculator_class,
+                    update_performance_calculations_for_unique_beatmap(
                         unique_beatmap["beatmap_id"],
                         unique_beatmap["mods"],
                         unique_beatmap_scores,
-                        pbar,
+                        difficulty_calculator,
                     )
+                    pbar.update(unique_beatmap_scores.count())

         self.stdout.write(
             self.style.SUCCESS(
@@ -388,82 +329,6 @@ def recalculate_scores_v2(
             )
         )

-    @transaction.atomic
-    def recalculate_scores_for_unique_beatmap_v2(
-        self,
-        difficulty_calculator_class: Type[AbstractDifficultyCalculator],
-        beatmap_id: int,
-        mods: int,
-        scores: Iterable[Score],
-        progress_bar: tqdm,
-    ):
-        # Validate all scores are of same beatmap/mods
-        for score in scores:
-            if score.beatmap_id != beatmap_id or score.mods != mods:
-                raise Exception(
-                    f"Score {score.id} does not match beatmap {beatmap_id} and mods {mods}"
-                )
-
-        # Create difficulty calculation
-        difficulty_calculation, _ = DifficultyCalculation.objects.get_or_create(
-            beatmap_id=beatmap_id,
-            mods=mods,
-            calculator_engine=difficulty_calculator_class.engine(),
-            calculator_version=difficulty_calculator_class.version(),
-        )
-
-        # Do difficulty calculation
-        difficulty_values = difficulty_calculation.calculate_difficulty_values(
-            difficulty_calculator_class
-        )
-        DifficultyValue.objects.bulk_create(
-            difficulty_values,
-            update_conflicts=True,
-            update_fields=["value"],
-            unique_fields=["calculation_id", "name"],
-        )
-
-        score_dict = {}
-        performance_calculations = []
-        for score in scores:
-            performance_calculations.append(
-                PerformanceCalculation(
-                    score_id=score.id,
-                    difficulty_calculation_id=difficulty_calculation.id,
-                    calculator_engine=difficulty_calculator_class.engine(),
-                    calculator_version=difficulty_calculator_class.version(),
-                )
-            )
-            score_dict[score.id] = score
-
-        # Create calculations
-        PerformanceCalculation.objects.bulk_create(
-            performance_calculations, ignore_conflicts=True
-        )
-        performance_calculations = PerformanceCalculation.objects.filter(
-            score_id__in=score_dict.keys(),
-            calculator_engine=difficulty_calculator_class.engine(),
-            calculator_version=difficulty_calculator_class.version(),
-        )
-
-        # Perform calculations
-        values = []
-        for calculation in performance_calculations:
-            values.extend(
-                calculation.calculate_performance_values(
-                    score_dict[calculation.score_id], difficulty_calculator_class
-                )
-            )
-            progress_bar.update()
-
-        # Create values
-        PerformanceValue.objects.bulk_create(
-            values,
-            update_conflicts=True,
-            update_fields=["value"],
-            unique_fields=["calculation_id", "name"],
-        )
-
     def recalculate_user_stats(
         self,
         all_user_stats: QuerySet[UserStats],
diff --git a/profiles/migrations/0018_remove_difficultycalculation_unique_difficulty_calculation_and_more.py b/profiles/migrations/0018_remove_difficultycalculation_unique_difficulty_calculation_and_more.py
new file mode 100644
index 0000000..e951aee
--- /dev/null
+++ b/profiles/migrations/0018_remove_difficultycalculation_unique_difficulty_calculation_and_more.py
@@ -0,0 +1,45 @@
+# Generated by Django 4.2.11 on 2024-05-08 14:02
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("profiles", "0017_difficultycalculation_performancecalculation_and_more"),
+    ]
+
+    operations = [
+        migrations.RemoveConstraint(
+            model_name="difficultycalculation",
+            name="unique_difficulty_calculation",
+        ),
+        migrations.RemoveConstraint(
+            model_name="performancecalculation",
+            name="unique_performance_calculation",
+        ),
+        migrations.AlterField(
+            model_name="performancevalue",
+            name="calculation",
+            field=models.ForeignKey(
+                on_delete=django.db.models.deletion.CASCADE,
+                related_name="performance_values",
+                to="profiles.performancecalculation",
+            ),
+        ),
+        migrations.AddConstraint(
+            model_name="difficultycalculation",
+            constraint=models.UniqueConstraint(
+                fields=("beatmap_id", "mods", "calculator_engine"),
+                name="unique_difficulty_calculation",
+            ),
+        ),
+        migrations.AddConstraint(
+            model_name="performancecalculation",
+            constraint=models.UniqueConstraint(
+                fields=("score_id", "calculator_engine"),
+                name="unique_performance_calculation",
+            ),
+        ),
+    ]
diff --git a/profiles/models.py b/profiles/models.py
index 3c6949a..79c07ed 100644
--- a/profiles/models.py
+++ b/profiles/models.py
@@ -207,7 +207,9 @@ def add_scores_from_data(self, score_data_list: list[dict]):
                             combo=score.best_combo,
                         )
                     )
-                    score.performance_total = calculation.performance
+                    score.performance_total = calculation.performance_values[
+                        "total"
+                    ]
                     score.difficulty_calculator_engine = (
                         DifficultyCalculator.engine()
                     )
@@ -447,7 +449,7 @@ def update_difficulty_values(
                     DifficultyCalculatorScore(beatmap_id=self.id)
                 )

-                self.difficulty_total = calculation.difficulty
+                self.difficulty_total = calculation.difficulty_values["total"]
                 self.difficulty_calculator_engine = difficulty_calculator.engine()
                 self.difficulty_calculator_version = difficulty_calculator.version()
         except DifficultyCalculatorException as e:
@@ -479,32 +481,6 @@ class DifficultyCalculation(models.Model):
     calculator_engine = models.CharField(max_length=50)
     calculator_version = models.CharField(max_length=50)

-    def calculate_difficulty_values(
-        self, difficulty_calculator: type[AbstractDifficultyCalculator]
-    ) -> list["DifficultyValue"]:
-        values = []
-        try:
-            with difficulty_calculator() as calculator:
-                calculation = calculator.calculate_score(
-                    DifficultyCalculatorScore(
-                        beatmap_id=self.beatmap_id,
-                        mods=self.mods,
-                    )
-                )
-
-                values.append(
-                    DifficultyValue(
-                        calculation_id=self.id,
-                        name="total",
-                        value=calculation.difficulty,
-                    )
-                )
-        except DifficultyCalculatorException as e:
-            error_reporter = ErrorReporter()
-            error_reporter.report_error(e)
-
-        return values
-
     def __str__(self):
         if self.mods == 0:
             map_string = f"{self.beatmap_id}"
@@ -515,14 +491,13 @@ def __str__(self):

     class Meta:
         constraints = [
-            # Difficulty values are unique on beatmap + mods + calculator_engine + calculator_version
+            # Difficulty calculations are unique on beatmap + mods + calculator_engine
             # The implicit unique b-tree index on these columns is useful also
             models.UniqueConstraint(
                 fields=[
                     "beatmap_id",
                     "mods",
                     "calculator_engine",
-                    "calculator_version",
                 ],
                 name="unique_difficulty_calculation",
             )
@@ -746,8 +721,12 @@ def process(self):
                             count_50=self.count_50,
                         )
                     )
-                    self.nochoke_performance_total = nochoke_calculation.performance
-                    self.difficulty_total = nochoke_calculation.difficulty
+                    self.nochoke_performance_total = (
+                        nochoke_calculation.performance_values["total"]
+                    )
+                    self.difficulty_total = nochoke_calculation.difficulty_values[
+                        "total"
+                    ]
                     self.difficulty_calculator_engine = "legacy"  # legacy because performance_total is still coming from the api response
                     self.difficulty_calculator_version = "legacy"
                 except DifficultyCalculatorException as e:
@@ -769,8 +748,8 @@ def update_performance_values(
                     combo=self.best_combo,
                 )
             )
-            self.performance_total = calculation.performance
-            self.difficulty_total = calculation.difficulty
+            self.performance_total = calculation.performance_values["total"]
+            self.difficulty_total = calculation.difficulty_values["total"]
             self.difficulty_calculator_engine = calculator.engine()
             self.difficulty_calculator_version = calculator.version()

@@ -782,7 +761,9 @@ def update_performance_values(
                     count_50=self.count_50,
                 )
             )
-            self.nochoke_performance_total = nochoke_calculation.performance
+            self.nochoke_performance_total = nochoke_calculation.performance_values[
+                "total"
+            ]
         except DifficultyCalculatorException as e:
             # TODO: handle this properly
             self.nochoke_performance_total = 0
@@ -850,48 +831,17 @@ class PerformanceCalculation(models.Model):
     calculator_engine = models.CharField(max_length=50)
     calculator_version = models.CharField(max_length=50)

-    def calculate_performance_values(
-        self, score: Score, difficulty_calculator: type[AbstractDifficultyCalculator]
-    ) -> list["PerformanceValue"]:
-        values = []
-        try:
-            with difficulty_calculator() as calculator:
-                calculation = calculator.calculate_score(
-                    DifficultyCalculatorScore(
-                        beatmap_id=score.beatmap_id,
-                        mods=score.mods,
-                        count_100=score.count_100,
-                        count_50=score.count_50,
-                        count_miss=score.count_miss,
-                        combo=score.best_combo,
-                    )
-                )
-
-                values.append(
-                    PerformanceValue(
-                        calculation_id=self.id,
-                        name="total",
-                        value=calculation.performance,
-                    )
-                )
-        except DifficultyCalculatorException as e:
-            error_reporter = ErrorReporter()
-            error_reporter.report_error(e)
-
-        return values
-
     def __str__(self):
         return f"{self.score_id}: {self.calculator_engine} ({self.calculator_version})"

     class Meta:
         constraints = [
-            # Performance values are unique on score + calculator_engine + calculator_version
+            # Performance calculations are unique on score + calculator_engine
             # The implicit unique b-tree index on these columns is useful also
             models.UniqueConstraint(
                 fields=[
                     "score_id",
                     "calculator_engine",
-                    "calculator_version",
                 ],
                 name="unique_performance_calculation",
             )
@@ -908,7 +858,7 @@ class PerformanceValue(models.Model):
     calculation = models.ForeignKey(
         PerformanceCalculation,
         on_delete=models.CASCADE,
-        related_name="performance_calculations",
+        related_name="performance_values",
     )

     name = models.CharField(max_length=20)
diff --git a/profiles/services.py b/profiles/services.py
index 39333cb..2af7386 100644
--- a/profiles/services.py
+++ b/profiles/services.py
@@ -1,10 +1,26 @@
+import itertools
 from datetime import datetime, timedelta, timezone
+from typing import Iterable

 from django.db import transaction

+from common.error_reporter import ErrorReporter
 from common.osu.apiv1 import OsuApiV1
-from common.osu.enums import Gamemode
-from profiles.models import UserStats
+from common.osu.difficultycalculator import (
+    AbstractDifficultyCalculator,
+    DifficultyCalculatorException,
+)
+from common.osu.difficultycalculator import Score as DifficultyCalculatorScore
+from common.osu.enums import Gamemode, Mods
+from profiles.models import (
+    Beatmap,
+    DifficultyCalculation,
+    DifficultyValue,
+    PerformanceCalculation,
+    PerformanceValue,
+    Score,
+    UserStats,
+)
 from profiles.tasks import update_user


@@ -69,3 +85,201 @@ def fetch_scores(user_id, beatmap_ids, gamemode):
     new_scores = user_stats.add_scores_from_data(full_score_data_list)

     return new_scores
+
+
+@transaction.atomic
+def update_difficulty_calculations(
+    beatmaps: Iterable[Beatmap], difficulty_calculator: AbstractDifficultyCalculator
+):
+    """
+    Update difficulty calculations for passed beatmaps using passed difficulty calculator.
+    Existing calculations will be updated.
+    """
+    # Create calculations
+    calculations = []
+    beatmap_ids = []
+    for beatmap in beatmaps:
+        calculations.append(
+            DifficultyCalculation(
+                beatmap_id=beatmap.id,
+                mods=Mods.NONE,
+                calculator_engine=difficulty_calculator.engine(),
+                calculator_version=difficulty_calculator.version(),
+            )
+        )
+        beatmap_ids.append(beatmap.id)
+
+    DifficultyCalculation.objects.bulk_create(
+        calculations,
+        update_conflicts=True,
+        update_fields=["calculator_version"],
+        unique_fields=["beatmap_id", "mods", "calculator_engine"],
+    )
+    # TODO: remove when bulk_create(update_conflicts) returns pks in django 5.0
+    calculations = DifficultyCalculation.objects.filter(
+        beatmap_id__in=beatmap_ids,
+        mods=Mods.NONE,
+        calculator_engine=difficulty_calculator.engine(),
+    )
+
+    values = calculate_difficulty_values(calculations, difficulty_calculator)
+
+    DifficultyValue.objects.bulk_create(
+        itertools.chain.from_iterable(values),
+        update_conflicts=True,
+        update_fields=["value"],
+        unique_fields=["calculation_id", "name"],
+    )
+
+    # TODO: what happens if the calculator is updated to remove a diff value?
+    # do we need to delete all values not returned for a calculation?
+    # with update_conflicts=True returning pks in django 5.0 we can just add a delete where not id in pks
+
+
+@transaction.atomic
+def update_performance_calculations_for_unique_beatmap(
+    beatmap_id: int,
+    mods: Mods,
+    scores: Iterable[Score],
+    difficulty_calculator: AbstractDifficultyCalculator,
+):
+    """
+    Update performance (and difficulty) calculations for passed scores using passed difficulty calculator.
+    Existing calculations will be updated.
+    """
+    # Validate all scores are of same beatmap/mods
+    for score in scores:
+        if score.beatmap_id != beatmap_id or score.mods != mods:
+            raise ValueError(
+                f"Score {score.id} does not match beatmap {beatmap_id} and mods {mods}"
+            )
+
+    # Create difficulty calculation
+    difficulty_calculation, _ = DifficultyCalculation.objects.get_or_create(
+        beatmap_id=beatmap_id,
+        mods=mods,
+        calculator_engine=difficulty_calculator.engine(),
+        calculator_version=difficulty_calculator.version(),
+    )
+
+    # Do difficulty calculation
+    difficulty_values = calculate_difficulty_values(
+        [difficulty_calculation], difficulty_calculator
+    )[0]
+    DifficultyValue.objects.bulk_create(
+        difficulty_values,
+        update_conflicts=True,
+        update_fields=["value"],
+        unique_fields=["calculation_id", "name"],
+    )
+
+    # TODO: delete potentially outdated (removed from calc) values?
+
+    # Create calculations
+    score_ids = []
+    performance_calculations = []
+    for score in scores:
+        performance_calculations.append(
+            PerformanceCalculation(
+                score=score,
+                difficulty_calculation_id=difficulty_calculation.id,
+                calculator_engine=difficulty_calculator.engine(),
+                calculator_version=difficulty_calculator.version(),
+            )
+        )
+        score_ids.append(score.id)
+
+    PerformanceCalculation.objects.bulk_create(
+        performance_calculations,
+        update_conflicts=True,
+        update_fields=["calculator_version", "difficulty_calculation_id"],
+        unique_fields=["score_id", "calculator_engine"],
+    )
+    # TODO: remove when bulk_create(update_conflicts) returns pks in django 5.0
+    performance_calculations = PerformanceCalculation.objects.filter(
+        score_id__in=score_ids,
+        calculator_engine=difficulty_calculator.engine(),
+    )
+
+    values = calculate_performance_values(
+        performance_calculations, difficulty_calculator
+    )
+
+    PerformanceValue.objects.bulk_create(
+        itertools.chain.from_iterable(values),
+        update_conflicts=True,
+        update_fields=["value"],
+        unique_fields=["calculation_id", "name"],
+    )
+
+    # TODO: what happens if the calculator is updated to remove a perf value?
+    # do we need to delete all values not returned for a calculation?
+    # with update_conflicts=True returning pks in django 5.0 we can just add a delete where not id in pks
+
+
+def calculate_difficulty_values(
+    difficulty_calculations: Iterable[DifficultyCalculation],
+    difficulty_calculator: AbstractDifficultyCalculator,
+) -> list[list[DifficultyValue]]:
+    """
+    Calculate difficulty values for the passed difficulty calculations using passed difficulty calculator.
+    """
+    calc_scores = [
+        DifficultyCalculatorScore(
+            beatmap_id=str(difficulty_calculation.beatmap_id),
+            mods=difficulty_calculation.mods,
+        )
+        for difficulty_calculation in difficulty_calculations
+    ]
+
+    results = difficulty_calculator.calculate_score_batch(calc_scores)
+
+    values = [
+        [
+            DifficultyValue(
+                calculation_id=difficulty_calculation.id,
+                name=name,
+                value=value,
+            )
+        ]
+        for difficulty_calculation, result in zip(difficulty_calculations, results)
+        for name, value in result.difficulty_values.items()
+    ]
+
+    return values
+
+
+def calculate_performance_values(
+    performance_calculations: Iterable[PerformanceCalculation],
+    difficulty_calculator: AbstractDifficultyCalculator,
+) -> list[list[PerformanceValue]]:
+    """
+    Calculate performance values for the passed performance calculations using passed difficulty calculator.
+    """
+    calc_scores = [
+        DifficultyCalculatorScore(
+            beatmap_id=str(performance_calculation.score.beatmap_id),
+            mods=performance_calculation.score.mods,
+            count_100=performance_calculation.score.count_100,
+            count_50=performance_calculation.score.count_50,
+            count_miss=performance_calculation.score.count_miss,
+            combo=performance_calculation.score.best_combo,
+        )
+        for performance_calculation in performance_calculations
+    ]
+
+    results = difficulty_calculator.calculate_score_batch(calc_scores)
+
+    values = [
+        [
+            PerformanceValue(
+                calculation_id=performance_calculation.id,
+                name=name,
+                value=value,
+            )
+        ]
+        for performance_calculation, result in zip(performance_calculations, results)
+        for name, value in result.performance_values.items()
+    ]
+
+    return values
diff --git a/profiles/test_models.py b/profiles/test_models.py
index 04b9637..83de770 100644
--- a/profiles/test_models.py
+++ b/profiles/test_models.py
@@ -73,50 +73,3 @@ def test_update_performance_values(self, score: Score):
         assert score.difficulty_total == 8.975730066553297
         assert score.difficulty_calculator_engine == "rosu-pp-py"
         assert score.difficulty_calculator_version == "1.0.0"
-
-
-@pytest.fixture
-def difficulty_calculation(beatmap: Beatmap):
-    return DifficultyCalculation.objects.create(
-        beatmap=beatmap,
-        mods=Mods.DOUBLETIME + Mods.HIDDEN,
-        calculator_engine="testcalc",
-        calculator_version="v1",
-    )
-
-
-@pytest.mark.django_db
-class TestDifficultyCalculation:
-    def test_calculate_difficulty_values(
-        self, difficulty_calculation: DifficultyCalculation
-    ):
-        difficulty_values = difficulty_calculation.calculate_difficulty_values(
-            DifficultyCalculator
-        )
-        assert len(difficulty_values) == 1
-        assert difficulty_values[0].name == "total"
-        assert difficulty_values[0].value == 8.975730066553297
-
-
-@pytest.mark.django_db
-class TestPerformanceCalculation:
-    @pytest.fixture
-    def performance_calculation(
-        self, score: Score, difficulty_calculation: DifficultyCalculation
-    ):
-        return PerformanceCalculation.objects.create(
-            score=score,
-            difficulty_calculation=difficulty_calculation,
-            calculator_engine="testcalc",
-            calculator_version="v1",
-        )
-
-    def test_calculate_performance_values(
-        self, performance_calculation: PerformanceCalculation
-    ):
-        performance_values = performance_calculation.calculate_performance_values(
-            performance_calculation.score, DifficultyCalculator
-        )
-        assert len(performance_values) == 1
-        assert performance_values[0].name == "total"
-        assert performance_values[0].value == 626.7353926695473
diff --git a/profiles/test_services.py b/profiles/test_services.py
new file mode 100644
index 0000000..fc81374
--- /dev/null
+++ b/profiles/test_services.py
@@ -0,0 +1,87 @@
+import pytest
+
+from common.osu.difficultycalculator import DifficultyCalculator
+from common.osu.enums import Mods
+from profiles.models import DifficultyCalculation, PerformanceCalculation
+from profiles.services import (
+    calculate_difficulty_values,
+    calculate_performance_values,
+    update_difficulty_calculations,
+    update_performance_calculations_for_unique_beatmap,
+)
+
+
+@pytest.mark.django_db
+class TestDifficultyCalculationServices:
+    def test_update_difficulty_calculations(self, beatmap):
+        difficulty_calculator = DifficultyCalculator()
+        update_difficulty_calculations([beatmap], difficulty_calculator)
+
+        calculation = DifficultyCalculation.objects.get(
+            beatmap_id=beatmap.id, mods=Mods.NONE
+        )
+
+        difficulty_values = calculation.difficulty_values.all()
+        assert len(difficulty_values) == 1
+        assert difficulty_values[0].name == "total"
+        assert difficulty_values[0].value == 6.711556915919059
+
+    def test_update_performance_calculations_for_unique_beatmap(self, score):
+        difficulty_calculator = DifficultyCalculator()
+        update_performance_calculations_for_unique_beatmap(
+            score.beatmap_id, score.mods, [score], difficulty_calculator
+        )
+
+        difficulty_calculation = DifficultyCalculation.objects.get(
+            beatmap_id=score.beatmap_id, mods=score.mods
+        )
+
+        difficulty_values = difficulty_calculation.difficulty_values.all()
+        assert len(difficulty_values) == 1
+        assert difficulty_values[0].name == "total"
+        assert difficulty_values[0].value == 8.975730066553297
+
+        performance_calculation = difficulty_calculation.performance_calculations.get(
+            score_id=score.id
+        )
+
+        performance_values = performance_calculation.performance_values.all()
+        assert len(performance_values) == 1
+        assert performance_values[0].name == "total"
+        assert performance_values[0].value == 626.7353926695473
+
+    @pytest.fixture
+    def difficulty_calculation(self, beatmap):
+        return DifficultyCalculation.objects.create(
+            beatmap=beatmap,
+            mods=Mods.DOUBLETIME + Mods.HIDDEN,
+            calculator_engine="testcalc",
+            calculator_version="v1",
+        )
+
+    def test_calculate_difficulty_values(self, difficulty_calculation):
+        difficulty_values = calculate_difficulty_values(
+            [difficulty_calculation], DifficultyCalculator()
+        )
+        assert len(difficulty_values) == 1
+        assert len(difficulty_values[0]) == 1
+        assert difficulty_values[0][0].name == "total"
+        assert difficulty_values[0][0].value == 8.975730066553297
+
+    @pytest.fixture
+    def performance_calculation(self, score, difficulty_calculation):
+        return PerformanceCalculation.objects.create(
+            score=score,
+            difficulty_calculation=difficulty_calculation,
+            calculator_engine="testcalc",
+            calculator_version="v1",
+        )
+
+    def test_calculate_performance_values(self, performance_calculation):
+        performance_values = calculate_performance_values(
+            [performance_calculation], DifficultyCalculator()
+        )
+        assert len(performance_values) == 1
+        assert len(performance_values[0]) == 1
+        assert performance_values[0][0].name == "total"
+        assert performance_values[0][0].value == 626.7353926695473
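
Note: the snippet below is not part of the patch; it is a minimal usage sketch of the new services API introduced above, assuming a Django shell with the project settings loaded and beatmaps already present in the database. The gamemode filter is illustrative, mirroring how the recalculate command selects beatmaps.

    from common.osu.difficultycalculator import DifficultyCalculator
    from profiles.models import Beatmap
    from profiles.services import update_difficulty_calculations

    # The configured calculator is instantiated once and passed into the service.
    difficulty_calculator = DifficultyCalculator()

    # One DifficultyCalculation per beatmap (mods=NONE) is created or updated,
    # with one DifficultyValue row per name/value pair returned by the calculator.
    beatmaps = Beatmap.objects.filter(gamemode=difficulty_calculator.gamemode())
    update_difficulty_calculations(beatmaps, difficulty_calculator)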