diff --git a/src/pytorch_ie/pipeline.py b/src/pytorch_ie/pipeline.py index 27969133..9bc4f222 100644 --- a/src/pytorch_ie/pipeline.py +++ b/src/pytorch_ie/pipeline.py @@ -70,7 +70,6 @@ def __init__( self._dataloader_params, self._forward_params, self._postprocess_params, - self._dataset_map_params, ) = self._sanitize_parameters(**kwargs) def save_pretrained(self, save_directory: str): @@ -167,7 +166,7 @@ def _ensure_tensor_on_device(self, inputs, device): def _sanitize_parameters( self, **pipeline_parameters - ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]: """ _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__` methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`, @@ -181,7 +180,6 @@ def _sanitize_parameters( dataloader_params = {} forward_parameters = {} postprocess_parameters: Dict[str, Any] = {} - dataset_map_parameters = {} # set preprocess parameters for p_name in ["document_batch_size"]: @@ -203,16 +201,11 @@ def _sanitize_parameters( if p_name in pipeline_parameters: postprocess_parameters[p_name] = pipeline_parameters[p_name] - for p_name in ["document_batch_size"]: - if p_name in pipeline_parameters: - dataset_map_parameters["batch_size"] = pipeline_parameters[p_name] - return ( preprocess_parameters, dataloader_params, forward_parameters, postprocess_parameters, - dataset_map_parameters, ) def preprocess( @@ -356,7 +349,6 @@ def __call__( dataloader_params, forward_params, postprocess_params, - dataset_map_params, ) = self._sanitize_parameters(**kwargs) if "TOKENIZERS_PARALLELISM" not in os.environ: @@ -370,7 +362,6 @@ def __call__( dataloader_params = {**self._dataloader_params, **dataloader_params} forward_params = {**self._forward_params, **forward_params} postprocess_params = {**self._postprocess_params, **postprocess_params} - dataset_map_params = {**self._dataset_map_params, **dataset_map_params} self.call_count += 1 if self.call_count > 10 and self.device.type == "cuda": @@ -408,7 +399,6 @@ def __call__( postprocess_params=postprocess_params, ), batched=True, - **dataset_map_params, ) finally: if was_caching_enabled: