add SMS-WSJ RETURNN datasets #116
base: main
Conversation
I added an extension to allow using WSJ alignments directly instead of the tailored ones for an SMS-WSJ bliss corpus. A corresponding …
```python
    num_outputs=None,
    rasr_num_outputs=None,
```
Do we really need both? Maybe they can be merged.
`num_outputs` allows setting this with all the details, while `rasr_num_outputs` just makes the same assumptions as `SmsWsjMixtureEarlyDataset` does. I'm not sure whether the full control will be necessary at some point. We could remove `num_outputs` and add it back if it's really needed, or we remove `rasr_num_outputs` and always use the more detailed form, which seems to overcomplicate things for our usual case though.
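For illustration, a rough sketch of what the two variants amount to; the stream names and dimensions below are placeholders, not taken from the PR, only the two parameter names and `SmsWsjMixtureEarlyDataset` appear in this thread:

```python
# Illustrative placeholders only: stream names and dims are not from the PR.

# Full control: spell out every output stream in RETURNN's usual num_outputs style
# ([dim, 1] for sparse streams, [dim, 2] for dense ones).
num_outputs = {
    "data": [1, 2],               # dense single-channel mixture signal
    "target_signals": [2, 2],     # dense source signals for two speakers
    "target_classes": [9001, 1],  # sparse alignment labels
}

# Shortcut: only give the RASR output size and let the dataset assume the rest,
# as SmsWsjMixtureEarlyDataset does.
rasr_num_outputs = 9001
```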
Co-authored-by: SimBe195 <[email protected]>
```python
        Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq,
        therefore the buffer does not contain the initial indices when continuing the training from an epoch > 0.
        """
        out = super().init_seq_order(epoch=epoch, **kwargs)
```
Suggested change:
```diff
-out = super().init_seq_order(epoch=epoch, **kwargs)
+self.seq_ordering = self._seq_ordering
+out = super().init_seq_order(epoch=epoch, **kwargs)
```
Not sure if this is relevant in general, but when using `ReturnnSearchJobV2`, `self.seq_ordering` might be changed by `def search()` in `returnn/tf/engine.py`, which leads to problems with the `update_buffer` function, so it has to be reset.
Could you please elaborate on what the issue is in `update_buffer`?
Ok, I'm sorry, what I said was not 100% right. I recreated the problem, and the issue is not `update_buffer`. In my case, `self.seq_ordering` was changed to `"sorted_reverse"`. The `init_seq_order()` function in `MapDatasetWrapper` calls `self.get_seq_order_for_epoch`. If `self.seq_ordering` is `"sorted_reverse"`, `get_seq_len()` is called, and it tries to get the sequence length for sequences which are not in the buffer. Therefore, `self.seq_ordering` has to be `"default"`, so that not all sequence lengths are needed. Maybe there is a better and more elegant solution to this problem, but this was the easiest one I could think of.
Maybe in that case you could set `buffer=False`, or find another way to avoid the overwriting of `seq_ordering`; enforcing it here does not seem like the best idea to me.
if "seq_ordering" not in kwargs: | ||
print("Warning: no shuffling is enabled by default", file=returnn_log.v2) |
if "seq_ordering" not in kwargs: | |
print("Warning: no shuffling is enabled by default", file=returnn_log.v2) | |
if "seq_ordering" not in kwargs: | |
print("Warning: no shuffling is enabled by default", file=returnn_log.v2) | |
else: | |
self._seq_ordering = kwargs["seq_ordering"] |
See other thread #116 (comment)
```python
        file=returnn_log.v5,
    )
    self._ds_iterator = iter(self._ds)
    for idx_ in range(min(self._buffer.max_size // 2, len(self))):
```
I like the idea of having a class for the buffer, but I have faced a problem with the `update_buffer` function, more precisely with the automatic popping of the oldest inserted elements. When the end of the dataset is reached, the buffer is "refilled" with `self._buffer.max_size // 2` sequences from the beginning of the dataset, and the oldest `self._buffer.max_size // 2` sequences are popped. However, when working with batches, this may cause sequences which are still needed for the current batch to be removed from the buffer too soon.
This could maybe be avoided by filling one sequence from the beginning of the dataset at a time, maybe something like: `for idx_ in range((seq_idx + self._buffer.max_size // 2) - len(self) + 1)`.
The issue with that might be that you only get here if `idx == len(self) - 1 and 0 not in self._buffer`, so probably when `seq_idx + self._buffer.max_size // 2 == len(self) - 1`. So you would have `range(0)`, right? A simpler thing might be to just increase the buffer size in your case.
I don't think increasing the buffer size will help, because the oldest `self._buffer.max_size // 2` sequences are popped no matter if the buffer size is 40 or 400; you always lose the first half of the sequences in the buffer with that loop.
So tell me if I'm wrong, but due to `min(seq_idx + self._buffer.max_size // 2, len(self))`, we get there when `seq_idx + self._buffer.max_size // 2 - 1 == len(self) - 1` and not when `seq_idx + self._buffer.max_size // 2 == len(self) - 1`. That's why I added the `+1`, so we have at least `range(1)`. However, you are right that the condition `0 not in self._buffer` causes a problem.
Sure, the oldest `self._buffer.max_size // 2` sequences are always popped, but if the buffer size is large enough, the remaining old sequences are sufficient to cover everything that you need for your remaining batches, no?
I think it's easier to give an example. Let's say `len(self)` is 125 and my buffer size is 40. My batch begins at sequence 100 and will contain 10 sequences. Now, at some point, `update_buffer` is called with `seq_idx=105`.
Then, in the first part, the buffer is filled with `85,...,105,...,124` (because `124 == seq_idx + self._buffer.max_size // 2 - 1`, so we do `for idx in range(105, 125)`). Now the condition `if idx == len(self) - 1 and 0 not in self._buffer` is fulfilled (because `124 == len(self) - 1`), and `for idx_ in range(min(self._buffer.max_size // 2, len(self)))` fills the buffer with sequences 0 to 19. As a result, the first 20 elements of the buffer are popped, so the buffer now contains elements `105,...,124,0,...,19`. However, for the batch we need sequences 100 to 104, which are not in the buffer anymore.
That means everything before sequence 105 is removed; there are no old sequences remaining. With an increased batch size, the problem would just appear earlier.
Again, tell me if I am wrong; maybe I am making a mistake here.
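A minimal, self-contained sketch of the eviction arithmetic in this example; an `OrderedDict` with FIFO eviction stands in for the actual buffer class, so this is illustrative rather than the PR code:

```python
from collections import OrderedDict

MAX_SIZE, DATASET_LEN = 40, 125  # buffer size 40, 125 sequences, batch 100-109

def put(buffer, idx):
    """Insert one sequence; evict the oldest entry once the buffer is full."""
    buffer[idx] = f"seq_{idx}"
    if len(buffer) > MAX_SIZE:
        buffer.popitem(last=False)  # automatic FIFO eviction of the oldest entry

def update_buffer(buffer, seq_idx):
    # refill logic as discussed above: prefetch ahead of seq_idx, and on reaching
    # the last index refill MAX_SIZE // 2 sequences from the dataset start
    for idx in range(seq_idx, min(seq_idx + MAX_SIZE // 2, DATASET_LEN)):
        if idx not in buffer:
            put(buffer, idx)
        if idx == DATASET_LEN - 1 and 0 not in buffer:
            for idx_ in range(min(MAX_SIZE // 2, DATASET_LEN)):
                if idx_ not in buffer:
                    put(buffer, idx_)

buf = OrderedDict((i, f"seq_{i}") for i in range(85, 105))  # buffer before the call
update_buffer(buf, seq_idx=105)
print(sorted(buf))  # -> [0, ..., 19, 105, ..., 124]
assert all(i not in buf for i in range(100, 105))  # 100-104, still needed for the batch, are gone
```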
Yes, I agree. In general, as soon as `update_buffer` is called on `seq_idx = len(self) - buffer_size // 2`, the start of the dataset is added and the buffer is then filled with sequences `[len(self) - buffer_size // 2 ... len(self) - 1] + [0 ... buffer_size // 2 - 1]`, so this only keeps the current minibatch intact if `seq_idx` happens to be the first sequence in the minibatch, regardless of `buffer_size`. In the example of @larissakl it only runs successfully if the batch begins at sequence 105.
I ran into this error as well and changed the instances of `// 2` to `// 3` locally, which seems to give enough room for old sequences, but maybe the safer solution is to only add `seq_idx + buffer_size // 2 - len(self)` sequences from the dataset start. For example, when `update_buffer(118)` is called, the first 13 sequences would be added such that the buffer contains `98,...,118,...,124,0,...,12` afterwards.
BTW, I think we should add some documentation that the buffer size should be set large enough so that the number of sequences in a mini-batch never exceeds `buffer_size // 2`; otherwise the training will crash.
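As a quick sanity check of the numbers in that proposal (illustrative arithmetic only, using the values from the example above):

```python
# buffer_size = 40, len(self) = 125, update_buffer(118)
seq_idx, buffer_size, dataset_len = 118, 40, 125
num_from_start = seq_idx + buffer_size // 2 - dataset_len
assert num_from_start == 13  # sequences 0..12 get prefetched from the dataset start
# resulting buffer: 98..124 plus 0..12 -> 27 + 13 = 40 entries, exactly buffer_size
assert (dataset_len - 98) + num_from_start == buffer_size
```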
> I ran into this error as well and changed the instances of `// 2` to `// 3` locally which seems to give enough room for old sequences but maybe the safer solution is to only add `seq_idx + buffer_size // 2 - len(self)` sequences from the dataset start. For example when `update_buffer(118)` is called, the first 13 sequences would be added such that the buffer contains `98,...,118,...,124,0,...,12` afterwards.
That's what I meant by `for idx_ in range((seq_idx + self._buffer.max_size // 2) - len(self) + 1)`. I tried the following and it seems to work:
```python
for idx in range(seq_idx, min(seq_idx + self._buffer.max_size // 2, len(self))):
    if idx not in self._buffer:
        self._buffer[idx] = next(self._ds_iterator)
    if idx == len(self) - 1:
        if 0 not in self._buffer:
            print(f"Reached end of dataset, reset iterator", file=returnn_log.v4)
            try:
                next(self._ds_iterator)
            except StopIteration:
                pass
            else:
                print(
                    "WARNING: reached final index of dataset, but iterator has more sequences. "
                    "Maybe the training was restarted from an epoch > 1?",
                    file=returnn_log.v3,
                )
            print(
                f"Current buffer indices: {self._buffer.keys()}",
                file=returnn_log.v5,
            )
            self._ds_iterator2 = iter(self._ds)
        for idx_ in range((seq_idx + self._buffer.max_size // 2) - len(self) + 1):
            if idx_ not in self._buffer:
                self._buffer[idx_] = next(self._ds_iterator2)
if seq_idx == len(self) - 1:
    self._ds_iterator = self._ds_iterator2
```
It realizes the idea of only adding `seq_idx + buffer_size // 2 - len(self)` sequences from the dataset start. The iterator is only reset if 0 is not in the buffer. If this happens, we have to work with two iterators in parallel: one at the end of the dataset (until the last sequence is reached) and one at the start.
I pushed a fix which also simplifies this a bit. When we reach the last index (`len(self) - 1`) for the first time, we reset the iterator and add index 0. From then on, we always add one more index at the beginning of the dataset by using `idx % len(self)` instead of just `idx` capped at `len(self)`.
Please go ahead and test it; for my local test case the issue is resolved.
> BTW I think we should add some documentation that the buffer size should be set large enough so that the number of sequences in a mini-batch never exceeds `buffer_size // 2`; otherwise the training will crash.

I added that in the docstring; you can have a look whether you find it sufficient.
The unzipping logic was quite annoying and prone to errors. Therefore, I added a custom reader class to allow reading the data directly from the zip in the …
Just as a note: we might fix the start index for higher epochs in …
As agreed offline, let's get active on this PR again. The main issue right now is the bug when restarting training at epochs > 1. There are also other open threads and the question about the effect of the buffer size. @larissakl: @SimBe195 mentioned that you might have some idea to fix the restarting issue. Is that right? Regarding the buffer size, I think we could also study that after merging.
In general, this issue should go along with the …
@larissakl do you have a setup to reproduce the issue? I tried (trained a model for one subepoch and then aborted and restarted the training), but that works just fine.
I added some helper classes which allow wrapping the SMS-WSJ dataset in RETURNN. They basically wrap RETURNN's `MapDatasetBase` and `MapDatasetWrapper`. They are part of `i6_experiments` so that they can be integrated into a pipeline. An example usage in the sisyphus config would look roughly like the sketch below. I will add a baseline pipeline for a joint separation and recognition system later on.