Skip to content

Commit

Permalink
allow string argument for Pipeline device and remove multi call w…
Browse files Browse the repository at this point in the history
…arning (#435)

* allow string argument for pipeline device

* remove warning when called multiple times on gpu

* fix

* make pre-commit happy
  • Loading branch information
ArneBinder authored Nov 5, 2024
1 parent 3145b17 commit 7c63106
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions src/pytorch_ie/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ def __init__(
self,
model: PyTorchIEModel,
taskmodule: TaskModule,
# args_parser: ArgumentHandler = None,
device: int = -1,
device: Union[int, str] = "cpu",
binary_output: bool = False,
**kwargs,
):
self.taskmodule = taskmodule
self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
device_str = (
("cpu" if device < 0 else f"cuda:{device}") if isinstance(device, int) else device
)
self.device = torch.device(device_str)
self.binary_output = binary_output

# Module.to() returns just self, but moved to the device. This is not correctly
Expand Down Expand Up @@ -324,13 +326,6 @@ def __call__(
forward_params = {**self._forward_params, **forward_params}
postprocess_params = {**self._postprocess_params, **postprocess_params}

self.call_count += 1
if self.call_count > 10 and self.device.type == "cuda":
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset",
UserWarning,
)

single_document = False
if isinstance(documents, Document):
single_document = True
Expand Down

0 comments on commit 7c63106

Please sign in to comment.