Skip to content

Commit

Permalink
probablity summing issue is suppressed because of sigmoid output of t…
Browse files Browse the repository at this point in the history
…he models
  • Loading branch information
mohaliyet committed Dec 12, 2024
1 parent 0a25434 commit ce2d960
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 95 deletions.
3 changes: 2 additions & 1 deletion mapie/conformity_scores/sets/aps.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_predictions(
Array of predictions.
"""
y_pred_proba = estimator.predict(X, agg_scores)
y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
# y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
if agg_scores != "crossval":
y_pred_proba = np.repeat(
y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
Expand Down Expand Up @@ -161,6 +161,7 @@ def get_conformity_scores(
y_proba_true = np.take_along_axis(
y_pred, y_enc.reshape(-1, 1), axis=1
)

random_state = check_random_state(self.random_state)
u = random_state.uniform(size=len(y_pred)).reshape(-1, 1)
conformity_scores -= u * y_proba_true
Expand Down
2 changes: 1 addition & 1 deletion mapie/conformity_scores/sets/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_predictions(
Array of predictions.
"""
y_pred_proba = estimator.predict(X, agg_scores)
y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
# y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
if agg_scores != "crossval":
y_pred_proba = np.repeat(
y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
Expand Down
2 changes: 1 addition & 1 deletion mapie/conformity_scores/sets/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_predictions(
Array of predictions.
"""
y_pred_proba = estimator.predict(X, agg_scores='mean')
y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
# y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
y_pred_proba = np.repeat(
y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
)
Expand Down
2 changes: 1 addition & 1 deletion mapie/conformity_scores/sets/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_predictions(
Array of predictions.
"""
y_pred_proba = estimator.predict(X, agg_scores="mean")
y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
# y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
y_pred_proba = np.repeat(
y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
)
Expand Down
2 changes: 1 addition & 1 deletion mapie/estimator/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def predict_proba_calib(

if self.cv == "prefit":
y_pred_proba = self.single_estimator_.predict_proba(X)
y_pred_proba = self._check_proba_normalized(y_pred_proba)
# y_pred_proba = self._check_proba_normalized(y_pred_proba)
else:
X = cast(NDArray, X)
y_pred_proba = np.empty((len(X), self.n_classes), dtype=float)
Expand Down
443 changes: 353 additions & 90 deletions notebooks/classification/Cifar10.ipynb

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions notebooks/classification/text.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "detectron2",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
56 changes: 56 additions & 0 deletions text.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from mapie.conformity_scores.sets.topk import TopKConformityScore"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "TopKConformityScore.get_conformity_scores() missing 2 required positional arguments: 'y' and 'y_pred'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[6], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m topK \u001b[38;5;241m=\u001b[39m TopKConformityScore()\n\u001b[1;32m----> 2\u001b[0m \u001b[43mtopK\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_conformity_scores\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[1;31mTypeError\u001b[0m: TopKConformityScore.get_conformity_scores() missing 2 required positional arguments: 'y' and 'y_pred'"
]
}
],
"source": [
"topK = TopKConformityScore()\n",
"topK.get_conformity_scores()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "detectron2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit ce2d960

Please sign in to comment.