[WIP] use dataset.map in pipeline #179

Open
wants to merge 9 commits into base: main
123 changes: 87 additions & 36 deletions src/pytorch_ie/pipeline.py
@@ -7,6 +7,7 @@

import torch
import tqdm
from datasets import disable_caching, enable_caching, is_caching_enabled
from packaging import version
from torch import Tensor
from torch.utils.data import DataLoader
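
Note: the three caching helpers imported from `datasets` above are process-global switches — `disable_caching()` turns transform caching off for every dataset, `enable_caching()` turns it back on, and `is_caching_enabled()` reports the current state. The snapshot/disable/restore dance used later in `__call__` can be captured in a small context manager; a minimal sketch of that pattern (the `caching_disabled` name is hypothetical, not part of this PR, and it assumes nothing else toggles the global switch concurrently):

    from contextlib import contextmanager

    from datasets import disable_caching, enable_caching, is_caching_enabled

    @contextmanager
    def caching_disabled():
        # Snapshot the global flag so a caller's earlier disable_caching()
        # is not silently undone when we exit.
        was_enabled = is_caching_enabled()
        disable_caching()
        try:
            yield
        finally:
            if was_enabled:
                enable_caching()
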
@@ -200,7 +201,12 @@ def _sanitize_parameters(
if p_name in pipeline_parameters:
postprocess_parameters[p_name] = pipeline_parameters[p_name]

return preprocess_parameters, dataloader_params, forward_parameters, postprocess_parameters
return (
preprocess_parameters,
dataloader_params,
forward_parameters,
postprocess_parameters,
)

def preprocess(
self,
@@ -292,12 +298,50 @@ def get_dataloader(

return dataloader

def _process_documents(
self,
documents: Sequence[Document],
preprocess_params: Dict[str, Any],
dataloader_params: Dict[str, Any],
forward_params: Dict[str, Any],
postprocess_params: Dict[str, Any],
) -> Sequence[Document]:
# This creates encodings from the documents. It modifies the documents and may produce multiple entries per
# document.
model_inputs = self.preprocess(documents, **preprocess_params)
if forward_params.pop("fast_dev_run", False):
warnings.warn(
"Execute a fast dev run, only the first two model inputs will be processed."
)
model_inputs = model_inputs[:2]
# Create a dataloader from the model inputs. This uses taskmodule.collate().
dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)

show_progress_bar = forward_params.pop("show_progress_bar", False)
model_outputs: List = []
with torch.no_grad():
for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
output = self.forward(batch, **forward_params)
processed_output = self.taskmodule.unbatch_output(output)
model_outputs.extend(processed_output)

assert len(model_inputs) == len(
model_outputs
), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"

documents = self.postprocess(
model_inputs=model_inputs,
model_outputs=model_outputs,
**postprocess_params,
)
return documents

def __call__(
self,
documents: Union[Document, Sequence[Document], Dataset],
*args,
**kwargs,
) -> Union[Document, Sequence[Document]]:
) -> Union[Document, Sequence[Document], Dataset]:
if args:
logger.warning(f"Ignoring args : {args}")
(
@@ -307,12 +351,6 @@ def __call__(
postprocess_params,
) = self._sanitize_parameters(**kwargs)

in_place: bool = postprocess_params.get("inplace", True)
if in_place and isinstance(documents, Dataset):
raise InplaceNotSupportedException(
"Datasets can't be modified in place. Please set inplace=False."
)

if "TOKENIZERS_PARALLELISM" not in os.environ:
logger.info(
"Disabling tokenizer parallelism, we're using DataLoader multithreading already"
@@ -337,35 +375,48 @@ def __call__(
single_document = True
documents = [documents]

# This creates encodings from the documents. It modifies the documents and may produce multiple entries per
# document.
model_inputs = self.preprocess(documents, **preprocess_params)
if forward_params.pop("fast_dev_run", False):
warnings.warn(
"Execute a fast dev run, only the first two model inputs will be processed."
)
model_inputs = model_inputs[:2]
# Create a dataloader from the model inputs. This uses taskmodule.collate().
dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)

show_progress_bar = forward_params.pop("show_progress_bar", False)
model_outputs: List = []
with torch.no_grad():
for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
output = self.forward(batch, **forward_params)
processed_output = self.taskmodule.unbatch_output(output)
model_outputs.extend(processed_output)
processed_documents: Union[Sequence[Document], Dataset]
if isinstance(documents, Dataset):
in_place: bool = postprocess_params.get("inplace", True)
if in_place:
raise InplaceNotSupportedException(
"Datasets can't be modified in place. Please set inplace=False."
)
# do not show inner progress bar
forward_params["show_progress_bar"] = False

# For now, we do not allow caching for pipeline results since fingerprinting may be incorrect
# TODO: elaborate why it may be incorrect, see https://huggingface.co/docs/datasets/about_cache
was_caching_enabled = is_caching_enabled()
disable_caching()
try:
processed_documents = documents.map(
self._process_documents,
fn_kwargs=dict(
preprocess_params=preprocess_params,
dataloader_params=dataloader_params,
forward_params=forward_params,
postprocess_params=postprocess_params,
),
result_document_type=documents.document_type,
batched=True,
)
finally:
if was_caching_enabled:
enable_caching()

assert len(model_inputs) == len(
model_outputs
), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
else:
processed_documents = self._process_documents(
documents=documents,
preprocess_params=preprocess_params,
dataloader_params=dataloader_params,
forward_params=forward_params,
postprocess_params=postprocess_params,
)

documents = self.postprocess(
model_inputs=model_inputs,
model_outputs=model_outputs,
**postprocess_params,
)
if single_document:
return documents[0]
# TODO: fix "type: ignore" (if processed_documents is a Dataset, mypy assumes the result is Dict[Any, Any])
processed_document: Document = processed_documents[0] # type: ignore
return processed_document
else:
return documents
return processed_documents
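
Note: from the caller's perspective, the new dispatch means a plain document or sequence of documents keeps the old in-memory path, while a `Dataset` is routed through `dataset.map` with caching disabled and must be processed out of place. A hedged usage sketch, reusing the fixture names from the tests below (`mock_model`, `prepared_taskmodule`, `list_of_documents`, and `train_dataset` are placeholders, not importable objects):

    # Sequence input keeps the old path; by default documents are modified in place.
    pipeline = Pipeline(model=mock_model, taskmodule=prepared_taskmodule, device=-1)
    predicted_docs = pipeline(list_of_documents, predict_field="entities")

    # Dataset input goes through dataset.map; in-place modification is impossible,
    # so inplace=False is required (here set at construction time, as in the tests),
    # and the caller gets a new, never-cached Dataset back.
    pipeline = Pipeline(
        model=mock_model, taskmodule=prepared_taskmodule, device=-1, inplace=False
    )
    predicted_dataset = pipeline(train_dataset, predict_field="entities")
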
12 changes: 12 additions & 0 deletions tests/test_pipeline.py
@@ -145,3 +145,15 @@ def test_pipeline_with_dataset(dataset, prepared_taskmodule, mock_model, inplace
assert returned_document is not document
assert not document.entities.predictions
assert returned_document.entities.predictions


@pytest.mark.slow
def test_pipeline_with_dataset_never_cached(dataset, prepared_taskmodule, mock_model):
train_dataset = dataset["train"]

pipeline = Pipeline(model=mock_model, taskmodule=prepared_taskmodule, device=-1, inplace=False)

returned_documents1 = pipeline(train_dataset, predict_field="entities")
returned_documents2 = pipeline(train_dataset, predict_field="entities")

assert returned_documents1._fingerprint != returned_documents2._fingerprint
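
Note: the assertion above inspects `_fingerprint`, the hash `datasets` uses to identify a dataset's state for caching. With caching enabled, the fingerprint after a transform is a deterministic hash of the previous fingerprint plus the transform and its arguments, so identical calls agree; because the pipeline disables caching around `map` (fingerprinting of pipeline runs may be incorrect, per the TODO in the diff), two identical runs must instead yield distinct fingerprints. A standalone sketch of the deterministic case, using plain `datasets` independently of pytorch-ie and relying on the private `_fingerprint` attribute just as the test does:

    from datasets import Dataset

    def add_one(example):
        return {"x": example["x"] + 1}

    ds = Dataset.from_dict({"x": [1, 2, 3]})

    # With caching enabled (the default), identical map calls produce the same
    # deterministic fingerprint; the pipeline test asserts the opposite precisely
    # because the pipeline runs map with caching disabled.
    assert ds.map(add_one)._fingerprint == ds.map(add_one)._fingerprint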