Skip to content

Commit

Permalink
Lazy import executor logic.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 656529715
Change-Id: I74e03b52a18af82bc17f5703dae3dc0095fca512
GitOrigin-RevId: 4cbaa12169492ffaa1f45ec8bedddd980e5083c5
  • Loading branch information
DeepMind Team authored and alpiccioni committed Dec 4, 2024
1 parent fe4994a commit b07be62
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 207 deletions.
42 changes: 39 additions & 3 deletions xmanager/cloud/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import asyncio
import functools
from typing import Callable, Dict, List, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence

import attr
from kubernetes import client as k8s_client
Expand All @@ -24,9 +24,11 @@
from xmanager import xm_flags
from xmanager.xm import utils
from xmanager.xm_local import executables as local_executables
from xmanager.xm_local import execution as local_execution
from xmanager.xm_local import executors as local_executors
from xmanager.xm_local import handles
from xmanager.xm_local import registry
from xmanager.xm_local import status as local_status
from xmanager.xm_local.storage import database


@functools.lru_cache()
Expand Down Expand Up @@ -165,7 +167,7 @@ async def wait_for_job(self, job: k8s_client.V1Job) -> None:


@attr.s(auto_attribs=True)
class KubernetesHandle(local_execution.ExecutionHandle):
class KubernetesHandle(handles.ExecutionHandle):
"""A handle for referring to the launched container."""

jobs: List[k8s_client.V1Job]
Expand All @@ -176,6 +178,14 @@ async def wait(self) -> None:
def get_status(self) -> local_status.LocalWorkUnitStatus:
raise NotImplementedError

def save_to_storage(self, experiment_id: int, work_unit_id: int) -> None:
for job in self.jobs:
namespace = job.metadata.namespace or 'default'
name = job.metadata.name
database.database().insert_kubernetes_job(
experiment_id, work_unit_id, namespace, name
)


# Must act on all jobs with `local_executors.Kubernetes` executor.
def launch(
Expand Down Expand Up @@ -252,3 +262,29 @@ def node_selector_from_executor(
)
}
return {}


async def _async_launch(
local_experiment_unit: Any, job_group: xm.JobGroup
) -> list[KubernetesHandle]:
return launch(local_experiment_unit.get_full_job_name, job_group)


def _create_handle(*args, data, kubernetes_jobs) -> KubernetesHandle:
del args # unused
data = k8s_client.V1Job(
metadata=k8s_client.V1ObjectMeta(
namespace=data.kubernetes.namespace, name=data.kubernetes.job_name
)
)
kubernetes_jobs.append(data)
return KubernetesHandle(jobs=kubernetes_jobs)


def register():
"""Registers Kubernetes execution logic."""
registry.register(
local_executors.Kubernetes,
launch=_async_launch,
create_handle=_create_handle,
)
34 changes: 30 additions & 4 deletions xmanager/cloud/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import math
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple

import attr
from google.cloud import aiplatform
Expand All @@ -28,9 +28,11 @@
from xmanager.cloud import auth
from xmanager.xm import utils
from xmanager.xm_local import executables as local_executables
from xmanager.xm_local import execution as local_execution
from xmanager.xm_local import executors as local_executors
from xmanager.xm_local import handles
from xmanager.xm_local import registry
from xmanager.xm_local import status as local_status
from xmanager.xm_local.storage import database

_DEFAULT_LOCATION = 'us-central1'
# The only machines available on AI Platform are N1 machines and A2 machines.
Expand Down Expand Up @@ -328,7 +330,7 @@ def get_machine_spec(job: xm.Job) -> Dict[str, Any]:


@attr.s(auto_attribs=True)
class VertexHandle(local_execution.ExecutionHandle):
class VertexHandle(handles.ExecutionHandle):
"""A handle for referring to the launched container."""

job_name: str
Expand All @@ -344,11 +346,16 @@ def get_status(self) -> local_status.LocalWorkUnitStatus:
status = _STATE_TO_STATUS[int(state)]
return local_status.LocalWorkUnitStatus(status=status)

def save_to_storage(self, experiment_id: int, work_unit_id: int) -> None:
database.database().insert_vertex_job(
experiment_id, work_unit_id, self.job_name
)


# Must act on all jobs with `local_executors.Vertex` executor.
def launch(
experiment_title: str, work_unit_name: str, job_group: xm.JobGroup
) -> List[VertexHandle]:
) -> list[VertexHandle]:
"""Launch Vertex jobs in the job_group and return a handler."""
jobs = xm.job_operators.collect_jobs_by_filter(
job_group, _vertex_job_predicate
Expand Down Expand Up @@ -388,3 +395,22 @@ def cpu_ram_to_machine_type(cpu: Optional[int], ram: Optional[int]) -> str:
raise ValueError(
'(cpu={}, ram={}) does not fit in any valid machine type'.format(cpu, ram)
)


async def _async_launch(
local_experiment_unit: Any, job_group: xm.JobGroup
) -> list[VertexHandle]:
return launch(
local_experiment_unit._experiment_title, # pylint: disable=protected-access
local_experiment_unit.experiment_unit_name,
job_group,
)


def register():
"""Registers Vertex execution logic."""
registry.register(
local_executors.Vertex,
launch=_async_launch,
create_handle=lambda *args, data: VertexHandle(data.caip.resource_name),
)
15 changes: 15 additions & 0 deletions xmanager/xm/job_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,21 @@ class Executor(abc.ABC):
def Spec(cls) -> ExecutorSpec: # pylint: disable=invalid-name
raise NotImplementedError

@classmethod
async def launch(
cls, local_experiment_unit: Any, job_group: 'JobGroup'
) -> Sequence[Any]:
"""Launches a job group on the executor platform.
Args:
local_experiment_unit: The experiment unit to launch.
job_group: The job group to launch.
Returns:
Execution handles for jobs in the group.
"""
raise NotImplementedError


def _validate_env_vars(
self: Any, attribute: Any, env_vars: Dict[str, str]
Expand Down
Loading

0 comments on commit b07be62

Please sign in to comment.