Skip to content

Commit

Permalink
Fixes #45
Browse files Browse the repository at this point in the history
Signed-off-by: Trevor Grant <[email protected]>
  • Loading branch information
ibm-peach-fish committed Oct 27, 2023
1 parent af417d8 commit 4c894a7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
3 changes: 3 additions & 0 deletions caikit_ray_backend/blocks/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def train(
error.value_check("<RYT87231812E>", num_gpus > 0)
env_vars["requested_gpus"] = num_gpus

training_timeout = self.config.get("training_timeout", 60)
env_vars["training_timeout"] = float(training_timeout)

# Serialize **kwargs and add them to environment variables
my_kwargs = {}
for key, value in kwargs.items():
Expand Down
14 changes: 2 additions & 12 deletions caikit_ray_backend/ray_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


# Standard
from time import sleep
import base64
import json
import os
Expand All @@ -24,7 +23,6 @@
import ray

# First Party
from caikit import get_config
from caikit.core.toolkit.errors import error_handler
import alog

Expand Down Expand Up @@ -78,28 +76,20 @@ def main():
if model_path:
error.type_check("<RYT70238308E>", str, model_path=model_path)

timeout = 3
if get_config().training_timeout:
try:
timeout = float(get_config().training_timeout)
except ValueError:
log.warn(
f"training_timeout: '{get_config().training_timeout}' cannot be converted to int, ignoring"
)
timeout = runtime_env.get("training_timeout", float(60))

# Finally kick off training
with alog.ContextTimer(log.debug, "Done training %s in: ", module_class):
task = ray_training_tasks.train_and_save.options(
num_cpus=num_cpus, num_gpus=num_gpus
).remote(module_class, model_path, *args, **kwargs)

ready, _ = ray.wait([task], timeout=timeout)

if ready:
ray.get(task)
else:
ray.cancel(task)
log.error("Task did not complete before time out.")
raise TimeoutError("Task did not complete before time out.")


if __name__ == "__main__":
Expand Down
27 changes: 20 additions & 7 deletions tests/test_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
# Standard
from datetime import datetime
import logging
import os
import time

Expand Down Expand Up @@ -46,7 +47,10 @@ def jsonl_file_data_stream():


def test_job_submission_client(mock_ray_cluster, jsonl_file_data_stream):
config = {"connection": {"address": mock_ray_cluster.address}}
config = {
"connection": {"address": mock_ray_cluster.address},
"training_timeout": 30.0,
}
trainer = RayJobTrainModule(config, "ray_backend")

args = [jsonl_file_data_stream]
Expand Down Expand Up @@ -82,7 +86,10 @@ def test_job_submission_client(mock_ray_cluster, jsonl_file_data_stream):


def test_wait(mock_ray_cluster, jsonl_file_data_stream):
config = {"connection": {"address": mock_ray_cluster.address}}
config = {
"connection": {"address": mock_ray_cluster.address},
"training_timeout": 30.0,
}
trainer = RayJobTrainModule(config, "ray_backend")

args = [jsonl_file_data_stream]
Expand All @@ -101,7 +108,10 @@ def test_wait(mock_ray_cluster, jsonl_file_data_stream):


def test_load(mock_ray_cluster, jsonl_file_data_stream):
config = {"connection": {"address": mock_ray_cluster.address}}
config = {
"connection": {"address": mock_ray_cluster.address},
"training_timeout": 30.0,
}
trainer = RayJobTrainModule(config, "ray_backend")

args = [jsonl_file_data_stream]
Expand All @@ -118,7 +128,10 @@ def test_load(mock_ray_cluster, jsonl_file_data_stream):


def test_cancel(mock_ray_cluster, jsonl_file_data_stream):
config = {"connection": {"address": mock_ray_cluster.address}}
config = {
"connection": {"address": mock_ray_cluster.address},
"training_timeout": 30.0,
}
trainer = RayJobTrainModule(config, "ray_backend")

args = [jsonl_file_data_stream]
Expand All @@ -145,7 +158,7 @@ def test_cancel(mock_ray_cluster, jsonl_file_data_stream):
def test_timeout(mock_ray_cluster, jsonl_file_data_stream):
config = {
"connection": {"address": mock_ray_cluster.address},
"training_timeout": 3,
"training_timeout": 0.1,
}
trainer = RayJobTrainModule(config, "ray_backend")

Expand All @@ -156,11 +169,11 @@ def test_timeout(mock_ray_cluster, jsonl_file_data_stream):
save_path="/tmp",
)

time.sleep(5)
time.sleep(3)

status = model_future.get_info().status
print("Final status was", status)
assert status == TrainingStatus.CANCELED
assert status == TrainingStatus.ERRORED


## Test Ray Backend
Expand Down

0 comments on commit 4c894a7

Please sign in to comment.