Skip to content

Commit

Permalink
feat: add support for TLS
Browse files Browse the repository at this point in the history
  • Loading branch information
qdelamea-aneo committed Nov 13, 2024
1 parent 9b19903 commit e44a058
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/armonik_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@


COMMON_OPTIONS = cast(
OptionGroupDict, {"name": "Common options", "options": ["--debug", "--output", "--help"]}
OptionGroupDict, {"name": "Common options", "options": ["--output", "--config", "--debug", "--help"]}
)
CONNECTION_OPTIONS = cast(
OptionGroupDict, {"name": "Connection options", "options": ["--endpoint"]}
OptionGroupDict, {"name": "Connection options", "options": ["--endpoint", "--ca", "--cert", "--key"]}
)
click.rich_click.OPTION_GROUPS = {
"armonik": [
Expand Down
46 changes: 23 additions & 23 deletions src/armonik_cli/commands/session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import grpc
import rich_click as click

from datetime import timedelta
from typing import List, Tuple, Union

from armonik.client import ArmoniKSessions, ArmoniKTasks
from armonik.common import SessionStatus, Session, TaskOptions, Task, TaskStatus, Direction
from grpc import Channel

from armonik_cli.core import console, base_command, KeyValuePairParam, TimeDeltaParam

Expand All @@ -22,9 +22,9 @@ def session_group() -> None:

@session_group.command()
@base_command
def list(endpoint: str, output: str, debug: bool) -> None:
def list(channel_ctx: Channel, output: str, debug: bool) -> None:
"""List the sessions of an ArmoniK cluster."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
total, sessions = sessions_client.list_sessions()

Expand All @@ -40,9 +40,9 @@ def list(endpoint: str, output: str, debug: bool) -> None:
@click.option("-s", "--stats", is_flag=True, help="Compute a set of statistics for the session.")
@session_argument
@base_command
def get(endpoint: str, output: str, session_id: str, stats: bool, debug: bool) -> None:
def get(channel_ctx: Channel, output: str, session_id: str, stats: bool, debug: bool) -> None:
"""Get details of a given session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.get_session(session_id=session_id)
session = _clean_up_status(session)
Expand Down Expand Up @@ -127,7 +127,7 @@ def get(endpoint: str, output: str, session_id: str, stats: bool, debug: bool) -
)
@base_command
def create(
endpoint: str,
channel_ctx: Channel,
max_retries: int,
max_duration: timedelta,
priority: int,
Expand All @@ -143,7 +143,7 @@ def create(
debug: bool,
) -> None:
"""Create a new session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session_id = sessions_client.create_session(
default_task_options=TaskOptions(
Expand All @@ -169,9 +169,9 @@ def create(
@click.confirmation_option("--confirm", prompt="Are you sure you want to cancel this session?")
@session_argument
@base_command
def cancel(endpoint: str, output: str, session_id: str, debug: bool) -> None:
def cancel(channel_ctx: Channel, output: str, session_id: str, debug: bool) -> None:
"""Cancel a session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.cancel_session(session_id=session_id)
session = _clean_up_status(session)
Expand All @@ -181,9 +181,9 @@ def cancel(endpoint: str, output: str, session_id: str, debug: bool) -> None:
@session_group.command()
@session_argument
@base_command
def pause(endpoint: str, output: str, session_id: str, debug: bool) -> None:
def pause(channel_ctx: Channel, output: str, session_id: str, debug: bool) -> None:
"""Pause a session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.pause_session(session_id=session_id)
session = _clean_up_status(session)
Expand All @@ -193,9 +193,9 @@ def pause(endpoint: str, output: str, session_id: str, debug: bool) -> None:
@session_group.command()
@session_argument
@base_command
def resume(endpoint: str, output: str, session_id: str, debug: bool) -> None:
def resume(channel_ctx: Channel, output: str, session_id: str, debug: bool) -> None:
"""Resume a session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.resume_session(session_id=session_id)
session = _clean_up_status(session)
Expand All @@ -206,9 +206,9 @@ def resume(endpoint: str, output: str, session_id: str, debug: bool) -> None:
@click.confirmation_option("--confirm", prompt="Are you sure you want to close this session?")
@session_argument
@base_command
def close(endpoint: str, output: str, session_id: str, debug: bool) -> None:
def close(channel_ctx: Channel, output: str, session_id: str, debug: bool) -> None:
"""Close a session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.close_session(session_id=session_id)
session = _clean_up_status(session)
Expand All @@ -219,9 +219,9 @@ def close(endpoint: str, output: str, session_id: str, debug: bool) -> None:
@click.confirmation_option("--confirm", prompt="Are you sure you want to purge this session?")
@session_argument
@base_command
def purge(endpoint: str, output: str, session_id: str, debug: bool) -> None:
def purge(channel_ctx: Channel, output: str, session_id: str, debug: bool) -> None:
"""Purge a session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.purge_session(session_id=session_id)
session = _clean_up_status(session)
Expand All @@ -232,9 +232,9 @@ def purge(endpoint: str, output: str, session_id: str, debug: bool) -> None:
@click.confirmation_option("--confirm", prompt="Are you sure you want to delete this session?")
@session_argument
@base_command
def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None:
def delete(channel_ctx: Channel, output: str, session_id: str, debug: bool) -> None:
"""Delete a session and associated data from the cluster."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.delete_session(session_id=session_id)
session = _clean_up_status(session)
Expand All @@ -257,10 +257,10 @@ def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None:
@session_argument
@base_command
def stop_submission(
endpoint: str, session_id: str, clients_only: bool, workers_only: bool, output: str, debug: bool
channel_ctx: Channel, session_id: str, clients_only: bool, workers_only: bool, output: str, debug: bool
) -> None:
"""Stop clients and/or workers from submitting new tasks in a session."""
with grpc.insecure_channel(endpoint) as channel:
with channel_ctx as channel:
sessions_client = ArmoniKSessions(channel)
session = sessions_client.stop_submission_session(
session_id=session_id, client=clients_only, worker=workers_only
Expand All @@ -274,7 +274,7 @@ def _clean_up_status(session: Session) -> Session:
session.status = SessionStatus.name_from_value(session.status).split("_")[-1].capitalize()
return session

def _get_session_throughput(channel: grpc.Channel, session_id: str) -> Tuple[float, float]:
def _get_session_throughput(channel: Channel, session_id: str) -> Tuple[float, float]:
client = ArmoniKTasks(channel)
total, first = client.list_tasks(task_filter=Task.session_id == session_id, page_size=1, sort_field=Task.created_at, sort_direction=Direction.ASC)
_, last = client.list_tasks(task_filter=Task.session_id == session_id, page_size=1, sort_field=Task.ended_at, sort_direction=Direction.DESC)
Expand All @@ -285,7 +285,7 @@ def _get_session_throughput(channel: grpc.Channel, session_id: str) -> Tuple[flo
return throughput, elapsed_time


def _get_session_task_status(channel: grpc.Channel, session_id: str) -> Tuple[str, int]:
def _get_session_task_status(channel: Channel, session_id: str) -> Tuple[str, int]:
client = ArmoniKTasks(channel)
task_status = client.count_tasks_by_status(task_filter=Task.session_id == session_id)
return {k.name.capitalize(): v for k, v in task_status.items()}
49 changes: 48 additions & 1 deletion src/armonik_cli/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import grpc
import rich_click as click

from armonik.common.channel import create_channel

from armonik_cli.core.console import console
from armonik_cli.exceptions import NotFoundError, InternalError

Expand Down Expand Up @@ -72,13 +74,55 @@ def decorator(func):
"-e",
"--endpoint",
type=str,
required=True,
required=False,
help="Endpoint of the cluster to connect to.",
envvar="ARMONIK__ENDPOINT",
metavar="ENDPOINT",
)
ca_option = click.option(
"--ca",
"--certificate-authority",
type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
required=False,
help="Path to the certificate authority to read.",
envvar="ARMONIK__CA",
metavar="CA_PATH",
)
cert_option = click.option(
"--cert",
"--client-certificate",
type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
required=False,
help="Path to the client certificate to read.",
envvar="ARMONIK__CERT",
metavar="CERT_PATH",
)
key_option = click.option(
"--key",
"--client-key",
type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
required=False,
help="Path to the client key to read.",
envvar="ARMONIK__KEY",
metavar="KEY_PATH",
)
config_option = click.option(
"-c",
"--config",
"optional_config_file",
type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
required=False,
help="Path to a third-party configuration file.",
envvar="ARMONIK__CONFIG",
metavar="CONFIG_PATH",
)

if connection_args:
func = endpoint_option(func)
func = ca_option(func)
func = cert_option(func)
func = key_option(func)
func = config_option(func)

@click.option(
"-o",
Expand All @@ -95,6 +139,9 @@ def decorator(func):
@error_handler
@wraps(func)
def wrapper(*args, **kwargs):
if connection_args:
optional_config = kwargs.pop("optional_config_file")
kwargs["channel_ctx"] = create_channel(kwargs.pop("endpoint"), certificate_authority=kwargs.pop("ca"), client_certificate=kwargs.pop("cert"), client_key=kwargs.pop("key"))
return func(*args, **kwargs)

return wrapper
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def test_func():
pass

assert test_func.__name__ == "test_func"
assert len(test_func.__click_params__) == 3 if connection_args else 2
assert len(test_func.__click_params__) == 7 if connection_args else 2
assert (
sorted([param.name for param in test_func.__click_params__])
== ["debug", "endpoint", "output"]
== ["ca", "cert", "debug", "endpoint", "key", "optional_config_file", "output"]
if connection_args
else ["debug", "output"]
)

0 comments on commit e44a058

Please sign in to comment.