Add weight support for LigerCrossEntropy #420

Open
wants to merge 5 commits into main
Changes from 3 commits
56 changes: 53 additions & 3 deletions src/liger_kernel/ops/cross_entropy.py
@@ -27,18 +27,22 @@ def liger_cross_entropy_kernel(
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
weight_stride,
loss_ptr,
z_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
sum_of_non_ignore_weight,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
@@ -50,18 +54,22 @@ def liger_cross_entropy_kernel(
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
weight_ptr: Pointer to weight tensor.
weight_stride (int): The stride of the weight tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
sum_of_non_ignore_weight (float): The denominator when `reduction="mean"` if `weight` is given.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr. It must be 0 or 1.
reduction (str): The string for the reduction to apply
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
BLOCK_SIZE (int): The block size for Triton operations.
HAS_WEIGHT (bool): The boolean value to determine whether to assign a weight to each class.
HAS_SOFTCAPPING (bool): The boolean value to determine whether to apply soft-capping.
"""

Expand All @@ -86,6 +94,9 @@ def liger_cross_entropy_kernel(
loss_ptr += program_id * loss_stride
z_loss_ptr += program_id * loss_stride

if HAS_WEIGHT:
weight = tl.load(weight_ptr + y).cast(tl.float32)

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

@@ -162,7 +173,12 @@
X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
# reduction scale
if reduction == "mean":
X_block = X_block / (n_non_ignore)
if HAS_WEIGHT:
X_block = X_block / (sum_of_non_ignore_weight)
else:
X_block = X_block / (n_non_ignore)
if HAS_WEIGHT:
X_block = X_block * weight
# chain rule
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
@@ -201,8 +217,16 @@
loss += z_loss
# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
z_loss = z_loss / n_non_ignore
loss = loss / n_non_ignore
if HAS_WEIGHT:
z_loss = z_loss / sum_of_non_ignore_weight
loss = loss / sum_of_non_ignore_weight
else:
z_loss = z_loss / n_non_ignore
loss = loss / n_non_ignore

if HAS_WEIGHT:
z_loss = z_loss * weight
loss = loss * weight

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
@@ -224,6 +248,7 @@ def liger_cross_entropy_kernel(
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -254,6 +279,23 @@ def cross_entropy_forward(
z_loss_1d = loss_1d # dummy ptr when return_z_loss == False

n_non_ignore = (target != ignore_index).sum().item()
sum_of_non_ignore_weight = n_non_ignore
if weight is not None:
assert (
weight.shape[0] == V
), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(
weight
), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
if ignore_index >= 0 and ignore_index < V:
weight_mask = torch.ones_like(weight)
weight_mask[ignore_index] = 0
selected_weight = torch.gather(weight * weight_mask, dim=-1, index=target)
else:
selected_weight = torch.gather(weight, dim=-1, index=target)
sum_of_non_ignore_weight = selected_weight.sum().item()
@Tcc0403 (Collaborator, Author) commented on Dec 8, 2024:
We can rewrite it with torch.masked_select:

    sum_of_non_ignore_weight = (
        torch.gather(weight, dim=0, index=target.masked_select(target_mask))
        .sum()
        .item()
    )

Refer to torch's impl mentioned above.
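Both forms compute the same denominator: the sum of class weights over non-ignored target positions. A minimal equivalent sketch, assuming target_mask = (target != ignore_index):

    # Sum of weight[target[i]] over positions where target[i] != ignore_index.
    target_mask = target != ignore_index
    sum_of_non_ignore_weight = weight[target[target_mask]].sum().item()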

if weight.stride(-1) != 1:
weight = weight.contiguous()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
@@ -267,18 +309,22 @@ def cross_entropy_forward(
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight if weight is not None else _input, # dummy if None
weight_stride=weight.stride(-1) if weight is not None else 0,
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
sum_of_non_ignore_weight=sum_of_non_ignore_weight,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -329,6 +375,7 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
Expand All @@ -343,6 +390,7 @@ def forward(
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V and floating point dtype.
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -356,6 +404,7 @@
loss, z_loss, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -397,4 +446,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
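For reference, the semantics the kernel implements mirror torch.nn.functional.cross_entropy with a class weight: each sample's loss (and gradient) is scaled by weight[target], and "mean" reduction divides by the sum of the selected weights (sum_of_non_ignore_weight) rather than by the count of non-ignored elements. A minimal PyTorch sketch of the forward semantics, illustrative only and not part of this PR:

    import torch

    def weighted_ce_reference(logits, target, weight, ignore_index=-100):
        # Log-softmax in float32, matching the kernel's internal accumulation dtype.
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        mask = target != ignore_index
        safe_target = target.clamp_min(0)  # avoid indexing with ignore_index
        nll = -log_probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
        w = weight[safe_target] * mask  # zero weight at ignored positions
        # "mean" divides by the sum of selected weights, not the element count.
        return (w * nll).sum() / w.sum()

This is the same reduction logic the diff adds to the kernel: with HAS_WEIGHT, both loss and z_loss are divided by sum_of_non_ignore_weight and scaled by weight[y].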
4 changes: 4 additions & 0 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -83,17 +83,21 @@ def fused_linear_cross_entropy_forward(
X_stride=logits_chunk.stride(-2),
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
weight_ptr=_input, # dummy ptr, not used
weight_stride=0,
loss_ptr=loss_1d_slice,
z_loss_ptr=loss_1d_slice, # dummy ptr, not used
loss_stride=loss_1d_slice.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
sum_of_non_ignore_weight=n_non_ignore,
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=True if softcap is not None else False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
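Triton requires every declared kernel argument at launch, so the fused path passes a dummy pointer with HAS_WEIGHT=False; since HAS_WEIGHT is a tl.constexpr, the weight branch is eliminated at compile time and the dummy pointer is never dereferenced. A minimal sketch of this pattern with a hypothetical kernel (not the PR's kernel):

    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, w_ptr, n, HAS_WEIGHT: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        if HAS_WEIGHT:  # dead branch when HAS_WEIGHT=False, so w_ptr is never read
            x = x * tl.load(w_ptr + offs, mask=mask)
        tl.store(x_ptr + offs, x, mask=mask)

    # x itself can serve as the dummy w_ptr when weighting is disabled:
    # scale_kernel[(triton.cdiv(n, 128),)](x, x, n, HAS_WEIGHT=False, BLOCK=128)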
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -8,6 +8,7 @@
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
weight: Optional[torch.FloatTensor] = None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -30,6 +31,7 @@ def __init__(
assert (
softcap is None or softcap > 0
), f"softcap must greater than 0.0 or None. Got: {softcap}"
self.weight = weight
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
@@ -41,6 +43,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss = LigerCrossEntropyFunction.apply(
_input,
target,
self.weight,
self.ignore_index,
self.lse_square_scale,
self.label_smoothing,
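With the new constructor argument, LigerCrossEntropyLoss is intended as a drop-in for torch.nn.CrossEntropyLoss(weight=...). A usage sketch, with the import path assumed from the repository layout:

    import torch
    from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

    V = 32000
    weight = torch.rand(V, device="cuda", dtype=torch.float32)
    loss_fn = LigerCrossEntropyLoss(weight=weight, reduction="mean")

    logits = torch.randn(8, V, device="cuda", requires_grad=True)
    target = torch.randint(0, V, (8,), device="cuda")
    loss_fn(logits, target).backward()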
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/functional.py
@@ -34,6 +34,7 @@ def liger_cross_entropy(
loss, z_loss = LigerCrossEntropyFunction.apply(
input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
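In the functional API, weight becomes the third positional argument, so existing positional call sites need an explicit None (as the test updates below show). A sketch, assuming logits, target, and weight are defined as in the module example above:

    from liger_kernel.transformers.functional import liger_cross_entropy

    # Pass None instead of weight to keep the previous unweighted behavior.
    loss, z_loss = liger_cross_entropy(
        logits, target, weight,
        ignore_index=-100, lse_square_scale=0.0, label_smoothing=0.0,
        reduction="mean", softcap=None, return_z_loss=True,
    )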
67 changes: 66 additions & 1 deletion test/transformers/test_cross_entropy.py
@@ -301,6 +301,27 @@ def _test_correctness_with_z_loss_with_other_params_once(
assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_weight_once(
target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol
):
torch.manual_seed(0)
torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction)

_tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
_input = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

output = torch_ce(_input, target)
output2 = target_ce(_input2, target)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_not_last_layer_once(
target_ce, B, T, V, reduction, scalar, dtype, atol, rtol
):
@@ -345,6 +366,7 @@ def _test_correctness_functional(
y1, y1_z = liger_cross_entropy(
x1,
target,
None,
ignore_index=0,
lse_square_scale=1e-4,
label_smoothing=0.1,
Expand All @@ -353,7 +375,7 @@ def _test_correctness_functional(
return_z_loss=True,
)
y2, y2_z = LigerCrossEntropyFunction.apply(
x2, target, 0, 1e-4, 0.1, "mean", 30.0, True
x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True
)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
@@ -687,6 +709,41 @@ def test_correctness_with_z_loss_with_other_params_once(
)


@pytest.mark.parametrize(
"B, T, V",
[
(2, 4096, 32000), # llama2, mistral
# weird shapes
(3, 423, 32000),
],
)
@pytest.mark.parametrize("weight", [0.5, 0.1])
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(1.0, torch.float32, 1e-8, 1e-6),
],
)
def test_correctness_with_weight_once(
B, T, V, weight, reduction, scalar, dtype, atol, rtol
):
weight = torch.rand(V, device=device, dtype=dtype)
test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction)
_test_correctness_with_weight_once(
test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol
)


@pytest.mark.parametrize(
"B, T, V",
[
@@ -746,17 +803,21 @@ def test_float32_internal():
X_stride=X_bf16.stride(-2),
Y_ptr=Y,
Y_stride=Y.stride(-1),
weight_ptr=X_bf16, # dummy ptr, not used
weight_stride=X_bf16.stride(-2),
z_loss_ptr=loss_bf16, # dummy ptr, not used
loss_ptr=loss_bf16,
loss_stride=loss_bf16.stride(-1),
n_cols=n_cols,
n_non_ignore=n_non_ignore,
sum_of_non_ignore_weight=n_non_ignore, # not used
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
@@ -770,17 +831,21 @@
X_stride=X_fp32.stride(-2),
Y_ptr=Y,
Y_stride=Y.stride(-1),
weight_ptr=X_fp32, # dummy ptr, not used
weight_stride=X_fp32.stride(-2),
loss_ptr=loss_fp32,
z_loss_ptr=loss_fp32, # dummy ptr, not used
loss_stride=loss_fp32.stride(-1),
n_cols=n_cols,
n_non_ignore=n_non_ignore,
sum_of_non_ignore_weight=n_non_ignore, # not used
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap,
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,