T2T batching #786
base: master
Conversation
@@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray:

     return np.array(numbers, dtype=dtype)


-def reader(files: List[str])-> Iterable[List[np.ndarray]]:
+def reader(files: List[str]) -> Iterable[List[np.ndarray]]:
This is unrelated to the change; it will just introduce a conflict into the branch with the tf dataset.
But if it won't pass Travis otherwise, leave it here.
@@ -4,6 +4,7 @@ tf_manager=<tf_manager>
 output="tests/outputs/hier-multiattention"
 overwrite_output_dir=True
 epochs=1
+batch_size=1
Batch size should not be mandatory just because there is a workaround somewhere.
@@ -85,6 +85,9 @@ def training_loop(cfg: Namespace) -> None:
         trainer_result = cfg.tf_manager.execute(
             batch, feedables, cfg.trainers, train=True,
             summaries=True)
+        # workaround: we need to use validation batching scheme
+        # during evaluation
+        batch.batching = BatchingScheme(batch_size=cfg.batch_size)
This is not the validation batching scheme. Drop this change; in my refactor it already works correctly, and this would just introduce an unnecessary conflict.
        batch sizes and sequence length tolerance.
    min_length: int, sequences shorter than this will be skipped.
  Return:
    A dictionary with parameters that can be passed to input_pipeline:
This is not true.
@@ -95,6 +95,84 @@ def __init__(self,
 # pylint: enable=too-few-public-methods


+def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
+    """Create a default set of length-bucket boundaries."""
I would add an example of the input and output; I don't quite understand why length_bucket_step is a float.
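For illustration, a runnable sketch of what the helper does, reconstructed from the linked tensor2tensor data_reader.py (the body is an approximation of that file, not code from this PR). The float step makes the boundaries grow geometrically, so buckets are dense for short sequences and get progressively wider for long ones:

def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
    """Create a default set of length-bucket boundaries."""
    assert length_bucket_step > 1.0
    boundaries = []
    x = min_length
    while x < max_length:
        boundaries.append(x)
        # widen the next bucket geometrically, but always by at least 1
        x = max(x + 1, int(x * length_bucket_step))
    return boundaries

print(_bucket_boundaries(32))
# [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30]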
@@ -95,6 +95,84 @@ def __init__(self,
 # pylint: enable=too-few-public-methods


+def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
Type annotations are missing.
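For instance, the signature could be annotated along these lines (a sketch only, keeping the defaults shown in the hunk above):

from typing import List

def _bucket_boundaries(max_length: int,
                       min_length: int = 8,
                       length_bucket_step: float = 1.1) -> List[int]:
    ...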
    max_length = max_length or batch_size
    if max_length < min_length:
        raise ValueError("max_length must be greater or equal to min_length")
Here we should also check that length_bucket_step is > 1.0 and raise a ValueError with a message, instead of leaving it to an assert in the helper function.
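A minimal sketch of the suggested check, placed next to the existing max_length validation rather than as an assert inside the helper:

    max_length = max_length or batch_size
    if max_length < min_length:
        raise ValueError("max_length must be greater or equal to min_length")
    # suggested: fail early with a descriptive message
    if length_bucket_step <= 1.0:
        raise ValueError("length_bucket_step must be greater than 1.0")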
Re the workaround: that already has its own pull request, why is it here as well?
Workaround == it is broken (meaning it crashes in normal scenarios), so I need a quick fix to be able to work on other things. Most of these things depend on each other; on the other hand, they can be split semantically, which is what I did across the pull requests. Next time I can just make one big PR and we won't have to deal with the dependencies.
Do I understand correctly that this needs to be merged first? What exactly is it blocked on?
The documentation in the dataset.* methods taken over from t2t needs to be fixed (and we should state that we take them from there). Annotations also need to be added... As I said, it was practically a copy-paste so that I don't have to compute bucket_batch_sizes and bucket_boundaries by hand every time. Of course the other PRs should work even without this one, but you will have to rebase them :)
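For reference, the computation being copied over is roughly the following (reconstructed from the linked data_reader.py; the wrapper name _example_batching_parameters is hypothetical and used here only for illustration). Per-bucket batch sizes are derived from the boundaries so that each batch holds roughly batch_size tokens:

def _example_batching_parameters(batch_size, boundaries, max_length):
    # buckets with longer sequences get proportionally fewer examples per batch
    bucket_batch_sizes = [max(1, batch_size // length)
                          for length in boundaries + [max_length]]
    return {"boundaries": boundaries, "batch_sizes": bucket_batch_sizes}

print(_example_batching_parameters(
    batch_size=256, boundaries=[8, 12, 16, 24, 32], max_length=32))
# {'boundaries': [8, 12, 16, 24, 32], 'batch_sizes': [32, 21, 16, 10, 8, 8]}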
So is this part of #802? If it is, please close this one.
It is not. I rebased incorrectly.
Included batching scheme methods from:
https://github.com/tensorflow/tensor2tensor/blob/415585f40d9f21c56df7bda35033bc915d82321e/tensor2tensor/utils/data_reader.py