Skip to content

Commit

Permalink
Update neptune_example.py (flyteorg#1743)
Browse files Browse the repository at this point in the history
* Update neptune_example.py

This PR fixes some of the formatting issues in the neptune plugin example

Signed-off-by: Niels Bantilan <[email protected]>

* Update neptune_example.py

Signed-off-by: Niels Bantilan <[email protected]>

* make linter happy

Signed-off-by: Niels Bantilan <[email protected]>

---------

Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy authored Oct 7, 2024
1 parent 568754e commit d6a9861
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions examples/neptune_plugin/neptune_plugin/neptune_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,23 @@
# %% [markdown]
# First, we specify the Neptune project that was created on Neptune's platform.
# Please update `NEPTUNE_PROJECT` to the value associated with your account.

# %%
NEPTUNE_PROJECT = "username/project"

# %% [markdown]
# Neptune requires an API key to authenticate with their service. In the above example,
# the secret is created using
# [Flyte's Secrets manager](https://docs.flyte.org/en/latest/user_guide/productionizing/secrets.html).

# %%
api_key = Secret(key="neptune-api-token", group="neptune-api-group")

# %% [markdown]
# Next, we use `ImageSpec` to construct a container with the dependencies for our
# XGBoost training task. Please set the `REGISTRY` to a registry that your cluster can access;

# %%
REGISTRY = "localhost:30000"

image = ImageSpec(
Expand All @@ -53,8 +59,11 @@
)


# %%
# %% [markdown]
# First, we use a task to download the dataset and cache the data in Flyte:


# %%
@task(
container_image=image,
cache=True,
Expand All @@ -68,14 +77,17 @@ def get_dataset() -> Tuple[np.ndarray, np.ndarray]:
return X, y


# %%
# %% [markdown]
# Next, we use the `neptune_init_run` decorator to configure Flyte to train an XGBoost
# model. The decorator requires an `api_key` secret to authenticate with Neptune and
# the task definition needs to request the same `api_key` secret. In the training
# function, the [Neptune run object](https://docs.neptune.ai/api/run/) is accessible
# through `current_context().neptune_run`, which is frequently used
# in Neptune's integrations. In this example, we pass the `Run` object into Neptune's
# XGBoost callback.


# %%
@task(
container_image=image,
secret_requests=[api_key],
Expand Down Expand Up @@ -119,9 +131,12 @@ def train_model(max_depth: int, X: np.ndarray, y: np.ndarray):
)


# %%
# %% [markdown]
# With Flyte's dynamic workflows, we can scale up multiple training jobs with different
# `max_depths`:


# %%
@dynamic(container_image=image)
def train_multiple_models(max_depths: List[int], X: np.ndarray, y: np.ndarray):
for max_depth in max_depths:
Expand All @@ -134,7 +149,7 @@ def train_wf(max_depths: List[int] = [2, 4, 10]):
train_multiple_models(max_depths=max_depths, X=X, y=y)


# %%
# %% [markdown]
# To run this workflow on a remote Flyte cluster run:
# ```bash
# union run --remote neptune_example.py train_wf
Expand Down

0 comments on commit d6a9861

Please sign in to comment.