diff --git a/metrics/crps/crps.py b/metrics/crps/crps.py index 1b2a8f80f..ccbecffc3 100644 --- a/metrics/crps/crps.py +++ b/metrics/crps/crps.py @@ -61,14 +61,17 @@ Examples: >>> crps_metric = evaluate.load("crps") - >>> predictions = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + >>> predictions = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]]) >>> references = np.array([0.3]) >>> results = crps_metric.compute(predictions=predictions, references=references) >>> print(results) - {'mase': 0.18333333333333335} - - + {'crps': 1.9999999254941967} + >>> crps_metric = evaluate.load("crps", "multilist") + >>> predictions = [[[0, 2, 4]], [[-1, 2, 5]], [[8, -5, 6]]] + >>> references = [[0.5], [-1], [7]] + >>> print(results) + {'crps': 3.0522875816993467} """ @@ -126,9 +129,9 @@ def _compute( ) if multioutput == "raw_values": - return weighted_quantile_loss + return {"crps": weighted_quantile_loss} elif multioutput == "uniform_average": - return np.average(weighted_quantile_loss) + return {"crps": np.average(weighted_quantile_loss)} else: raise ValueError( "The multioutput parameter should be one of the following: "