Commit 2f57ef5
remove is_training parameter from encode_input() and encode_inputs() (#…
ArneBinder authored Nov 8, 2023
1 parent a25bf0f commit 2f57ef5
Showing 1 changed file with 2 additions and 6 deletions.
src/pytorch_ie/core/taskmodule.py (8 changes: 2 additions & 6 deletions)
@@ -175,10 +175,8 @@ def batch_encode(
     ) -> Tuple[
         Sequence[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]], Sequence[DocumentType]
     ]:
-        ## TODO: revisit the assumption that encode_target=True always implies that
-        ## is_training=True
         task_encodings, documents_in_order = self.encode_inputs(
-            documents, is_training=encode_target, show_progress=show_progress
+            documents, show_progress=show_progress
         )
 
         if encode_target:
@@ -298,7 +296,6 @@ def encode(
     def encode_inputs(
         self,
         documents: Sequence[DocumentType],
-        is_training: bool = False,
         show_progress: bool = False,
     ) -> Tuple[
         Sequence[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
@@ -310,7 +307,7 @@ def encode_inputs(
             # a document might be generated on the fly (e.g. with a Dataset), so we add it here
             documents_in_order.append(document)
 
-            possible_task_encodings = self.encode_input(document, is_training)
+            possible_task_encodings = self.encode_input(document)
 
             # encode_input returns None or an empty list
             if possible_task_encodings is None or not possible_task_encodings:
@@ -328,7 +325,6 @@ def encode_inputs(
     def encode_input(
         self,
         document: DocumentType,
-        is_training: bool = False,
     ) -> Optional[
         Union[
             TaskEncoding[DocumentType, InputEncoding, TargetEncoding],

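For readers skimming the diff: after this commit, encode_input() takes only the document and encode_inputs() no longer forwards a training flag, so whether targets get encoded (encode_target) is decoupled from input encoding. Below is a minimal, self-contained sketch of that calling pattern; it is not the pytorch-ie implementation, and names such as ToyTaskModule and Doc are hypothetical stand-ins used only to illustrate the signatures.

from typing import List, Optional, Sequence, Tuple

# Simplified, self-contained sketch. NOT the pytorch-ie implementation:
# "ToyTaskModule", "Doc", and the dict-based encodings are hypothetical
# stand-ins illustrating the signatures after this commit.


class Doc:
    def __init__(self, text: str) -> None:
        self.text = text


class ToyTaskModule:
    def encode_input(self, document: Doc) -> Optional[List[dict]]:
        # No is_training argument: the input encoding depends only on the document.
        return [{"inputs": document.text.split()}]

    def encode_inputs(
        self, documents: Sequence[Doc], show_progress: bool = False
    ) -> Tuple[List[dict], List[Doc]]:
        task_encodings: List[dict] = []
        documents_in_order: List[Doc] = []
        for document in documents:
            documents_in_order.append(document)
            possible_task_encodings = self.encode_input(document)
            # encode_input returns None or an empty list when nothing can be encoded
            if not possible_task_encodings:
                continue
            task_encodings.extend(possible_task_encodings)
        return task_encodings, documents_in_order

    def batch_encode(
        self, documents: Sequence[Doc], encode_target: bool, show_progress: bool = False
    ) -> Tuple[List[dict], List[Doc]]:
        # encode_target no longer doubles as an is_training flag for input encoding.
        task_encodings, documents_in_order = self.encode_inputs(
            documents, show_progress=show_progress
        )
        if encode_target:
            for encoding in task_encodings:
                encoding["targets"] = len(encoding["inputs"])  # toy target
        return task_encodings, documents_in_order


if __name__ == "__main__":
    taskmodule = ToyTaskModule()
    encodings, docs = taskmodule.batch_encode([Doc("a b c"), Doc("d e")], encode_target=True)
    print(encodings)  # inputs always present; targets added only because encode_target=True

The design point illustrated here: input encodings are produced the same way for training and inference, and only the optional target-encoding step differs.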