Skip to content

Commit

Permalink
allow string argument for pipeline device
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 5, 2024
1 parent 3145b17 commit 0c3e5e4
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/pytorch_ie/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,17 @@ def __init__(
self,
model: PyTorchIEModel,
taskmodule: TaskModule,
# args_parser: ArgumentHandler = None,
device: int = -1,
device: Union[int, str] = "cpu",
binary_output: bool = False,
**kwargs,
):
self.taskmodule = taskmodule
self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
device_str = (
torch.device("cpu" if device < 0 else f"cuda:{device}")
if isinstance(device, int)
else device
)
self.device = torch.device(device_str)
self.binary_output = binary_output

# Module.to() returns just self, but moved to the device. This is not correctly
Expand Down

0 comments on commit 0c3e5e4

Please sign in to comment.