From 748c57773b226a1c4598109be1659482f8fd1ab7 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Fri, 4 Aug 2023 16:45:09 -0700 Subject: [PATCH] Seamlessly load audio from Xeno-Canto or URL (eg, using the A2O API). PiperOrigin-RevId: 553946803 --- chirp/audio_utils.py | 34 +++++++++++++++++++++---- chirp/inference/search_embeddings.ipynb | 11 ++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/chirp/audio_utils.py b/chirp/audio_utils.py index beca66fe..86d385dc 100644 --- a/chirp/audio_utils.py +++ b/chirp/audio_utils.py @@ -19,6 +19,7 @@ """ import concurrent import functools +import io import logging import os import tempfile @@ -46,7 +47,17 @@ _BOUNDARY_TO_PADDING_MODE = {'zeros': 'CONSTANT'} -def load_audio( +def load_audio(path: str, target_sample_rate: int, **kwargs) -> jnp.ndarray: + """Load a general audio resource.""" + if path.startswith('xc'): + return load_xc_audio(path, target_sample_rate) + elif path.startswith('http'): + return load_url_audio(path, target_sample_rate) + else: + return load_audio_file(path, target_sample_rate, **kwargs) + + +def load_audio_file( filepath: str | epath.Path, target_sample_rate: int, resampling_type: str = 'polyphase', @@ -170,7 +181,7 @@ def multi_load_audio_window( yield futures.pop(0).result() -def load_xc_audio(xc_id: str, sample_rate: int) -> jnp.ndarray | None: +def load_xc_audio(xc_id: str, sample_rate: int) -> jnp.ndarray: """Load audio from Xeno-Canto given an ID like 'xc12345'.""" if not xc_id.startswith('xc'): raise ValueError(f'XenoCanto id {xc_id} does not start with "xc".') @@ -190,12 +201,25 @@ def load_xc_audio(xc_id: str, sample_rate: int) -> jnp.ndarray | None: try: data = session.get(url=url).content except requests.exceptions.RequestException as e: - print(f'Failed to load audio from Xeno-Canto {xc_id}') - return None + raise requests.exceptions.RequestException( + f'Failed to load audio from Xeno-Canto {xc_id}' + ) from e with tempfile.NamedTemporaryFile(suffix='.mp3', mode='wb') as f: f.write(data) f.flush() - audio = load_audio(f.name, target_sample_rate=sample_rate) + audio = load_audio_file(f.name, target_sample_rate=sample_rate) + return audio + + +def load_url_audio(url: str, sample_rate: int) -> jnp.ndarray: + """Load audio from a URL.""" + data = requests.get(url).content + with io.BytesIO(data) as f: + sf = soundfile.SoundFile(f) + audio = sf.read(dtype='float32') + audio = librosa.resample( + audio, sf.samplerate, sample_rate, res_type='polyphase' + ) return audio diff --git a/chirp/inference/search_embeddings.ipynb b/chirp/inference/search_embeddings.ipynb index baa571a4..048e483c 100644 --- a/chirp/inference/search_embeddings.ipynb +++ b/chirp/inference/search_embeddings.ipynb @@ -110,17 +110,14 @@ "source": [ "#@title Load query audio. { vertical-output: true }\n", "\n", - "# Point to an audio file or Xeno-Canto id (like 'xc12345') of your choice.\n", + "# Point to an audio file, Xeno-Canto id (like 'xc12345') or audio file URL.\n", "audio_path = 'xc12345' #@param\n", "# Muck around with manual selection of the query start time...\n", "start_s = 1 #@param\n", "\n", "window_s = config.model_config['window_size_s']\n", "sample_rate = config.model_config['sample_rate']\n", - "if audio_path.startswith('xc'):\n", - " audio = audio_utils.load_xeno_canto_audio(audio_path, sample_rate)\n", - "else:\n", - " audio = audio_utils.load_audio(audio_path, sample_rate)\n", + "audio = audio_utils.load_audio(audio_path, sample_rate)\n", "\n", "# Display the full file.\n", "display.plot_audio_melspec(audio, sample_rate)\n", @@ -128,9 +125,11 @@ "# Display the selected window.\n", "print('-' * 80)\n", "print('Selected audio window:')\n", - "# TODO(tomdenton): Pad or shift if too close to the end of the file.\n", "st = int(start_s * sample_rate)\n", "end = int(st + window_s * sample_rate)\n", + "if end \u003e audio.shape[0]:\n", + " end = audio.shape[0]\n", + " st = max([0, end - window_s * sample_rate])\n", "audio_window = audio[st:end]\n", "display.plot_audio_melspec(audio_window, sample_rate)\n", "\n",