Skip to content

Commit

Permalink
updated tests for PAD/EOS
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffrey committed May 24, 2024
1 parent e2d7447 commit d33d1e1
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/test_tokenize_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,63 @@ def run_around_tests():
def test_tokenize_shuffle_simple():
    """Run tokenize_shuffle.py end to end and check EOS / PAD bookkeeping.

    The script is invoked through the shell with ``--seqlen`` set to
    ``content_len`` and must exit with status 0. Every sample in the shard
    ``test_output/00000001.tar`` must then hold ``content_len + 1`` entries,
    the number of ``EOS`` entries across all samples must equal
    ``NUM_PAGES``, and the number of samples whose last entry is ``PAD``
    must equal ``NUM_JSONLS``.

    The manifest is read and parsed (so malformed JSON still fails the
    test), but its contents are not asserted on at the moment.
    """
    content_len = 2048
    NUM_TOKENS = 86058
    NUM_PAGES = 160
    NUM_JSONLS = 16
    EOS = 1
    PAD = 0

    command = f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len}"
    exit_value = os.system(command)
    assert exit_value == 0

    total = 0
    eos_tokens = 0
    padded_sequences = 0
    dataset = wds.WebDataset("test_output/00000001.tar").decode()
    for sample in dataset:
        tokens = sample["json.gz"]
        assert len(tokens) == content_len + 1
        total += len(tokens)
        eos_tokens += tokens.count(EOS)
        # A sequence counts as padded when its final entry is the PAD id.
        if tokens[-1] == PAD:
            padded_sequences += 1

    # Disabled check, kept for reference:
    # assert total == NUM_TOKENS
    assert eos_tokens == NUM_PAGES
    assert padded_sequences == NUM_JSONLS

    with open("test_output/manifest.jsonl", "rb") as manifest_file:
        raw_manifest = manifest_file.read()
    # Drop the last piece of the split: the text after the final newline.
    manifest_lines = raw_manifest.decode("utf-8").split("\n")[:-1]
    manifest = [json.loads(line) for line in manifest_lines]

    # Disabled checks, kept for reference:
    # assert manifest[0]["shard"] == "00000001"
    # assert manifest[0]["num_sequences"] == NUM_TOKENS // (content_len + 1)


def test_tokenize_shuffle_overide_eos_and_pad():
content_len = 2048
NUM_TOKENS = 86058
NUM_PAGES = 160
NUM_JSONLS = 16
EOS = 1
PAD = 0

# Swap the identity of EOS and PAD special tokens to test whether --eos_overwrite and --pad_overwrite flags work correctly.
exit_value = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len} --eos_overwrite {EOS} --pad_overwrite {PAD}"
)
assert exit_value == 0
ds = wds.WebDataset("test_output/00000001.tar").decode()
total = 0
eos_tokens = 0
padded_sequences = 0
for x in ds:
assert len(x["json.gz"]) == content_len + 1
total += len(x["json.gz"])
eos_tokens += x["json.gz"].count(EOS)
padded_sequences += 1 if x["json.gz"][-1] == PAD else 0

# assert total == NUM_TOKENS
assert eos_tokens == NUM_PAGES
assert padded_sequences == NUM_JSONLS

with open("test_output/manifest.jsonl", "rb") as f:
out = f.read()
Expand Down

0 comments on commit d33d1e1

Please sign in to comment.