diff --git a/src/gallia/commands/discover/doip.py b/src/gallia/commands/discover/doip.py index f0e411eda..ee64e1811 100644 --- a/src/gallia/commands/discover/doip.py +++ b/src/gallia/commands/discover/doip.py @@ -326,7 +326,7 @@ async def enumerate_target_addresses( search_space = range(start, stop + 1) target_template = f"doip://{tgt_hostname}:{tgt_port}?protocol_version={self.protocol_version}&activation_type={correct_rat:#x}&src_addr={correct_src:#x}&target_addr={{:#x}}" - conn = await self.create_DoIP_conn(tgt_hostname, tgt_port, correct_rat, correct_src, 0xAFFE) + conn = await self.create_DoIP_conn(tgt_hostname, tgt_port, correct_rat, correct_src, 0xAFFE, fast_queue=True) reader_task = asyncio.create_task(self.task_read_diagnostic_messages(conn, target_template)) for target_addr in search_space: @@ -413,13 +413,14 @@ async def enumerate_target_addresses( ) async def task_read_diagnostic_messages( - self, conn: FastDoIPConnection, target_template: str + self, conn: DoIPConnection, target_template: str ) -> None: responsive_targets = [] potential_broadcasts = [] try: while True: - source_address, data = await conn.read_diag_request_custom() + _, payload = await conn.read_diag_request_raw() + (source_address, data) = (payload.SourceAddress, payload.UserData) current_target = target_template.format(source_address) resp = TesterPresentResponse.parse_static(data) @@ -453,9 +454,10 @@ async def task_read_diagnostic_messages( logger.notice(item) # TODO: the discoverer could be extended to search for and validate the broadcast address(es) automatically - logger.notice( - "[๐Ÿ•ต๏ธ] You could also investigate these target addresses that appear to be near broadcasts:" - ) + if len(potential_broadcasts) > 0: + logger.notice( + "[๐Ÿ•ต๏ธ] You could also investigate these target addresses that appear to be near broadcasts:" + ) for target_addr in potential_broadcasts: logger.notice(f"[๐Ÿค‘] B-B-B-B-B-B-BROADCAST around TargetAddress {target_addr:#x}!") @@ -466,16 +468,18 @@ async def create_DoIP_conn( routing_activation_type: int, src_addr: int, target_addr: int, - ) -> FastDoIPConnection: # noqa: PLR0913 + fast_queue: bool = False, + ) -> DoIPConnection: # noqa: PLR0913 while True: try: # Ensure that connections do not remain in TIME_WAIT - conn = await FastDoIPConnection.connect( + conn = await DoIPConnection.connect( hostname, port, src_addr, target_addr, so_linger=True, protocol_version=self.protocol_version, + separate_diagnostic_message_queue=fast_queue, ) logger.info("[๐Ÿ“ซ] Sending RoutingActivationRequest") await conn.write_routing_activation_request( @@ -543,56 +547,3 @@ async def run_udp_discovery(self) -> list[tuple[str, int]]: ) return found - - -class FastDoIPConnection(DoIPConnection): - # This code is pretty much copied from the original except for the special treatment of DiagnosticMessages which - # are placed in the separate queue ._diagnostic_message_queue which can be read with .read_diag_request_custom() - - def __init__( # noqa: PLR0913 - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - src_addr: int, - target_addr: int, - protocol_version: int, - ): - self._diagnostic_message_queue: asyncio.Queue[DoIPFrame] = asyncio.Queue() - super().__init__( - reader, - writer, - src_addr, - target_addr, - protocol_version, - ) - - async def _read_worker(self) -> None: - try: - while True: - hdr, data = await self._read_frame() - if hdr is None or data is None: - continue - if hdr.PayloadType == PayloadTypes.AliveCheckRequest: - await self.write_alive_check_response() - continue - if hdr.PayloadType == PayloadTypes.DiagnosticMessage: - await self._diagnostic_message_queue.put((hdr, data)) - else: - await self._read_queue.put((hdr, data)) - except asyncio.CancelledError: - logger.debug("DoIP read worker got cancelled") - except asyncio.IncompleteReadError as e: - logger.debug(f"DoIP read worker received EOF: {e!r}") - except Exception as e: - logger.info(f"DoIP read worker died with {e!r}") - finally: - logger.debug("Feeding EOF to reader and requesting a close") - self.reader.feed_eof() - await self.close() - - async def read_diag_request_custom(self) -> tuple[int, bytes]: - hdr, payload = await self._diagnostic_message_queue.get() - if not isinstance(payload, DiagnosticMessage): - logger.warning(f"[๐Ÿงจ] Unexpected DoIP message: {hdr} {payload}") - return 0, b"" - return payload.SourceAddress, payload.UserData diff --git a/src/gallia/transports/doip.py b/src/gallia/transports/doip.py index b71c55b46..0489aacda 100644 --- a/src/gallia/transports/doip.py +++ b/src/gallia/transports/doip.py @@ -470,12 +470,15 @@ def __init__( # noqa: PLR0913 src_addr: int, target_addr: int, protocol_version: int, + separate_diagnostic_message_queue: bool = False, ): self.reader = reader self.writer = writer self.src_addr = src_addr self.target_addr = target_addr self.protocol_version = protocol_version + self.separate_diagnostic_message_queue = separate_diagnostic_message_queue + self._diagnostic_message_queue: asyncio.Queue[DoIPDiagFrame] = asyncio.Queue() self._read_queue: asyncio.Queue[DoIPFrame] = asyncio.Queue() self._read_task = asyncio.create_task(self._read_worker()) self._read_task.add_done_callback( @@ -494,6 +497,7 @@ async def connect( # noqa: PLR0913 target_addr: int, so_linger: bool = False, protocol_version: int = ProtocolVersions.ISO_13400_2_2019, + separate_diagnostic_message_queue: bool = False, ) -> Self: reader, writer = await asyncio.open_connection(host, port) @@ -508,7 +512,7 @@ async def connect( # noqa: PLR0913 sock = writer.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)) - return cls(reader, writer, src_addr, target_addr, protocol_version) + return cls(reader, writer, src_addr, target_addr, protocol_version, separate_diagnostic_message_queue) async def _read_frame(self) -> DoIPFrame | tuple[None, None]: # Header is fixed size 8 byte. @@ -547,6 +551,9 @@ async def _read_worker(self) -> None: if hdr.PayloadType == PayloadTypes.AliveCheckRequest: await self.write_alive_check_response() continue + if isinstance(data, DiagnosticMessage) and self.separate_diagnostic_message_queue: + await self._diagnostic_message_queue.put((hdr, data)) + continue await self._read_queue.put((hdr, data)) except asyncio.CancelledError: logger.debug("DoIP read worker got cancelled") @@ -573,6 +580,8 @@ async def read_frame(self) -> DoIPFrame: async def read_diag_request_raw(self) -> DoIPDiagFrame: unexpected_packets: list[tuple[Any, Any]] = [] while True: + if self.separate_diagnostic_message_queue: + return await self._diagnostic_message_queue.get() hdr, payload = await self.read_frame() if not isinstance(payload, DiagnosticMessage): logger.warning(f"expected DoIP DiagnosticMessage, instead got: {hdr} {payload}")