diff --git a/common/osu/difficultycalculator.py b/common/osu/difficultycalculator.py index b3dd6da..1a2a2b9 100644 --- a/common/osu/difficultycalculator.py +++ b/common/osu/difficultycalculator.py @@ -518,16 +518,29 @@ def gamemode(): return Gamemode.MANIA -difficulty_calculators: dict[str, type[AbstractDifficultyCalculator]] = { +difficulty_calculators_classes: dict[str, type[AbstractDifficultyCalculator]] = { name: import_string(calculator_class) for name, calculator_class in settings.DIFFICULTY_CALCULATOR_CLASSES.items() } def get_difficulty_calculator_class(name: str) -> Type[AbstractDifficultyCalculator]: - return difficulty_calculators[name] + return difficulty_calculators_classes[name] -DifficultyCalculator: Type[AbstractDifficultyCalculator] = import_string( - settings.DIFFICULTY_CALCULATOR_CLASS -) +def get_default_difficulty_calculator_class( + gamemode: Gamemode, +) -> Type[AbstractDifficultyCalculator]: + return get_difficulty_calculator_class( + settings.DEFAULT_DIFFICULTY_CALCULATORS[gamemode] + ) + + +def get_difficulty_calculators_for_gamemode( + gamemode: Gamemode, +) -> list[Type[AbstractDifficultyCalculator]]: + return [ + calculator_class + for calculator_class in difficulty_calculators_classes.values() + if calculator_class.gamemode() == gamemode + ] diff --git a/osuchan/settings.py b/osuchan/settings.py index 5f35b76..96c6997 100644 --- a/osuchan/settings.py +++ b/osuchan/settings.py @@ -8,6 +8,8 @@ from celery.schedules import crontab from pydantic_settings import BaseSettings +from common.osu.enums import Gamemode + class EnvSettings(BaseSettings): SECRET_KEY: str @@ -324,10 +326,6 @@ class EnvSettings(BaseSettings): # Difficulty calculation -DIFFICULTY_CALCULATOR_CLASS = ( - "common.osu.difficultycalculator.RosuppDifficultyCalculator" -) - DIFFICULTY_CALCULATOR_CLASSES = { "oppai": "common.osu.difficultycalculator.OppaiDifficultyCalculator", "rosupp": "common.osu.difficultycalculator.RosuppDifficultyCalculator", @@ -337,6 +335,13 @@ class EnvSettings(BaseSettings): "difficalcy-mania": "common.osu.difficultycalculator.DifficalcyManiaDifficultyCalculator", } +DEFAULT_DIFFICULTY_CALCULATORS = { + Gamemode.STANDARD: "rosupp", + Gamemode.TAIKO: "difficalcy-taiko", + Gamemode.CATCH: "difficalcy-catch", + Gamemode.MANIA: "difficalcy-mania", +} + DIFFICALCY_OSU_URL = f"http://{env_settings.DIFFICALCY_OSU_HOST}" DIFFICALCY_TAIKO_URL = f"http://{env_settings.DIFFICALCY_TAIKO_HOST}" DIFFICALCY_CATCH_URL = f"http://{env_settings.DIFFICALCY_CATCH_HOST}" diff --git a/profiles/management/commands/calculationbreakdown.py b/profiles/management/commands/calculationbreakdown.py index 2e87888..7fe4d7b 100644 --- a/profiles/management/commands/calculationbreakdown.py +++ b/profiles/management/commands/calculationbreakdown.py @@ -3,7 +3,7 @@ from common.osu.difficultycalculator import ( AbstractDifficultyCalculator, - difficulty_calculators, + difficulty_calculators_classes, ) from common.osu.enums import Gamemode from profiles.models import ( @@ -18,7 +18,7 @@ class Command(BaseCommand): help = "Displays current db calculation breakdown (new models)" def handle(self, *args, **options): - for name, difficulty_calculator_class in difficulty_calculators.items(): + for name, difficulty_calculator_class in difficulty_calculators_classes.items(): difficulty_calculator = difficulty_calculator_class() gamemode = difficulty_calculator.gamemode() diff --git a/profiles/management/commands/calculationstatus.py b/profiles/management/commands/calculationstatus.py index 944acde..3491ece 100644 --- a/profiles/management/commands/calculationstatus.py +++ b/profiles/management/commands/calculationstatus.py @@ -5,7 +5,7 @@ from common.osu.difficultycalculator import ( AbstractDifficultyCalculator, - difficulty_calculators, + difficulty_calculators_classes, ) from common.osu.enums import Gamemode from profiles.models import Beatmap, Score @@ -15,7 +15,7 @@ class Command(BaseCommand): help = "Displays current db calculation status (old models)" def handle(self, *args, **options): - for name, difficulty_calculator_class in difficulty_calculators.items(): + for name, difficulty_calculator_class in difficulty_calculators_classes.items(): gamemode = difficulty_calculator_class.gamemode() self.stdout.write( diff --git a/profiles/management/commands/recalculate.py b/profiles/management/commands/recalculate.py index 574f5c0..c11b6d3 100644 --- a/profiles/management/commands/recalculate.py +++ b/profiles/management/commands/recalculate.py @@ -9,8 +9,8 @@ from common.osu.difficultycalculator import ( AbstractDifficultyCalculator, CalculationException, - DifficultyCalculator, get_difficulty_calculator_class, + get_difficulty_calculators_for_gamemode, ) from common.osu.enums import Gamemode from leaderboards.models import Membership @@ -49,44 +49,45 @@ def handle(self, *args, **options): if diffcalc_name: difficulty_calculator_class = get_difficulty_calculator_class(diffcalc_name) else: - difficulty_calculator_class = DifficultyCalculator + difficulty_calculator_class = get_difficulty_calculators_for_gamemode( + Gamemode.STANDARD + )[0] - difficulty_calculator = difficulty_calculator_class() + with difficulty_calculator_class() as difficulty_calculator: - gamemode = difficulty_calculator.gamemode() - - self.stdout.write( - f"Gamemode: {Gamemode(gamemode).name}\n" - f"Difficulty Calculator Engine: {difficulty_calculator.engine()}\n" - f"Difficulty Calculator Version: {difficulty_calculator.version()}\n" - ) - - if v2: - # Recalculate beatmaps - beatmaps = Beatmap.objects.filter(gamemode=gamemode) - self.recalculate_beatmaps_v2(difficulty_calculator, beatmaps, force) - - # Recalculate scores - scores = Score.objects.filter(gamemode=gamemode) - self.recalculate_scores_v2(difficulty_calculator, scores, force) - else: - # Recalculate beatmaps - beatmaps = Beatmap.objects.filter(gamemode=gamemode) - self.recalculate_beatmaps(difficulty_calculator_class, beatmaps, force) - - # Recalculate scores - scores = Score.objects.filter(gamemode=gamemode) - self.recalculate_scores(difficulty_calculator_class, scores, force) - - # Recalculate user stats - all_user_stats = UserStats.objects.filter(gamemode=gamemode) - self.recalculate_user_stats(all_user_stats) + gamemode = difficulty_calculator.gamemode() + self.stdout.write( + f"Gamemode: {Gamemode(gamemode).name}\n" + f"Difficulty Calculator Engine: {difficulty_calculator.engine()}\n" + f"Difficulty Calculator Version: {difficulty_calculator.version()}\n" + ) - # Recalculate memberships - memberships = Membership.objects.select_related("leaderboard").filter( - leaderboard__gamemode=gamemode - ) - self.recalculate_memberships(memberships) + if v2: + # Recalculate beatmaps + beatmaps = Beatmap.objects.filter(gamemode=gamemode) + self.recalculate_beatmaps_v2(difficulty_calculator, beatmaps, force) + + # Recalculate scores + scores = Score.objects.filter(gamemode=gamemode) + self.recalculate_scores_v2(difficulty_calculator, scores, force) + else: + # Recalculate beatmaps + beatmaps = Beatmap.objects.filter(gamemode=gamemode) + self.recalculate_beatmaps(difficulty_calculator_class, beatmaps, force) + + # Recalculate scores + scores = Score.objects.filter(gamemode=gamemode) + self.recalculate_scores(difficulty_calculator_class, scores, force) + + # Recalculate user stats + all_user_stats = UserStats.objects.filter(gamemode=gamemode) + self.recalculate_user_stats(all_user_stats) + + # Recalculate memberships + memberships = Membership.objects.select_related("leaderboard").filter( + leaderboard__gamemode=gamemode + ) + self.recalculate_memberships(memberships) def recalculate_beatmap_page( self, diff --git a/profiles/models.py b/profiles/models.py index 17acddc..7d259ce 100644 --- a/profiles/models.py +++ b/profiles/models.py @@ -8,10 +8,10 @@ from common.osu import utils from common.osu.difficultycalculator import ( AbstractDifficultyCalculator, - DifficultyCalculator, DifficultyCalculatorException, ) from common.osu.difficultycalculator import Score as DifficultyCalculatorScore +from common.osu.difficultycalculator import get_difficulty_calculator_class from common.osu.enums import BeatmapStatus, Gamemode, Mods from profiles.enums import AllowedBeatmapStatus, ScoreResult, ScoreSet @@ -260,10 +260,6 @@ def from_data(cls, beatmap_data): beatmap_data["last_update"], "%Y-%m-%d %H:%M:%S" ).replace(tzinfo=timezone.utc) - beatmap.difficulty_total = float(beatmap_data["difficultyrating"]) - beatmap.difficulty_calculator_engine = "legacy" - beatmap.difficulty_calculator_version = "legacy" - # Update foreign key ids beatmap.creator_id = int(beatmap_data["creator_id"]) @@ -541,7 +537,7 @@ def process(self): self.__process_score_result() try: # only need to pass beatmap_id, 100s, 50s, and mods since all other options default to best possible - with DifficultyCalculator() as calc: + with get_difficulty_calculator_class("rosupp")() as calc: nochoke_calculation = calc.calculate_score( DifficultyCalculatorScore( beatmap_id=self.beatmap_id, @@ -553,11 +549,6 @@ def process(self): self.nochoke_performance_total = ( nochoke_calculation.performance_values["total"] ) - self.difficulty_total = nochoke_calculation.difficulty_values[ - "total" - ] - self.difficulty_calculator_engine = "legacy" # legacy because performance_total is still coming from the api response - self.difficulty_calculator_version = "legacy" except DifficultyCalculatorException as e: error_reporter = ErrorReporter() error_reporter.report_error(e) diff --git a/profiles/services.py b/profiles/services.py index f9d191f..d0218bd 100644 --- a/profiles/services.py +++ b/profiles/services.py @@ -9,10 +9,13 @@ from common.osu.apiv1 import OsuApiV1 from common.osu.difficultycalculator import ( AbstractDifficultyCalculator, - DifficultyCalculator, DifficultyCalculatorException, ) from common.osu.difficultycalculator import Score as DifficultyCalculatorScore +from common.osu.difficultycalculator import ( + get_default_difficulty_calculator_class, + get_difficulty_calculators_for_gamemode, +) from common.osu.enums import BeatmapStatus, Gamemode, Mods from leaderboards.models import Leaderboard, Membership from profiles.models import ( @@ -217,24 +220,22 @@ def refresh_user_from_api( score_data_list.extend( osu_api_v1.get_user_best_scores(user_stats.user_id, gamemode) ) - if gamemode == Gamemode.STANDARD: - # If standard, check user recent because we will be able to calculate pp for those scores - score_data_list.extend( - score - for score in osu_api_v1.get_user_recent_scores(user_stats.user_id, gamemode) - if score["rank"] != "F" - ) + score_data_list.extend( + score + for score in osu_api_v1.get_user_recent_scores(user_stats.user_id, gamemode) + if score["rank"] != "F" + ) user_stats.save() # Process and add scores created_scores = add_scores_from_data(user_stats, score_data_list) - # TODO: iterate all registered difficulty calculators for gamemode - if gamemode == Gamemode.STANDARD: - difficulty_calculator = DifficultyCalculator() - for score in created_scores: - update_performance_calculation(score, difficulty_calculator) + difficulty_calculators = get_difficulty_calculators_for_gamemode(gamemode) + for difficulty_calculator in difficulty_calculators: + with difficulty_calculator() as calc: + for score in created_scores: + update_performance_calculation(score, calc) return user_stats @@ -258,8 +259,27 @@ def refresh_beatmap_from_api(beatmap_id: int): ]: return None + gamemode = Gamemode(beatmap.gamemode) + + with get_default_difficulty_calculator_class(gamemode)() as calc: + calculation = calc.calculate_score( + DifficultyCalculatorScore( + mods=Mods.NONE.value, + beatmap_id=str(beatmap.id), + ) + ) + beatmap.difficulty_total = calculation.difficulty_values["total"] + beatmap.difficulty_calculator_engine = calc.engine() + beatmap.difficulty_calculator_version = calc.version() + beatmap.save() + for difficulty_calculator_class in get_difficulty_calculators_for_gamemode( + gamemode + ): + with difficulty_calculator_class() as difficulty_calculator: + update_difficulty_calculations([beatmap], difficulty_calculator) + return beatmap @@ -293,10 +313,12 @@ def fetch_scores(user_id, beatmap_ids, gamemode): # Process add scores created_scores = add_scores_from_data(user_stats, full_score_data_list) - # TODO: iterate all registered difficulty calculators for gamemode - difficulty_calculator = DifficultyCalculator() - for score in created_scores: - update_performance_calculation(score, difficulty_calculator) + for difficulty_calculator_class in get_difficulty_calculators_for_gamemode( + gamemode + ): + with difficulty_calculator_class() as difficulty_calculator: + for score in created_scores: + update_performance_calculation(score, difficulty_calculator) return created_scores @@ -382,36 +404,34 @@ def add_scores_from_data(user_stats: UserStats, score_data_list: list[dict]): score.beatmap = beatmap score.user_stats = user_stats - # Update pp - if "pp" in score_data and score_data["pp"] is not None: - score.performance_total = float(score_data["pp"]) - score.difficulty_calculator_engine = "legacy" - score.difficulty_calculator_version = "legacy" - else: - # Check for gamemode - if user_stats.gamemode != Gamemode.STANDARD: - # We cant calculate pp for this mode yet so we need to disregard this score - continue + gamemode = Gamemode(user_stats.gamemode) - try: - with DifficultyCalculator() as calc: - calculation = calc.calculate_score( - DifficultyCalculatorScore( - mods=score.mods, - beatmap_id=beatmap_id, - count_100=score.count_100, - count_50=score.count_50, - count_miss=score.count_miss, - combo=score.best_combo, - ) + # Calculate performance total + try: + with get_default_difficulty_calculator_class(gamemode)() as calc: + calculation = calc.calculate_score( + DifficultyCalculatorScore( + mods=score.mods, + beatmap_id=str(beatmap_id), + count_katu=score.count_katu, + count_300=score.count_300, + count_100=score.count_100, + count_50=score.count_50, + count_miss=score.count_miss, + combo=score.best_combo, ) - score.performance_total = calculation.performance_values["total"] - score.difficulty_calculator_engine = DifficultyCalculator.engine() - score.difficulty_calculator_version = DifficultyCalculator.version() - except DifficultyCalculatorException as e: - error_reporter = ErrorReporter() - error_reporter.report_error(e) - continue + ) + score.performance_total = calculation.performance_values["total"] + score.difficulty_total = calculation.difficulty_values["total"] + score.difficulty_calculator_engine = calc.engine() + score.difficulty_calculator_version = calc.version() + except DifficultyCalculatorException as e: + error_reporter = ErrorReporter() + error_reporter.report_error(e) + score.performance_total = 0 + score.difficulty_total = 0 + score.difficulty_calculator_engine = "error" + score.difficulty_calculator_version = "error" # Update convenience fields score.gamemode = user_stats.gamemode diff --git a/profiles/test_models.py b/profiles/test_models.py index 220f026..f0bca49 100644 --- a/profiles/test_models.py +++ b/profiles/test_models.py @@ -1,16 +1,8 @@ import pytest -from common.osu.difficultycalculator import DifficultyCalculator -from common.osu.enums import Mods +from common.osu.difficultycalculator import get_difficulty_calculator_class from profiles.enums import ScoreResult -from profiles.models import ( - Beatmap, - DifficultyCalculation, - OsuUser, - PerformanceCalculation, - Score, - UserStats, -) +from profiles.models import Beatmap, OsuUser, Score, UserStats @pytest.mark.django_db @@ -46,7 +38,7 @@ def test_from_data(self): pass def test_update_difficulty_values(self, beatmap: Beatmap): - beatmap.update_difficulty_values(DifficultyCalculator) + beatmap.update_difficulty_values(get_difficulty_calculator_class("rosupp")) assert beatmap.difficulty_total == 6.711556915919059 assert beatmap.difficulty_calculator_engine == "rosu-pp-py" assert beatmap.difficulty_calculator_version == "1.0.1" @@ -60,14 +52,10 @@ def test_magic_str(self, score: Score): def test_process(self, score: Score): score.process() assert score.result == ScoreResult.END_CHOKE - assert score.performance_total == 395.282 assert score.nochoke_performance_total == 626.7353926695473 - assert score.difficulty_total == 8.975730066553297 - assert score.difficulty_calculator_engine == "legacy" - assert score.difficulty_calculator_version == "legacy" def test_update_performance_values(self, score: Score): - score.update_performance_values(DifficultyCalculator) + score.update_performance_values(get_difficulty_calculator_class("rosupp")) assert score.performance_total == 626.7353926695473 assert score.nochoke_performance_total == 626.7353926695473 assert score.difficulty_total == 8.975730066553297 diff --git a/profiles/test_services.py b/profiles/test_services.py index 928a6aa..6597852 100644 --- a/profiles/test_services.py +++ b/profiles/test_services.py @@ -1,6 +1,6 @@ import pytest -from common.osu.difficultycalculator import DifficultyCalculator +from common.osu.difficultycalculator import get_difficulty_calculator_class from common.osu.enums import Mods from profiles.models import DifficultyCalculation, PerformanceCalculation from profiles.services import ( @@ -35,15 +35,15 @@ def test_refresh_user_from_api(self): PerformanceCalculation.objects.filter( score__user_stats_id=user_stats.id ).count() - == 5 + == 15 # 5 scores * 3 calculators ) @pytest.mark.django_db class TestDifficultyCalculationServices: def test_update_difficulty_calculations(self, beatmap): - difficulty_calculator = DifficultyCalculator() - update_difficulty_calculations([beatmap], difficulty_calculator) + with get_difficulty_calculator_class("rosupp")() as difficulty_calculator: + update_difficulty_calculations([beatmap], difficulty_calculator) calculation = DifficultyCalculation.objects.get( beatmap_id=beatmap.id, mods=Mods.NONE @@ -55,10 +55,10 @@ def test_update_difficulty_calculations(self, beatmap): assert difficulty_values[0].value == 6.711556915919059 def test_update_performance_calculations_for_unique_beatmap(self, score): - difficulty_calculator = DifficultyCalculator() - update_performance_calculations_for_unique_beatmap( - score.beatmap_id, score.mods, [score], difficulty_calculator - ) + with get_difficulty_calculator_class("rosupp")() as difficulty_calculator: + update_performance_calculations_for_unique_beatmap( + score.beatmap_id, score.mods, [score], difficulty_calculator + ) difficulty_calculation = DifficultyCalculation.objects.get( beatmap_id=score.beatmap_id, mods=score.mods @@ -82,8 +82,8 @@ def test_update_performance_calculation( self, score, ): - difficulty_calculator = DifficultyCalculator() - update_performance_calculation(score, difficulty_calculator) + with get_difficulty_calculator_class("rosupp")() as difficulty_calculator: + update_performance_calculation(score, difficulty_calculator) difficulty_calculation = DifficultyCalculation.objects.get( beatmap_id=score.beatmap_id, mods=score.mods @@ -113,9 +113,10 @@ def difficulty_calculation(self, beatmap): ) def test_calculate_difficulty_values(self, difficulty_calculation): - difficulty_values = calculate_difficulty_values( - [difficulty_calculation], DifficultyCalculator() - ) + with get_difficulty_calculator_class("rosupp")() as difficulty_calculator: + difficulty_values = calculate_difficulty_values( + [difficulty_calculation], difficulty_calculator + ) assert len(difficulty_values) == 1 assert len(difficulty_values[0]) == 1 assert difficulty_values[0][0].name == "total" @@ -131,9 +132,10 @@ def performance_calculation(self, score, difficulty_calculation): ) def test_calculate_performance_values(self, performance_calculation): - performance_values = calculate_performance_values( - [performance_calculation], DifficultyCalculator() - ) + with get_difficulty_calculator_class("rosupp")() as difficulty_calculator: + performance_values = calculate_performance_values( + [performance_calculation], difficulty_calculator + ) assert len(performance_values) == 1 assert len(performance_values[0]) == 1 assert performance_values[0][0].name == "total" diff --git a/profiles/test_views.py b/profiles/test_views.py index d2d9016..76b2792 100644 --- a/profiles/test_views.py +++ b/profiles/test_views.py @@ -70,13 +70,13 @@ def test_get(self, arf: APIRequestFactory, view, stub_user_stats): assert response.status_code == HTTPStatus.OK assert len(response.data) == 4 assert response.data[0]["difficulty_total"] == 6.264344677869616 - assert response.data[0]["performance_total"] == 395.281 + assert response.data[0]["performance_total"] == 395.2821554526868 assert response.data[0]["nochoke_performance_total"] == 395.39084780089814 assert response.data[1]["difficulty_total"] == 6.679077669651381 - assert response.data[1]["performance_total"] == 381.606 + assert response.data[1]["performance_total"] == 381.60801992603007 assert response.data[1]["nochoke_performance_total"] == 381.60801992603007 assert response.data[2]["difficulty_total"] == 6.28551550473302 - assert response.data[2]["performance_total"] == 371.204 + assert response.data[2]["performance_total"] == 371.203519484766 assert response.data[2]["nochoke_performance_total"] == 371.203519484766 assert response.data[3]["difficulty_total"] == 5.5699192504372625 assert response.data[3]["performance_total"] == 143.53942289330428