diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 1be37d5c..d10c291e 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -7,6 +7,7 @@

 import torch
 import tqdm
+from datasets import disable_caching, enable_caching, is_caching_enabled
 from packaging import version
 from torch import Tensor
 from torch.utils.data import DataLoader
@@ -200,7 +201,12 @@ def _sanitize_parameters(
             if p_name in pipeline_parameters:
                 postprocess_parameters[p_name] = pipeline_parameters[p_name]

-        return preprocess_parameters, dataloader_params, forward_parameters, postprocess_parameters
+        return (
+            preprocess_parameters,
+            dataloader_params,
+            forward_parameters,
+            postprocess_parameters,
+        )

     def preprocess(
         self,
@@ -292,12 +298,50 @@ def get_dataloader(

         return dataloader

+    def _process_documents(
+        self,
+        documents: Sequence[Document],
+        preprocess_params: Dict[str, Any],
+        dataloader_params: Dict[str, Any],
+        forward_params: Dict[str, Any],
+        postprocess_params: Dict[str, Any],
+    ) -> Sequence[Document]:
+        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
+        # document.
+        model_inputs = self.preprocess(documents, **preprocess_params)
+        if forward_params.pop("fast_dev_run", False):
+            warnings.warn(
+                "Execute a fast dev run, only the first two model inputs will be processed."
+            )
+            model_inputs = model_inputs[:2]
+        # Create a dataloader from the model inputs. This uses taskmodule.collate().
+        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
+
+        show_progress_bar = forward_params.pop("show_progress_bar", False)
+        model_outputs: List = []
+        with torch.no_grad():
+            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
+                output = self.forward(batch, **forward_params)
+                processed_output = self.taskmodule.unbatch_output(output)
+                model_outputs.extend(processed_output)
+
+        assert len(model_inputs) == len(
+            model_outputs
+        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+
+        documents = self.postprocess(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            **postprocess_params,
+        )
+        return documents
+
     def __call__(
         self,
         documents: Union[Document, Sequence[Document], Dataset],
         *args,
         **kwargs,
-    ) -> Union[Document, Sequence[Document]]:
+    ) -> Union[Document, Sequence[Document], Dataset]:
         if args:
             logger.warning(f"Ignoring args : {args}")
         (
@@ -307,12 +351,6 @@ def __call__(
             postprocess_params,
         ) = self._sanitize_parameters(**kwargs)

-        in_place: bool = postprocess_params.get("inplace", True)
-        if in_place and isinstance(documents, Dataset):
-            raise InplaceNotSupportedException(
-                "Datasets can't be modified in place. Please set inplace=False."
-            )
-
         if "TOKENIZERS_PARALLELISM" not in os.environ:
             logger.info(
                 "Disabling tokenizer parallelism, we're using DataLoader multithreading already"
             )
@@ -337,35 +375,48 @@ def __call__(
             single_document = True
             documents = [documents]

-        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
-        # document.
-        model_inputs = self.preprocess(documents, **preprocess_params)
-        if forward_params.pop("fast_dev_run", False):
-            warnings.warn(
-                "Execute a fast dev run, only the first two model inputs will be processed."
-            )
-            model_inputs = model_inputs[:2]
-        # Create a dataloader from the model inputs. This uses taskmodule.collate().
-        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
-
-        show_progress_bar = forward_params.pop("show_progress_bar", False)
-        model_outputs: List = []
-        with torch.no_grad():
-            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
-                output = self.forward(batch, **forward_params)
-                processed_output = self.taskmodule.unbatch_output(output)
-                model_outputs.extend(processed_output)
+        processed_documents: Union[Sequence[Document], Dataset]
+        if isinstance(documents, Dataset):
+            in_place: bool = postprocess_params.get("inplace", True)
+            if in_place:
+                raise InplaceNotSupportedException(
+                    "Datasets can't be modified in place. Please set inplace=False."
+                )
+            # do not show inner progress bar
+            forward_params["show_progress_bar"] = False
+
+            # For now, we do not allow caching for pipeline results since fingerprinting may be incorrect
+            # TODO: elaborate why it may be incorrect, see https://huggingface.co/docs/datasets/about_cache
+            was_caching_enabled = is_caching_enabled()
+            disable_caching()
+            try:
+                processed_documents = documents.map(
+                    self._process_documents,
+                    fn_kwargs=dict(
+                        preprocess_params=preprocess_params,
+                        dataloader_params=dataloader_params,
+                        forward_params=forward_params,
+                        postprocess_params=postprocess_params,
+                    ),
+                    result_document_type=documents.document_type,
+                    batched=True,
+                )
+            finally:
+                if was_caching_enabled:
+                    enable_caching()

-        assert len(model_inputs) == len(
-            model_outputs
-        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+        else:
+            processed_documents = self._process_documents(
+                documents=documents,
+                preprocess_params=preprocess_params,
+                dataloader_params=dataloader_params,
+                forward_params=forward_params,
+                postprocess_params=postprocess_params,
+            )

-        documents = self.postprocess(
-            model_inputs=model_inputs,
-            model_outputs=model_outputs,
-            **postprocess_params,
-        )
         if single_document:
-            return documents[0]
+            # TODO: fix "type: ignore" (if processed_documents is a Dataset, mypy assumes the result is Dict[Any, Any])
+            processed_document: Document = processed_documents[0]  # type: ignore
+            return processed_document
         else:
-            return documents
+            return processed_documents
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index dbdc2adb..9eb83154 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -145,3 +145,15 @@ def test_pipeline_with_dataset(dataset, prepared_taskmodule, mock_model, inplace
         assert not (id(returned_document) == id(document))
         assert not document.entities.predictions
         assert returned_document.entities.predictions
+
+
+@pytest.mark.slow
+def test_pipeline_with_dataset_never_cached(dataset, prepared_taskmodule, mock_model):
+    train_dataset = dataset["train"]
+
+    pipeline = Pipeline(model=mock_model, taskmodule=prepared_taskmodule, device=-1, inplace=False)
+
+    returned_documents1 = pipeline(train_dataset, predict_field="entities")
+    returned_documents2 = pipeline(train_dataset, predict_field="entities")
+
+    assert returned_documents1._fingerprint != returned_documents2._fingerprint
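Usage sketch (not part of the patch): the new Dataset branch is driven the same way as the added test above. The names my_model, my_taskmodule, and my_dataset are placeholders for a prepared model, taskmodule, and pytorch-ie document dataset; they are not introduced by this diff.

    from pytorch_ie.pipeline import Pipeline

    # Placeholder objects: any prepared taskmodule/model pair and document dataset will do.
    pipeline = Pipeline(model=my_model, taskmodule=my_taskmodule, device=-1, inplace=False)

    # Passing a Dataset requires inplace=False; the Dataset branch maps _process_documents
    # over the dataset with datasets caching disabled, so two identical calls produce
    # results with different fingerprints (this is what the new test asserts).
    predicted = pipeline(my_dataset["train"], predict_field="entities")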