-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #216 from Renumics/add_hf_embedding_audio
Add hf embedding audio
- Loading branch information
Showing
1 changed file
with
307 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,307 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Create embeddings with the transformer library\n", | ||
"\n", | ||
"We use the Huggingface transformers library to create an embedding for an audio dataset. For more data-centric AI workflows, check out our [Awesome Open Data-centric AI](https://github.com/Renumics/awesome-open-data-centric-ai) list on Github\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## tldr; Play as callable functions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Install required packages with PIP" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install renumics-spotlight transformers torch datasets umap-learn numpy" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Play as copy-and-paste functions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import datasets\n", | ||
"from transformers import AutoFeatureExtractor, AutoModel, ASTForAudioClassification\n", | ||
"import torch\n", | ||
"from renumics import spotlight\n", | ||
"import pandas as pd\n", | ||
"import umap\n", | ||
"import numpy as np\n", | ||
"import requests\n", | ||
"import json\n", | ||
"\n", | ||
"def __set_device():\n", | ||
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\" \n", | ||
" if device == \"cuda\":\n", | ||
" torch.cuda.empty_cache()\n", | ||
" return device\n", | ||
"\n", | ||
"\n", | ||
"def extract_embeddings(model, feature_extractor):\n", | ||
" \"\"\"Utility to compute embeddings.\"\"\"\n", | ||
" device = model.device\n", | ||
"\n", | ||
" def pp(batch):\n", | ||
" audios = [element[\"array\"] for element in batch[\"audio\"]]\n", | ||
" inputs = feature_extractor(raw_speech=audios, return_tensors=\"pt\", padding=True).to(device) \n", | ||
" embeddings = model(**inputs).last_hidden_state[:, 0].cpu()\n", | ||
" \n", | ||
" return {\"embedding\": embeddings}\n", | ||
" \n", | ||
"\n", | ||
" return pp\n", | ||
"\n", | ||
"\n", | ||
"def huggingface_embedding(dataset, modelname, batched=True, batch_size=8):\n", | ||
" # initialize huggingface model\n", | ||
" feature_extractor = AutoFeatureExtractor.from_pretrained(modelname, padding=True)\n", | ||
" model = AutoModel.from_pretrained(modelname, output_hidden_states=True)\n", | ||
"\n", | ||
" #compute embedding \n", | ||
" device = __set_device()\n", | ||
" extract_fn = extract_embeddings(model.to(device), feature_extractor)\n", | ||
" updated_dataset = dataset.map(extract_fn, batched=batched, batch_size=batch_size)\n", | ||
" \n", | ||
" return updated_dataset\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Step-by-step example on speech-commands\n", | ||
"\n", | ||
"### Load speech-commands from Huggingface hub" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Map enrichment on subset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import random\n", | ||
"\n", | ||
"dataset = datasets.load_dataset('speech_commands', 'v0.01', split=\"all\")\n", | ||
"labels = dataset.features[\"label\"].names\n", | ||
"num_rows = dataset.num_rows" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"subset_dataset = dataset.select([random.randint(0, num_rows) for i in range(100)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Let's have a look at all of the labels that we want to predict" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"print(labels)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Compute embedding with audio transformer from Huggingface" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataset_enriched = huggingface_embedding(subset_dataset, \"MIT/ast-finetuned-speech-commands-v2\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Reduce embeddings for faster visualization" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"embeddings = np.stack(np.array(dataset_enriched['embedding']))\n", | ||
"reducer = umap.UMAP()\n", | ||
"reduced_embedding = reducer.fit_transform(embeddings)\n", | ||
"dataset_enriched = dataset_enriched.add_column(\"embedding_reduced\", list(reduced_embedding))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Perform EDA with Spotlight" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df = dataset_enriched.to_pandas()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df.head(10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df_show = df.drop(columns=[\"embedding\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Perform EDA with Spotlight\n", | ||
"\n", | ||
"> ⚠️ Running Spotlight in Colab currently has severe limitations (slow, no similarity map, no layouts) due to Colab restrictions (e.g. no websocket support). Run the notebook locally for the full Spotlight experience" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# handle google colab differently\n", | ||
"import sys\n", | ||
"\n", | ||
"IN_COLAB = \"google.colab\" in sys.modules\n", | ||
"\n", | ||
"if IN_COLAB:\n", | ||
" # visualization in Google Colab only works in chrome and does not support websockets, we need some hacks to visualize something\n", | ||
" df_show[\"embx\"] = [emb[0] for emb in df_show[\"embedding_reduced\"]]\n", | ||
" df_show[\"emby\"] = [emb[1] for emb in df_show[\"embedding_reduced\"]]\n", | ||
" port = 50123\n", | ||
" layout_url = \"https://raw.githubusercontent.com/Renumics/spotlight/main/playbook/rookie/embedding_layout_colab.json\"\n", | ||
" response = requests.get(layout_url)\n", | ||
" layout = spotlight.layout.nodes.Layout(**json.loads(response.text))\n", | ||
" spotlight.show(df_show, port=port, dtype={\"audio\": spotlight.Audio}, layout=layout)\n", | ||
" from google.colab.output import eval_js # type: ignore\n", | ||
"\n", | ||
" print(str(eval_js(f\"google.colab.kernel.proxyPort({port}, {{'cache': true}})\")))\n", | ||
"\n", | ||
"else:\n", | ||
" layout_url = \"https://raw.githubusercontent.com/Renumics/spotlight/main/playbook/rookie/embedding_layout.json\"\n", | ||
" response = requests.get(layout_url)\n", | ||
" layout = spotlight.layout.nodes.Layout(**json.loads(response.text))\n", | ||
" spotlight.show(\n", | ||
" df_show,\n", | ||
" dtype={\"audio\": spotlight.Audio, \"embedding_reduced\": spotlight.Embedding},\n", | ||
" layout=layout,\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Optional: Save enriched dataframe to disk" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#dataset_enriched.to_parquet('dataset_audio_annotated_and_embedding.parquet.gzip', compression='gzip')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "a0712b5d67739eb35ad63a27edd5fddd52f439fecfdecbb3b06a7c473af2eb31" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |