Skip to content

Commit

Permalink
Introduce async alternatives in the client command hierarchy
Browse files Browse the repository at this point in the history
TODO: Consider best approach here and in the topmost nonblocking
reading.
  • Loading branch information
pevogam committed Mar 21, 2023
1 parent b16da4e commit e6d28d2
Showing 1 changed file with 266 additions and 0 deletions.
266 changes: 266 additions & 0 deletions aexpect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import subprocess
import locale
import logging
import asyncio

from aexpect.exceptions import ExpectError
from aexpect.exceptions import ExpectProcessTerminatedError
Expand Down Expand Up @@ -806,6 +807,41 @@ def _read_nonblocking(self, internal_timeout=None, timeout=None):
if end_time and time.time() > end_time:
return read, data

async def _read_nonblocking_async(self, internal_timeout=None, timeout=None):
"""
Read from child until there is nothing to read for timeout seconds via a coroutine.
All arguments are identical to the regular function.
"""
if internal_timeout is None:
internal_timeout = 100
else:
internal_timeout *= 1000
end_time = None
if timeout:
end_time = time.time() + timeout
expect_pipe = self._get_fd("expect")
poller = select.poll()
poller.register(expect_pipe, select.POLLIN)
data = ""
read = 0
while True:
try:
poll_status = poller.poll(internal_timeout)
except select.error:
return read, data
if poll_status:
raw_data = os.read(expect_pipe, 1024)
if not raw_data:
return read, data
read += len(raw_data)
data += raw_data.decode(self.encoding, "ignore")
else:
return read, data
if end_time and time.time() > end_time:
return read, data
await asyncio.sleep(1)

def read_nonblocking(self, internal_timeout=None, timeout=None):
"""
Read from child until there is nothing to read for timeout seconds.
Expand Down Expand Up @@ -925,6 +961,54 @@ def read_until_output_matches(self, patterns, filter_func=lambda x: x,
# This shouldn't happen
raise ExpectError(patterns, output)

async def read_until_output_matches_async(self, patterns, filter_func=lambda x: x,
timeout=60.0, internal_timeout=None,
print_func=None, match_func=None):
"""
Read from child using read_nonblocking until a pattern matches via a coroutine.
All arguments are identical to the regular function.
"""
if not match_func:
match_func = self.match_patterns
expect_pipe = self._get_fd("expect")
poller = select.poll()
poller.register(expect_pipe, select.POLLIN)
output = ""
end_time = time.time() + timeout
while True:
try:
max_ms = int((end_time - time.time()) * 1000)
poll_timeout_ms = max(0, max_ms)
poll_status = poller.poll(poll_timeout_ms)
except select.error:
break
if not poll_status:
raise ExpectTimeoutError(patterns, output)
# Read data from child
read, data = await self._read_nonblocking_async(internal_timeout,
end_time - time.time())
if not read:
break
if not data:
continue
# Print it if necessary-
if print_func:
for line in data.splitlines():
print_func(line)
# Look for patterns
output += data
match = match_func(filter_func(output), patterns)
if match is not None:
return match, output

# Check if the child has terminated
if utils_wait.wait_for(lambda: not self.is_alive(), 5, 0, 0.1):
raise ExpectProcessTerminatedError(patterns, self.get_status(),
output)
# This shouldn't happen
raise ExpectError(patterns, output)

def read_until_last_word_matches(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Expand Down Expand Up @@ -953,6 +1037,25 @@ def _get_last_word(cont):
timeout, internal_timeout,
print_func)

async def read_until_last_word_matches_async(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Read using read_nonblocking until the last word of the output matches
one of the patterns (using match_patterns), or until timeout expires
via a coroutine.
All arguments are identical to the regular function.
"""

def _get_last_word(cont):
if cont:
return cont.split()[-1]
return ""

return await self.read_until_output_matches_async(patterns, _get_last_word,
timeout, internal_timeout,
print_func)

def read_until_last_line_matches(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Expand Down Expand Up @@ -987,6 +1090,25 @@ def _get_last_nonempty_line(cont):
timeout, internal_timeout,
print_func)

async def read_until_last_line_matches_async(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Read until the last non-empty line matches a pattern via a coroutine.
All arguments are identical to the regular function.
"""

def _get_last_nonempty_line(cont):
nonempty_lines = [_ for _ in cont.splitlines() if _.strip()]
if nonempty_lines:
return nonempty_lines[-1]
return ""

return await self.read_until_output_matches_async(patterns,
_get_last_nonempty_line,
timeout, internal_timeout,
print_func)

def read_until_any_line_matches(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Expand Down Expand Up @@ -1016,6 +1138,19 @@ def read_until_any_line_matches(self, patterns, timeout=60.0,
print_func,
self.match_patterns_multiline)

async def read_until_any_line_matches_async(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Read using read_nonblocking until any line matches a pattern via a coroutine.
All arguments are identical to the regular function.
"""
return await self.read_until_output_matches_async(patterns,
lambda x: x.splitlines(),
timeout, internal_timeout,
print_func,
self.match_patterns_multiline)


class ShellSession(Expect):

Expand Down Expand Up @@ -1170,6 +1305,18 @@ def read_up_to_prompt(self, timeout=60.0, internal_timeout=None,
internal_timeout,
print_func)[1]

async def read_up_to_prompt_async(self, timeout=60.0, internal_timeout=None,
print_func=None):
"""
Read until the last non-empty line matches the prompt via a coroutine.
All arguments are identical to the regular function.
"""
_, data = await self.read_until_last_line_matches_async([self.prompt], timeout,
internal_timeout,
print_func)
return data

def cmd_output(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Expand Down Expand Up @@ -1215,6 +1362,35 @@ def cmd_output(self, cmd, timeout=60, internal_timeout=None,
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

async def cmd_output_async(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Send a command and return its output via a coroutine.
All arguments are identical to the regular function.
"""
if safe:
return await self.cmd_output_safe_async(cmd, timeout)
session_tag = f"[{self.output_prefix}] " if self.output_prefix else ""
LOG.debug("%sSending command: %s", session_tag, cmd)
self.read_nonblocking(0, timeout)
self.sendline(cmd)
try:
out = await self.read_up_to_prompt_async(timeout, internal_timeout, print_func)
except ExpectTimeoutError as error:
output = self.remove_command_echo(error.output, cmd)
raise ShellTimeoutError(cmd, output) from error
except ExpectProcessTerminatedError as error:
output = self.remove_command_echo(error.output, cmd)
raise ShellProcessTerminatedError(cmd, error.status, output) from error
except ExpectError as error:
output = self.remove_command_echo(error.output, cmd)
raise ShellError(cmd, output) from error

# Remove the echoed command and the final shell prompt
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

def cmd_output_safe(self, cmd, timeout=60):
"""
Send a command and return its output (serial sessions).
Expand Down Expand Up @@ -1264,6 +1440,42 @@ def cmd_output_safe(self, cmd, timeout=60):
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

async def cmd_output_safe_async(self, cmd, timeout=60):
"""
Send a command and return its output (serial sessions) via a coroutine.
All arguments are identical to the regular function.
"""
session_tag = f"[{self.output_prefix}] " if self.output_prefix else ""
LOG.debug("%sSending command (safe): %s", session_tag, cmd)
self.read_nonblocking(0, timeout)
self.sendline(cmd)
out = ""
success = False
start_time = time.time()
while (time.time() - start_time) < timeout:
try:
out += await self.read_up_to_prompt_async(0.5)
success = True
break
except ExpectTimeoutError as error:
out = f"{out}{error.output}"
self.sendline()
except ExpectProcessTerminatedError as error:
output = self.remove_command_echo(f"{out}{error.output}", cmd)
raise ShellProcessTerminatedError(cmd, error.status,
output) from error
except ExpectError as error:
output = self.remove_command_echo(f"{out}{error.output}", cmd)
raise ShellError(cmd, output) from error

if not success:
raise ShellTimeoutError(cmd, out)

# Remove the echoed command and the final shell prompt
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

def cmd_status_output(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Expand Down Expand Up @@ -1304,6 +1516,28 @@ def cmd_status_output(self, cmd, timeout=60, internal_timeout=None,
return int(digit_lines[0].strip()), out
raise ShellStatusError(cmd, out)

async def cmd_status_output_async(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Send a command and return its exit status and output via a coroutine.
All arguments are identical to the regular function.
"""
out = await self.cmd_output_async(cmd, timeout, internal_timeout, print_func, safe)
try:
# Send the 'echo $?' (or equivalent) command to get the exit status
status = self.cmd_output(self.status_test_command, 30,
internal_timeout, print_func, safe)
except ShellError as error:
raise ShellStatusError(cmd, out) from error

# Get the first line consisting of digits only
digit_lines = [_ for _ in status.splitlines()
if self.__RE_STATUS.match(_.strip())]
if digit_lines:
return int(digit_lines[0].strip()), out
raise ShellStatusError(cmd, out)

def cmd_status(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Expand Down Expand Up @@ -1331,6 +1565,16 @@ def cmd_status(self, cmd, timeout=60, internal_timeout=None,
return self.cmd_status_output(cmd, timeout, internal_timeout,
print_func, safe)[0]

async def cmd_status_async(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Send a command and return its exit status via a coroutine.
All arguments are identical to the regular function.
"""
return await self.cmd_status_output_async(cmd, timeout, internal_timeout,
print_func, safe)[0]

def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None,
ok_status=None, ignore_all_errors=False):
"""
Expand Down Expand Up @@ -1372,6 +1616,28 @@ def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None,
return None
raise

async def cmd_async(self, cmd, timeout=60, internal_timeout=None, print_func=None,
ok_status=None, ignore_all_errors=False):
"""
Send a command and return its output via a coroutine. If the command's
exit status is nonzero, raise an exception.
All arguments are identical to the regular function.
"""
if ok_status is None:
ok_status = [0, ]
try:
status, output = await self.cmd_status_output_async(cmd, timeout,
internal_timeout,
print_func)
if status not in ok_status:
raise ShellCmdError(cmd, status, output)
return output
except ShellError:
if ignore_all_errors:
return None
raise

def get_command_output(self, cmd, timeout=60, internal_timeout=None,
print_func=None):
"""
Expand Down

0 comments on commit e6d28d2

Please sign in to comment.