Add loss generating token counts #1610

Merged · 12 commits · Oct 27, 2024
Changes from 5 commits
17 changes: 16 additions & 1 deletion llmfoundry/data/utils.py
@@ -83,7 +83,7 @@ def get_data_spec(
 
 def get_tokens_per_batch_func(
     decoder_only: bool = True,
-) -> Callable[[Batch], int]:
+) -> Callable[[Batch], Union[int, dict[str, int]]]:
     """Returns a callable that counts the number of tokens in a batch.
 
     Args:
@@ -114,13 +114,28 @@ def get_num_tokens_in_batch(batch: Batch) -> int:
         else:
             input_ids_tokens = batch['input_ids'].numel()
 
+        loss_generating_tokens = 0
+        if 'labels' in batch:
+            loss_generating_tokens = int(
+                torch.sum(batch['labels'] != -100).item(),
+            )
+
+            # Subtract one for each example in the batch because the labels
+            # will be shifted by one
+            loss_generating_tokens -= batch['labels'].shape[0]
+
         # For encoder decoder models only
         decoder_input_ids_tokens = 0
         if not decoder_only:
             decoder_input_ids_tokens = int(
                 torch.sum(batch['decoder_attention_mask']).item(),
             )
 
+        if loss_generating_tokens != 0:
+            return {
+                'total': input_ids_tokens + decoder_input_ids_tokens,
+                'loss_generating': loss_generating_tokens,
+            }
         return input_ids_tokens + decoder_input_ids_tokens
 
     return get_num_tokens_in_batch
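For illustration only (not part of the PR): a minimal sketch of how the new counting logic behaves on a toy decoder-only batch, assuming the standard -100 ignore-index convention for labels. The standalone helper below is a simplified stand-in for the nested get_num_tokens_in_batch above, not the llm-foundry function itself.

```python
import torch


def count_tokens(batch: dict) -> dict[str, int]:
    # Total tokens: sum of the attention mask if present, otherwise every position.
    if 'attention_mask' in batch:
        total = int(torch.sum(batch['attention_mask']).item())
    else:
        total = batch['input_ids'].numel()

    # Loss-generating tokens: label positions not set to the -100 ignore index,
    # minus one per example because the labels are shifted by one during training.
    loss_generating = int(torch.sum(batch['labels'] != -100).item())
    loss_generating -= batch['labels'].shape[0]

    return {'total': total, 'loss_generating': loss_generating}


# Two sequences of length 5; the second has a masked prompt and one padded position.
batch = {
    'input_ids': torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]),
    'labels': torch.tensor([[1, 2, 3, 4, 5], [-100, -100, 8, 9, -100]]),
}
print(count_tokens(batch))  # {'total': 9, 'loss_generating': 5}
```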
13 changes: 11 additions & 2 deletions tests/data/test_dataloader.py
@@ -1186,6 +1186,7 @@ def test_token_counting_func_dataloader_setting(

     batch_strings = []
     expected_token_count = 0
+    expected_loss_generating_token_count = 0
     for _ in range(batch_size):
         # Get randomly different lengths if we are going to add padding
         sample_length = random.randint(1, model_max_length // 4) if (
@@ -1208,8 +1209,14 @@
         for b in batch_tokenized:
             b['labels'] = b['input_ids'].copy()  # type: ignore
         batch_tokenized = [{'turns': [b]} for b in batch_tokenized]
+        expected_loss_generating_token_count = expected_token_count
         expected_token_count *= 2
         expected_token_count += 1 * batch_size  # for the eos token
+        expected_loss_generating_token_count += 1 * batch_size  # for the eos token
     else:
+        expected_loss_generating_token_count = expected_token_count
+
+    expected_loss_generating_token_count -= 1 * batch_size  # because the labels will be shifted
+
     common_args = {
         'drop_last': False,
@@ -1311,9 +1318,11 @@ def build_from_hf(
         raise NotImplementedError()
 
     batch_collated = dl.dataloader.collate_fn(batch_tokenized)  # type: ignore
-    actual_token_count = dl.get_num_tokens_in_batch(batch_collated)
+    actual_total_token_count = dl.get_num_tokens_in_batch(batch_collated, token_type='total')

Review comment:
I might be missing something, but how can we pass in token_type here when it's not in the function definition here? https://github.com/mosaicml/llm-foundry/pull/1610/files#diff-9568d89aed75ca69416abe2a592c6bb9732129049a62c34e4e9263c18495a236R99

Collaborator Author:
The function being called here is actually defined on the DataSpec class in Composer (https://github.com/mosaicml/composer/blob/28756dd52e96371689b764cb72c336406460ad35/composer/core/data_spec.py#L301). The DataSpec takes in a function from the user and uses it.

Collaborator Author:
Part of the reason for doing it this way was to maintain backwards compatibility with any existing user-defined get_num_tokens_in_batch functions out there.
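To make the backwards-compatibility point above concrete, here is an editorial sketch (assumed names, not Composer's actual DataSpec code) of how a wrapper can accept a user-supplied counting function that returns either a plain int, as before this PR, or the new dict of counts, and resolve a requested token_type itself:

```python
from typing import Callable, Union

Batch = dict  # placeholder for Composer's Batch type in this sketch
Counts = Union[int, dict[str, int]]


def resolve_token_count(
    get_num_tokens_in_batch: Callable[[Batch], Counts],
    batch: Batch,
    token_type: str = 'total',
) -> int:
    counts = get_num_tokens_in_batch(batch)
    if isinstance(counts, int):
        # Legacy user functions return a single int; one reasonable fallback is to
        # treat that value as the count for whichever token_type was requested.
        return counts
    return counts[token_type]
```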

+    actual_loss_generating_token_count = dl.get_num_tokens_in_batch(batch_collated, token_type='loss_generating')

-    assert actual_token_count == expected_token_count
+    assert actual_total_token_count == expected_token_count
+    assert actual_loss_generating_token_count == expected_loss_generating_token_count
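As an editorial check of the expectation arithmetic in this test (not part of the PR): if the per-sample token counts sum to T before the finetuning branch and the batch holds B examples, the branch doubles the total count and adds one eos token per example, while the loss-generating count gains the eos tokens and then loses one token per example for the label shift, landing back at T.

```python
# Hypothetical values: T tokens before the finetuning branch, B examples per batch.
T, B = 12, 2

expected_token_count = 2 * T + B                    # 26 total tokens
expected_loss_generating_token_count = (T + B) - B  # 12, i.e. back to T

assert expected_token_count == 26
assert expected_loss_generating_token_count == 12
```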


def test_build_unknown_dataloader():