Skip to content

Commit

Permalink
return dict
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Oct 27, 2023
1 parent 421166e commit 7f21013
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions metrics/crps/crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@
Examples:
>>> crps_metric = evaluate.load("crps")
>>> predictions = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
>>> predictions = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]])
>>> references = np.array([0.3])
>>> results = crps_metric.compute(predictions=predictions, references=references)
>>> print(results)
{'mase': 0.18333333333333335}
{'crps': 1.9999999254941967}
>>> crps_metric = evaluate.load("crps", "multilist")
>>> predictions = [[[0, 2, 4]], [[-1, 2, 5]], [[8, -5, 6]]]
>>> references = [[0.5], [-1], [7]]
>>> print(results)
{'crps': 3.0522875816993467}
"""


Expand Down Expand Up @@ -126,9 +129,9 @@ def _compute(
)

if multioutput == "raw_values":
return weighted_quantile_loss
return {"crps": weighted_quantile_loss}
elif multioutput == "uniform_average":
return np.average(weighted_quantile_loss)
return {"crps": np.average(weighted_quantile_loss)}
else:
raise ValueError(
"The multioutput parameter should be one of the following: "
Expand Down

0 comments on commit 7f21013

Please sign in to comment.