
add SMS-WSJ RETURNN datasets #116

Open · wants to merge 47 commits into main

Conversation

@vieting (Contributor) commented Feb 2, 2023

I added some helper classes that allow wrapping the SMS-WSJ dataset in RETURNN. They basically wrap RETURNN's MapDatasetBase and MapDatasetWrapper. They are part of i6_experiments so that they can be integrated into a pipeline. An example usage in the sisyphus config would be:

def segment_to_rasr(seg_name):
    """
    Maps SMS-WSJ segment name (e.g. "0_4k6c0303_4k4c0319") to a list of RASR segment names
    (e.g. ["cv_dev93/4k6c0303/0000", "cv_dev93/4k4c0319/0000"]), which can be found in HDF.
    """
    segs = str(seg_name).split("_")[1:]
    return [f"cv_dev93/{seg}/0000" for seg in segs]


returnn_config = ReturnnConfig(
  dict(
    dev={
      "class": CodeWrapper("SmsWsjMixtureEarlyAlignmentDataset"),
      "dataset_name": "cv_dev93",
      "json_path": "/path/to/sms_wsj.json",
      "rasr_num_outputs": 9001,
      "segment_to_rasr": segment_to_rasr
      "rasr_classes_hdf": "/path/to/cv_dev93_alignments.hdf",
      "pad_label": 9000,
      "seq_ordering": "default",
    },
    ...
  ),
  python_prolog=[
    SmsWsjBase,
    SmsWsjBaseWithRasrClasses,
    SmsWsjWrapper,
    SmsWsjMixtureEarlyDataset,
    SmsWsjMixtureEarlyAlignmentDataset,
  ],
  ...
)
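
For orientation, a minimal sketch of what such a wrapper class looks like, assuming RETURNN's map-style dataset interface (__len__ plus __getitem__ returning a dict of numpy arrays); the class name, the json layout, and the loading logic here are illustrative, not the actual implementation in this PR:

import json

import numpy as np
from returnn.datasets.map import MapDatasetBase


class SmsWsjSketchDataset(MapDatasetBase):
    """Illustrative map-style dataset exposing one SMS-WSJ subset to RETURNN."""

    def __init__(self, json_path, dataset_name, **kwargs):
        super().__init__(**kwargs)
        with open(json_path) as f:
            db = json.load(f)
        # Assumption: the SMS-WSJ json groups examples by dataset name, e.g. "cv_dev93".
        self._examples = list(db["datasets"][dataset_name].values())

    def __len__(self):
        return len(self._examples)

    def __getitem__(self, seq_idx):
        example = self._examples[seq_idx]
        # Hypothetical loading step; the real classes read mixture/early signals
        # from the audio paths stored in the example dict.
        audio = np.zeros((16000, 1), dtype="float32")  # placeholder waveform
        return {"data": audio}

RETURNN's MapDatasetWrapper can then turn such an object into a regular RETURNN dataset, which is what the classes used in the config above build on.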

I will add a baseline pipeline for a joint separation and recognition system later on.

@vieting (Contributor, Author) commented Feb 6, 2023

I added an extension that allows using WSJ alignments directly instead of the ones tailored to an SMS-WSJ bliss corpus. A corresponding segment_to_rasr function now needs to be given for each case.

Comment on lines 452 to 453:

    num_outputs=None,
    rasr_num_outputs=None,
Contributor:

Do we really need both? Maybe they can be merged.

Contributor Author:

num_outputs allows setting this with full detail, while rasr_num_outputs just assumes the same defaults as SmsWsjMixtureEarlyDataset. I'm not sure whether the full control will be necessary at some point. We could remove num_outputs and re-add it if it's really needed, or remove rasr_num_outputs and always use the detailed variant, which seems to overcomplicate things for our usual case though.
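
For illustration, the two variants might look like this in the config (hypothetical keys and shapes; the exact num_outputs format for these classes is an assumption based on RETURNN's usual {name: [dim, ndim]} convention):

# Explicit variant (hypothetical): full control over each data stream.
dataset_explicit = {
    "class": "SmsWsjMixtureEarlyAlignmentDataset",
    "num_outputs": {"data": [1, 2], "target_rasr": [9001, 1]},  # [dim, ndim], ndim 1 = sparse
    # ...
}

# Shortcut variant: only the RASR output size, everything else derived internally.
dataset_shortcut = {
    "class": "SmsWsjMixtureEarlyAlignmentDataset",
    "rasr_num_outputs": 9001,
    # ...
}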

    Override this in order to update the buffer. get_seq_length is often called before _collect_single_seq,
    therefore the buffer does not contain the initial indices when continuing the training from an epoch > 0.
    """
    out = super().init_seq_order(epoch=epoch, **kwargs)


Suggested change:

+    self.seq_ordering = self._seq_ordering
     out = super().init_seq_order(epoch=epoch, **kwargs)

Not sure if this is relevant in general, but when using ReturnnSearchJobV2, self.seq_ordering might be changed by search() in returnn/tf/engine.py, which leads to problems with the update_buffer function, so it has to be reset.

Contributor Author:

Could you please elaborate on what the issue in update_buffer is?


Ok, I'm sorry, what I said was not 100% right; the issue is not update_buffer. I reproduced the problem:
In my case, self.seq_ordering was changed to "sorted_reverse". The init_seq_order() function in MapDatasetWrapper calls self.get_seq_order_for_epoch. If self.seq_ordering is "sorted_reverse", get_seq_len() is called, and it tries to get the sequence length for sequences which are not in the buffer. Therefore, self.seq_ordering has to be "default" so that not all sequence lengths are needed. Maybe there is a better and more elegant solution to this problem, but this was the easiest one I could think of.
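
For context, a minimal sketch of why "sorted_reverse" needs all sequence lengths up front (illustrative, not RETURNN's actual code):

def seq_order_sorted_reverse(num_seqs, get_seq_len):
    # Sorting by length calls get_seq_len(i) for every index i, including
    # indices that the streaming buffer has not loaded yet.
    return sorted(range(num_seqs), key=get_seq_len, reverse=True)

def seq_order_default(num_seqs):
    # "default" order, in contrast, never needs a sequence length.
    return list(range(num_seqs))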

Contributor Author:

Maybe in that case you could set buffer=False, or find another way to avoid the overwriting of seq_ordering; enforcing it here does not seem like the best idea to me.

Comment on lines +362 to +363:

    if "seq_ordering" not in kwargs:
        print("Warning: no shuffling is enabled by default", file=returnn_log.v2)

Suggested change:

     if "seq_ordering" not in kwargs:
         print("Warning: no shuffling is enabled by default", file=returnn_log.v2)
+    else:
+        self._seq_ordering = kwargs["seq_ordering"]

Contributor Author:

See other thread #116 (comment)

            file=returnn_log.v5,
        )
        self._ds_iterator = iter(self._ds)
        for idx_ in range(min(self._buffer.max_size // 2, len(self))):


I like the idea of having a class for the buffer, but I have run into a problem with the update_buffer function, more precisely with the automatic popping of the oldest inserted elements.
When the end of the dataset is reached, the buffer is "refilled" with self._buffer.max_size // 2 sequences from the beginning of the dataset, and the oldest self._buffer.max_size // 2 sequences are popped. However, when working with batches, this may cause sequences that are still needed for the batch to be removed from the buffer too soon.

This could maybe be avoided by filling one sequence from the beginning of the dataset at a time, maybe something like:
for idx_ in range((seq_idx + self._buffer.max_size // 2) - len(self) + 1)

Contributor Author:

The issue with that might be that you only get here if idx == len(self) - 1 and 0 is not in self._buffer, so probably when seq_idx + self._buffer.max_size // 2 == len(self) - 1. Then you would have range(0), right? A simpler option might be to just increase the buffer size in your case.


I don't think increasing the buffer size will help, because the oldest self._buffer.max_size // 2 sequences are popped: no matter if the buffer size is 40 or 400, you will always lose the first half of the sequences in the buffer with that loop.

So tell me if I'm wrong, but due to min(seq_idx + self._buffer.max_size // 2, len(self)), we get there when seq_idx + self._buffer.max_size // 2 - 1 == len(self) - 1, and not when seq_idx + self._buffer.max_size // 2 == len(self) - 1. That's why I added +1, so we have at least range(1). However, you are right that the condition 0 not in self._buffer causes a problem.

Contributor Author:

Sure, the oldest self._buffer.max_size // 2 sequences are always popped, but if the buffer size is large enough, the remaining old sequences are sufficient to cover everything that you need for your remaining batches, no?

@larissakl commented Feb 13, 2023

I think it's easier to give an example:
Let's say len(self) is 125 and my buffer size is 40. My batch begins at sequence 100 and will contain 10 sequences. Now, at some point, update_buffer is called with seq_idx=105.
Then, in the first part, the buffer is filled with 85,...,105,...,124 (because 124 == seq_idx + self._buffer.max_size // 2 - 1, so we do for idx in range(105, 125)). Now, the condition if idx == len(self) - 1 and 0 not in self._buffer is fulfilled (because 124 == len(self) - 1). for idx_ in range(min(self._buffer.max_size // 2, len(self))) fills the buffer with sequences 0 to 19. As a result, the first 20 elements of the buffer are popped, so the buffer is now filled with elements 105,...,124,0,...,19. However, for the batch we need sequences 100 to 104, which are not in the buffer anymore.
That means everything before sequence 105 is removed; there are no old sequences remaining. With an increased batch size, the problem would just appear earlier.
Again, tell me if I'm wrong, maybe I am making a mistake here.
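
The eviction can be reproduced with a tiny standalone simulation (this models the buffer as a size-capped ordered dict, which is an assumption about the actual buffer class):

from collections import OrderedDict

class CappedBuffer(OrderedDict):
    """Ordered dict that evicts its oldest entry once max_size is exceeded."""
    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.max_size:
            self.popitem(last=False)  # pop the oldest inserted key

buf = CappedBuffer(max_size=40)
for idx in range(85, 125):  # buffer around seq_idx=105: indices 85..124
    buf[idx] = "seq"
for idx in range(40 // 2):  # end-of-dataset refill adds 0..19, evicting 85..104
    buf[idx] = "seq"
print(sorted(buf)[:6])  # -> [0, 1, 2, 3, 4, 5]; sequences 100..104 are gone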

Contributor:

Yes, I agree. In general, as soon as update_buffer is called on seq_idx = len(self) - buffer_size // 2, the start of the dataset is added and the buffer is then filled with sequences [len(self) - buffer_size // 2 ... len(self) - 1] + [0 ... buffer_size // 2 - 1], so this only keeps the current mini-batch intact if seq_idx happens to be the first sequence in the mini-batch, regardless of buffer_size. In the example of @larissakl, it only runs successfully if the batch begins at sequence 105.

I ran into this error as well and changed the instances of // 2 to // 3 locally, which seems to give enough room for old sequences, but maybe the safer solution is to only add seq_idx + buffer_size // 2 - len(self) sequences from the dataset start. For example, when update_buffer(118) is called, the first 13 sequences would be added such that the buffer contains 98,...,118,...,124,0,...,12 afterwards.

BTW, I think we should add some documentation that the buffer size should be set large enough that the number of sequences in a mini-batch never exceeds buffer_size // 2; otherwise the training will crash.


> I ran into this error as well and changed the instances of // 2 to // 3 locally, which seems to give enough room for old sequences, but maybe the safer solution is to only add seq_idx + buffer_size // 2 - len(self) sequences from the dataset start. For example, when update_buffer(118) is called, the first 13 sequences would be added such that the buffer contains 98,...,118,...,124,0,...,12 afterwards.

That's what I meant by for idx_ in range((seq_idx + self._buffer.max_size // 2) - len(self) + 1). I tried the following, and it seems to work:

        for idx in range(seq_idx, min(seq_idx + self._buffer.max_size // 2, len(self))):
            if idx not in self._buffer:
                self._buffer[idx] = next(self._ds_iterator)
            if idx == len(self) - 1:
                # End of dataset reached. Only reset the iterator if the dataset
                # start is not already in the buffer.
                if 0 not in self._buffer:
                    print("Reached end of dataset, reset iterator", file=returnn_log.v4)
                    try:
                        next(self._ds_iterator)
                    except StopIteration:
                        pass
                    else:
                        print(
                            "WARNING: reached final index of dataset, but iterator has more sequences. "
                            "Maybe the training was restarted from an epoch > 1?",
                            file=returnn_log.v3,
                        )
                    print(
                        f"Current buffer indices: {self._buffer.keys()}",
                        file=returnn_log.v5,
                    )
                    # Second iterator starting at the dataset start, used in parallel
                    # with the one still finishing the end of the dataset.
                    self._ds_iterator2 = iter(self._ds)
                # Only add as many sequences from the dataset start as the window
                # seq_idx + max_size // 2 reaches past the end of the dataset.
                for idx_ in range((seq_idx + self._buffer.max_size // 2) - len(self) + 1):
                    if idx_ not in self._buffer:
                        self._buffer[idx_] = next(self._ds_iterator2)
                    if seq_idx == len(self) - 1:
                        self._ds_iterator = self._ds_iterator2

It realizes the idea of only adding seq_idx + buffer_size // 2 - len(self) sequences from the dataset start. The iterator is only reset if 0 is not in the buffer. If this happens, we have to work with two iterators in parallel: one at the end of the dataset (until the last sequence is reached) and one at the start.

Contributor Author:

I pushed a fix which also simplifies this a bit. When we reach the last index (len(self) - 1) for the first time, we reset the iterator and add the index 0. Then, we always add one more index at the beginning of the dataset by using idx % len(self) instead of just idx capped at len(self).
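
In sketch form, my reading of the described fix (illustrative; the iterator reset at the wrap-around point is omitted here):

        # Wrap indices past the end of the dataset around to the start. Each call
        # then adds at most a few new start-of-dataset indices instead of popping
        # half the buffer at once.
        for idx in range(seq_idx, seq_idx + self._buffer.max_size // 2):
            idx = idx % len(self)
            if idx not in self._buffer:
                self._buffer[idx] = next(self._ds_iterator)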

Contributor Author:

Please go ahead and test that; for my local test case the issue is resolved.

Contributor Author:

> BTW, I think we should add some documentation that the buffer size should be set large enough that the number of sequences in a mini-batch never exceeds buffer_size // 2; otherwise the training will crash.

I added that in the docstring; have a look at whether you find it sufficient.

@vieting (Contributor, Author) commented Feb 13, 2023

The unzipping logic was quite annoying and error-prone. Therefore, I added a custom reader class that allows reading the data directly from the zip in the SmsWsj dataset. This clearly simplifies the pre-processing, and I hope it does not affect the dataset performance, but I don't have numbers on that.
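
Such a reader could look roughly like this; a minimal sketch using Python's zipfile and soundfile, where the class name and interface are assumptions for illustration, not the code added in this PR:

import io
import zipfile

import soundfile  # pysoundfile; reads audio from file-like objects


class ZipAudioReader:
    """Read audio members directly from a zip archive instead of unzipping first."""

    def __init__(self, zip_path):
        self._zip = zipfile.ZipFile(zip_path, "r")

    def read(self, member_name):
        # soundfile accepts file-like objects, so no temporary files are needed;
        # BytesIO makes the member seekable, which some formats require.
        with self._zip.open(member_name) as f:
            audio, sample_rate = soundfile.read(io.BytesIO(f.read()))
        return audio, sample_rate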

@vieting (Contributor, Author) commented Mar 1, 2023

Just as a note: we might fix the start index for higher epochs in init_seq_order() and study the effect of the buffer size. I suggest doing that after the deadline.

@vieting (Contributor, Author) commented Mar 28, 2023

As agreed offline, let's get active on this PR again. The main issue right now is the bug when restarting training at epochs > 1. There are also other open threads, as well as the question about the effect of the buffer size.

@larissakl: @SimBe195 mentioned that you might have some idea to fix the restarting issue. Is that right?

Regarding the buffer size, I think we could also study that after merging.

@larissakl commented:

In general, this issue seems to go along with the partition_epoch parameter, because restarting from an epoch divisible by partition_epoch works fine.
In RETURNN's Dataset class (returnn/datasets/basic), the get_seq_order_for_epoch() function seems to take partition_epoch into account. Therefore, I thought self.partition_epoch might not be set properly. Unfortunately, I tested this and it did not solve the problem. Maybe we should nevertheless check whether all relevant parameters are set and whether the right functions, especially in the super classes, are called.

@vieting (Contributor, Author) commented Mar 30, 2023

@larissakl do you have a setup to reproduce the issue? I tried (trained a model for one subepoch, then aborted and restarted the training), but that works just fine.
