RNN.predict should return labels, not probabilities #20
@dnouri -- What makes this a bit tricky is that Passage does not have distinct classes for regression and classification tasks, meaning that there's no simple interface fix to make this fit the sklearn interface properly (changing to
@Newmu -- Would you prefer that this be a documentation / example change or would you prefer a solution where we provide sklearn-compatible interfaces via subclasses of
result = model.predict(tokenizer.transform(dataTest))
result = model.predict(tokenizer.transform(dataTest)) > 0.5
How can this be solved for multiclass predictions? Thanks!
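On the multiclass question, here is a minimal sketch. It assumes the model was built with a sigmoid output of size 1 for binary tasks and a softmax output of size k for k classes, so that `model.predict` returns either an (n, 1) column of positive-class probabilities or an (n, k) matrix of per-class probabilities. The `> 0.5` threshold only makes sense for the binary case; for multiclass, the usual choice is the most probable column per row. The helper name `probabilities_to_labels` is hypothetical and not part of Passage.

```python
import numpy as np


def probabilities_to_labels(probs, threshold=0.5):
    """Hypothetical helper: turn RNN.predict-style probabilities into labels.

    Assumes `probs` is either an (n,) / (n, 1) array of positive-class
    probabilities (binary task) or an (n, k) array of per-class
    probabilities (multiclass task).
    """
    probs = np.asarray(probs)
    if probs.ndim == 1 or probs.shape[1] == 1:
        # Binary: threshold the single probability column, giving 0 or 1.
        return (probs.ravel() > threshold).astype(int)
    # Multiclass: index of the highest-probability column in each row.
    return probs.argmax(axis=1)
```

With the variables from the snippets above, this would be `labels = probabilities_to_labels(model.predict(tokenizer.transform(dataTest)))`. For multiclass tasks the result is a column index, so it maps back to the original class names only if you know the column order used when the training targets were encoded.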
There's a slight incompatibility with sklearn in the `RNN.predict` method: it should return predicted class labels. `predict_proba` is the name of the method that returns probabilities. In Passage's case it would be like the existing `predict`, except that for binary classification tasks sklearn expects an (n, 2) matrix with one column each for the negative and positive class probabilities.

Here are the two methods (a hack) that I use in a subclass to make `predict` and `predict_proba` work with sklearn, on top of the existing `RNN.predict`. As it is, it only works with binary classification:

As I'm not sure what else the current `predict` can return (i.e. when it's not doing binary classification), I'm also not sure what the right way is to change the original code so that it still works with all the tasks it was designed for.
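The subclass code mentioned in the issue is not included in this copy of the thread. As a rough sketch of the subclassing approach (not the original author's code), a mixin could supply sklearn-style `predict` and `predict_proba` on top of the model's own `predict`. It assumes `RNN.predict` returns an (n, 1) probability column for binary tasks or an (n, k) matrix for multiclass tasks. `SklearnProbaMixin` and `SklearnRNN` are hypothetical names. The sketch deliberately does not handle regression models, which is the ambiguity raised in the discussion above.

```python
import numpy as np


class SklearnProbaMixin(object):
    """Hypothetical mixin adding sklearn-style predict / predict_proba.

    Assumes the next class in the MRO (e.g. Passage's RNN) has a
    predict(X) that returns class probabilities: an (n, 1) column for
    binary tasks or an (n, k) matrix for multiclass tasks.
    Not meant for regression models.
    """

    def predict_proba(self, X):
        # Call the underlying model's original predict (e.g. RNN.predict).
        probs = np.asarray(super(SklearnProbaMixin, self).predict(X), dtype=float)
        if probs.ndim == 1:
            probs = probs.reshape(-1, 1)
        if probs.shape[1] == 1:
            # Binary: sklearn expects one column per class, [P(neg), P(pos)].
            probs = np.hstack([1.0 - probs, probs])
        return probs

    def predict(self, X):
        # Label = index of the most probable column. For binary tasks this is
        # 1 when P(pos) > 0.5 and 0 otherwise, matching a 0.5 threshold.
        return self.predict_proba(X).argmax(axis=1)


# Usage sketch (assumes Passage is installed and exposes passage.models.RNN):
# from passage.models import RNN
#
# class SklearnRNN(SklearnProbaMixin, RNN):
#     pass
```

Putting the mixin first in the base-class list is what makes `SklearnRNN().predict` resolve to the label-returning version while `super()` still reaches `RNN.predict`. sklearn utilities such as cloning or grid search may additionally expect things like `get_params` and a `classes_` attribute, which this sketch does not provide.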