-
Notifications
You must be signed in to change notification settings - Fork 1
/
split_scps.py
executable file
·106 lines (87 loc) · 3.22 KB
/
split_scps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
import argparse
import logging
import sys
from collections import Counter
from contextlib import ExitStack
from itertools import zip_longest
from pathlib import Path
from typing import List, Optional

from espnet.utils.cli_utils import get_commandline_args
def split_scps(
    scps: List[str],
    num_splits: int,
    names: Optional[List[str]],
    output_dir: str,
    log_level: str,
):
    """Split several key-aligned scp/text files into ``num_splits`` pieces.

    Line ``i`` of every input file is written to piece ``i % num_splits``, so
    pieces are interleaved rather than contiguous. For each output name,
    ``<output_dir>/<name>/split.<k>`` receives piece ``k`` and
    ``<output_dir>/<name>/num_splits`` records the number of pieces.

    Args:
        scps: Paths of the input files. All files must have the same number of
            lines, and line ``i`` of every file must start with the same key
            (the first whitespace-separated field).
        num_splits: Number of pieces to produce; must be at least 2.
        names: Output sub-directory name for each input file, in the same
            order as ``scps``. ``None`` means use the basename of each path.
        output_dir: Root directory under which per-name directories are made.
        log_level: Level name passed to ``logging.basicConfig``.

    Raises:
        RuntimeError: If ``num_splits`` is missing or less than 2, ``names``
            has duplicates or a different length than ``scps``, the inputs
            differ in line count or per-line key, an input line is blank, or
            the inputs contain fewer lines than ``num_splits``.
    """
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if num_splits is None or num_splits < 2:
        raise RuntimeError(f"num_splits must be more than 1, but got {num_splits}")
    if names is None:
        names = [Path(s).name for s in scps]
    if len(set(names)) != len(names):
        raise RuntimeError(f"names are duplicated: {names}")
    if len(names) != len(scps):
        # zip(lines, names) below would otherwise silently drop inputs
        # (too few names) or leave empty output directories (too many).
        raise RuntimeError(
            "The number of names must match the number of scps: "
            f"{len(names)} != {len(scps)}"
        )

    out_root = Path(output_dir)
    for name in names:
        (out_root / name).mkdir(parents=True, exist_ok=True)

    counter = Counter()
    linenum = -1  # stays -1 when every input is empty
    # ExitStack closes (and flushes) every handle even if a check raises.
    with ExitStack() as stack:
        scp_files = [
            stack.enter_context(open(s, "r", encoding="utf-8")) for s in scps
        ]
        # 'w' mode overwrites split files left over from a previous run.
        out_files = {
            name: [
                stack.enter_context(
                    (out_root / name / f"split.{num}").open("w", encoding="utf-8")
                )
                for num in range(num_splits)
            ]
            for name in names
        }
        for linenum, lines in enumerate(zip_longest(*scp_files)):
            # zip_longest pads exhausted files with None.
            if any(line is None for line in lines):
                raise RuntimeError(
                    f"Number of lines are mismatched at line {linenum + 1}"
                )
            prev_key = None
            for line in lines:
                fields = line.rstrip().split(maxsplit=1)
                if not fields:
                    raise RuntimeError(f"Empty line found at line {linenum + 1}")
                key = fields[0]
                if prev_key is not None and prev_key != key:
                    raise RuntimeError(
                        "Not sorted or not having same keys at line "
                        f"{linenum + 1}: {prev_key} != {key}"
                    )
                prev_key = key

            # Select a piece from split texts alternatively (round-robin).
            num = linenum % num_splits
            counter[num] += 1
            for line, name in zip(lines, names):
                # Keep every record newline-terminated so that split files can
                # be concatenated without merging two records into one line.
                out_files[name][num].write(line if line.endswith("\n") else line + "\n")

    if linenum + 1 < num_splits:
        raise RuntimeError(
            f"The number of lines is less than num_splits: {linenum + 1} < {num_splits}"
        )

    for name in names:
        with (out_root / name / "num_splits").open("w", encoding="utf-8") as f:
            f.write(str(num_splits))

    logging.info(f"N lines of split text: {set(counter.values())}")
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Every option name matches a keyword parameter of ``split_scps``, so the
    parsed namespace can be passed to it directly via ``**vars(args)``.
    ``--log_level`` is upper-cased before being checked against the allowed
    level names.
    """
    parser = argparse.ArgumentParser(
        description="Split scp files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--scps", required=True, help="Input texts", nargs="+")
    parser.add_argument("--names", help="Output names for each files", nargs="+")
    # Required: split_scps needs an integer here. Without required=True, an
    # omitted flag would reach it as None and fail with an unclear error.
    parser.add_argument("--num_splits", required=True, help="Split number", type=int)
    parser.add_argument("--output_dir", required=True, help="Output directory")
    return parser
def main(cmd=None):
    """Run the splitter from the command line.

    Args:
        cmd: Argument list to parse instead of the process command line;
            ``None`` makes argparse read ``sys.argv[1:]``.
    """
    # Echo the invocation to stderr before doing any work.
    print(get_commandline_args(), file=sys.stderr)
    parsed = get_parser().parse_args(cmd)
    split_scps(**vars(parsed))


if __name__ == "__main__":
    main()