Skip to content

Commit

Permalink
Merge branch 'main' into add_sliding_window_attn_to_torch_attn
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML authored Aug 15, 2024
2 parents ea37b73 + cc703c6 commit 0f44dae
Show file tree
Hide file tree
Showing 91 changed files with 676 additions and 701 deletions.
32 changes: 16 additions & 16 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

from composer.callbacks import CheckpointSaver
from composer.core import Event, State, Time, Timestamp, TimeUnit
Expand Down Expand Up @@ -84,10 +84,10 @@ def get_run_name(training_run_name: str, current_interval: str) -> str:


def get_eval_parameters(
parameters: Dict[str, Any],
parameters: dict[str, Any],
checkpoint: str,
training_run_name: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Get the parameters needed for the eval run.
Args:
Expand Down Expand Up @@ -164,8 +164,8 @@ def validate_interval(


def validate_eval_run_config(
eval_run_config: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
eval_run_config: Optional[dict[str, Any]],
) -> dict[str, Any]:

if not eval_run_config:
return {}
Expand Down Expand Up @@ -220,9 +220,9 @@ class AsyncEval(CallbackWithConfig):

def __init__(
self,
train_config: Dict[str, Any],
train_config: dict[str, Any],
interval: Union[str, int, Time],
eval_run_config: Optional[Dict[str, Any]] = None,
eval_run_config: Optional[dict[str, Any]] = None,
):

# Run these during init to fail fast in any of the error cases
Expand Down Expand Up @@ -263,7 +263,7 @@ def __init__(

# Keep track of checkpoints that have already been evaled
# Format: {eval_timestamp: (checkpoint, run_name)}
self.checkpoints_evaled: Dict[Time, Tuple[str, str]] = {}
self.checkpoints_evaled: dict[Time, tuple[str, str]] = {}

# Scheduling is based on the check interval, while _get_checkpoints_and_launch_runs
# will only launch runs at the interval
Expand All @@ -279,7 +279,7 @@ def __init__(
f'interval {interval}, checking at {self.check_interval}',
)

def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
checkpoints_evaled = []
for eval_ts, (checkpoint, run_name) in self.checkpoints_evaled.items():
eval_ts_dict = {
Expand All @@ -292,7 +292,7 @@ def state_dict(self) -> Dict[str, Any]:
'checkpoints_evaled': checkpoints_evaled,
}

def load_state_dict(self, state_dict: Dict[str, Any]):
def load_state_dict(self, state_dict: dict[str, Any]):
previous_checkpoints_evaled = state_dict.get('checkpoints_evaled', [])
if previous_checkpoints_evaled:
for (eval_ts, checkpoint, run_name) in previous_checkpoints_evaled:
Expand All @@ -305,9 +305,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]):

@staticmethod
def _get_ready_sharded_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
remote_files: List[str],
) -> Dict[str, Timestamp]:
checkpointer_checkpoints: dict[str, Timestamp],
remote_files: list[str],
) -> dict[str, Timestamp]:
"""Identify checkpoints ready to be evaled based on remote files.
This has special logic for sharded checkpoints to consider checkpoints composed
Expand Down Expand Up @@ -349,9 +349,9 @@ def _get_ready_sharded_checkpoints(

@staticmethod
def _get_ready_single_checkpoints(
checkpointer_checkpoints: Dict[str, Timestamp],
remote_checkpoints: List[str],
) -> Dict[str, Timestamp]:
checkpointer_checkpoints: dict[str, Timestamp],
remote_checkpoints: list[str],
) -> dict[str, Timestamp]:
"""Identify checkpoints ready to be evaled based on remote checkpoints.
This is much simpler than the sharded case, because there is only one file
Expand Down
10 changes: 5 additions & 5 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import math
from enum import Enum
from typing import Dict, Optional
from typing import Optional

from composer.core import Callback, State
from composer.loggers import Logger
Expand All @@ -23,8 +23,8 @@ class Weighting(Enum):


def calculate_named_averages(
average_names: Dict[str, list],
category_scores: Dict[str, float],
average_names: dict[str, list],
category_scores: dict[str, float],
):
"""Calculates the named averages based off the raw category scores.
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.',
)

def extract_metrics_from_state(self, state: State) -> Dict[str, float]:
def extract_metrics_from_state(self, state: State) -> dict[str, float]:
results = {}

for key in self.logger_keys:
Expand All @@ -169,7 +169,7 @@ def extract_metrics_from_state(self, state: State) -> Dict[str, float]:

return {k: sum(v) / len(v) for k, v in results.items()}

def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
def eval_after_all(self, state: State, logger: Logger) -> dict[str, float]:
computed_metrics = self.extract_metrics_from_state(state)
if len(computed_metrics) == 0:
return {}
Expand Down
8 changes: 4 additions & 4 deletions llmfoundry/callbacks/eval_output_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Optional, Sequence, Union

import torch
from composer.core import Callback, State
Expand Down Expand Up @@ -59,7 +59,7 @@ def init(self, state: State, logger: Logger) -> None:
self.log_output_text = has_output_text

def eval_batch_end(self, state: State, logger: Logger) -> None:
if not isinstance(state.batch, Dict):
if not isinstance(state.batch, dict):
warnings.warn(
f"""EvalOutputLogging only supports batches that are dictionary. \
Found batch for type {type(state.batch)}. \
Expand All @@ -69,8 +69,8 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:

assert state.outputs is not None
assert state.metric_outputs is not None
logging_dict: Dict[str,
Union[List[Any], torch.Tensor,
logging_dict: dict[str,
Union[list[Any], torch.Tensor,
Sequence[torch.Tensor]],
] = deepcopy(
state.metric_outputs,
Expand Down
12 changes: 6 additions & 6 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
from multiprocessing.context import SpawnProcess
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -249,7 +249,7 @@ def __init__(
self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

self.child_processes: List[SpawnProcess] = []
self.child_processes: list[SpawnProcess] = []
# Temporary save directory used by child_processes.
self.temp_save_dir = None

Expand Down Expand Up @@ -349,7 +349,7 @@ def transform_model_and_tokenizer(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
) -> tuple[PreTrainedModel, PreTrainedTokenizerBase]:
"""Transform the model and tokenizer before saving.
This allows a subclass to modify the model and tokenizer before saving. The base class implementation will
Expand Down Expand Up @@ -457,10 +457,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
# Add hook to move tensors to cpu to avoid CUDA OOM
def tensor_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
dtensor_fqns = []
for fqn in state_dict.keys():
tensor = state_dict[fqn]
Expand Down Expand Up @@ -612,7 +612,7 @@ def tensor_hook(
# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
model_saving_kwargs: Dict[str, Any] = {
model_saving_kwargs: dict[str, Any] = {
'path': local_save_path,
}
if self.using_peft:
Expand Down
6 changes: 3 additions & 3 deletions llmfoundry/callbacks/loss_perp_v_len_callback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Mapping, Optional, Tuple
from typing import Any, Mapping, Optional

import torch
from composer.core import Callback, State
Expand Down Expand Up @@ -150,7 +150,7 @@ def preprocess_metric_inputs(
logits: torch.Tensor,
seq_parallel_world_size: int,
seq_parallel_rank: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
del sequence_id, seq_parallel_rank
if seq_parallel_world_size > 1:
raise ValueError(
Expand Down Expand Up @@ -315,7 +315,7 @@ def update(
self.sum_perplexity_seq_id += torch.sum(perplexity, dim=(0, 1))
self.sum_length_seq_id += torch.sum(mask, dim=(0, 1))

def compute(self) -> Dict[str, torch.Tensor]:
def compute(self) -> dict[str, torch.Tensor]:
"""Aggregate the state over all processes to compute the metric.
Returns:
Expand Down
3 changes: 1 addition & 2 deletions llmfoundry/callbacks/resumption_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List

from composer.core import Callback, State
from composer.loggers import Logger
Expand Down Expand Up @@ -69,7 +68,7 @@ class LayerFreezing(Callback):
layer_names (float): Names of layers to freeze.
"""

def __init__(self, layer_names: List[str]):
def __init__(self, layer_names: list[str]):
self.layer_names = set(layer_names)

def fit_start(self, state: State, logger: Logger) -> None:
Expand Down
6 changes: 3 additions & 3 deletions llmfoundry/command_utils/data_prep/convert_dataset_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import platform
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Iterable, Optional, Union

import datasets as hf_datasets
import psutil
Expand Down Expand Up @@ -39,7 +39,7 @@ class DataSplitConstants:
class DatasetConstants:
chars_per_sample: int
chars_per_token: int
splits: Dict[str, DataSplitConstants] = field(default_factory=dict)
splits: dict[str, DataSplitConstants] = field(default_factory=dict)

def __iter__(self):
for v in self.splits.values():
Expand Down Expand Up @@ -273,7 +273,7 @@ def build_dataloader(
def generate_samples(
loader: DataLoader,
truncate_num_samples: Optional[int] = None,
) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
) -> Iterable[Union[dict[str, bytes], dict[str, NDArray]]]:
"""Generator over samples of a dataloader.
Args:
Expand Down
14 changes: 7 additions & 7 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import urllib.parse
from collections import namedtuple
from concurrent.futures import ProcessPoolExecutor
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Optional, Union
from uuid import uuid4

import google.protobuf.any_pb2 as any_pb2
Expand Down Expand Up @@ -70,7 +70,7 @@

def to_cf(self: 'SparkConnectClient',
plan: 'pb2.Plan',
type: str = 'json') -> Tuple[List[Result], int, bool]:
type: str = 'json') -> tuple[list[Result], int, bool]:
"""Executes the query plans and return as presigned URLS for cloud fetch.
It can handle the current output formats that are supported by the server.
Expand Down Expand Up @@ -163,7 +163,7 @@ def to_cf(self: 'SparkConnectClient',


def collect_as_cf(self: 'DataFrame',
type: str = 'json') -> Tuple[List[Result], int, bool]:
type: str = 'json') -> tuple[list[Result], int, bool]:
"""Collects DataFrame execution plan as presigned URLs.
This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the
Expand Down Expand Up @@ -213,7 +213,7 @@ def run_query(
cursor: Optional['Cursor'] = None,
spark: Optional['SparkSession'] = None,
collect: bool = True,
) -> Optional[Union[List['Row'], 'DataFrame', 'SparkDataFrame']]:
) -> Optional[Union[list['Row'], 'DataFrame', 'SparkDataFrame']]:
"""Run SQL query via databricks-connect or databricks-sql.
Args:
Expand All @@ -240,7 +240,7 @@ def run_query(
raise ValueError(f'Unrecognized method: {method}')


def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable:
def get_args(signed: list, json_output_folder: str, columns: list) -> Iterable:
for i, r in enumerate(signed):
yield (i, r.url, json_output_folder, columns)

Expand All @@ -249,7 +249,7 @@ def download(
ipart: int,
url: str,
json_output_folder: str,
columns: Optional[List] = None,
columns: Optional[list] = None,
resp_format: str = 'arrow',
compressed: bool = False,
) -> None:
Expand Down Expand Up @@ -299,7 +299,7 @@ def download(
)


def download_starargs(args: Tuple) -> None:
def download_starargs(args: tuple) -> None:
return download(*args)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import platform
import warnings
from typing import Any, Callable, Dict, Iterable, Optional, Union
from typing import Any, Callable, Iterable, Optional, Union

import datasets as hf_datasets
import psutil
Expand Down Expand Up @@ -63,7 +63,7 @@ def build_dataloader(
def generate_samples(
loader: DataLoader,
truncate_num_samples: Optional[int] = None,
) -> Iterable[Dict[str, bytes]]:
) -> Iterable[dict[str, bytes]]:
"""Generator over samples of a dataloader.
Args:
Expand Down
Loading

0 comments on commit 0f44dae

Please sign in to comment.