Airflow operator - reattach to running jobs on retries (disabled by default) (#286)
Martynas Asipauskas authored and GitHub Enterprise committed Nov 26, 2024
1 parent 7313fb8 commit a30c99e
Showing 10 changed files with 270 additions and 32 deletions.
2 changes: 1 addition & 1 deletion client/python/armada_client/asyncio_client.py
@@ -216,7 +216,7 @@ async def get_job_status_by_external_job_uri(
         :rtype: JobStatusResponse
         """
         req = job_pb2.JobStatusUsingExternalJobUriRequest(
-            queue, job_set_id, external_job_uri
+            queue=queue, jobset=job_set_id, external_job_uri=external_job_uri
         )
         resp = await self.job_stub.GetJobStatusUsingExternalJobUri(req)
         return resp
2 changes: 1 addition & 1 deletion client/python/armada_client/client.py
@@ -191,7 +191,7 @@ def get_job_status_by_external_job_uri(
         :rtype: JobStatusResponse
         """
         req = job_pb2.JobStatusUsingExternalJobUriRequest(
-            queue, job_set_id, external_job_uri
+            queue=queue, jobset=job_set_id, external_job_uri=external_job_uri
         )
         return self.job_stub.GetJobStatusUsingExternalJobUri(req)

2 changes: 1 addition & 1 deletion client/python/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "armada_client"
-version = "0.4.7"
+version = "0.4.8"
 description = "Armada gRPC API python client"
 readme = "README.md"
 requires-python = ">=3.7"
62 changes: 55 additions & 7 deletions third_party/airflow/armada/hooks.py
@@ -8,7 +8,7 @@
 from airflow.models import TaskInstance
 from airflow.utils.log.logging_mixin import LoggingMixin
 from armada.model import GrpcChannelArgs
-from armada_client.armada.job_pb2 import JobRunDetails
+from armada_client.armada.job_pb2 import JobDetailsResponse, JobRunDetails
 from armada_client.armada.submit_pb2 import JobSubmitRequestItem
 from armada_client.client import ArmadaClient
 from armada_client.typings import JobState
@@ -99,10 +99,55 @@ def submit_job(
         )
     @log_exceptions
     def job_termination_reason(self, job_context: RunningJobContext) -> str:
-        resp = self.client.get_job_errors([job_context.job_id])
-        job_error = resp.job_errors.get(job_context.job_id, "")
+        if job_context.state in {
+            JobState.REJECTED,
+            JobState.PREEMPTED,
+            JobState.FAILED,
+        }:
+            resp = self.client.get_job_errors([job_context.job_id])
+            job_error = resp.job_errors.get(job_context.job_id, "")
+            return job_error or ""
+        return ""

-        return job_error or ""
+    @tenacity.retry(
+        wait=tenacity.wait_random_exponential(max=3),
+        stop=tenacity.stop_after_attempt(5),
+        reraise=True,
+    )
+    @log_exceptions
+    def job_by_external_job_uri(
+        self,
+        armada_queue: str,
+        job_set: str,
+        external_job_uri: str,
+    ) -> RunningJobContext:
+        response = self.client.get_job_status_by_external_job_uri(
+            armada_queue, job_set, external_job_uri
+        )
+        job_ids = list(response.job_states.keys())
+        job_details = self.client.get_job_details(job_ids).job_details.values()
+        last_submitted = next(
+            iter(
+                sorted(job_details, key=lambda d: d.submitted_ts.seconds, reverse=True)
+            ),
+            None,
+        )
+        if last_submitted:
+            cluster = None
+            latest_run = self._get_latest_job_run_details(last_submitted)
+            if latest_run:
+                cluster = latest_run.cluster
+            return RunningJobContext(
+                armada_queue,
+                last_submitted.job_id,
+                job_set,
+                DateTime.utcnow(),
+                last_log_time=None,
+                cluster=cluster,
+                job_state=JobState(last_submitted.state).name,
+            )
+
+        return None

     @tenacity.retry(
         wait=tenacity.wait_random_exponential(max=3),
@@ -125,7 +170,9 @@ def refresh_context(
         if not cluster:
             # Job is running / or completed already
             if state == JobState.RUNNING or state.is_terminal():
-                run_details = self._get_latest_job_run_details(job_context.job_id)
+                job_id = job_context.job_id
+                job_details = self.client.get_job_details([job_id]).job_details[job_id]
+                run_details = self._get_latest_job_run_details(job_details)
                 if run_details:
                     cluster = run_details.cluster
         return dataclasses.replace(job_context, job_state=state.name, cluster=cluster)
@@ -167,8 +214,9 @@ def context_to_xcom(
             },
         )

-    def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]:
-        job_details = self.client.get_job_details([job_id]).job_details[job_id]
+    def _get_latest_job_run_details(
+        self, job_details: Optional[JobDetailsResponse]
+    ) -> Optional[JobRunDetails]:
         if job_details and job_details.latest_run_id:
             for run in job_details.job_runs:
                 if run.run_id == job_details.latest_run_id:
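Aside (not part of the diff): a minimal sketch of how the new hook lookup could be driven, assuming a reachable Armada server; the channel target, queue, job set and URI below are illustrative placeholders, and the `ArmadaHook(GrpcChannelArgs(...))` construction follows the imports shown above.

    from armada.hooks import ArmadaHook
    from armada.model import GrpcChannelArgs

    # Placeholder endpoint - substitute a real Armada gRPC address.
    hook = ArmadaHook(GrpcChannelArgs(target="localhost:50051"))

    # Find the most recently submitted job carrying this external-job URI
    # (the operator derives the URI from the dag/task/run identifiers).
    ctx = hook.job_by_external_job_uri(
        "example-queue",
        "example-job-set",
        "airflow://example_dag/example_task/manual__2024-11-26T00:00:00/-1",
    )
    if ctx:
        print(f"Latest matching job {ctx.job_id} is in state {ctx.state}")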
64 changes: 44 additions & 20 deletions third_party/airflow/armada/operators/armada.py
@@ -26,7 +26,9 @@
 import jinja2
 import tenacity
 from airflow.configuration import conf
+from airflow.exceptions import AirflowFailException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
+from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.serialization.serde import deserialize
 from airflow.utils.context import Context
@@ -42,8 +44,9 @@
 from .errors import ArmadaOperatorJobFailedError
 from ..hooks import ArmadaHook
 from ..model import RunningJobContext
+from ..policies.reattach import external_job_uri, policy
 from ..triggers import ArmadaPollJobTrigger
-from ..utils import log_exceptions, xcom_pull_for_ti
+from ..utils import log_exceptions, xcom_pull_for_ti, resolve_parameter_value


 class LookoutLink(BaseOperatorLink):
@@ -102,6 +105,8 @@ class ArmadaOperator(BaseOperator, LoggingMixin):
     :type job_acknowledgement_timeout: int
     :param dry_run: Run Operator in dry-run mode - render Armada request and terminate.
     :type dry_run: bool
+    :param reattach_policy: Operator reattach policy to use (defaults to: never)
+    :type reattach_policy: Optional[str]
     :param kwargs: Additional keyword arguments to pass to the BaseOperator.
     """

@@ -130,6 +135,7 @@ def __init__(
         dry_run: bool = conf.getboolean(
             "armada_operator", "default_dry_run", fallback=False
         ),
+        reattach_policy: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -148,6 +154,15 @@ def __init__(
         self.dry_run = dry_run
         self.job_context = None

+        configured_reattach_policy: str = resolve_parameter_value(
+            "reattach_policy", reattach_policy, kwargs, "never"
+        )
+        self.log.info(
+            f"Configured reattach policy to: '{configured_reattach_policy}',"
+            f" max retries: {self.retries}"
+        )
+        self.reattach_policy = policy(configured_reattach_policy)
+
         if self.container_logs and self.k8s_token_retriever is None:
             self.log.warning(
                 "Token refresh mechanism not configured, airflow may stop retrieving "
@@ -326,35 +341,52 @@ def _reattach_or_submit_job(
     def _try_reattach_to_running_job(
         self, context: Context
     ) -> Optional[RunningJobContext]:
-        # TODO: We should support re-attaching to currently running jobs.
-        # This is subject to re-attach policy / discovering jobs we already submitted.
-        # Issue - xcom state gets cleared before re-entry.
-        # ctx = self.hook.context_from_xcom(ti, re_attach=True)
+        # On first try we intentionally do not re-attach.
         self.log.info(context)
         if context["ti"].try_number == 1:
             return None

+        expected_job_uri = external_job_uri(context)
+        ctx = self.hook.job_by_external_job_uri(
+            self.armada_queue, self.job_set_id, expected_job_uri
+        )
+
-        # if ctx:
-        #     if ctx.state not in {JobState.FAILED, JobState.PREEMPTED}:
+        if ctx:
+            termination_reason = self.hook.job_termination_reason(ctx)
+            if self.reattach_policy(ctx.state, termination_reason):
+                return ctx
+            else:
+                self.log.info(
+                    f"Found: job-id {ctx.job_id} in {ctx.state}. "
+                    "Didn't reattach due to reattach policy."
+                )

         return None

-    def _poll_for_termination(self, context) -> None:
+    def _poll_for_termination(self, context: Context) -> None:
         while self.job_context.state.is_active():
             self._check_job_status_and_fetch_logs(context)
             if self.job_context.state.is_active():
                 self._yield()

-        self._running_job_terminated(self.job_context)
+        self._running_job_terminated(context["ti"], self.job_context)

-    def _running_job_terminated(self, context: RunningJobContext):
+    def _running_job_terminated(self, ti: TaskInstance, context: RunningJobContext):
         self.log.info(
             f"job {context.job_id} terminated with state: {context.state.name}"
         )
         if context.state != JobState.SUCCEEDED:
-            raise ArmadaOperatorJobFailedError(
+            error = ArmadaOperatorJobFailedError(
                 context.armada_queue,
                 context.job_id,
                 context.state,
                 self.hook.job_termination_reason(context),
             )
+            if self.reattach_policy(error.state, error.reason):
+                self.log.error(str(error))
+                raise AirflowFailException()
+            else:
+                raise error

     def _not_acknowledged_within_timeout(self) -> bool:
         if self.job_context.state == JobState.UNKNOWN:
@@ -416,14 +448,6 @@ def _xcom_push(self, context, key: str, value: Any):
         task_instance = context["ti"]
         task_instance.xcom_push(key=key, value=value)

-    def _external_job_uri(self, context: Context) -> str:
-        task_id = context["ti"].task_id
-        map_index = context["ti"].map_index
-        run_id = context["run_id"]
-        dag_id = context["dag"].dag_id
-
-        return f"airflow://{dag_id}/{task_id}/{run_id}/{map_index}"
-
     def _annotate_job_request(self, context, request: JobSubmitRequestItem):
         if "ANNOTATION_KEY_PREFIX" in os.environ:
             annotation_key_prefix = f'{os.environ.get("ANNOTATION_KEY_PREFIX")}'
@@ -438,5 +462,5 @@ def _annotate_job_request(self, context, request: JobSubmitRequestItem):
         request.annotations[annotation_key_prefix + "taskRunId"] = run_id
         request.annotations[annotation_key_prefix + "dagId"] = dag_id
         request.annotations[annotation_key_prefix + "externalJobUri"] = (
-            self._external_job_uri(context)
+            external_job_uri(context)
         )
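Usage note (not part of the diff): a hedged sketch of wiring the new reattach_policy knob into a DAG. Apart from reattach_policy itself, the constructor arguments follow the operator's existing signature as I read it, and all concrete values (endpoint, queue, job spec) are placeholders.

    from airflow import DAG
    from armada.model import GrpcChannelArgs
    from armada.operators.armada import ArmadaOperator
    from armada_client.armada.submit_pb2 import JobSubmitRequestItem

    with DAG(dag_id="example_dag") as dag:
        train = ArmadaOperator(
            task_id="train",
            name="train",
            channel_args=GrpcChannelArgs(target="armada-server:50051"),
            armada_queue="example-queue",
            job_request=JobSubmitRequestItem(priority=1),  # placeholder job spec
            # New in this commit: on retry, re-attach to a job that is still
            # running (or already succeeded) instead of submitting a duplicate.
            reattach_policy="running_or_succeeded",
            retries=3,
        )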
54 changes: 54 additions & 0 deletions third_party/airflow/armada/policies/reattach.py
@@ -0,0 +1,54 @@
+from typing import Literal
+from airflow.utils.context import Context
+
+from armada_client.typings import JobState
+
+
+def external_job_uri(context: Context) -> str:
+    task_id = context["ti"].task_id
+    map_index = context["ti"].map_index
+    run_id = context["run_id"]
+    dag_id = context["dag"].dag_id
+
+    return f"airflow://{dag_id}/{task_id}/{run_id}/{map_index}"
+
+
+def policy(policy_type: Literal["always", "never", "running_or_succeeded"]) -> callable:
+    """
+    Returns the corresponding re-attach policy function based on the policy type.
+    :param policy_type: The type of policy ('always', 'never', 'running_or_succeeded').
+    :type policy_type: Literal['always', 'never', 'running_or_succeeded']
+    :return: A function that determines whether to re-attach to an existing job.
+    :rtype: Callable[[JobState, str], bool]
+    """
+    policy_type = policy_type.lower()
+    if policy_type == "always":
+        return always_reattach
+    elif policy_type == "never":
+        return never_reattach
+    elif policy_type == "running_or_succeeded":
+        return running_or_succeeded_reattach
+    else:
+        raise ValueError(f"Unknown policy type: {policy_type}")
+
+
+def never_reattach(state: JobState, termination_reason: str) -> bool:
+    """
+    Policy that never allows re-attaching a job.
+    """
+    return False
+
+
+def always_reattach(state: JobState, termination_reason: str) -> bool:
+    """
+    Policy that always re-attaches to a job.
+    """
+    return True
+
+
+def running_or_succeeded_reattach(state: JobState, termination_reason: str) -> bool:
+    """
+    Policy that allows re-attaching unless the job failed or was rejected.
+    """
+    return state not in {JobState.FAILED, JobState.REJECTED}
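To make the policy contract concrete, a small sketch of how a selected policy evaluates; this mirrors the operator's call site, which passes the job state and the termination reason string.

    from armada.policies.reattach import policy
    from armada_client.typings import JobState

    should_reattach = policy("running_or_succeeded")

    # A job that is still running is re-attachable under this policy ...
    assert should_reattach(JobState.RUNNING, "") is True
    # ... while a failed one forces a fresh submission on retry.
    assert should_reattach(JobState.FAILED, "pod failed") is False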
Empty file.
34 changes: 33 additions & 1 deletion third_party/airflow/armada/utils.py
@@ -1,7 +1,8 @@
 import functools
-from typing import Any
+from typing import Any, Callable, Optional, TypeVar

 import tenacity
+from airflow.configuration import conf
 from airflow.models import TaskInstance


@@ -26,3 +27,34 @@ def wrapper(self, *args, **kwargs):
 @log_exceptions
 def xcom_pull_for_ti(ti: TaskInstance, key: str) -> Any:
     return ti.xcom_pull(key=key, task_ids=ti.task_id, map_indexes=ti.map_index)
+
+
+T = TypeVar("T")
+
+
+def resolve_parameter_value(
+    param_name: str,
+    param_value: Optional[T],
+    kwargs: dict,
+    fallback_value: T,
+    type_converter: Callable[[str], T] = lambda x: x,
+) -> T:
+    if param_value is not None:
+        return param_value
+
+    dag = kwargs.get("dag")
+    if dag and getattr(dag, "default_args", None):
+        default_args = dag.default_args
+        if param_name in default_args:
+            return default_args[param_name]
+
+    airflow_config_value = conf.get("armada_operator", param_name, fallback=None)
+    if airflow_config_value is not None:
+        try:
+            return type_converter(airflow_config_value)
+        except ValueError as e:
+            raise ValueError(
+                f"Failed to convert '{airflow_config_value}' for '{param_name}': {e}"
+            )
+
+    return fallback_value
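The resolution order is: explicit operator argument, then the DAG's default_args, then the Airflow config section, then the fallback. A short sketch, assuming no [armada_operator] reattach_policy entry exists in airflow.cfg:

    from armada.utils import resolve_parameter_value

    # 1. An explicit argument always wins.
    assert resolve_parameter_value("reattach_policy", "always", {}, "never") == "always"

    # 2. Otherwise the DAG's default_args are consulted (the kwargs passed to
    #    BaseOperator carry the dag object).
    # 3. Otherwise the Airflow config is read, with optional type conversion.
    # 4. Otherwise the fallback value is returned.
    assert resolve_parameter_value("reattach_policy", None, {}, "never") == "never"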
4 changes: 3 additions & 1 deletion third_party/airflow/test/unit/operators/test_armada.py
@@ -25,6 +25,8 @@ def default_hook() -> MagicMock:
     mock = MagicMock()
     job_context = running_job_context()
     mock.submit_job.return_value = job_context
+    mock.job_by_external_job_uri.return_value = None
+    mock.job_termination_reason.return_value = "FAILED"
     mock.refresh_context.return_value = dataclasses.replace(
         job_context, job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER
     )
@@ -51,7 +53,7 @@ def mock_operator_dependencies():
 def context():
     mock_ti = MagicMock()
     mock_ti.task_id = DEFAULT_TASK_ID
-    mock_ti.try_number = 0
+    mock_ti.try_number = 1
     mock_ti.xcom_pull.return_value = None

     mock_dag = MagicMock()
