Skip to content

Commit

Permalink
Fix: disable caching for pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 17, 2023
1 parent 734fc3c commit 6290566
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions src/pytorch_ie/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.utils.data import DataLoader
from transformers.utils import ModelOutput

from datasets import is_caching_enabled
from datasets import is_caching_enabled, disable_caching, enable_caching
from pytorch_ie.core.document import Document
from pytorch_ie.core.model import PyTorchIEModel
from pytorch_ie.core.taskmodule import (
Expand Down Expand Up @@ -394,21 +394,26 @@ def __call__(
# do not show inner progress bar
forward_params["show_progress_bar"] = False

processed_documents = documents.map(
self._process_documents,
fn_kwargs=dict(
preprocess_params=preprocess_params,
dataloader_params=dataloader_params,
forward_params=forward_params,
postprocess_params=postprocess_params,
),
batched=True,
**dataset_map_params,
)
# For now, we do not allow caching of pipeline results since fingerprinting may be incorrect
# TODO: elaborate why it may be incorrect
if is_caching_enabled() and documents._fingerprint == processed_documents._fingerprint:
raise Exception("Caching is not allowed for pipeline calls")
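# For now, caching is temporarily disabled while the pipeline is mapped over the dataset, since
# the fingerprint computed for the result may be incorrect. The previous setting is restored afterwards.
# TODO: elaborate on why it may be incorrect; see https://huggingface.co/docs/datasets/about_cache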
was_caching_enabled = is_caching_enabled()
disable_caching()
try:
processed_documents = documents.map(
self._process_documents,
fn_kwargs=dict(
preprocess_params=preprocess_params,
dataloader_params=dataloader_params,
forward_params=forward_params,
postprocess_params=postprocess_params,
),
batched=True,
**dataset_map_params,
)
finally:
if was_caching_enabled:
enable_caching()

else:
processed_documents = self._process_documents(
documents=documents,
Expand Down

0 comments on commit 6290566

Please sign in to comment.