Fix grad accum when using loss_mask #842

Open · wants to merge 5 commits into base: main
4 changes: 2 additions & 2 deletions config/backpack.yaml
@@ -12,7 +12,7 @@ model:
trainer:
tracker:
project: "levanter"
tags: [ "openwebtext", "backpack" ]
tags: ["openwebtext", "backpack"]

mp: p=f32,c=bfloat16

@@ -21,5 +21,5 @@ trainer:
model_axis_size: 1

optimizer:
learning_rate: 6E-4
learning_rate: 6e-4
weight_decay: 0.1
40 changes: 21 additions & 19 deletions src/levanter/callbacks.py
@@ -17,6 +17,7 @@
from tqdm_loggable import tqdm_logging
from tqdm_loggable.auto import tqdm

import haliax as hax
import haliax.nn
from haliax import NamedArray, is_named_array
from haliax.jax_utils import is_jax_array_like
@@ -30,6 +31,8 @@
from levanter.utils import flop_utils, jax_utils
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.utils.logging import save_xla_dumps_to_wandb
from levanter.utils.stat_utils import RunningMean
from levanter.utils.types import Extras
from levanter.visualization import compute_and_visualize_log_probs as viz_probs


@@ -145,10 +148,8 @@ async def compute_length():


def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
total_loss = 0.0
total_load_time = 0.0
total_loss_time = 0.0
n = 0
loss = RunningMean(jnp.zeros(()), jnp.zeros(()))
extras: Extras = {}

if name is not None:
desc = f"eval {name}"
@@ -159,28 +160,27 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n
pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches)

iter_ = iter(pbar)
n = 0
while True:
time_in = time.time()
n += 1
batch = next(iter_, None)
if batch is None:
break
load_time = time.time() - time_in
total_load_time += load_time
loss = loss_fn(model, batch)
total_loss += loss.item()
n += 1
loss_time = time.time() - time_in - load_time
total_loss_time += loss_time
losses, where, batch_extras = loss_fn(model, batch)
mean_loss = hax.mean(losses, where=where)
loss += RunningMean(mean_loss, where.sum())
for k, v in batch_extras.items():
if k not in extras:
extras[k] = v
else:
extras[k] += v

pbar.set_postfix(loss=total_loss / n)
pbar.set_postfix(loss=loss.mean.item())

if max_batches is not None and n >= max_batches:
break

if n > 0:
total_loss /= n

return total_loss
return loss.item(), {k: v.item() for k, v in extras.items()}


def compute_validation_loss(
@@ -190,12 +190,14 @@ def compute_validation_loss(
name: Optional[str] = None,
):
def compute_loss(info: StepInfo):
loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)
loss, extras = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)

prefix = "eval"
if name:
prefix += "/" + name
levanter.tracker.log({f"{prefix}/loss": loss}, step=info.step)
levanter.tracker.log(
{f"{prefix}/loss": loss} | {f"{prefix}/{k}": v for k, v in extras.items()}, step=info.step
)

if name:
logger.info(f"{name} validation loss: {loss:.3f}")
9 changes: 7 additions & 2 deletions src/levanter/doremi.py
@@ -134,12 +134,17 @@ def doremi_step(state: DoremiState, ref, batch, domains):
proxy = inference_mode(state.model, False)
with hax.axis_mapping(trainer.compute_axis_mapping):
# calculate per-token losses for proxy and ref
proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy)
ref_losses = loss_fn(ref, batch, reduction_axis=())
def scalar_loss_fn(p, batch):
ret, _, _ = loss_fn(p, batch)
return ret

proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: scalar_loss_fn(p, batch), proxy)
ref_losses = scalar_loss_fn(ref, batch)

# calculate excess losses, aggregate per-domain losses
excess_losses = proxy_losses - ref_losses
clipped_losses = hax.maximum(excess_losses, 0)
per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains)

# Update domain weights
3 changes: 1 addition & 2 deletions src/levanter/eval.py
@@ -301,8 +301,7 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample,
m = self.mp.cast_to_compute(m)

with hax.axis_mapping(axis_mapping):
losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=())
mask = batch.loss_mask # [Batch, Pos]
losses, mask, _extras = compute_next_token_loss(m, batch)
this_tokens = hax.sum(mask)
this_loss = hax.einsum("->", losses, mask) # to scalar

13 changes: 2 additions & 11 deletions src/levanter/eval_harness.py
@@ -37,7 +37,6 @@
import levanter.tracker
from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
from levanter.models.gpt2 import Gpt2Config
from levanter.models.loss import next_token_loss
from levanter.utils.hf_utils import HfTokenizer


@@ -58,7 +57,7 @@
import levanter.config
from levanter.checkpoint import load_checkpoint
from levanter.data import AsyncDataset, DataLoader
from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss
from levanter.trainer import StepInfo, TrainerConfig
from levanter.utils.jax_utils import use_cpu_device
from levanter.utils.tree_utils import inference_mode
@@ -157,15 +156,7 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr
logits = logits.astype(jnp.float32)
Pos = logits.resolve_axis(self.EvalPos.name)

loss = next_token_loss(
Pos=Pos,
Vocab=model.Vocab,
logits=logits,
true_ids=example.tokens,
loss_mask=example.loss_mask,
reduction=hax.sum,
reduction_axis=Pos,
)
loss, _, _ = compute_next_token_loss(model, example)

not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool)
pred_targets = hax.argmax(logits, axis=model.Vocab)
64 changes: 38 additions & 26 deletions src/levanter/grad_accum.py
@@ -1,4 +1,3 @@
import enum
import functools
from typing import Callable, Optional, ParamSpec, TypeVar

@@ -9,34 +8,31 @@
from jax.sharding import PartitionSpec

import haliax as hax
import haliax.quantization as hq
from haliax import Axis
from haliax.partitioning import ResourceAxis
from haliax.util import is_named_array

from levanter.utils.jax_utils import zeros_like_tree
from levanter.utils.types import ComputeLossFunction


Args = ParamSpec("Args")
R = TypeVar("R")


class ReductionType(enum.Enum):
SUM = enum.auto()
MEAN = enum.auto()
# TODO: add MAX?
M_con = TypeVar("M_con", contravariant=True) # Model
X = TypeVar("X", contravariant=True) # Input


# TODO: should we use a custom_jvp on microbatched?

# cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617
def microbatched(
fn: Callable[Args, R],
loss_fn: ComputeLossFunction[M_con, X],
Batch: Axis,
microbatch_size: int,
accum_axis_mapping,
compute_axis_mapping,
patch_in_rng_key: Optional[str] = "key",
reduce: ReductionType = ReductionType.MEAN,
accum_dtype: Optional[jnp.dtype] = None,
) -> Callable[Args, R]:
"""
@@ -78,20 +74,32 @@ def microbatched(
num_micro_steps = batch_size // microbatch_size

if num_micro_steps == 1:
return fn

@functools.wraps(loss_fn)
def no_accum_loss_fn(*args, **kwargs):
losses, where, extras = loss_fn(*args, **kwargs)
seen_tokens = where.sum().scalar()
extras["seen_tokens"] = seen_tokens
return hax.mean(losses, where=where).scalar(), extras

return eqx.filter_value_and_grad(no_accum_loss_fn, has_aux=True)

Microbatch = Batch.resize(microbatch_size)
AccumStep = Axis("accum_step", num_micro_steps)
assert num_micro_steps * microbatch_size == batch_size

if reduce not in ReductionType:
raise ValueError(f"accum_type must be one of {ReductionType}")
@functools.wraps(loss_fn)
def accum_loss_fn(*args, **kwargs):
losses, where, extras = loss_fn(*args, **kwargs)
return hax.sum(losses, where=where).scalar(), (where.sum(), extras)

@functools.wraps(fn)
grad_fn = eqx.filter_value_and_grad(accum_loss_fn, has_aux=True)

@functools.wraps(grad_fn)
def wrapped_fn(*args, **kwargs):

# first, determine the shape and make accumulator arrays
r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
r_shape = eqx.filter_eval_shape(grad_fn, *args, **kwargs)
acc = zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype)

# then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...)
@@ -106,30 +114,34 @@ def wrapped_fn(*args, **kwargs):
args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping)

def loop(acc, microbatch_and_key):
(loss, (total, extras)), grads = acc
microbatch, microbatch_kwargs, key = microbatch_and_key
with jax.named_scope("compute"):
microbatch_kwargs = microbatch_kwargs.copy()
if key is not None:
microbatch_kwargs[patch_in_rng_key] = key
this_r = fn(*microbatch, **microbatch_kwargs)
(loss_mb, (n_mb, extras_mb)), grads_mb = grad_fn(*microbatch, **microbatch_kwargs)

with jax.named_scope("accum"):
import haliax.quantization as hq

# TODO: this uses the latest value for the scale for fp8, which seems not ideal but probably ok?
overwrites, updates = hq.partition_for_grad_overwrite(this_r)
acc = hq.apply_updates(acc, updates, overwrites)
acc = hax.shard_with_axis_mapping(acc, accum_axis_mapping)
overwrites, updates = hq.partition_for_grad_overwrite(grads_mb)
grads = hq.apply_updates(grads, updates, overwrites)
grads = hax.shard_with_axis_mapping(grads, accum_axis_mapping)
loss += loss_mb
total += n_mb

return acc
return (loss, (total, {k: v + extras_mb[k] for k, v in extras.items()})), grads

with jax.named_scope("microbatched"):
acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key))

if reduce == ReductionType.MEAN:
acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc)

return acc
(loss, (total, extras)), grads, = hax.fold(
loop, AccumStep
)(acc, (args, kwargs, key))
grads = jax.tree_util.tree_map(lambda x: x / total, grads)
loss /= total
extras["seen_tokens"] = total

return (loss, extras), grads

return wrapped_fn
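
To see why dividing by the total unmasked-token count (rather than by the number of microbatches) matters when a loss mask is present, here is a small numeric sketch, independent of Levanter's APIs: averaging per-microbatch means over-weights microbatches with few unmasked tokens, while summing masked losses and dividing once by the total mask count reproduces the full-batch masked mean. The accumulated gradients in the diff are normalized the same way (divided by `total`).

```python
import jax.numpy as jnp

# One batch of per-token losses and its loss mask, split into two microbatches.
losses = jnp.array([[2.0, 2.0, 2.0, 2.0],   # microbatch 1: 4 unmasked tokens
                    [8.0, 8.0, 0.0, 0.0]])  # microbatch 2: 2 unmasked tokens
mask = jnp.array([[1.0, 1.0, 1.0, 1.0],
                  [1.0, 1.0, 0.0, 0.0]])

# Reference: masked mean over the whole batch.
full_batch_mean = jnp.sum(losses * mask) / jnp.sum(mask)             # 24 / 6 = 4.0

# Old normalization: average of per-microbatch means. The half-masked
# microbatch is over-weighted.
per_mb_mean = jnp.sum(losses * mask, axis=1) / jnp.sum(mask, axis=1)
mean_of_means = jnp.mean(per_mb_mean)                                # (2 + 8) / 2 = 5.0

# New normalization: accumulate the masked loss sum and the unmasked-token
# count across microbatches, then divide once by the total.
acc_loss, acc_tokens = 0.0, 0.0
for mb_losses, mb_mask in zip(losses, mask):
    acc_loss += jnp.sum(mb_losses * mb_mask)
    acc_tokens += jnp.sum(mb_mask)
accumulated_mean = acc_loss / acc_tokens                             # 24 / 6 = 4.0

print(full_batch_mean, mean_of_means, accumulated_mean)
```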

4 changes: 1 addition & 3 deletions src/levanter/main/train_asr.py
@@ -88,10 +88,8 @@ def compute_loss(
example: AudioTextExample,
*,
key=None,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
) -> jax.numpy.ndarray | hax.NamedArray:
return m.compute_loss(example, key=key, reduction=reduction, reduction_axis=reduction_axis)
return m.compute_loss(example, key=key)

# Using the trainer as a context manager does 3 things:
# 1. Sets the device mesh
2 changes: 1 addition & 1 deletion src/levanter/main/viz_logprobs.py
@@ -74,7 +74,7 @@ def main(config: VizGpt2Config):
def compute_log_probs(model: LmHeadModel, example: LmExample):
model = inference_mode(model, True)
model = mp.cast_to_compute(model)
logprobs = compute_next_token_loss(model, example, reduction=None)
logprobs, where, _ = compute_next_token_loss(model, example)
# roll forward to get the loss for each predicted token
logprobs = hax.roll(logprobs, 1, Pos)
return logprobs.rearrange((EvalBatch, Pos)).array
12 changes: 7 additions & 5 deletions src/levanter/models/asr_model.py
@@ -11,6 +11,7 @@

from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmConfig
from levanter.utils.types import Extras


class AudioTextExample(eqx.Module):
@@ -97,9 +98,7 @@ def compute_loss(
example: AudioTextExample,
*,
key=None,
reduction: Optional[hax.ReductionFunction] = hax.mean,
reduction_axis: Optional[hax.AxisSelection] = None,
) -> jnp.ndarray | NamedArray:
) -> tuple[jnp.ndarray | NamedArray, NamedArray, Extras]:
"""
Computes the cross-entropy loss for predicted ASR tokens. If reduction is not None, the loss is reduced
across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
@@ -110,10 +109,13 @@ def compute_loss(
targets = hax.roll(example.tokens, -1, axis=self.Pos.name)
target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype)
loss = cross_entropy_loss(
logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask
logits,
self.Vocab,
target_y,
reduction=None,
)

return loss
return loss, example.loss_mask, {}

@property
def vocab_size(self) -> int: