Skip to content

Commit

Permalink
Add a timeout to the os.read function (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
Miauwkeru authored Nov 26, 2024
1 parent ef2e3d1 commit c284e87
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
37 changes: 35 additions & 2 deletions acquire/volatilestream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from concurrent import futures
from io import SEEK_SET, UnsupportedOperation
from pathlib import Path
from stat import S_IRGRP, S_IROTH, S_IRUSR
from typing import Any, Callable

from dissect.util.stream import AlignedStream

Expand All @@ -14,6 +16,35 @@
HAS_FCNTL = False


def timeout(func: Callable, *, timelimit: int) -> Callable:
"""Timeout a function if it takes too long to complete.
Args:
func: a function to wrap.
timelimit: The time in seconds that an operation is allowed to run.
Raises:
TimeoutError: If its time exceeds the timelimit
"""

def wrapper(*args: Any, **kwargs: Any) -> Any:
with futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(func, *args, **kwargs)

try:
result = future.result(timelimit)
except futures.TimeoutError:
raise TimeoutError
finally:
# Make sure the thread stops right away.
executor._threads.clear()
futures.thread._threads_queues.clear()

return result

return wrapper


class VolatileStream(AlignedStream):
"""Streaming class to handle various procfs and sysfs edge-cases. Backed by `AlignedStream`.
Expand Down Expand Up @@ -41,6 +72,8 @@ def __init__(
st_mode = os.fstat(self.fd).st_mode
write_only = (st_mode & (S_IRUSR | S_IRGRP | S_IROTH)) == 0 # novermin

self._os_read = timeout(os.read, timelimit=5)

super().__init__(0 if write_only else size)

def seek(self, pos: int, whence: int = SEEK_SET) -> int:
Expand All @@ -53,8 +86,8 @@ def _read(self, offset: int, length: int) -> bytes:
result = []
while length:
try:
buf = os.read(self.fd, min(length, self.size - offset))
except BlockingIOError:
buf = self._os_read(self.fd, min(length, self.size - offset))
except (BlockingIOError, TimeoutError):
break

if not buf:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_volatile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from time import sleep, time

import pytest

from acquire.volatilestream import timeout


def test_timeout():
def snooze():
sleep(10)

function = timeout(snooze, timelimit=5)
start = time()

with pytest.raises(TimeoutError):
function()

end = time()

assert end - start < 6

0 comments on commit c284e87

Please sign in to comment.