
Commit

added some scripts
mjpost committed May 27, 2023
1 parent 633fb60 commit e30a130
Showing 4 changed files with 243 additions and 0 deletions.
54 changes: 54 additions & 0 deletions bin/extract_sent.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

"""
Two run modes:
(a) Take a parallel stream, index + document, and extract sentence {index} from it
(b) Take a stream of documents, and a single index -i {index}, and extract {index}
That is, (b) fixes the index, where (a) it can be variable.
Used for studying the effect of forward and backward context on contragen.
"""

import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from docmt import DOC_SEPARATOR

def extract_sent(line, separator=DOC_SEPARATOR, index=None, proportional=False):
if proportional:
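        # The last source sentence accounts for some fraction of the source
        # document's tokens; return that same fraction of the target
        # document's tokens, taken from the end of the target.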
source, target = line.rstrip().split("\t")
source_pct = len(source.split(separator)[-1].split()) / len(source.split())
target_tokens = target.split()
        num_target_tokens = int(source_pct * len(target_tokens))
        # guard against num_target_tokens == 0: target_tokens[-0:] would return the whole list
        target = " ".join(target_tokens[-num_target_tokens:]) if num_target_tokens > 0 else ""
# print(source_pct, source, target, sep="\n", file=sys.stderr)
return target
else:
if index is None:
fields = line.split("\t", maxsplit=1)
index = int(fields[0])
line = fields[1]

sents = line.rstrip().split(separator)
return sents[min(index, len(sents) - 1)].strip()


def main(args):
for line in sys.stdin:
print(extract_sent(line, index=args.index, separator=args.separator, proportional=args.proportional))


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--proportional", action="store_true", help="Expects source TAB target")
parser.add_argument("--separator", default=DOC_SEPARATOR)
parser.add_argument("--index", "-i", type=int, default=None, help="Index to extract; if None, read field from STDIN")
args = parser.parse_args()

main(args)
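
A minimal usage sketch for the three modes (the filenames are illustrative, not from the repo):

    # (a) variable index: each input line is "index<TAB>document"
    paste indices.txt docs.txt | bin/extract_sent.py

    # (b) fixed index: extract sentence 2 from every document
    bin/extract_sent.py -i 2 < docs.txt

    # proportional: each line is "source-document<TAB>target-document"
    paste source_docs.txt target_docs.txt | bin/extract_sent.py --proportional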

126 changes: 126 additions & 0 deletions bin/pack.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python3

"""Creates documents from monolingual data. Used to transform
sentences into sentences-with-context for document translation sytems.
There are two modes: chunking input lines within a document (--chunk),
or applying a sliding window (default).
The output is four columns: the docid, the start and end lines in the
original document, and the merged document string.
Context length is infinite by default, but controlled by two variables:
--max-tokens: the maximum number of tokens in the entire line
--max-sents: the maximum number of sentences (including the current one)
Length is determined by white-space-delimited tokens or by sentence-
piece tokens, if --spm-model is applied.
By default, separates sentences on a line with " <eos>", but you can
override this with "--separator".
Example usage:
# sliding window, separate with spaces
paste source.txt docids.txt \
| pack.py --max-sents 5 --max-tokens 200 --spm-model /path/to/spm --separator " "
# chunked, no sentence limit, separate with " <eos>"
paste source.txt docids.txt \
| pack.py --chunk --max-tokens 200 --spm-model /path/to/spm
"""

import os
import sys

from typing import List, Tuple

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from docmt import read_docs, smart_open


def main(args):

spm = None
if args.spm_model is not None:
import sentencepiece as sp
spm = sp.SentencePieceProcessor(model_file=args.spm_model)

def count_tokens(line):
return len(spm.encode(line) if spm else line.split())

def get_context(line: str,
context: List[Tuple]):
"""
Takes the line, and adds context until max_tokens is reached.
:return: The merged line, and the number of lines removed from the context.
"""

num_lines_context = len(context)
lines = [text[0] for text in context] + [line]
lens = [count_tokens(line) + 1 for line in lines] # add 1 for <eos> token
while len(lines) > 1 and sum(lens) > args.max_tokens:
lens.pop(0)
lines.pop(0)
if args.max_sents is not None:
while len(lines) > args.max_sents + 1:
lines.pop(0)
return args.separator.join(lines), num_lines_context - (len(lines) - 1)

    def chunk_doc(doc: List[Tuple]):
        """
        Greedily splits a document into chunks that respect --max-tokens
        and --max-sents.
        :return: A generator over (chunk, number of lines in it) pairs.
        """
subdoc = []
subdoclen = 0
for segment in doc:
segmentlen = count_tokens(segment[0]) + 1

            # flush the current chunk before it would exceed either limit;
            # the `subdoc` guard avoids ever emitting an empty chunk
            if subdoc and ((args.max_tokens != 0 and subdoclen + segmentlen > args.max_tokens)
                           or (args.max_sents is not None and len(subdoc) >= args.max_sents)):
                yield args.separator.join(subdoc), len(subdoc)
                subdoc = []
                subdoclen = 0

subdoc.append(segment[0])
subdoclen += segmentlen

if len(subdoc):
yield args.separator.join(subdoc), len(subdoc)

lineno = 1
for docno, doc in enumerate(read_docs(args.infile, docfield=args.docid_field)):
docid = doc[0][-1]
if args.chunk:
for subdoc, subdoclen in chunk_doc(doc):
print(docid, lineno, lineno + subdoclen - 1, subdoc, sep="\t")
lineno += subdoclen
else:
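            # Sliding window: doc[doci:docj] is the current context window;
            # get_context() reports how many of its oldest lines it dropped,
            # so doci only ever advances and the window shrinks at the front.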
doci = 0
for docj, line in enumerate(doc):
line_with_context, num_deleted = get_context(line[0], doc[doci:docj])
doci += num_deleted
# print(docno, count_tokens(line_with_context), args.max_tokens, line_with_context)
print(docid, lineno + doci, lineno + docj, line_with_context, sep="\t")
lineno += len(doc)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("infile", nargs="?", default=sys.stdin)
parser.add_argument("--max-tokens", "-t", type=int, metavar="T", default=250, help="Maximum tokens in total (default: %(default)s)")
parser.add_argument("--max-sents", "-c", type=int, metavar="N", default=None, help="Maximum sentences of context (default: %(default)s)")
parser.add_argument("--spm-model", "-m", default=None)
parser.add_argument("--chunk", action="store_true")
parser.add_argument("--docid-field", "-f", metavar="F", default=1, help="Field containing the doc ID (default: %(default)s)")
parser.add_argument("--separator", "-s", default=" <eos>",
help="separator for sentences (default: %(default)s)")
args = parser.parse_args()

main(args)
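
To make the output format concrete, a small worked sketch of the default sliding-window mode (the docid "d1" and the text are illustrative; nothing is trimmed at the default --max-tokens of 250):

    input ("text<TAB>docid", e.g. from: paste text.txt docids.txt):
        Hello .<TAB>d1
        How are you ?<TAB>d1
        Fine , thanks .<TAB>d1

    output (docid<TAB>start<TAB>end<TAB>merged document):
        d1<TAB>1<TAB>1<TAB>Hello .
        d1<TAB>1<TAB>2<TAB>Hello . <eos>How are you ?
        d1<TAB>1<TAB>3<TAB>Hello . <eos>How are you ? <eos>Fine , thanks .
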
1 change: 1 addition & 0 deletions docmt/__init__.py
@@ -0,0 +1 @@
from .data import read_docs, smart_open, DOC_SEPARATOR
62 changes: 62 additions & 0 deletions docmt/data.py
@@ -0,0 +1,62 @@
import sys
import gzip

from typing import Iterable, List, Tuple


DOC_SEPARATOR = " <eos>"


def extract_sent(line, fieldno=-1, sep=DOC_SEPARATOR):
"""Splits up a line that is a document and returns the requested field
(which is a sentence)"""

return line.split(sep)[fieldno]


def read_docs(infile, docfield=-1) -> Iterable[List[Tuple]]:
"""Generator over documents; returns documents as list of lines.
:param infile: The file stream to read from.
:param docfield: The field containing the document ID (default: last field).
:return: A generator of documents, each a list of tuples. The tuples are the fields (e.g., source, target, docid).
"""
doc = []
prev_docid = None
for lineno, line in enumerate(infile, 1):
# Split on tabs, then strip whitespace from either side
fields = list(map(str.strip, line.rstrip().split("\t")))

docid = fields[docfield] if len(fields) > docfield else None
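        # A docid of "0" is treated as "no docid"; combined with the
        # `docid is None` test below, such lines are never grouped and
        # each passes through as its own single-line document.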
if docid == "0":
docid = None

if docid != prev_docid or docid is None:
if len(doc):
yield doc
doc = []

doc.append(tuple(fields))
prev_docid = docid

if len(doc):
yield doc


def smart_open(filepath):
"""
Generalized open; works for plain files, compressed files, and STDIN.
"""
infile = None
if filepath == "-":
infile = sys.stdin
elif filepath.endswith(".gz"):
infile = gzip.open(filepath, "rt")
else:
infile = open(filepath, "rt")
return infile


def main():
    for doc in read_docs(sys.stdin):
        print(doc)


if __name__ == "__main__":
    main()
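
A short sketch of how these pieces compose (the filename is hypothetical):

    from docmt import read_docs, smart_open

    # each doc is a list of tuples of tab-separated fields, e.g. (text, docid),
    # with the docid in the last field by default (docfield=-1)
    with smart_open("corpus.tsv.gz") as infile:
        for doc in read_docs(infile):
            docid = doc[0][-1]
            print(docid, len(doc))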
