Merge pull request #33 from microsoft/alvinn/march_12_updated_baseline_mvdr

Latest baseline, MVDR, results, and more
nidleo authored Mar 12, 2024
2 parents b2d2483 + 159f046 commit 5ed9039
Showing 22 changed files with 837 additions and 312 deletions.
61 changes: 51 additions & 10 deletions README.md
@@ -1,11 +1,45 @@
[![Slack][slack-badge]][slack-invite]


[slack-badge]: https://img.shields.io/badge/slack-chat-green.svg?logo=slack
[slack-invite]: https://join.slack.com/t/chime-fey5388/shared_invite/zt-1oha0gedv-JEUr1mSztR7~iK9AxM4HOA

# Introduction
Welcome to the "NOTSOFAR-1: Distant Meeting Transcription with a Single Device" Challenge.

This repo contains the baseline system code for the NOTSOFAR-1 Challenge.

For more details:
- Preprint: https://arxiv.org/abs/2401.08887
- Visit [CHiME's official challenge website](https://www.chimechallenge.org/current/task2/index) for more information about NOTSOFAR.
- [Register](https://www.chimechallenge.org/current/task2/submission) to participate.
- [Baseline system description](https://www.chimechallenge.org/current/task2/baseline).
- Contact us: join the `chime-8-notsofar` channel on the [CHiME Slack](https://join.slack.com/t/chime-fey5388/shared_invite/zt-1oha0gedv-JEUr1mSztR7~iK9AxM4HOA), or open a [GitHub issue](https://github.com/microsoft/NOTSOFAR1-Challenge/issues).

### 📊 Baseline Results on NOTSOFAR dev-set-1

Values are presented in `tcpWER / tcORC-WER (session count)` format.
<br>
As described on the [official website](https://www.chimechallenge.org/current/task2/index#tracks),
systems are ranked by the speaker-attributed
[tcpWER](https://github.com/fgnt/meeteval/blob/main/doc/tcpwer.md),
while the speaker-agnostic [tcORC-WER](https://github.com/fgnt/meeteval) serves as a supplementary metric for analysis.
<br>
We include analysis based on a selection of hashtags from our [metadata](https://www.chimechallenge.org/current/task2/data#metadata), providing insights into how different conditions affect system performance.



| Subset               | Single-Channel        | Multi-Channel         |
|----------------------|-----------------------|-----------------------|
| All Sessions | **46.8** / 38.5 (177) | **32.4** / 26.7 (106) |
| #NaturalMeeting | 47.6 / 40.2 (30) | 32.3 / 26.2 (18) |
| #DebateOverlaps | 54.9 / 44.7 (39) | 38.0 / 31.4 (24) |
| #TurnsNoOverlap | 32.4 / 29.7 (10) | 21.2 / 18.8 (6) |
| #TransientNoise=high | 51.0 / 43.7 (10) | 33.6 / 29.1 (5) |
| #TalkNearWhiteboard | 55.4 / 43.9 (40) | 39.9 / 31.2 (22) |
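
Both metrics are computed with the [meeteval](https://github.com/fgnt/meeteval) toolkit. The snippet below is a hedged sketch of scoring one reference/hypothesis pair: the `meeteval.io.load` / `meeteval.wer.tcpwer` entry points, file names, and `collar=5` are assumptions to verify against the meeteval documentation and the official challenge scoring setup.

```python
# Hedged sketch: computing the ranking metric with meeteval's Python API.
# Assumed API; verify against https://github.com/fgnt/meeteval before use.
import meeteval

ref = meeteval.io.load('ref.stm')  # reference transcripts (STM)
hyp = meeteval.io.load('hyp.stm')  # system output (STM)

# Speaker-attributed, time-constrained WER used for ranking.
tcp = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=5)
print(tcp)
```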

# Project Setup
@@ -143,7 +177,8 @@ Begin by exploring the following components:

### Training datasets
For training and fine-tuning your models, NOTSOFAR offers the **simulated training set** and the training portion of the
**recorded meeting dataset**. Refer to the `download_simulated_subset` and `download_meeting_subset` functions in
[utils/azure_storage.py](https://github.com/microsoft/NOTSOFAR1-Challenge/blob/main/utils/azure_storage.py#L109),
or the [NOTSOFAR-1 Datasets](#notsofar-1-datasets---download-instructions) section.


@@ -159,10 +194,12 @@ python run_training_css_local.py
## 2. Training on the full simulated training dataset

### Step 1: Download the simulated training dataset
You can use the `download_simulated_subset` function in
[utils/azure_storage.py](https://github.com/microsoft/NOTSOFAR1-Challenge/blob/main/utils/azure_storage.py)
to download the training dataset from blob storage.
You have the option to download either the complete dataset, comprising almost 1000 hours, or a smaller, 200-hour subset.

Examples:
```python
ver='v1.5' # this should point to the latest and greatest version of the dataset.
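
# The calls below are a hedged completion of this truncated example. They
# assume the download_simulated_subset signature in utils/azure_storage.py
# (version, volume, subset_name, destination_dir); my_dir is a hypothetical
# destination directory.
import os

my_dir = '/path/to/datasets'  # hypothetical destination

# Download the full 1000-hour training and validation subsets.
train_set_path = download_simulated_subset(
    version=ver, volume='1000hrs', subset_name='train',
    destination_dir=os.path.join(my_dir, 'train'))

val_set_path = download_simulated_subset(
    version=ver, volume='1000hrs', subset_name='val',
    destination_dir=os.path.join(my_dir, 'val'))
```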
@@ -223,8 +260,11 @@ Alternatively, using AzCopy CLI, set these arguments and run the following command:
- `version`: version to download (`240103g` / etc.). Use the latest version.
- `datasets_path` - path to the directory where you want to download the benchmarking dataset (destination directory must exist). <br>
Train, dev, and eval sets for the NOTSOFAR challenge are released in stages.
See the release timeline on the [NOTSOFAR page](https://www.chimechallenge.org/current/task2/index#dates).
See the docstring of the `download_meeting_subset` function in
[utils/azure_storage.py](https://github.com/microsoft/NOTSOFAR1-Challenge/blob/main/utils/azure_storage.py#L109)
for the latest available versions. A Python sketch of the equivalent call follows the AzCopy command below.
```bash
azcopy copy https://notsofarsa.blob.core.windows.net/benchmark-datasets/<subset_name>/<version>/MTG <datasets_path>/benchmark --recursive
```
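
The same download can be scripted with the repo's Python helper. This is a hedged sketch assuming the `download_meeting_subset` arguments mirror those listed above; the version string is copied from the argument description and may be outdated.

```python
# Hedged sketch: downloading the dev set via the Python helper. Check
# utils/azure_storage.py for the authoritative signature and latest versions.
import os
from utils.azure_storage import download_meeting_subset

datasets_path = '/path/to/datasets'  # destination directory must exist

dev_set_path = download_meeting_subset(
    subset_name='dev_set',      # e.g. 'dev_set' or 'train_set'
    version='240103g',          # use the latest version (see docstring)
    destination_dir=os.path.join(datasets_path, 'benchmark'))
```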
@@ -259,7 +299,7 @@ azcopy copy https://notsofarsa.blob.core.windows.net/css-datasets/<version>/<volume>
Example:
```bash
azcopy copy https://notsofarsa.blob.core.windows.net/css-datasets/v1.5/200hrs/train . --recursive
```
@@ -273,4 +313,5 @@ Thank you for your interest and patience.
# 🤝 Contribute
Please refer to our [contributing guide](CONTRIBUTING.md) for more information on how to contribute!
4 changes: 4 additions & 0 deletions configs/inference/debug_inference.yaml
@@ -10,6 +10,10 @@ css:
device: "cuda:0"
show_progressbar: True
slice_audio_for_debug: False
mc_mvdr: True  # apply MVDR beamforming to the multi-channel input
mc_mask_floor_db: 0.  # mask floor in dB; 0 means the mask has no effect (MVDR output only)
sc_mask_floor_db: -inf  # -inf means no floor, i.e. plain masking
activity_th: 0.3  # threshold for the segmentation mask

diarization:
method: 'word_nmesc' # choose from "word_nmesc", "nmesc" and "nmesc_msdd"
4 changes: 4 additions & 0 deletions configs/inference/inference_v1.yaml
@@ -11,6 +11,10 @@ css:
show_progressbar: True
slice_audio_for_debug: False
pass_through_ch0: False
mc_mvdr: True
mc_mask_floor_db: 0. # for MC, MVDR without any direct masking worked best
sc_mask_floor_db: -inf # for SC, direct masking without floor worked best
activity_th: 0.3

diarization:
method: 'word_nmesc' # choose from "word_nmesc", "nmesc" and "nmesc_msdd"
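
The `mc_mask_floor_db` / `sc_mask_floor_db` settings above control how much direct masking is applied on top of the separation front-end. A minimal sketch of the dB-to-amplitude flooring performed in `css/css.py`, assuming masks in the [0, 1] range:

```python
# Minimal sketch of mask flooring (mirrors the clip in css/css.py):
# floor_db = 0    -> floor = 1.0 -> all-ones mask, i.e. no direct masking;
# floor_db = -inf -> floor = 0.0 -> plain, unfloored masking.
import numpy as np

def apply_mask_floor(mask: np.ndarray, floor_db: float) -> np.ndarray:
    floor = 10.0 ** (floor_db / 20.0)  # dB to linear amplitude
    return np.clip(mask, floor, None)
```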
71 changes: 47 additions & 24 deletions css/css.py
@@ -12,6 +12,7 @@
from torch import nn
from tqdm import trange

from css.css_with_conformer.utils.mvdr_util import make_mvdr
from css.training.conformer_wrapper import ConformerCssWrapper
from css.training.train import TrainCfg, get_model
from css.training.losses import mse_loss, l1_loss, PitWrapper
Expand All @@ -20,7 +21,7 @@
from utils.mic_array_model import multichannel_mic_pos_xyz_cm
from utils.conf import load_yaml_to_dataclass
from utils.numpy_utils import dilate, erode
from utils.plot_utils import plot_stitched_masks, plot_left_right_stitch, plot_separation_methods
from utils.audio_utils import play_wav

_LOG = get_logger('css')
@@ -33,18 +34,23 @@ class CssCfg:
hop_size_sec: float = 1.5 # in seconds
normalize_segment_power: bool = False
stitching_loss: str = 'l1' # loss function for stitching adjacent segments ('l1' or 'mse')
stitching_input: str = 'mask' # type of input for stitching loss ('mask' or 'separation_result')
seg_weight_m0_sec: float = 0.15 # see calc_segment_weight
seg_weight_m1_sec: float = 0.3
activity_th: float = 0.4 # threshold for segmentation mask
activity_dilation_sec: float = 0.4 # dilation and erosion for segmentation mask
activity_erosion_sec: float = 0.2
device: Optional[str] = None
show_progressbar: bool = True
# segment-wise single-channel model
checkpoint_sc: str = 'notsofar/conformer1.0/sc'
# segment-wise multi-channel model
checkpoint_mc: str = 'notsofar/conformer1.0/mc'
device_id: int = 0
num_spks: int = 3 # the number of streams the separation model outputs
mc_mvdr: bool = True # if True, applies MVDR to the multi-channel input
mc_mask_floor_db: float = 0. # mask floor in dB; -inf means no floor, 0 means the mask has no effect
sc_mask_floor_db: float = -np.inf
pass_through_ch0: bool = False # if True, simply returns the first channel of the input and skips CSS
slice_audio_for_debug: bool = False # if True, only processes 10 seconds of the input audio

@@ -103,7 +109,7 @@ def css_inference(out_dir: str, models_dir: str, session: pd.Series, cfg: CssCfg
mixwav = mixwav[np.newaxis, :, np.newaxis] # [Batch, Nsamples, Channels]

if cfg.slice_audio_for_debug:
mixwav = mixwav[:, sr*100:sr*110, :]

separated_wavs = separate_and_stitch(mixwav, separator, sr, device, cfg)

@@ -213,7 +219,7 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper,
T = segment_frames
pad_shape = stft_mix[:, :, :T].shape # [B, F, T, Mics]

separated_seg_list = []
spk_masks_list = []

assert not separator.training
@@ -244,17 +250,34 @@

assert masks['spk_masks'].shape[3] == cfg.num_spks
assert masks['spk_masks'].shape[:3] == stft_seg_device.shape[:3]
# masks['spk_masks']: [B, F, T, num_spks]
# stft_seg_device: [B, F, T, Channels]
# stft_seg_device_chref: [B, F, T]
ref_channel = 0
stft_seg_device_chref = stft_seg_device[:, :, :, ref_channel]

num_channels = stft_seg_device.shape[3]
if num_channels > 1 and cfg.mc_mvdr:
mvdr_responses = make_mvdr(masks['spk_masks'].squeeze(0).moveaxis(2, 0).cpu().numpy(),
masks['noise_masks'].squeeze(0).moveaxis(2, 0).cpu().numpy(),
mix_stft=stft_seg_device.squeeze(0).moveaxis(2, 0).cpu().numpy(),
return_stft=True)
mvdr_responses = torch.from_numpy(np.stack(mvdr_responses, axis=-1)).unsqueeze(0).to(device)
# [B, F, T, num_spks]
seg_for_masking = mvdr_responses
else:
seg_for_masking = stft_seg_device_chref.unsqueeze(-1) # [B, F, T, 1]

# floored mask multiplication. if mask_floor_db == 0, mask is all-ones (assuming mask in [0, 1] range)
mask_floor_db = cfg.mc_mask_floor_db if num_channels > 1 else cfg.sc_mask_floor_db
assert mask_floor_db <= 0
mask_floor = 10. ** (mask_floor_db / 20.) # dB to amplitude
mask_clipped = torch.clip(masks['spk_masks'], min=mask_floor)
separated_seg = seg_for_masking * mask_clipped # [B, F, T, num_spks]

# Plot for debugging
# plot_separation_methods(stft_seg_device_chref, masks, mvdr_responses, separator, cfg,
# plots=['mvdr', 'masked_mvdr', 'spk_masks', 'masked_ref_ch', 'mixture'])

if cfg.normalize_segment_power:
# normalize to match the input mixture power
@@ -264,15 +287,15 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper,
dim=(1, 2), keepdim=True)
)

assert torch.is_complex(separated_seg)
sep_energy = torch.sqrt(
torch.mean(separated_seg[:, :, :t].sum(-1).abs().pow(2), # sum over spks, squared mag
dim=(1, 2), keepdim=True
)
)
separated_seg = (mix_energy / sep_energy)[..., None] * separated_seg

separated_seg_list.append(separated_seg.cpu()) # [B, F, T, num_spks]
spk_masks_list.append(masks['spk_masks'].cpu()) # [B, F, T, num_spks]


@@ -283,7 +306,7 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper,
# add first segment
wg_seg = calc_segment_weight(segment_frames, m0_frames, m1_frames, is_first_seg=True)
wg_stitched[:segment_frames] += wg_seg
stft_stitched[:, :, :segment_frames] += wg_seg.view(1, 1, -1, 1) * separated_seg_list[0]
mask_stitched[:, :, :segment_frames] += wg_seg.view(1, 1, -1, 1) * spk_masks_list[0]

pit = PitWrapper({'mse': mse_loss, 'l1': l1_loss}[cfg.stitching_loss])
@@ -292,9 +315,9 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper,
for i in range(1, num_segments):
if cfg.stitching_input == 'mask':
left_input, right_input = spk_masks_list[i-1], spk_masks_list[i]
elif cfg.stitching_input == 'masked_mag':
elif cfg.stitching_input == 'separation_result':
# masked magnitudes
left_input, right_input = masked_seg_list[i - 1].abs(), masked_seg_list[i].abs()
left_input, right_input = separated_seg_list[i - 1].abs(), separated_seg_list[i].abs()
else:
assert False, f'unexpected stitching_input: {cfg.stitching_input}'

@@ -303,21 +326,21 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper,

# Plot for debugging:
# plot_left_right_stitch(separator, left_input, right_input, right_perm,
# overlap_frames, cfg, stft_seg_to_play=separated_seg_list[i][..., 0], fs=fs)

# permute current segment to match with the previous one
for ib in range(batch_size):
spk_masks_list[i][ib] = spk_masks_list[i][ib, ..., right_perm[ib]]
separated_seg_list[i][ib] = separated_seg_list[i][ib, ..., right_perm[ib]]

st = i * hop_frames
en = min(st + segment_frames, mix_frames)
# weighted overlap-and-add
wg_seg = calc_segment_weight(segment_frames, m0_frames, m1_frames, is_last_seg=(i==num_segments-1))
wg_seg = wg_seg[:en-st] # last segment may be shorter
wg_stitched[st:en] += wg_seg
assert torch.is_complex(separated_seg_list[i]), 'summation assumes complex representation'
stft_stitched[:, :, st:en] += wg_seg.view(1, 1, -1, 1) * separated_seg_list[i][:, :, :en-st]
mask_stitched[:, :, st:en] += wg_seg.view(1, 1, -1, 1) * spk_masks_list[i][:, :, :en-st]

assert (wg_stitched > 1e-5).all(), 'zero weights found. check hop_size, segment_size or m0, m1'
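
For context on the `mc_mvdr` path above: `make_mvdr` applies mask-based MVDR beamforming per speaker stream, using the speaker and noise masks to estimate spatial covariances. The sketch below is an illustrative numpy implementation of the standard Souden-style mask-based MVDR, not the repo's code; the actual implementation lives in `css/css_with_conformer/utils/mvdr_util.py`.

```python
# Illustrative sketch of mask-based MVDR (Souden formulation). Assumes
# nothing about the repo's make_mvdr beyond its inputs: per-speaker masks,
# noise masks, and the multi-channel mixture STFT.
import numpy as np

def mvdr_beamform(mix_stft: np.ndarray, spk_mask: np.ndarray,
                  noise_mask: np.ndarray, ref_ch: int = 0,
                  eps: float = 1e-6) -> np.ndarray:
    """mix_stft: [C, F, T] complex; spk_mask, noise_mask: [F, T] in [0, 1].
    Returns the beamformed STFT [F, T] for one speaker stream."""
    C, F, T = mix_stft.shape
    out = np.zeros((F, T), dtype=mix_stft.dtype)
    for f in range(F):
        X = mix_stft[:, f, :]                           # [C, T]
        phi_s = (spk_mask[f] * X) @ X.conj().T          # speech spatial covariance [C, C]
        phi_n = (noise_mask[f] * X) @ X.conj().T        # noise spatial covariance [C, C]
        phi_n += eps * np.eye(C)                        # diagonal loading for stability
        ratio = np.linalg.solve(phi_n, phi_s)           # Phi_n^{-1} Phi_s
        w = ratio[:, ref_ch] / (np.trace(ratio) + eps)  # MVDR weights for the reference channel
        out[f] = w.conj() @ X                           # y = w^H x
    return out
```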
3 changes: 3 additions & 0 deletions css/css_with_conformer/README.md
@@ -1,3 +1,6 @@
The code under this directory is mostly a copy of "CSS with Conformer" from the original repo at the URL below.
Some extensions were made when adapting it to NOTSOFAR.

We didn't copy the README.md file from the original repo because it contains Shared Access Signatures (SAS) that
are considered secrets by some version control systems. One can find the original README.md file at the URL below.
https://github.com/Sanyuan-Chen/CSS_with_Conformer/blob/master/README.md
2 changes: 1 addition & 1 deletion css/css_with_conformer/separate.py
@@ -100,7 +100,7 @@ def run(args):
print('spks',len(spks),spks[0].shape)

if args.mvdr:
res1, res2 = make_mvdr(spks[:2], spks[2:], np.asfortranarray(mixed.T))  # speaker masks, noise masks, mixture
spks = [res1, res2]

sf.write(dump_dir / f"{duration_sec}_mix.wav", egs['mix'][0].cpu().numpy(), sr)