Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce async alternatives in the client command hierarchy #121

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 268 additions & 0 deletions aexpect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import subprocess
import locale
import logging
import asyncio

from aexpect.exceptions import ExpectError
from aexpect.exceptions import ExpectProcessTerminatedError
Expand Down Expand Up @@ -806,6 +807,42 @@ def _read_nonblocking(self, internal_timeout=None, timeout=None):
if end_time and time.time() > end_time:
return read, data

async def _read_nonblocking_async(self, internal_timeout=None, timeout=None):
"""
Read from child until there is nothing to read for timeout seconds via a coroutine.

All arguments are identical to the regular function.
"""
if internal_timeout is None:
internal_timeout = 100
else:
internal_timeout *= 1000
end_time = None
if timeout:
end_time = time.time() + timeout
expect_pipe = self._get_fd("expect")
poller = select.poll()
poller.register(expect_pipe, select.POLLIN)
data = ""
read = 0
while True:
try:
poll_status = poller.poll(internal_timeout)
except select.error:
return read, data
if poll_status:
raw_data = os.read(expect_pipe, 1024)
if not raw_data:
return read, data
read += len(raw_data)
data += raw_data.decode(self.encoding, "ignore")
else:
return read, data
if end_time and time.time() > end_time:
return read, data
# TODO: sleeping for long here slows down the async command too much
#await asyncio.sleep(1)

def read_nonblocking(self, internal_timeout=None, timeout=None):
"""
Read from child until there is nothing to read for timeout seconds.
Expand Down Expand Up @@ -925,6 +962,55 @@ def read_until_output_matches(self, patterns, filter_func=lambda x: x,
# This shouldn't happen
raise ExpectError(patterns, output)

async def read_until_output_matches_async(self, patterns, filter_func=lambda x: x,
timeout=60.0, internal_timeout=None,
print_func=None, match_func=None):
"""
Read from child using read_nonblocking until a pattern matches via a coroutine.

All arguments are identical to the regular function.
"""
if not match_func:
match_func = self.match_patterns
expect_pipe = self._get_fd("expect")
poller = select.poll()
poller.register(expect_pipe, select.POLLIN)
output = ""
end_time = time.time() + timeout
while True:
try:
max_ms = int((end_time - time.time()) * 1000)
poll_timeout_ms = max(0, max_ms)
poll_status = poller.poll(poll_timeout_ms)
except select.error:
break
if not poll_status:
raise ExpectTimeoutError(patterns, output)
# Read data from child
read, data = await self._read_nonblocking_async(internal_timeout,
end_time - time.time())
if not read:
break
if not data:
continue
# Print it if necessary-
if print_func:
for line in data.splitlines():
print_func(line)
# Look for patterns
output += data
match = match_func(filter_func(output), patterns)
if match is not None:
return match, output
await asyncio.sleep(0.1)

# Check if the child has terminated
if utils_wait.wait_for(lambda: not self.is_alive(), 5, 0, 0.1):
raise ExpectProcessTerminatedError(patterns, self.get_status(),
output)
# This shouldn't happen
raise ExpectError(patterns, output)

def read_until_last_word_matches(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Expand Down Expand Up @@ -953,6 +1039,25 @@ def _get_last_word(cont):
timeout, internal_timeout,
print_func)

async def read_until_last_word_matches_async(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Read using read_nonblocking until the last word of the output matches
one of the patterns (using match_patterns), or until timeout expires
via a coroutine.

All arguments are identical to the regular function.
"""

def _get_last_word(cont):
if cont:
return cont.split()[-1]
return ""

return await self.read_until_output_matches_async(patterns, _get_last_word,
timeout, internal_timeout,
print_func)

def read_until_last_line_matches(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Expand Down Expand Up @@ -987,6 +1092,25 @@ def _get_last_nonempty_line(cont):
timeout, internal_timeout,
print_func)

async def read_until_last_line_matches_async(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Read until the last non-empty line matches a pattern via a coroutine.

All arguments are identical to the regular function.
"""

def _get_last_nonempty_line(cont):
nonempty_lines = [_ for _ in cont.splitlines() if _.strip()]
if nonempty_lines:
return nonempty_lines[-1]
return ""

return await self.read_until_output_matches_async(patterns,
_get_last_nonempty_line,
timeout, internal_timeout,
print_func)

def read_until_any_line_matches(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Expand Down Expand Up @@ -1016,6 +1140,19 @@ def read_until_any_line_matches(self, patterns, timeout=60.0,
print_func,
self.match_patterns_multiline)

async def read_until_any_line_matches_async(self, patterns, timeout=60.0,
internal_timeout=None, print_func=None):
"""
Read using read_nonblocking until any line matches a pattern via a coroutine.

All arguments are identical to the regular function.
"""
return await self.read_until_output_matches_async(patterns,
lambda x: x.splitlines(),
timeout, internal_timeout,
print_func,
self.match_patterns_multiline)


class ShellSession(Expect):

Expand Down Expand Up @@ -1170,6 +1307,18 @@ def read_up_to_prompt(self, timeout=60.0, internal_timeout=None,
internal_timeout,
print_func)[1]

async def read_up_to_prompt_async(self, timeout=60.0, internal_timeout=None,
print_func=None):
"""
Read until the last non-empty line matches the prompt via a coroutine.

All arguments are identical to the regular function.
"""
_, data = await self.read_until_last_line_matches_async([self.prompt], timeout,
internal_timeout,
print_func)
return data

def cmd_output(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Expand Down Expand Up @@ -1215,6 +1364,35 @@ def cmd_output(self, cmd, timeout=60, internal_timeout=None,
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

async def cmd_output_async(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Send a command and return its output via a coroutine.

All arguments are identical to the regular function.
"""
if safe:
return await self.cmd_output_safe_async(cmd, timeout)
session_tag = f"[{self.output_prefix}] " if self.output_prefix else ""
LOG.debug("%sSending command: %s", session_tag, cmd)
self.read_nonblocking(0, timeout)
self.sendline(cmd)
try:
out = await self.read_up_to_prompt_async(timeout, internal_timeout, print_func)
except ExpectTimeoutError as error:
output = self.remove_command_echo(error.output, cmd)
raise ShellTimeoutError(cmd, output) from error
except ExpectProcessTerminatedError as error:
output = self.remove_command_echo(error.output, cmd)
raise ShellProcessTerminatedError(cmd, error.status, output) from error
except ExpectError as error:
output = self.remove_command_echo(error.output, cmd)
raise ShellError(cmd, output) from error

# Remove the echoed command and the final shell prompt
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

def cmd_output_safe(self, cmd, timeout=60):
"""
Send a command and return its output (serial sessions).
Expand Down Expand Up @@ -1264,6 +1442,42 @@ def cmd_output_safe(self, cmd, timeout=60):
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

async def cmd_output_safe_async(self, cmd, timeout=60):
"""
Send a command and return its output (serial sessions) via a coroutine.

All arguments are identical to the regular function.
"""
session_tag = f"[{self.output_prefix}] " if self.output_prefix else ""
LOG.debug("%sSending command (safe): %s", session_tag, cmd)
self.read_nonblocking(0, timeout)
self.sendline(cmd)
out = ""
success = False
start_time = time.time()
while (time.time() - start_time) < timeout:
try:
out += await self.read_up_to_prompt_async(0.5)
success = True
break
except ExpectTimeoutError as error:
out = f"{out}{error.output}"
self.sendline()
except ExpectProcessTerminatedError as error:
output = self.remove_command_echo(f"{out}{error.output}", cmd)
raise ShellProcessTerminatedError(cmd, error.status,
output) from error
except ExpectError as error:
output = self.remove_command_echo(f"{out}{error.output}", cmd)
raise ShellError(cmd, output) from error

if not success:
raise ShellTimeoutError(cmd, out)

# Remove the echoed command and the final shell prompt
return self.remove_last_nonempty_line(self.remove_command_echo(out,
cmd))

def cmd_status_output(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Expand Down Expand Up @@ -1304,6 +1518,28 @@ def cmd_status_output(self, cmd, timeout=60, internal_timeout=None,
return int(digit_lines[0].strip()), out
raise ShellStatusError(cmd, out)

async def cmd_status_output_async(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Send a command and return its exit status and output via a coroutine.

All arguments are identical to the regular function.
"""
out = await self.cmd_output_async(cmd, timeout, internal_timeout, print_func, safe)
try:
# Send the 'echo $?' (or equivalent) command to get the exit status
status = self.cmd_output(self.status_test_command, 30,
internal_timeout, print_func, safe)
except ShellError as error:
raise ShellStatusError(cmd, out) from error

# Get the first line consisting of digits only
digit_lines = [_ for _ in status.splitlines()
if self.__RE_STATUS.match(_.strip())]
if digit_lines:
return int(digit_lines[0].strip()), out
raise ShellStatusError(cmd, out)

def cmd_status(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Expand Down Expand Up @@ -1331,6 +1567,16 @@ def cmd_status(self, cmd, timeout=60, internal_timeout=None,
return self.cmd_status_output(cmd, timeout, internal_timeout,
print_func, safe)[0]

async def cmd_status_async(self, cmd, timeout=60, internal_timeout=None,
print_func=None, safe=False):
"""
Send a command and return its exit status via a coroutine.

All arguments are identical to the regular function.
"""
return await self.cmd_status_output_async(cmd, timeout, internal_timeout,
print_func, safe)[0]

def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None,
ok_status=None, ignore_all_errors=False):
"""
Expand Down Expand Up @@ -1372,6 +1618,28 @@ def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None,
return None
raise

async def cmd_async(self, cmd, timeout=60, internal_timeout=None, print_func=None,
ok_status=None, ignore_all_errors=False):
"""
Send a command and return its output via a coroutine. If the command's
exit status is nonzero, raise an exception.

All arguments are identical to the regular function.
"""
if ok_status is None:
ok_status = [0, ]
try:
status, output = await self.cmd_status_output_async(cmd, timeout,
internal_timeout,
print_func)
if status not in ok_status:
raise ShellCmdError(cmd, status, output)
return output
except ShellError:
if ignore_all_errors:
return None
raise

def get_command_output(self, cmd, timeout=60, internal_timeout=None,
print_func=None):
"""
Expand Down