Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
sveinbjornt committed Dec 12, 2022
2 parents 9809e5a + 41adb9b commit b3b1269
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
16 changes: 7 additions & 9 deletions src/reynir_correct/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,13 @@
"""

from typing import List, overload
from typing import List, Union, overload


try:
from datasets import load_dataset
from transformers import pipeline # type: ignore
except:
import sys
import warnings

warningtext = (
Expand Down Expand Up @@ -82,19 +81,18 @@ def classify(self, text: str) -> bool:
def classify(self, text: List[str]) -> List[bool]:
...

def classify(self, text):
def classify(self, text: Union[str, List[str]]) -> Union[List[bool], bool]:
"""Classify a sentence or sentences.
For each sentence, return true iff the sentence probably contains an error."""
if isinstance(text, str):
text = [text]

result = self.pipe([self._domain_prefix + t for t in text])
result = [r["generated_text"] == self._true_label for r in result]
pipe_result = self.pipe([self._domain_prefix + t for t in text])
result: List[bool] = [
r["generated_text"] == self._true_label for r in pipe_result
]

if len(result) == 1:
result = result[0]

return result
return result[0] if len(result) == 1 else result


def _main() -> None:
Expand Down
2 changes: 2 additions & 0 deletions src/reynir_correct/config/GreynirCorrect.conf
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@
"áratugs", "áratugar"
"áratugsins", "áratugarins"
"árð", "árið"
"arfleið", "arfleifð"
"arfleiðar", "arfleifðar"
"Argentísk", "Argentínsk"
"argentíska", "argentínska"
"argentíski", "argentínski"
Expand Down
2 changes: 1 addition & 1 deletion src/reynir_correct/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def wrong_preposition_grin_af(self, match: SimpleTree) -> None:
pp = match.first_match('P > { "af" }')
if pp is None:
pp = match.first_match('ADVP > { "af" }')
if np is None or pp is None:
if vp is None or np is None or pp is None:
return
pp_af = pp.first_match('"af"')
if pp_af is None:
Expand Down
4 changes: 4 additions & 0 deletions src/reynir_correct/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@
Any,
Union,
cast,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from .classifier import SentenceClassifier

import sys
import argparse
import json
Expand Down

0 comments on commit b3b1269

Please sign in to comment.