From a47d2dc379261b5c0da3f973db69c1c76cad7d91 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sat, 7 May 2022 19:24:45 +0200
Subject: [PATCH 1/9] use dataset.map in pipeline

---
 src/pytorch_ie/pipeline.py | 124 ++++++++++++++++++++++++++-----------
 1 file changed, 87 insertions(+), 37 deletions(-)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 1be37d5c..945f8bae 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -69,6 +69,7 @@ def __init__(
             self._dataloader_params,
             self._forward_params,
             self._postprocess_params,
+            self._dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
     def save_pretrained(self, save_directory: str):
@@ -165,7 +166,7 @@ def _ensure_tensor_on_device(self, inputs, device):
 
     def _sanitize_parameters(
         self, **pipeline_parameters
-    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         """
         _sanitize_parameters will be called with any excessive named arguments from either `__init__` or
         `__call__` methods. It should return 4 dictionaries of the resolved parameters used by the various
         `preprocess`,
@@ -179,6 +180,7 @@
         dataloader_params = {}
         forward_parameters = {}
         postprocess_parameters: Dict[str, Any] = {}
+        dataset_map_parameters = {}
 
         # set preprocess parameters
         for p_name in ["document_batch_size"]:
@@ -200,7 +202,17 @@
             if p_name in pipeline_parameters:
                 postprocess_parameters[p_name] = pipeline_parameters[p_name]
 
-        return preprocess_parameters, dataloader_params, forward_parameters, postprocess_parameters
+        for p_name in ["document_batch_size"]:
+            if p_name in pipeline_parameters:
+                dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
+
+        return (
+            preprocess_parameters,
+            dataloader_params,
+            forward_parameters,
+            postprocess_parameters,
+            dataset_map_parameters,
+        )
 
     def preprocess(
         self,
@@ -292,12 +304,50 @@ def get_dataloader(
 
         return dataloader
 
+    def _process_documents(
+        self,
+        documents: Sequence[Document],
+        preprocess_params: Dict[str, Any],
+        dataloader_params: Dict[str, Any],
+        forward_params: Dict[str, Any],
+        postprocess_params: Dict[str, Any],
+    ) -> Sequence[Document]:
+        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
+        # document.
+        model_inputs = self.preprocess(documents, **preprocess_params)
+        if forward_params.pop("fast_dev_run", False):
+            warnings.warn(
+                "Execute a fast dev run, only the first two model inputs will be processed."
+            )
+            model_inputs = model_inputs[:2]
+        # Create a dataloader from the model inputs. This uses taskmodule.collate().
+        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
+
+        show_progress_bar = forward_params.pop("show_progress_bar", False)
+        model_outputs: List = []
+        with torch.no_grad():
+            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
+                output = self.forward(batch, **forward_params)
+                processed_output = self.taskmodule.unbatch_output(output)
+                model_outputs.extend(processed_output)
+
+        assert len(model_inputs) == len(
+            model_outputs
+        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
+
+        documents = self.postprocess(
+            model_inputs=model_inputs,
+            model_outputs=model_outputs,
+            **postprocess_params,
+        )
+        return documents
+
     def __call__(
         self,
         documents: Union[Document, Sequence[Document], Dataset],
         *args,
         **kwargs,
-    ) -> Union[Document, Sequence[Document]]:
+    ) -> Union[Document, Sequence[Document], Dataset]:
         if args:
             logger.warning(f"Ignoring args : {args}")
         (
@@ -305,14 +355,9 @@
             dataloader_params,
             forward_params,
             postprocess_params,
+            dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
-        in_place: bool = postprocess_params.get("inplace", True)
-        if in_place and isinstance(documents, Dataset):
-            raise InplaceNotSupportedException(
-                "Datasets can't be modified in place. Please set inplace=False."
-            )
-
         if "TOKENIZERS_PARALLELISM" not in os.environ:
             logger.info(
                 "Disabling tokenizer parallelism, we're using DataLoader multithreading already"
             )
@@ -324,6 +369,7 @@
         dataloader_params = {**self._dataloader_params, **dataloader_params}
         forward_params = {**self._forward_params, **forward_params}
         postprocess_params = {**self._postprocess_params, **postprocess_params}
+        dataset_map_params = {**self._dataset_map_params, **dataset_map_params}
 
         self.call_count += 1
         if self.call_count > 10 and self.device.type == "cuda":
@@ -337,35 +383,39 @@
             single_document = True
             documents = [documents]
 
-        # This creates encodings from the documents. It modifies the documents and may produce multiple entries per
-        # document.
-        model_inputs = self.preprocess(documents, **preprocess_params)
-        if forward_params.pop("fast_dev_run", False):
-            warnings.warn(
-                "Execute a fast dev run, only the first two model inputs will be processed."
+        processed_documents: Union[Sequence[Document], Dataset]
+        if isinstance(documents, Dataset):
+            in_place: bool = postprocess_params.get("inplace", True)
+            if in_place:
+                raise InplaceNotSupportedException(
+                    "Datasets can't be modified in place. Please set inplace=False."
+                )
+            # do not show inner progress bar
+            forward_params["show_progress_bar"] = False
+
+            processed_documents = documents.map(
+                self._process_documents,
+                fn_kwargs=dict(
+                    preprocess_params=preprocess_params,
+                    dataloader_params=dataloader_params,
+                    forward_params=forward_params,
+                    postprocess_params=postprocess_params,
+                ),
+                batched=True,
+                **dataset_map_params,
+            )
+        else:
+            processed_documents = self._process_documents(
+                documents=documents,
+                preprocess_params=preprocess_params,
+                dataloader_params=dataloader_params,
+                forward_params=forward_params,
+                postprocess_params=postprocess_params,
             )
-            model_inputs = model_inputs[:2]
-        # Create a dataloader from the model inputs. This uses taskmodule.collate().
-        dataloader = self.get_dataloader(model_inputs=model_inputs, **dataloader_params)
-
-        show_progress_bar = forward_params.pop("show_progress_bar", False)
-        model_outputs: List = []
-        with torch.no_grad():
-            for batch in tqdm.tqdm(dataloader, desc="inference", disable=not show_progress_bar):
-                output = self.forward(batch, **forward_params)
-                processed_output = self.taskmodule.unbatch_output(output)
-                model_outputs.extend(processed_output)
-
-        assert len(model_inputs) == len(
-            model_outputs
-        ), f"length mismatch: len(model_inputs) [{len(model_inputs)}] != len(model_outputs) [{len(model_outputs)}]"
 
-        documents = self.postprocess(
-            model_inputs=model_inputs,
-            model_outputs=model_outputs,
-            **postprocess_params,
-        )
         if single_document:
-            return documents[0]
+            # TODO: fix "type: ignore" (if processed_documents is a Dataset, mypy assumes the result is Dict[Any, Any])
+            processed_document: Document = processed_documents[0]  # type: ignore
+            return processed_document
         else:
-            return documents
+            return processed_documents
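Patch 1's central move is visible above: instead of looping over documents inline, `__call__` now routes `Dataset` inputs through `datasets.Dataset.map`, handing the resolved parameter dictionaries to `_process_documents` via `fn_kwargs` and chunking the input with `batched=True`. A minimal, self-contained sketch of that `map` pattern, with a toy column and a toy batch function standing in for the pytorch-ie API:

```python
# Sketch of the batched `map` pattern patch 1 relies on (illustrative
# stand-ins, not pytorch-ie code): a batch-level function receives one
# chunk of rows plus fixed keyword arguments supplied via `fn_kwargs`.
from datasets import Dataset


def process_batch(batch, multiplier):
    # `batch` is a dict mapping column name -> list of values for one chunk.
    return {"value": [v * multiplier for v in batch["value"]]}


ds = Dataset.from_dict({"value": [1, 2, 3, 4]})
processed = ds.map(
    process_batch,
    fn_kwargs=dict(multiplier=10),  # extra arguments, like the pipeline's *_params dicts
    batched=True,                   # hand over chunks instead of single rows
    batch_size=2,                   # what `document_batch_size` is translated into
)
print(processed["value"])  # [10, 20, 30, 40]
```

The `batch_size` knob is what `_sanitize_parameters` maps `document_batch_size` onto in the `dataset_map_parameters` dictionary above.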
From 8668486f1a52851f04e39cdd79766e81fe1b6af7 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 8 May 2022 14:47:25 +0200
Subject: [PATCH 2/9] add test_pipeline_with_dataset_never_cached

---
 tests/test_pipeline.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index dbdc2adb..3d58e1a7 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -145,3 +145,16 @@ def test_pipeline_with_dataset(dataset, prepared_taskmodule, mock_model, inplace
         assert not (id(returned_document) == id(document))
         assert not document.entities.predictions
         assert returned_document.entities.predictions
+
+
+@pytest.mark.slow
+def test_pipeline_with_dataset_never_cached(dataset, prepared_taskmodule, mock_model):
+
+    train_dataset = dataset["train"]
+
+    pipeline = Pipeline(model=mock_model, taskmodule=prepared_taskmodule, device=-1, inplace=False)
+
+    returned_documents1 = pipeline(train_dataset, predict_field="entities")
+    returned_documents2 = pipeline(train_dataset, predict_field="entities")
+
+    assert returned_documents1._fingerprint != returned_documents2._fingerprint

From 734fc3cdf91e164efe1440b8353e1c541c81ea8d Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 8 May 2022 14:52:13 +0200
Subject: [PATCH 3/9] raise exception when datasets would like to cache
 pipeline result

---
 src/pytorch_ie/pipeline.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 945f8bae..b34195ad 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -12,6 +12,7 @@
 from torch.utils.data import DataLoader
 from transformers.utils import ModelOutput
 
+from datasets import is_caching_enabled
 from pytorch_ie.core.document import Document
 from pytorch_ie.core.model import PyTorchIEModel
 from pytorch_ie.core.taskmodule import (
@@ -404,6 +405,10 @@
                 batched=True,
                 **dataset_map_params,
             )
+            # For now, we do not allow caching of pipeline results since fingerprinting may be incorrect
+            # TODO: elaborate why it may be incorrect
+            if is_caching_enabled() and documents._fingerprint == processed_documents._fingerprint:
+                raise Exception("Caching is not allowed for pipeline calls")
         else:
             processed_documents = self._process_documents(
                 documents=documents,
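The guard added in patch 3 leans on `datasets` fingerprinting: each transform result carries a fingerprint derived from the input's fingerprint plus a hash of the applied function, and the on-disk cache is keyed by these fingerprints. If a pipeline result came back with a fingerprint the cache already knows, a later call could silently reuse stale predictions, e.g. after swapping in model weights that the function hasher cannot see. A small standalone illustration of the mechanics, not pipeline code (`_fingerprint` is a private attribute, used here only for inspection):

```python
from datasets import Dataset, is_caching_enabled

ds = Dataset.from_dict({"n": [1, 2, 3]})
mapped = ds.map(lambda ex: {"n": ex["n"] + 1})

print(is_caching_enabled())  # True unless explicitly disabled
print(ds._fingerprint)       # fingerprint of the input dataset
print(mapped._fingerprint)   # differs: derived from input fingerprint + function hash
```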
From 6290566f9a593bccd7baa759bb5f7270649bf1b3 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 9 May 2022 11:56:36 +0200
Subject: [PATCH 4/9] fix disable caching for pipeline

---
 src/pytorch_ie/pipeline.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index b34195ad..19916d7b 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -12,7 +12,7 @@
 from torch.utils.data import DataLoader
 from transformers.utils import ModelOutput
 
-from datasets import is_caching_enabled
+from datasets import is_caching_enabled, disable_caching, enable_caching
 from pytorch_ie.core.document import Document
 from pytorch_ie.core.model import PyTorchIEModel
 from pytorch_ie.core.taskmodule import (
@@ -394,21 +394,26 @@
             # do not show inner progress bar
             forward_params["show_progress_bar"] = False
 
-            processed_documents = documents.map(
-                self._process_documents,
-                fn_kwargs=dict(
-                    preprocess_params=preprocess_params,
-                    dataloader_params=dataloader_params,
-                    forward_params=forward_params,
-                    postprocess_params=postprocess_params,
-                ),
-                batched=True,
-                **dataset_map_params,
-            )
-            # For now, we do not allow caching of pipeline results since fingerprinting may be incorrect
-            # TODO: elaborate why it may be incorrect
-            if is_caching_enabled() and documents._fingerprint == processed_documents._fingerprint:
-                raise Exception("Caching is not allowed for pipeline calls")
+            # For now, we do not allow caching for pipeline results since fingerprinting may be incorrect
+            # TODO: elaborate why it may be incorrect, see https://huggingface.co/docs/datasets/about_cache
+            was_caching_enabled = is_caching_enabled()
+            disable_caching()
+            try:
+                processed_documents = documents.map(
+                    self._process_documents,
+                    fn_kwargs=dict(
+                        preprocess_params=preprocess_params,
+                        dataloader_params=dataloader_params,
+                        forward_params=forward_params,
+                        postprocess_params=postprocess_params,
+                    ),
+                    batched=True,
+                    **dataset_map_params,
+                )
+            finally:
+                if was_caching_enabled:
+                    enable_caching()
+
         else:
             processed_documents = self._process_documents(
                 documents=documents,

From 3d7c1c03c4547d07ea72e0e6dfe677122dfc8d06 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 9 May 2022 11:57:08 +0200
Subject: [PATCH 5/9] isort

---
 src/pytorch_ie/pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 19916d7b..27969133 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -12,7 +12,7 @@
 from torch.utils.data import DataLoader
 from transformers.utils import ModelOutput
 
-from datasets import is_caching_enabled, disable_caching, enable_caching
+from datasets import disable_caching, enable_caching, is_caching_enabled
 from pytorch_ie.core.document import Document
 from pytorch_ie.core.model import PyTorchIEModel
 from pytorch_ie.core.taskmodule import (
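Patch 4 replaces the after-the-fact fingerprint check with prevention: caching is switched off around the `map` call and restored afterwards. The same guard can be phrased as a reusable context manager; the following is an editorial sketch, not code from the series:

```python
# The try/finally dance from patch 4, wrapped into a context manager.
from contextlib import contextmanager

from datasets import disable_caching, enable_caching, is_caching_enabled


@contextmanager
def caching_disabled():
    """Temporarily disable `datasets` caching, restoring the previous state on exit."""
    was_enabled = is_caching_enabled()
    disable_caching()
    try:
        yield
    finally:
        if was_enabled:
            enable_caching()

# usage:
# with caching_disabled():
#     processed = documents.map(process_fn, batched=True)
```

The `try`/`finally` (here hidden inside the context manager) matters: without it, an exception raised inside `map` would leave caching globally disabled for the rest of the process.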
From 47ca43b1278eb2158ecc822c1c0f87b320606c7c Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 9 May 2022 12:42:46 +0200
Subject: [PATCH 6/9] do not allow parameters for documents.map to simplify
 pipeline

---
 src/pytorch_ie/pipeline.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 27969133..9bc4f222 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -70,7 +70,6 @@ def __init__(
             self._dataloader_params,
             self._forward_params,
             self._postprocess_params,
-            self._dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
     def save_pretrained(self, save_directory: str):
@@ -167,7 +166,7 @@ def _ensure_tensor_on_device(self, inputs, device):
 
     def _sanitize_parameters(
         self, **pipeline_parameters
-    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
         """
         _sanitize_parameters will be called with any excessive named arguments from either `__init__` or
         `__call__` methods. It should return 4 dictionaries of the resolved parameters used by the various
         `preprocess`,
@@ -181,7 +180,6 @@
         dataloader_params = {}
         forward_parameters = {}
         postprocess_parameters: Dict[str, Any] = {}
-        dataset_map_parameters = {}
 
         # set preprocess parameters
         for p_name in ["document_batch_size"]:
@@ -203,16 +201,11 @@
             if p_name in pipeline_parameters:
                 postprocess_parameters[p_name] = pipeline_parameters[p_name]
 
-        for p_name in ["document_batch_size"]:
-            if p_name in pipeline_parameters:
-                dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
-
         return (
             preprocess_parameters,
             dataloader_params,
             forward_parameters,
             postprocess_parameters,
-            dataset_map_parameters,
         )
 
     def preprocess(
@@ -356,7 +349,6 @@
             dataloader_params,
             forward_params,
             postprocess_params,
-            dataset_map_params,
         ) = self._sanitize_parameters(**kwargs)
 
         if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -370,7 +362,6 @@
         dataloader_params = {**self._dataloader_params, **dataloader_params}
         forward_params = {**self._forward_params, **forward_params}
         postprocess_params = {**self._postprocess_params, **postprocess_params}
-        dataset_map_params = {**self._dataset_map_params, **dataset_map_params}
 
         self.call_count += 1
         if self.call_count > 10 and self.device.type == "cuda":
@@ -408,7 +399,6 @@
                     postprocess_params=postprocess_params,
                 ),
                 batched=True,
-                **dataset_map_params,
             )
         finally:
             if was_caching_enabled:

From 85d1a3809917375d3e1763766ba51c923819b45f Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Tue, 14 Feb 2023 11:39:00 +0100
Subject: [PATCH 7/9] add missing result_document_type

---
 src/pytorch_ie/pipeline.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 9bc4f222..64db8019 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -398,6 +398,7 @@
                     forward_params=forward_params,
                     postprocess_params=postprocess_params,
                 ),
+                result_document_type=documents.document_type,
                 batched=True,
             )
         finally:

From 6d63f08b97a9ebeaa4704529d57c4c574b0c6309 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Tue, 14 Feb 2023 11:47:23 +0100
Subject: [PATCH 8/9] make pre-commit happy

---
 tests/test_pipeline.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 3d58e1a7..9eb83154 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -149,7 +149,6 @@ def test_pipeline_with_dataset(dataset, prepared_taskmodule, mock_model, inplace
 
 @pytest.mark.slow
 def test_pipeline_with_dataset_never_cached(dataset, prepared_taskmodule, mock_model):
-
     train_dataset = dataset["train"]
 
     pipeline = Pipeline(model=mock_model, taskmodule=prepared_taskmodule, device=-1, inplace=False)
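The test from patch 2 (trimmed by patch 8) pins this behavior down: two identical pipeline calls over the same dataset must yield results with different fingerprints, i.e. the second call must not be served from the cache. That relies on `datasets` assigning fresh, random fingerprints to transform results while caching is disabled. A standalone demonstration without the pytorch-ie fixtures, assuming the `datasets` behavior current at the time of this series:

```python
from datasets import Dataset, disable_caching

disable_caching()
ds = Dataset.from_dict({"n": [1, 2, 3]})

run1 = ds.map(lambda ex: {"n": ex["n"] + 1})
run2 = ds.map(lambda ex: {"n": ex["n"] + 1})

# With caching off, each transform result gets a fresh (random) fingerprint,
# so nothing is looked up from or written to the cache between runs.
print(run1._fingerprint != run2._fingerprint)  # expected: True
```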
From 2ddd79a1cec32678dd912dae359e44f062f0103f Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 17 Sep 2023 21:56:18 +0200
Subject: [PATCH 9/9] make pre-commit happy

---
 src/pytorch_ie/pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py
index 64db8019..d10c291e 100644
--- a/src/pytorch_ie/pipeline.py
+++ b/src/pytorch_ie/pipeline.py
@@ -7,12 +7,12 @@
 import torch
 import tqdm
+from datasets import disable_caching, enable_caching, is_caching_enabled
 from packaging import version
 from torch import Tensor
 from torch.utils.data import DataLoader
 from transformers.utils import ModelOutput
 
-from datasets import disable_caching, enable_caching, is_caching_enabled
 from pytorch_ie.core.document import Document
 from pytorch_ie.core.model import PyTorchIEModel
 from pytorch_ie.core.taskmodule import (
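Taken together, the series leaves `__call__` with a three-way contract: a single `Document` in, a single `Document` out; a sequence in, a sequence out; a `Dataset` in, a mapped `Dataset` out (with `inplace=False` required and caching disabled for the duration of the call). A simplified, runnable sketch of that dispatch, with plain dictionaries standing in for `Document` objects and a toy worker standing in for `_process_documents`:

```python
from typing import Any, Dict, List, Sequence, Union

from datasets import Dataset


def process_documents(docs: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Toy stand-in for Pipeline._process_documents (preprocess -> forward -> postprocess).
    return [{**d, "prediction": len(d["text"])} for d in docs]


def call(documents: Union[Dict[str, Any], Sequence[Dict[str, Any]], Dataset]) -> Any:
    single = isinstance(documents, dict)
    if single:
        documents = [documents]
    if isinstance(documents, Dataset):
        # Dataset branch: hand whole batches to the worker via map, as in patch 1.
        processed = documents.map(
            lambda batch: {"prediction": [len(t) for t in batch["text"]]},
            batched=True,
        )
    else:
        processed = process_documents(documents)
    return processed[0] if single else processed


print(call({"text": "hello"}))  # {'text': 'hello', 'prediction': 5}
print(call(Dataset.from_dict({"text": ["a", "bc"]}))["prediction"])  # [1, 2]
```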