-
Notifications
You must be signed in to change notification settings - Fork 71
/
fused_cross_entropy.py
423 lines (395 loc) · 15.7 KB
/
fused_cross_entropy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
from typing import Any, Tuple
import torch
import torch.nn as nn
import triton
import triton.language as tl
from fla.utils import contiguous
# `all_gather_into_tensor` and `reduce_scatter_tensor` are the public replacements for the
# private `_all_gather_base` and `_reduce_scatter_base` collectives, and only exist in recent
# PyTorch releases. For backward compatibility with older PyTorch, alias the public
# all-gather name to the private one (this file only uses all_gather_into_tensor).
# The alias is only installed when the legacy name exists, so builds that expose neither
# (presumably torch compiled without distributed support) do not fail at import time.
if not hasattr(torch.distributed, "all_gather_into_tensor") and hasattr(
    torch.distributed, "_all_gather_base"
):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
# Forward cross-entropy kernel.
# The host launches it on a (n_rows, n_splits) grid: program_id(0) selects the row and
# program_id(1) selects a block of BLOCK_SIZE columns. Each program writes, at flat index
# `col_block_idx * n_rows + row_idx`:
#   lse_ptr    - log-sum-exp of the scaled logits inside its column block,
#   loss_ptr   - this block's contribution to the row's loss,
#   z_loss_ptr - lse_square_scale * lse ** 2 (only when not SPLIT).
# With SPLIT the per-block lse is not the final lse, so the lse term and the z-loss are left
# out here and must be added by the caller after reducing across blocks / ranks.
# HAS_SMOOTHING is derived from `label_smoothing > 0.0` by the heuristic below.
@triton.heuristics({
    "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
})
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    label_smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # The stride is cast to int64 so the row offset is computed in 64-bit arithmetic.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    # Columns past n_cols load as -inf, so they contribute exp(-inf) = 0 to the lse sum.
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf"))
    logits = logits.to(tl.float32) * logit_scale
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        # Sum of the in-range scaled logits; the -inf padding is masked back to 0.0.
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    # Numerically stable log-sum-exp: subtract the block max before exponentiating.
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignore_index:
        # Ignored rows contribute nothing.
        loss = 0.0
        z_loss = 0.0
    else:
        # Map the global class index to a column index within this rank's logits.
        label_idx -= class_start_idx
        # Only the block whose column range contains the label adds the -logit[label] term.
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - label_smoothing * sum_logits / total_classes
                    - (1 - label_smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the label_smoothing loss
            if HAS_SMOOTHING:
                loss = label_smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            # The lse is final here, so the z-loss can be folded into the loss directly.
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
# Backward cross-entropy kernel.
# program_id(0) selects the row and program_id(1) a block of BLOCK_SIZE columns. For each
# in-range column it stores
#   dlogits = dloss * logit_scale * (p * (1 + 2 * lse_square_scale * lse) - y)
# where p = exp(logit_scale * logits - lse). The target distribution y is:
#   - without smoothing: 1 at the label column, 0 elsewhere;
#   - with smoothing: (1 - label_smoothing) at the label column plus
#     label_smoothing / total_classes on every column.
# The `2 * lse_square_scale * lse * p` term is the gradient of the z-loss lse_square_scale * lse ** 2.
# lse is read as one value per row (`lse_ptr + row_idx`), so the caller must pass the final,
# fully reduced lse.
@triton.heuristics({
    "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0,
})
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    label_smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Row offsets are computed in 64-bit arithmetic via the int64 stride cast.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    # Ignored rows use an upstream gradient of 0, which zeroes their whole dlogits row.
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    # Softmax probabilities recomputed from the saved lse.
    probs = tl.exp(logits - lse)
    # Add the z-loss gradient: d(lse_square_scale * lse^2)/dlogits = 2 * lse_square_scale * lse * p.
    probs += 2.0 * lse_square_scale * lse * probs
    # Map the global label to this rank's local column index; labels owned by another rank never
    # match any col_offset, so only the smoothing term is subtracted for them.
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_negative = label_smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - (1 - label_smoothing), probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
def fused_cross_entropy_forward(
    logits: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index: int = -100,
    process_group=None,
):
    """
    Compute per-row cross-entropy losses with the fused Triton forward kernel.

    Arguments:
        logits: [n_rows, n_cols] logits. When `process_group` is given, this rank holds the
            class slice [rank * n_cols, (rank + 1) * n_cols) of a vocab of size
            world_size * n_cols.
        target: [n_rows,] class indices into the full (global) vocab.
        label_smoothing: float. Label smoothing factor; smoothing is enabled when > 0.
        logit_scale: float. Logits are multiplied by this before the loss is computed.
        lse_square_scale: float. Weight of the z-loss term lse_square_scale * lse ** 2.
        ignore_index: int. Rows whose target equals this value get a loss of 0.0.
        process_group: optional process group used for vocab-parallel (tensor parallel) loss.

    Returns:
        losses: [n_rows,] float32 per-row losses (z-loss included).
        z_losses: [n_rows,] float32 z-loss component only.
        lse: [n_rows,] float32 log-sum-exp of the scaled logits over the full vocab.
        total_classes: int, world_size * n_cols.
        class_start_idx: int, first global class index owned by this rank.
    """
    n_rows, n_cols = logits.shape
    assert target.shape == (n_rows,)
    world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
    total_classes = world_size * n_cols
    rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
    class_start_idx = rank * n_cols

    # The kernel reads each logits row with unit column stride.
    if logits.stride(-1) != 1:
        logits = logits.contiguous()
    # The kernel reads target[row] at `labels_ptr + row_idx`, i.e. it assumes unit stride.
    # No-op for an already contiguous target; otherwise (strided slice, expanded view) the
    # kernel would read the wrong labels.
    target = target.contiguous()

    # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
    MAX_BLOCK_SIZE = 64 * 1024
    BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
    # BLOCK_SIZE is capped at MAX_BLOCK_SIZE (64k), so 16 warps is the largest reachable choice.
    num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
    # We may split the lse computation across multiple blocks, then do a reduction
    # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
    # where having just one thread block processing more than 64k elements is slow.
    split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
    n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
    # The kernel writes at col_block_idx * n_rows + row_idx, i.e. one row of results per split.
    loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
    losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
    lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
    z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(logits.device.index):
        cross_entropy_fwd_kernel[(n_rows, n_splits)](
            losses,  # data ptrs
            lse,
            z_losses,
            logits,
            target,
            label_smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,  # shapes
            n_rows,
            logits.stride(0),  # strides
            BLOCK_SIZE=BLOCK_SIZE,  # constants
            num_warps=num_warps,
            SPLIT=split
        )
    if split:
        # If there's no label_smoothing, if target are in the vocab of this partition, losses contains
        # - predicted logit, and 0 otherwise.
        # If there's label_smoothing=0.1, for target in the vocab of this partition, losses contains
        # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
        # For target not in the vocab of this partition, losses contains
        # -0.1 * sum logit / total_classes.
        if n_splits > 1:
            lse = torch.logsumexp(lse, dim=0)
            losses = losses.sum(dim=0)
        if world_size > 1:
            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
            # Overlap the loss all-reduce with the global logsumexp below.
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            handle_losses.wait()
        # After the allreduce, if there's no label_smoothing, the total losses are - predicted_logit,
        # we just have to add the (global) lse.
        # If there's label_smoothing=0.1, the total losses are
        # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
        # Again, we just have to add the (global) lse.
        losses += lse
        if lse_square_scale != 0.0:
            z_losses = lse_square_scale * lse.square()
            z_losses.masked_fill_(target == ignore_index, 0.0)
            losses += z_losses
        else:
            z_losses = torch.zeros_like(losses)
        losses.masked_fill_(target == ignore_index, 0.0)
    return losses, z_losses, lse, total_classes, class_start_idx
class CrossEntropyLossFunction(torch.autograd.Function):
    """
    Autograd function pairing `fused_cross_entropy_forward` with `cross_entropy_bwd_kernel`.

    Only `logits` receives a gradient. `z_losses` is returned for logging and is marked
    non-differentiable.
    """

    @staticmethod
    @contiguous
    def forward(
        ctx,
        logits,
        target,
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        Run the fused forward pass and stash what the backward kernel needs.

        Returns:
            losses: [n_rows,] float32 per-row losses (z-loss included).
            z_losses: [n_rows,] float32 z-loss component (non-differentiable).
        """
        losses, z_losses, lse, total_classes, class_start_idx = fused_cross_entropy_forward(
            logits,
            target,
            label_smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            process_group,
        )
        # The backward kernel recomputes the softmax from the logits and the final per-row lse.
        ctx.save_for_backward(logits, lse, target)
        ctx.mark_non_differentiable(z_losses)
        ctx.label_smoothing = label_smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses, z_losses

    @staticmethod
    @contiguous
    def backward(ctx, grad_losses, grad_z_losses):
        """
        Compute the gradient w.r.t. `logits` with the Triton backward kernel.

        Returns one entry per forward input: `dlogits` for `logits`, None for the rest.
        """
        del grad_z_losses  # z_losses are only for logging.
        logits, lse, target = ctx.saved_tensors
        # With inplace_backward, the gradient overwrites the saved logits to save memory.
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        # BLOCK_SIZE is capped at 4k, so at most 8 warps are ever selected.
        num_warps = 4 if BLOCK_SIZE < 2048 else 8

        def grid(META):
            # One program per (row, column block).
            return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))

        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                target,
                ctx.label_smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignore_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        # forward takes 8 inputs (logits, target, label_smoothing, logit_scale,
        # lse_square_scale, ignore_index, inplace_backward, process_group): one gradient each.
        return dlogits, None, None, None, None, None, None, None
def cross_entropy_loss(
    logits: torch.Tensor,
    target: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Fused cross-entropy loss returning one loss per row (no reduction).

    Arguments:
        logits: [batch, vocab_size]
        target: [batch,]
        label_smoothing: float.
            Label smoothing factor; smoothing is disabled when it is 0.0.
        logit_scale: float.
            The logits are multiplied by this factor before the loss is computed.
        lse_square_scale: float.
            When > 0, lse_square_scale * lse(logits) ^ 2 is added to the loss
            (the so-called "z-loss").
        ignore_index: int.
            Rows whose target equals this value get a loss of 0.0.
        inplace_backward: bool.
            When True, the backward pass writes the gradient into the logits tensor,
            which saves memory.
        process_group:
            When not None, the loss is computed tensor-parallel: every process owns one
            slice of the vocab and the loss is aggregated across processes.

    Returns:
        losses: [batch,], float
        z_losses: [batch,], float
    """
    # Forwarded positionally, in the order CrossEntropyLossFunction.forward expects.
    forward_args = (
        logits,
        target,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
    return CrossEntropyLossFunction.apply(*forward_args)
class FusedCrossEntropyLoss(nn.Module):
    """nn.Module front-end for the fused Triton cross-entropy loss with optional reduction."""

    def __init__(
        self,
        ignore_index: int = -100,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
        logit_scale: float = 1.0,
        lse_square_scale: float = 0.0,
        inplace_backward: bool = False,
        process_group: Any = None,
        return_z_loss: bool = False,
    ):
        """
        Arguments:
            ignore_index: int. If target == ignore_index, the loss is set to 0.0.
            reduction: str. One of 'mean' (sum of losses divided by the number of
                non-ignored targets), 'sum', or 'none' (per-row losses).
            label_smoothing: float. Label smoothing factor; disabled when 0.0.
            logit_scale: float. Logits are multiplied by this before computing the loss.
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
                This is also referred to as "z-loss".
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
                This saves memory.
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
            return_z_loss: bool. If True, we return the component of the loss contributed by
                the lse_square_scale value. This value is only for logging and does not support
                backprop.

        Raises:
            NotImplementedError: if `reduction` is not 'mean', 'none' or 'sum'.
        """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def _reduce(self, values, target):
        """Apply `self.reduction` to per-row `values`; 'mean' divides by the non-ignored count."""
        if self.reduction == "mean":
            # Ignored rows are already 0 in `values`; they are excluded from the denominator.
            # If every target is ignored this is 0 / 0 = NaN, as before.
            return values.sum() / (target != self.ignore_index).sum()
        if self.reduction == "sum":
            return values.sum()
        return values

    def forward(self, input, target):
        """
        Arguments:
            input: (batch, vocab_size) logits. Any other rank (..., vocab_size) is flattened to
                (-1, vocab_size) before calling the fused kernel.
            target: (batch,) class indices, or a shape matching input.shape[:-1].
        Returns:
            losses: input.shape[:-1] if reduction is 'none', else a 0-dim tensor, dtype float
            z_loss: same shape as losses, dtype float (only if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        # The fused kernel only handles 2D logits: fold any extra leading dims into rows.
        leading_shape = input.shape[:-1]
        flattened = input.dim() != 2
        if flattened:
            input = input.reshape(-1, input.shape[-1])
            target = target.reshape(-1)
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        restore_shape = flattened and self.reduction == "none"
        loss = self._reduce(loss, target)
        if restore_shape:
            loss = loss.view(leading_shape)
        if not self.return_z_loss:
            return loss
        z_loss = self._reduce(z_loss, target)
        if restore_shape:
            z_loss = z_loss.view(leading_shape)
        return loss, z_loss