diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 8c24cab..0000000 --- a/mypy.ini +++ /dev/null @@ -1,7 +0,0 @@ -[mypy] -plugins = pydantic.mypy -disallow_untyped_defs = False -ignore_missing_imports = True - -[mypy-scaling.*] -disallow_untyped_defs = True diff --git a/requirements/base.txt b/requirements/base.txt deleted file mode 100644 index 92cf521..0000000 --- a/requirements/base.txt +++ /dev/null @@ -1,13 +0,0 @@ -torch==2.1.1 -torchvision==0.16.1 -blended-dataset-loop==0.1 -pydantic==2.8.2 -einops -pyyaml -numpy==1.26.4 -tensorboard -tqdm -python-dateutil -tokenizers -pillow -wandb diff --git a/requirements/determined.txt b/requirements/determined.txt deleted file mode 100644 index 18ecce2..0000000 --- a/requirements/determined.txt +++ /dev/null @@ -1,4 +0,0 @@ -determined==0.26.4 -msrest==0.6.21 -google-api-core==2.8.2 -google-api-python-client==2.61.0 diff --git a/requirements/gpu_optimization.txt b/requirements/gpu_optimization.txt deleted file mode 100644 index 8bc78fe..0000000 --- a/requirements/gpu_optimization.txt +++ /dev/null @@ -1 +0,0 @@ -flash-attn==2.4.2 diff --git a/requirements/test.txt b/requirements/test.txt deleted file mode 100644 index 5cd3694..0000000 --- a/requirements/test.txt +++ /dev/null @@ -1,6 +0,0 @@ -pytest -pre-commit -ruff -mypy==1.10.0 -types-requests -types-pyyaml diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 9fc1d1f..0000000 --- a/ruff.toml +++ /dev/null @@ -1,3 +0,0 @@ -line-length = 120 -[lint.extend-per-file-ignores] -"__init__.py" = ["F401"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 57ce865..0000000 --- a/setup.py +++ /dev/null @@ -1,48 +0,0 @@ -from setuptools import setup, find_packages -from pathlib import Path - -reqs_dir = Path("./requirements") - - -def get_whitelisted_packages(sub_package: str): - whitelisted_packages = [] - for package_name in find_packages("src"): - if package_name.startswith(sub_package): - whitelisted_packages.append(package_name) - return whitelisted_packages - -whitelisted_packages = get_whitelisted_packages(sub_package="scaling") - -# Gather scaling requirements. -requirements_base = (reqs_dir / "base.txt").read_text().splitlines() -requirements_test = (reqs_dir / "test.txt").read_text().splitlines() - -requirements_optimization = (reqs_dir / "gpu_optimization.txt").read_text().splitlines() -requirements_determined = (reqs_dir / "determined.txt").read_text().splitlines() - -setup( - name="aleph-alpha-scaling", - url="https://github.com/Aleph-Alpha", - author="Aleph Alpha", - author_email="requests@aleph-alpha-ip.ai", - install_requires=requirements_base, - tests_require=requirements_test, - extras_require={ - "test": requirements_test, - "gpu_optimization": requirements_optimization, - "determined": requirements_determined, - }, - package_dir={"": "src"}, - packages=whitelisted_packages, - version="0.1.0", - license="Open Aleph License", - description="Non-distributed transformer implementation aimed at loading neox checkpoints for inference.", - # long_description=open("README.md").read(), - long_description_content_type="text/markdown", - entry_points="", - package_data={ - # If any package contains *.json or *.typed - "": ["*.json", "*.typed", "warnings.txt"], - }, - include_package_data=True, -) diff --git a/tests/core/test_logging/__init__.py b/tests/core/test_logging/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/core/test_logging/test_logger_config.py b/tests/core/test_logging/test_logger_config.py deleted file mode 100644 index 4106f6e..0000000 --- a/tests/core/test_logging/test_logger_config.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -from contextlib import nullcontext as does_not_raise -from typing import ContextManager - -import pytest -from pydantic import ValidationError - -from scaling.core.logging import LoggerConfig - - -@pytest.mark.parametrize( - "use_wandb, api_key, is_env_variable_set, expectation", - [ - pytest.param( - True, - "", - False, - pytest.raises(ValidationError, match="If 'use_wandb' is set to True a wandb api key needs to be provided."), - id="use_wandb is true, but api key is empty string", - ), - pytest.param( - True, - None, - False, - pytest.raises(ValidationError, match="If 'use_wandb' is set to True a wandb api key needs to be provided."), - id="use_wandb is true, but api key is None", - ), - pytest.param( - True, - "some_key", - False, - does_not_raise(), - id="use_wandb is true, api key is provided", - ), - pytest.param( - True, - None, - True, - does_not_raise(), - id="use_wandb is true, api key is not provided in config, but set as env variable", - ), - pytest.param(False, "", False, does_not_raise(), id="use_wandb is false and api key is empty string"), - pytest.param(False, None, False, does_not_raise(), id="use_wandb is false and api key is None"), - pytest.param(False, "some_key", False, does_not_raise(), id="use_wandb is false and api key is provided"), - ], -) -def test_logger_config_validation_for_wandb_and_api_key( - use_wandb: bool, api_key: str | None, is_env_variable_set: bool, expectation: ContextManager -) -> None: - if is_env_variable_set: - os.environ["WANDB_API_KEY"] = "some_key" - - with expectation: - LoggerConfig(use_wandb=use_wandb, wandb_api_key=api_key) diff --git a/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.bin b/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.bin deleted file mode 100644 index 0f36152..0000000 Binary files a/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.bin and /dev/null differ diff --git a/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.idx b/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.idx deleted file mode 100644 index 8db7793..0000000 Binary files a/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.idx and /dev/null differ diff --git a/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.meta.json b/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.meta.json deleted file mode 100644 index ca88af0..0000000 --- a/tests/transformer/files/dataset/data_index_cache_decoder_dataset_seed_42_seq_len_64.meta.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "document_count": 55, - "dtype": "int64", - "index_dtype": "int64" -} diff --git a/tests/transformer/utils.py b/tests/transformer/utils.py deleted file mode 100644 index c39cec8..0000000 --- a/tests/transformer/utils.py +++ /dev/null @@ -1,306 +0,0 @@ -import math -import os - -import torch - -from scaling.core.logging import LoggerConfig, logger - - -def normal_round(n, digits=0): - """ - python rounding rounds 0.5 down - this function fixes it and rounds up - """ - n = n * (10**digits) - if n - math.floor(n) < 0.5: - return float(math.floor(n)) / (10**digits) - return float(math.ceil(n)) / (10**digits) - - -def rounded_equal(a: float, b: float, digits=10): - """ - sometimes float have different numbers of digits due to serialization and rounding - this function checks if two floats are equal while allowing for one float to be rounded - If different numbers of digits are encountered, the more precise number is rounded - """ - # no rounding necessary if both are equal - if a == b: - return True - - a = float(a) - b = float(b) - - # find number of characters - str_a = str(a) - str_b = str(b) - digits_a = len(str_a.split(".")[1]) - digits_b = len(str_b.split(".")[1]) - - # if same length we know that the two floats must truly be different - if digits_a < digits_b: - round_to_digits = min(digits_a, digits) - else: - # b is shorter, a is rounded - round_to_digits = min(digits_b, digits) - - # if the number itself has been rounded, rounded again at the last position can result in problems - if min(digits_a, digits_b) - 1 == round_to_digits: - round_to_digits += 1 - - a = normal_round(a, round_to_digits) - b = normal_round(b, round_to_digits) - return a == b - - -def assert_nested_dicts_equal(result, target, message=None, precision=None): - """ - compare two dictionaries and do readable assertions - """ - assert isinstance(result, dict), "result is a dict" - assert isinstance(target, dict), "target is a dict" - - assert set(result.keys()) == set(target.keys()), ( - f"result and target have different keys: {set(result.keys())} vs. {set(target.keys())}" - + ("" if message is None else " (" + message + ")") - ) - - for k, r_v in result.items(): - t_v = target[k] - assert type(r_v) is type(t_v), ( - "result and target have different value types for " - + str(k) - + ( - f": {type(r_v)} with value {r_v} vs. {type(t_v)} with value {t_v}" - if message is None - else " (" + message + f"): {type(r_v)} vs. {type(t_v)}" - ) - ) - if isinstance(r_v, dict): - assert_nested_dicts_equal( - result=r_v, - target=t_v, - message=(str(k) if message is None else message + "." + str(k)), - precision=precision, - ) - elif isinstance(r_v, list): - assert_nested_lists_equal( - result=r_v, - target=t_v, - message=(str(k) if message is None else message + "." + str(k)), - precision=precision, - ) - else: - if precision is not None and isinstance(r_v, float): - assert abs(r_v - t_v) < 0.1**precision, ( - "result and target have different values for " - + str(k) - + "; r_v == " - + str(r_v) - + " (" - + str(type(r_v)) - + "); t_v == " - + str(t_v) - + "(" - + str(type(r_v)) - + ")" - + ("" if message is None else " (" + message + ")") - + "; precision == " - + str(precision) - ) - else: - if torch.is_tensor(r_v) and torch.is_tensor(t_v): - assert (r_v == t_v).all(), ( - "result and target have different values for " - + str(k) - + "; r_v == " - + str(r_v) - + " (" - + str(type(r_v)) - + "); t_v == " - + str(t_v) - + "(" - + str(type(r_v)) - + ")" - + ("" if message is None else " (" + message + ")") - ) - else: - assert r_v == t_v, ( - "result and target have different values for " - + str(k) - + "; r_v == " - + str(r_v) - + " (" - + str(type(r_v)) - + "); t_v == " - + str(t_v) - + "(" - + str(type(r_v)) - + ")" - + ("" if message is None else " (" + message + ")") - ) - - -def assert_nested_lists_equal(result, target, message=None, precision=None): - assert isinstance(result, list), "result is a list" - assert isinstance(target, list), "target is a list" - - assert len(result) == len(target), "result and target have different lengths" + ( - "" if message is None else " (" + message + ")" - ) - for i, (r_v, t_v) in enumerate(zip(result, target)): - assert type(r_v) is type(t_v), ( - "result and target have different value types for list item " - + str(i) - + ("" if message is None else " (" + message + ")") - ) - if isinstance(r_v, dict): - assert_nested_dicts_equal( - result=r_v, - target=t_v, - message=("list item " + str(i) if message is None else message + "." + "list item " + str(i)), - precision=precision, - ) - elif isinstance(r_v, list): - assert_nested_lists_equal( - result=r_v, - target=t_v, - message=("list item " + str(i) if message is None else message + "." + "list item " + str(i)), - precision=precision, - ) - else: - if precision is not None and isinstance(r_v, float): - assert rounded_equal(r_v, t_v, digits=precision), ( - "result and target have different values" - + "; r_v == " - + str(r_v) - + "(" - + str(type(r_v)) - + ")" - + "; t_v == " - + str(t_v) - + "(" - + str(type(t_v)) - + ")" - + ("" if message is None else " (" + message + ")") - + "; precision == " - + str(precision) - ) - elif torch.is_tensor(r_v) and torch.is_tensor(t_v): - assert (r_v == t_v).all().item(), ( - "result and target have different values" - + "; r_v == " - + str(r_v) - + "(" - + str(type(r_v)) - + ")" - + "; t_v == " - + str(t_v) - + "(" - + str(type(t_v)) - + ")" - + ("" if message is None else " (" + message + ")") - ) - else: - assert r_v == t_v, ( - "result and target have different values" - + "; r_v == " - + str(r_v) - + "(" - + str(type(r_v)) - + ")" - + "; t_v == " - + str(t_v) - + "(" - + str(type(t_v)) - + ")" - + ("" if message is None else " (" + message + ")") - ) - - -# Worker timeout *after* the first worker has completed. -PROCESS_TIMEOUT = 120 - - -def dist_init( - run_func, - master_port, - local_rank, - world_size, - return_dict, - *func_args, - **func_kwargs, -): - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(master_port) - os.environ["WORLD_SIZE"] = str(world_size) - # NOTE: unit tests don't support multi-node so local_rank == global rank - os.environ["RANK"] = str(local_rank) - os.environ["LOCAL_SLOT"] = str(local_rank) - logger.configure(LoggerConfig()) - run_func(return_dict=return_dict, *func_args, **func_kwargs) - - -def dist_launcher(run_func, world_size, master_port, *func_args, **func_kwargs): - """Launch processes and gracefully handle failures.""" - ctx = torch.multiprocessing.get_context("spawn") - manager = ctx.Manager() - return_dict = manager.dict() - # Spawn all workers on subprocesses. - processes = [] - for local_rank in range(world_size): - p = ctx.Process( - target=dist_init, - args=( - run_func, - master_port, - local_rank, - world_size, - return_dict, - *func_args, - ), - kwargs=func_kwargs, - ) - p.start() - processes.append(p) - - # Now loop and wait for a test to complete. The spin-wait here isn't a big - # deal because the number of processes will be O(#GPUs) << O(#CPUs). - any_done = False - any_failed = False - while not any_done: - for p in processes: - if not p.is_alive(): - any_done = True - if p.exitcode is not None: - any_failed = any_failed or (p.exitcode != 0) - - if any_failed: - for p in processes: - # If the process hasn't terminated, kill it because it hung. - if p.is_alive(): - p.terminate() - if p.is_alive(): - p.kill() - - # Wait for all other processes to complete - for p in processes: - p.join(PROCESS_TIMEOUT) - - # Collect exit codes and terminate hanging process - failures = [] - for rank, p in enumerate(processes): - if p.exitcode is None: - # If it still hasn't terminated, kill it because it hung. - p.terminate() - if p.is_alive(): - p.kill() - failures.append(f"Worker {rank} hung.") - elif p.exitcode < 0: - failures.append(f"Worker {rank} killed by signal {-p.exitcode}") - elif p.exitcode > 0: - failures.append(f"Worker {rank} exited with code {p.exitcode}") - - if len(failures) > 0: - raise RuntimeError("\n".join(failures)) - - return dict(return_dict)