Skip to content

Commit

Permalink
Do not allow parameters for `documents.map`, to simplify the pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 17, 2023
1 parent 3d7c1c0 commit 47ca43b
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions src/pytorch_ie/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(
self._dataloader_params,
self._forward_params,
self._postprocess_params,
self._dataset_map_params,
) = self._sanitize_parameters(**kwargs)

def save_pretrained(self, save_directory: str):
Expand Down Expand Up @@ -167,7 +166,7 @@ def _ensure_tensor_on_device(self, inputs, device):

def _sanitize_parameters(
self, **pipeline_parameters
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
"""
_sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
Expand All @@ -181,7 +180,6 @@ def _sanitize_parameters(
dataloader_params = {}
forward_parameters = {}
postprocess_parameters: Dict[str, Any] = {}
dataset_map_parameters = {}

# set preprocess parameters
for p_name in ["document_batch_size"]:
Expand All @@ -203,16 +201,11 @@ def _sanitize_parameters(
if p_name in pipeline_parameters:
postprocess_parameters[p_name] = pipeline_parameters[p_name]

for p_name in ["document_batch_size"]:
if p_name in pipeline_parameters:
dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]

return (
preprocess_parameters,
dataloader_params,
forward_parameters,
postprocess_parameters,
dataset_map_parameters,
)

def preprocess(
Expand Down Expand Up @@ -356,7 +349,6 @@ def __call__(
dataloader_params,
forward_params,
postprocess_params,
dataset_map_params,
) = self._sanitize_parameters(**kwargs)

if "TOKENIZERS_PARALLELISM" not in os.environ:
Expand All @@ -370,7 +362,6 @@ def __call__(
dataloader_params = {**self._dataloader_params, **dataloader_params}
forward_params = {**self._forward_params, **forward_params}
postprocess_params = {**self._postprocess_params, **postprocess_params}
dataset_map_params = {**self._dataset_map_params, **dataset_map_params}

self.call_count += 1
if self.call_count > 10 and self.device.type == "cuda":
Expand Down Expand Up @@ -408,7 +399,6 @@ def __call__(
postprocess_params=postprocess_params,
),
batched=True,
**dataset_map_params,
)
finally:
if was_caching_enabled:
Expand Down

0 comments on commit 47ca43b

Please sign in to comment.