Skip to content

Commit

Permalink
clean up train/launch scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Aug 31, 2024
1 parent c43039e commit 6569131
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- The `work_dir` argument to `TrainerConfig` now defaults to `save_folder` if `save_folder` is a local path, otherwise to a temporary directory with the same name as the basename of the `save_folder`.
- The `seed` argument to `prepare_training_environment()` is now optional.

### Fixed

Expand Down
9 changes: 6 additions & 3 deletions src/examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
SpeedMonitorCallback,
WandBCallback,
)
from olmo_core.utils import get_default_device
from olmo_core.utils import get_default_device, seed_all


@dataclass
Expand All @@ -39,6 +39,7 @@ class ExperimentConfig(Config):
optim: AdamWConfig
dataset: MemMapDatasetConfig
trainer: TrainerConfig
init_seed: int = 12536


def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
Expand Down Expand Up @@ -108,6 +109,9 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
def main(run_name: str, overrides: List[str]):
config = build_config(run_name, overrides)

# Set RNG states on all devices.
seed_all(config.init_seed)

# Build components.
model = config.model.build(
init_device="meta",
Expand All @@ -133,8 +137,7 @@ def main(run_name: str, overrides: List[str]):
print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]")
sys.exit(1)

run_name = sys.argv[1]
overrides = sys.argv[2:]
run_name, *overrides = sys.argv[1:]

prepare_training_environment()
try:
Expand Down
3 changes: 1 addition & 2 deletions src/examples/train_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig:
print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]")
sys.exit(1)

run_name = sys.argv[1]
overrides = sys.argv[2:]
run_name, *overrides = sys.argv[1:]

prepare_cli_environment()

Expand Down
5 changes: 3 additions & 2 deletions src/olmo_core/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

def prepare_training_environment(
*,
seed: int = 0,
seed: Optional[int] = None,
backend: Optional[str] = "cpu:gloo,cuda:nccl",
timeout: timedelta = timedelta(minutes=5),
log_filter_type: Optional[LogFilterType] = None,
Expand Down Expand Up @@ -112,7 +112,8 @@ def prepare_training_environment(
add_cached_path_clients()

# Init RNG states.
seed_all(seed)
if seed is not None:
seed_all(seed)

if is_distributed():
log.info(f"Using distributed backend {dist.get_backend()}")
Expand Down
49 changes: 33 additions & 16 deletions src/scripts/train/OLMo-7B.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Train a 7B OLMo model. See below for usage.
Train a 7B OLMo model. Run this script without any arguments to see usage info.
"""

import json
Expand Down Expand Up @@ -36,7 +36,12 @@
SpeedMonitorCallback,
WandBCallback,
)
from olmo_core.utils import generate_uuid, get_default_device, prepare_cli_environment
from olmo_core.utils import (
generate_uuid,
get_default_device,
prepare_cli_environment,
seed_all,
)

log = logging.getLogger(__name__)

Expand All @@ -55,6 +60,7 @@ class ExperimentConfig(Config):
optim: AdamWConfig
dataset: MemMapDatasetConfig
trainer: TrainerConfig
init_seed: int = 12536


def build_config(run_name: str, cluster: str, overrides: List[str]) -> ExperimentConfig:
Expand Down Expand Up @@ -185,6 +191,9 @@ def launch(config: ExperimentConfig):


def train(config: ExperimentConfig):
# Set RNG states on all devices.
seed_all(config.init_seed)

# Build components.
model = config.model.build(
init_device="meta",
Expand All @@ -208,33 +217,41 @@ def train(config: ExperimentConfig):


if __name__ == "__main__":
usage = (
f"Usage: python {sys.argv[0]} {SubCmd.launch}|{SubCmd.train}|{SubCmd.dry_run} run_name cluster [OVERRIDES...]\n\n"
"Example:\n"
f"$ python {sys.argv[0]} {SubCmd.launch} OLMo-core-7B ai2/pluto-cirrascale --launch.num_nodes=2"
)
usage = f"""
[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]RUN_NAME CLUSTER[/] [i][OVERRIDES...][/]
[b]Subcommands[/]
[b magenta]launch:[/] Launch the script on Beaker with the [b magenta]train[/] subcommand.
[b magenta]train:[/] Run the trainer. You usually shouldn't invoke the script with this subcommand directly.
Instead use [b magenta]launch[/] or run it with torchrun.
[b magenta]dry_run:[/] Pretty print the config and exit.
if len(sys.argv) < 4:
print(usage)
[b]Examples[/]
$ [i]python {sys.argv[0]} {SubCmd.launch} OLMo-core-7B ai2/pluto-cirrascale --launch.num_nodes=2[/]
""".strip()

if len(sys.argv) < 4 or sys.argv[1] not in set(SubCmd):
import rich

rich.get_console().print(usage, highlight=False)
sys.exit(1)

cmd = sys.argv[1]
run_name = sys.argv[2]
cluster = sys.argv[3]
overrides = sys.argv[4:]
cmd, run_name, cluster, *overrides = sys.argv[1:]

if sys.argv[1] == SubCmd.launch:
if cmd == SubCmd.launch:
prepare_cli_environment()
config = build_config(run_name, cluster, overrides)
launch(config)
elif sys.argv[1] == SubCmd.dry_run:
elif cmd == SubCmd.dry_run:
prepare_cli_environment()
config = build_config(run_name, cluster, overrides)
log.info(config)
else:
elif cmd == SubCmd.train:
prepare_training_environment()
config = build_config(run_name, cluster, overrides)
try:
train(config)
finally:
teardown_training_environment()
else:
raise NotImplementedError(cmd)

0 comments on commit 6569131

Please sign in to comment.