Skip to content

Commit

Permalink
fixup! feat(doip-discover): Greatly speed up DoIP discover
Browse files Browse the repository at this point in the history
  • Loading branch information
ferdinandjarisch committed Dec 4, 2024
1 parent 6e61b58 commit 9704227
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 62 deletions.
73 changes: 12 additions & 61 deletions src/gallia/commands/discover/doip.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ async def enumerate_target_addresses(
search_space = range(start, stop + 1)

target_template = f"doip://{tgt_hostname}:{tgt_port}?protocol_version={self.protocol_version}&activation_type={correct_rat:#x}&src_addr={correct_src:#x}&target_addr={{:#x}}"
conn = await self.create_DoIP_conn(tgt_hostname, tgt_port, correct_rat, correct_src, 0xAFFE)
conn = await self.create_DoIP_conn(tgt_hostname, tgt_port, correct_rat, correct_src, 0xAFFE, fast_queue=True)
reader_task = asyncio.create_task(self.task_read_diagnostic_messages(conn, target_template))

for target_addr in search_space:
Expand Down Expand Up @@ -413,13 +413,14 @@ async def enumerate_target_addresses(
)

async def task_read_diagnostic_messages(
self, conn: FastDoIPConnection, target_template: str
self, conn: DoIPConnection, target_template: str
) -> None:
responsive_targets = []
potential_broadcasts = []
try:
while True:
source_address, data = await conn.read_diag_request_custom()
_, payload = await conn.read_diag_request_raw()
(source_address, data) = (payload.SourceAddress, payload.UserData)
current_target = target_template.format(source_address)

resp = TesterPresentResponse.parse_static(data)
Expand Down Expand Up @@ -453,9 +454,10 @@ async def task_read_diagnostic_messages(
logger.notice(item)

# TODO: the discoverer could be extended to search for and validate the broadcast address(es) automatically
logger.notice(
"[🕵️] You could also investigate these target addresses that appear to be near broadcasts:"
)
if len(potential_broadcasts) > 0:
logger.notice(
"[🕵️] You could also investigate these target addresses that appear to be near broadcasts:"
)
for target_addr in potential_broadcasts:
logger.notice(f"[🤑] B-B-B-B-B-B-BROADCAST around TargetAddress {target_addr:#x}!")

Expand All @@ -466,16 +468,18 @@ async def create_DoIP_conn(
routing_activation_type: int,
src_addr: int,
target_addr: int,
) -> FastDoIPConnection: # noqa: PLR0913
fast_queue: bool = False,
) -> DoIPConnection: # noqa: PLR0913
while True:
try: # Ensure that connections do not remain in TIME_WAIT
conn = await FastDoIPConnection.connect(
conn = await DoIPConnection.connect(
hostname,
port,
src_addr,
target_addr,
so_linger=True,
protocol_version=self.protocol_version,
separate_diagnostic_message_queue=fast_queue,
)
logger.info("[📫] Sending RoutingActivationRequest")
await conn.write_routing_activation_request(
Expand Down Expand Up @@ -543,56 +547,3 @@ async def run_udp_discovery(self) -> list[tuple[str, int]]:
)

return found


class FastDoIPConnection(DoIPConnection):
# This code is pretty much copied from the original except for the special treatment of DiagnosticMessages which
# are placed in the separate queue ._diagnostic_message_queue which can be read with .read_diag_request_custom()

def __init__( # noqa: PLR0913
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
src_addr: int,
target_addr: int,
protocol_version: int,
):
self._diagnostic_message_queue: asyncio.Queue[DoIPFrame] = asyncio.Queue()
super().__init__(
reader,
writer,
src_addr,
target_addr,
protocol_version,
)

async def _read_worker(self) -> None:
try:
while True:
hdr, data = await self._read_frame()
if hdr is None or data is None:
continue
if hdr.PayloadType == PayloadTypes.AliveCheckRequest:
await self.write_alive_check_response()
continue
if hdr.PayloadType == PayloadTypes.DiagnosticMessage:
await self._diagnostic_message_queue.put((hdr, data))
else:
await self._read_queue.put((hdr, data))
except asyncio.CancelledError:
logger.debug("DoIP read worker got cancelled")
except asyncio.IncompleteReadError as e:
logger.debug(f"DoIP read worker received EOF: {e!r}")
except Exception as e:
logger.info(f"DoIP read worker died with {e!r}")
finally:
logger.debug("Feeding EOF to reader and requesting a close")
self.reader.feed_eof()
await self.close()

async def read_diag_request_custom(self) -> tuple[int, bytes]:
hdr, payload = await self._diagnostic_message_queue.get()
if not isinstance(payload, DiagnosticMessage):
logger.warning(f"[🧨] Unexpected DoIP message: {hdr} {payload}")
return 0, b""
return payload.SourceAddress, payload.UserData
11 changes: 10 additions & 1 deletion src/gallia/transports/doip.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,15 @@ def __init__( # noqa: PLR0913
src_addr: int,
target_addr: int,
protocol_version: int,
separate_diagnostic_message_queue: bool = False,
):
self.reader = reader
self.writer = writer
self.src_addr = src_addr
self.target_addr = target_addr
self.protocol_version = protocol_version
self.separate_diagnostic_message_queue = separate_diagnostic_message_queue
self._diagnostic_message_queue: asyncio.Queue[DoIPDiagFrame] = asyncio.Queue()
self._read_queue: asyncio.Queue[DoIPFrame] = asyncio.Queue()
self._read_task = asyncio.create_task(self._read_worker())
self._read_task.add_done_callback(
Expand All @@ -494,6 +497,7 @@ async def connect( # noqa: PLR0913
target_addr: int,
so_linger: bool = False,
protocol_version: int = ProtocolVersions.ISO_13400_2_2019,
separate_diagnostic_message_queue: bool = False,
) -> Self:
reader, writer = await asyncio.open_connection(host, port)

Expand All @@ -508,7 +512,7 @@ async def connect( # noqa: PLR0913
sock = writer.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0))

return cls(reader, writer, src_addr, target_addr, protocol_version)
return cls(reader, writer, src_addr, target_addr, protocol_version, separate_diagnostic_message_queue)

async def _read_frame(self) -> DoIPFrame | tuple[None, None]:
# Header is fixed size 8 byte.
Expand Down Expand Up @@ -547,6 +551,9 @@ async def _read_worker(self) -> None:
if hdr.PayloadType == PayloadTypes.AliveCheckRequest:
await self.write_alive_check_response()
continue
if isinstance(data, DiagnosticMessage) and self.separate_diagnostic_message_queue:
await self._diagnostic_message_queue.put((hdr, data))
continue
await self._read_queue.put((hdr, data))
except asyncio.CancelledError:
logger.debug("DoIP read worker got cancelled")
Expand All @@ -573,6 +580,8 @@ async def read_frame(self) -> DoIPFrame:
async def read_diag_request_raw(self) -> DoIPDiagFrame:
unexpected_packets: list[tuple[Any, Any]] = []
while True:
if self.separate_diagnostic_message_queue:
return await self._diagnostic_message_queue.get()
hdr, payload = await self.read_frame()
if not isinstance(payload, DiagnosticMessage):
logger.warning(f"expected DoIP DiagnosticMessage, instead got: {hdr} {payload}")
Expand Down

0 comments on commit 9704227

Please sign in to comment.