Skip to content

Commit

Permalink
Refactor diffcalc to return all values
Browse files Browse the repository at this point in the history
  • Loading branch information
Syriiin committed May 8, 2024
1 parent e815c46 commit 221ffe3
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 35 deletions.
15 changes: 8 additions & 7 deletions common/osu/difficultycalculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class Score(NamedTuple):


class Calculation(NamedTuple):
difficulty: float
performance: float
difficulty_values: dict[str, float]
performance_values: dict[str, float]


class DifficultyCalculatorException(Exception):
Expand Down Expand Up @@ -90,7 +90,8 @@ def calculate_score(self, score: Score) -> Calculation:

self.calculate()
return Calculation(
difficulty=self.difficulty_total, performance=self.performance_total
difficulty_values={"total": self.difficulty_total},
performance_values={"total": self.performance_total},
)

def calculate_score_batch(self, scores: Iterable[Score]) -> list[Calculation]:
Expand Down Expand Up @@ -332,8 +333,8 @@ def calculate_score(self, score: Score) -> Calculation:
) from e

return Calculation(
difficulty=data["difficulty"]["total"],
performance=data["performance"]["total"],
difficulty_values=data["difficulty"],
performance_values=data["performance"],
)

def calculate_score_batch(self, scores: Iterable[Score]) -> list[Calculation]:
Expand All @@ -351,8 +352,8 @@ def calculate_score_batch(self, scores: Iterable[Score]) -> list[Calculation]:

return [
Calculation(
difficulty=calculation_data["difficulty"]["total"],
performance=calculation_data["performance"]["total"],
difficulty_values=calculation_data["difficulty"],
performance_values=calculation_data["performance"],
)
for calculation_data in data
]
Expand Down
112 changes: 99 additions & 13 deletions common/osu/test_difficultycalculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def test_calculate_score(self):
combo=2000,
)
assert calc.calculate_score(score) == Calculation(
difficulty=5.919765949249268, performance=298.1595153808594
difficulty_values={"total": 5.919765949249268},
performance_values={"total": 298.1595153808594},
)

def test_calculate_score_batch(self):
Expand Down Expand Up @@ -71,9 +72,18 @@ def test_calculate_score_batch(self):
),
]
assert calc.calculate_score_batch(scores) == [
Calculation(difficulty=5.919765949249268, performance=298.1595153808594),
Calculation(difficulty=6.20743465423584, performance=476.4307861328125),
Calculation(difficulty=6.20743465423584, performance=630.419677734375),
Calculation(
difficulty_values={"total": 5.919765949249268},
performance_values={"total": 298.1595153808594},
),
Calculation(
difficulty_values={"total": 6.20743465423584},
performance_values={"total": 476.4307861328125},
),
Calculation(
difficulty_values={"total": 6.20743465423584},
performance_values={"total": 630.419677734375},
),
]

@pytest.fixture
Expand Down Expand Up @@ -148,7 +158,8 @@ def test_calculate_score(self):
combo=2000,
)
assert calc.calculate_score(score) == Calculation(
difficulty=6.264344677869616, performance=312.43705315450256
difficulty_values={"total": 6.264344677869616},
performance_values={"total": 312.43705315450256},
)

def test_calculate_score_batch(self):
Expand Down Expand Up @@ -176,9 +187,18 @@ def test_calculate_score_batch(self):
),
]
assert calc.calculate_score_batch(scores) == [
Calculation(difficulty=6.264344677869616, performance=312.43705315450256),
Calculation(difficulty=6.531051472171891, performance=487.5904861756349),
Calculation(difficulty=6.531051472171891, performance=655.9388807525456),
Calculation(
difficulty_values={"total": 6.264344677869616},
performance_values={"total": 312.43705315450256},
),
Calculation(
difficulty_values={"total": 6.531051472171891},
performance_values={"total": 487.5904861756349},
),
Calculation(
difficulty_values={"total": 6.531051472171891},
performance_values={"total": 655.9388807525456},
),
]

@pytest.fixture
Expand Down Expand Up @@ -234,7 +254,19 @@ def test_version(self):
def test_context_manager(self):
with DifficalcyOsuDifficultyCalculator() as calc:
assert calc.calculate_score(Score("307618")) == Calculation(
difficulty=4.4569433791337945, performance=135.0040504515237
difficulty_values={
"aim": 2.08629357857818,
"speed": 2.1778593015565684,
"flashlight": 0,
"total": 4.4569433791337945,
},
performance_values={
"aim": 44.12278272319251,
"speed": 50.54174287197802,
"accuracy": 36.07670429437059,
"flashlight": 0,
"total": 135.0040504515237,
},
)

def test_invalid_beatmap(self):
Expand All @@ -253,7 +285,19 @@ def test_calculate_score(self):
combo=2000,
)
assert calc.calculate_score(score) == Calculation(
difficulty=6.263707394408435, performance=312.36671287580185
difficulty_values={
"aim": 2.892063051954271,
"speed": 3.0958487396004704,
"flashlight": 0,
"total": 6.263707394408435,
},
performance_values={
"aim": 98.6032935956297,
"speed": 118.92511309917593,
"accuracy": 84.96884392557897,
"flashlight": 0,
"total": 312.36671287580185,
},
)

def test_calculate_score_batch(self):
Expand Down Expand Up @@ -281,7 +325,49 @@ def test_calculate_score_batch(self):
),
]
assert calc.calculate_score_batch(scores) == [
Calculation(difficulty=6.263707394408435, performance=312.36671287580185),
Calculation(difficulty=6.530286188377548, performance=487.4810004992573),
Calculation(difficulty=6.530286188377548, performance=655.7872855036575),
Calculation(
difficulty_values={
"aim": 2.892063051954271,
"speed": 3.0958487396004704,
"flashlight": 0,
"total": 6.263707394408435,
},
performance_values={
"aim": 98.6032935956297,
"speed": 118.92511309917593,
"accuracy": 84.96884392557897,
"flashlight": 0,
"total": 312.36671287580185,
},
),
Calculation(
difficulty_values={
"aim": 3.1381340530266333,
"speed": 3.1129549941521066,
"flashlight": 0,
"total": 6.530286188377548,
},
performance_values={
"aim": 153.058022351103,
"speed": 153.10941688245896,
"accuracy": 166.32370945374015,
"flashlight": 0,
"total": 487.4810004992573,
},
),
Calculation(
difficulty_values={
"aim": 3.1381340530266333,
"speed": 3.1129549941521066,
"flashlight": 0,
"total": 6.530286188377548,
},
performance_values={
"aim": 207.5808620241847,
"speed": 215.2746980112218,
"accuracy": 212.8087296294707,
"flashlight": 0,
"total": 655.7872855036575,
},
),
]
22 changes: 15 additions & 7 deletions profiles/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def add_scores_from_data(self, score_data_list: list[dict]):
combo=score.best_combo,
)
)
score.performance_total = calculation.performance
score.performance_total = calculation.performance_values[
"total"
]
score.difficulty_calculator_engine = (
DifficultyCalculator.engine()
)
Expand Down Expand Up @@ -447,7 +449,7 @@ def update_difficulty_values(
DifficultyCalculatorScore(beatmap_id=self.id)
)

self.difficulty_total = calculation.difficulty
self.difficulty_total = calculation.difficulty_values["total"]
self.difficulty_calculator_engine = difficulty_calculator.engine()
self.difficulty_calculator_version = difficulty_calculator.version()
except DifficultyCalculatorException as e:
Expand Down Expand Up @@ -719,8 +721,12 @@ def process(self):
count_50=self.count_50,
)
)
self.nochoke_performance_total = nochoke_calculation.performance
self.difficulty_total = nochoke_calculation.difficulty
self.nochoke_performance_total = (
nochoke_calculation.performance_values["total"]
)
self.difficulty_total = nochoke_calculation.difficulty_values[
"total"
]
self.difficulty_calculator_engine = "legacy" # legacy because performance_total is still coming from the api response
self.difficulty_calculator_version = "legacy"
except DifficultyCalculatorException as e:
Expand All @@ -742,8 +748,8 @@ def update_performance_values(
combo=self.best_combo,
)
)
self.performance_total = calculation.performance
self.difficulty_total = calculation.difficulty
self.performance_total = calculation.performance_values["total"]
self.difficulty_total = calculation.difficulty_values["total"]
self.difficulty_calculator_engine = calculator.engine()
self.difficulty_calculator_version = calculator.version()

Expand All @@ -755,7 +761,9 @@ def update_performance_values(
count_50=self.count_50,
)
)
self.nochoke_performance_total = nochoke_calculation.performance
self.nochoke_performance_total = nochoke_calculation.performance_values[
"total"
]
except DifficultyCalculatorException as e:
# TODO: handle this properly
self.nochoke_performance_total = 0
Expand Down
14 changes: 6 additions & 8 deletions profiles/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,19 +234,16 @@ def calculate_difficulty_values(

results = difficulty_calculator.calculate_score_batch(calc_scores)

# TODO: handle multiple values per calculation
# do we need a "primary" or "total" boolean to determine the main value?
# or should calculation have a "primary_value" field?
# or should we just always use "total" as the main value?
values = [
[
DifficultyValue(
calculation_id=difficulty_calculation.id,
name="total",
value=result.difficulty,
name=name,
value=value,
)
]
for difficulty_calculation, result in zip(difficulty_calculations, results)
for name, value in result.difficulty_values.items()
]

return values
Expand Down Expand Up @@ -281,11 +278,12 @@ def calculate_performance_values(
[
PerformanceValue(
calculation_id=performance_calculation.id,
name="total",
value=result.performance,
name=name,
value=value,
)
]
for performance_calculation, result in zip(performance_calculations, results)
for name, value in result.performance_values.items()
]

return values

0 comments on commit 221ffe3

Please sign in to comment.