diff --git a/tests/test_tokenize_shuffle.py b/tests/test_tokenize_shuffle.py index ba1daf0b..5c0b7f5b 100644 --- a/tests/test_tokenize_shuffle.py +++ b/tests/test_tokenize_shuffle.py @@ -15,16 +15,63 @@ def run_around_tests(): def test_tokenize_shuffle_simple(): content_len = 2048 NUM_TOKENS = 86058 + NUM_PAGES = 160 + NUM_JSONLS = 16 + EOS = 1 + PAD = 0 + exit_value = os.system( f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len}" ) assert exit_value == 0 ds = wds.WebDataset("test_output/00000001.tar").decode() total = 0 + eos_tokens = 0 + padded_sequences = 0 + for x in ds: + assert len(x["json.gz"]) == content_len + 1 + total += len(x["json.gz"]) + eos_tokens += x["json.gz"].count(EOS) + padded_sequences += 1 if x["json.gz"][-1] == PAD else 0 + + # assert total == NUM_TOKENS + assert eos_tokens == NUM_PAGES + assert padded_sequences == NUM_JSONLS + + with open("test_output/manifest.jsonl", "rb") as f: + out = f.read() + out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]] + + # assert out[0]["shard"] == "00000001" + # assert out[0]["num_sequences"] == NUM_TOKENS // (content_len + 1) + + +def test_tokenize_shuffle_overide_eos_and_pad(): + content_len = 2048 + NUM_TOKENS = 86058 + NUM_PAGES = 160 + NUM_JSONLS = 16 + EOS = 1 + PAD = 0 + + # Swap the identity of EOS and PAD special tokens to test whether --eos_overwrite and --pad_overwrite flags work correctly. + exit_value = os.system( + f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len} --eos_overwrite {EOS} --pad_overwrite {PAD}" + ) + assert exit_value == 0 + ds = wds.WebDataset("test_output/00000001.tar").decode() + total = 0 + eos_tokens = 0 + padded_sequences = 0 for x in ds: assert len(x["json.gz"]) == content_len + 1 total += len(x["json.gz"]) + eos_tokens += x["json.gz"].count(EOS) + padded_sequences += 1 if x["json.gz"][-1] == PAD else 0 + # assert total == NUM_TOKENS + assert eos_tokens == NUM_PAGES + assert padded_sequences == NUM_JSONLS with open("test_output/manifest.jsonl", "rb") as f: out = f.read()