Merge pull request caikit#304 from evaline-ju/empty-text
🥅 Handle empty text for filtered span classification
gkumbhat authored Jan 17, 2024
2 parents c381f6e + e0c8633 commit 6bc17e2
Showing 2 changed files with 57 additions and 37 deletions.
File 1 of 2

@@ -136,8 +136,16 @@ def run(
         Returns:
             TokenClassificationResults
         """
+        error.type_check("<NLP82129006E>", str, text=text)
+        error.type_check("<NLP01414077E>", float, allow_none=True, threshold=threshold)
+
         if threshold is None:
             threshold = self.default_threshold
+        if not text:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            return TokenClassificationResults(results=[])
+
         token_classification_results = []
         if self.classification_task == TextClassificationTask:
             # Split document into spans
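For context, a minimal usage sketch of the new guard (editorial example, not part of the diff; assumes `model` is a bootstrapped FilteredSpanClassification instance):

    # Empty input now short-circuits before tokenization and returns an
    # empty TokenClassificationResults instead of erroring downstream.
    result = model.run("")
    assert isinstance(result, TokenClassificationResults)
    assert result.results == []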
@@ -196,10 +204,17 @@ def run_bidi_stream(
         Returns:
             Iterable[TokenClassificationStreamResult]
         """
+        error.type_check("<NLP96166348E>", float, allow_none=True, threshold=threshold)
         # TODO: For optimization implement window based approach.
         if threshold is None:
             threshold = self.default_threshold

+        # Types on the stream are checked later on iteration
+        if len(text_stream) == 0:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            yield TokenClassificationStreamResult(results=[], processed_index=0)
+
         for span_output in self._stream_span_output(text_stream):
             classification_result = self.classifier.run(span_output.text)
             results_to_end_of_span = False
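Likewise for the streaming path, a sketch (editorial example; assumes `model` as above and caikit's data_model.DataStream):

    # An empty stream yields exactly one sentinel result with no
    # classifications and processed_index 0; the span loop below then
    # has nothing to iterate, so the generator ends.
    empty_stream = data_model.DataStream.from_iterable("")
    outputs = list(model.run_bidi_stream(empty_stream))
    assert len(outputs) == 1
    assert outputs[0].results == []
    assert outputs[0].processed_index == 0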
@@ -344,6 +359,7 @@ def __update_spans(token)
             return token

         for text in text_stream:
+            error.type_check("<NLP38357927E>", str, text=text)
             stream_accumulator += text
             # In order to avoid processing all of the spans again, we only
             # send out the spans that are not yet finalized in detected_spans
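With the per-chunk check, a stream carrying a non-string now fails fast during accumulation; a sketch (editorial example with a hypothetical bad_stream; assumes caikit's error handler raises TypeError on a failed type_check):

    # The int chunk should trip error.type_check("<NLP38357927E>", ...)
    # before it can corrupt stream_accumulator.
    bad_stream = data_model.DataStream.from_iterable(["Hello. ", 42])
    try:
        list(model.run_bidi_stream(bad_stream))
    except TypeError:
        pass  # non-str chunk rejected during accumulation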
File 2 of 2

@@ -51,6 +51,14 @@
 )
 TOK_CLASSIFICATION_RESULT = TokenClassificationResults(results=[FOX_CLASS, DOG_CLASS])

+# NOTE: First test will test this separately
+BOOTSTRAPPED_MODEL = FilteredSpanClassification.bootstrap(
+    lang="en",
+    tokenizer=SENTENCE_TOKENIZER,
+    classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
+    default_threshold=0.5,
+)
+
 # Modules that already returns token classification for tests
 @module(
     "44d61711-c64b-4774-a39f-a9f40f1fcff0",
@@ -120,13 +128,7 @@ def test_bootstrap_run():

 def test_bootstrap_run_with_threshold():
     """Check if we can bootstrap span classification models with overriden threshold"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-    token_classification_result = model.run(DOCUMENT, threshold=0.0)
+    token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0.0)
     assert isinstance(token_classification_result, TokenClassificationResults)
     assert (
         len(token_classification_result.results) == 4
@@ -187,16 +189,17 @@ def test_bootstrap_run_with_token_classification_no_results():
     assert len(token_classification_result.results) == 0


+def test_bootstrap_run_empty():
+    """Check if span classification model can run with empty string"""
+    token_classification_result = BOOTSTRAPPED_MODEL.run("")
+    assert isinstance(token_classification_result, TokenClassificationResults)
+    assert len(token_classification_result.results) == 0
+
+
 def test_save_load_and_run_model():
     """Check if we can run a saved model successfully"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
     with tempfile.TemporaryDirectory() as model_dir:
-        model.save(model_dir)
+        BOOTSTRAPPED_MODEL.save(model_dir)
         assert os.path.exists(os.path.join(model_dir, "config.yml"))
         assert os.path.exists(os.path.join(model_dir, "tokenizer"))
         assert os.path.exists(os.path.join(model_dir, "classification"))
@@ -216,14 +219,9 @@ def test_run_bidi_stream_model():
"""Check if model prediction works as expected for bi-directional stream"""

stream_input = data_model.DataStream.from_iterable(DOCUMENT)
model = FilteredSpanClassification.bootstrap(
lang="en",
tokenizer=SENTENCE_TOKENIZER,
classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
default_threshold=0.5,
streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
stream_input
)

streaming_token_classification_result = model.run_bidi_stream(stream_input)
assert isinstance(streaming_token_classification_result, Iterable)
# Convert to list to more easily check outputs
result_list = list(streaming_token_classification_result)
@@ -351,14 +349,10 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     works as expected for bi-directional stream"""
     doc_stream = (DOCUMENT, " I am another sentence.")
     stream_input = data_model.DataStream.from_iterable(doc_stream)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )

-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -385,22 +379,30 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     assert count == expected_number_of_sentences


+def test_run_bidi_stream_empty():
+    """Check if span classification model can run with empty string for streaming"""
+    stream_input = data_model.DataStream.from_iterable("")
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
+    assert isinstance(streaming_token_classification_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_token_classification_result)
+    assert len(result_list) == 1
+    assert result_list[0].results == []
+    assert result_list[0].processed_index == 0
+
+
 def test_run_stream_vs_no_stream():
     """Check if model prediction on stream with multiple sentences/spans
     works as expected for bi-directional stream and gives expected span results
     as non-stream"""
     multiple_sentences = (
         "The dragon hoarded gold. The cow ate grass. What is happening? What a day!"
     )
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )

     # Non-stream run
-    nonstream_classification_result = model.run(multiple_sentences)
+    nonstream_classification_result = BOOTSTRAPPED_MODEL.run(multiple_sentences)
     assert len(nonstream_classification_result.results) == 4
     assert nonstream_classification_result.results[0].word == "The dragon hoarded gold."
     assert nonstream_classification_result.results[0].start == 0
@@ -411,7 +413,7 @@ def test_run_stream_vs_no_stream():

     # Char-based stream
     stream_input = data_model.DataStream.from_iterable(multiple_sentences)
-    stream_classification_result = model.run_bidi_stream(stream_input)
+    stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(stream_input)
     # Convert to list to more easily check outputs
     result_list = list(stream_classification_result)
     assert len(result_list) == 4  # one per sentence
@@ -422,7 +424,9 @@ def test_run_stream_vs_no_stream():

     # Chunk-based stream
     chunk_stream_input = data_model.DataStream.from_iterable((multiple_sentences,))
-    chunk_stream_classification_result = model.run_bidi_stream(chunk_stream_input)
+    chunk_stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        chunk_stream_input
+    )
     result_list = list(chunk_stream_classification_result)
     assert len(result_list) == 4  # one per sentence
     assert result_list[0].processed_index == 24
