Commit

Merge pull request #4 from idiap/add-jax-support
Add Jax support
PuckCh authored Sep 24, 2024
2 parents b17ce0c + ec22522 commit 86cc708
Showing 5 changed files with 56 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/deploy.yml
@@ -26,6 +26,11 @@ jobs:
run: |
pip install --upgrade pip
pip install build
- name: Patch the README links to point to the correct files at the current tag
run: |
perl -i.bak -pe 's/\[(.*)\]\((?!http)\.?\/?(.*)\.([a-z]+)\)/[\1](https:\/\/raw.github.com\/idiap\/RawSpeechClassification\/${{ github.event.release.tag_name }}\/\2.\3)/g' README.md
perl -i.bak -pe 's/\[(.*)\]\((?!http)\.?\/?(.*)\)/[\1](https:\/\/github.com\/idiap\/RawSpeechClassification\/tree\/${{ github.event.release.tag_name }}\/\2)/g' README.md
rm README.md.bak
- name: Package the project
run: python -m build
- name: Produce a GitHub actions artifact (the package)
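The two `perl` substitutions in the new workflow step can be hard to read at a glance. The sketch below is a rough Python rendering of the same two regular expressions, applied to one line at a time as `perl -pe` does. The function name `patch_links` and the `tag` parameter are illustrative and do not appear in the commit; `tag` stands in for `${{ github.event.release.tag_name }}`.

```python
import re

# Repository path taken from the URLs in the workflow step.
REPO = "idiap/RawSpeechClassification"


def patch_links(line: str, tag: str) -> str:
    """Rewrite relative Markdown links in one README line (hypothetical helper)."""
    # First pass: relative links ending in a lowercase extension point to the raw file.
    line = re.sub(
        r"\[(.*)\]\((?!http)\.?/?(.*)\.([a-z]+)\)",
        lambda m: f"[{m.group(1)}](https://raw.github.com/{REPO}/{tag}/"
        f"{m.group(2)}.{m.group(3)})",
        line,
    )
    # Second pass: remaining relative links (no extension) point to the tree view.
    line = re.sub(
        r"\[(.*)\]\((?!http)\.?/?(.*)\)",
        lambda m: f"[{m.group(1)}](https://github.com/{REPO}/tree/{tag}/{m.group(2)})",
        line,
    )
    return line


print(patch_links("[plot](results/plot.png)", "v1.0.2"))
print(patch_links("[conda](./conda)", "v1.0.2"))
```

Links that already start with `http` are skipped by the negative lookahead `(?!http)`, which is why the second pass does not touch the links rewritten by the first.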
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,11 @@ SPDX-License-Identifier: GPL-3.0-only

# Changelog

## September 2024

- Add Jax backend
- Make pip installable package on PyPi

## August 2024

- Update the code for Keras 3 with PyTorch or Tensorflow backend
31 changes: 26 additions & 5 deletions README.md
@@ -6,6 +6,8 @@ SPDX-License-Identifier: GPL-3.0-only

# Raw Speech Classification

[![PyPI package](https://shields.io/pypi/v/raw-speech-classification.svg?logo=pypi)](https://pypi.org/project/raw-speech-classification)

Trains CNN (or any neural network based) classifiers from raw speech using Keras and
tests them. The inputs are lists of wav files, where each file is labelled. It then
creates fixed length signals and processes them. During testing, it computes scores at
@@ -14,7 +16,7 @@ the fixed length signals.

## Installation

### From source, In a Conda environment
### From source in a conda environment

To install Keras 3 with PyTorch backend, run:

@@ -28,6 +30,12 @@ To install Keras 3 with TensorFlow backend, run:
conda env create -f conda/rsclf-tensorflow.yaml
```

To install Keras 3 with Jax backend, run:

```bash
conda env create -f conda/rsclf-jax.yaml
```

Then install the package in that environment (the default name is `rsclf`) with:

```bash
@@ -49,15 +57,28 @@ or
pip install raw-speech-classification[tensorflow]
```

You'll also need to set the `KERAS_BACKEND` environment variable to the correct backend
or

```bash
pip install raw-speech-classification[jax]
```

If you already have an environment with PyTorch, TensorFlow, or Jax
installed, you can simply run:

```bash
pip install raw-speech-classification
```

You will also need to set the `KERAS_BACKEND` environment variable to the correct backend
before running `rsclf-train` or `rsclf-test` (see below), or globally for the current
bash session with:

```bash
export KERAS_BACKEND=torch
```

Replace `torch` by `tensorflow` accordingly.
Replace `torch` by `tensorflow` or `jax` accordingly.
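As an alternative to `export`, the variable can also be set from Python, provided this happens before the first `import keras`, since Keras 3 reads `KERAS_BACKEND` at import time. The following is a minimal sketch, not part of the package: `select_backend` is a hypothetical helper, and the set of names is limited to the three backends mentioned in this README.

```python
import os

# Backends mentioned in this README (assumption: no others are of interest here).
SUPPORTED_BACKENDS = {"torch", "tensorflow", "jax"}


def select_backend(name: str) -> None:
    """Set KERAS_BACKEND for the current process (hypothetical helper)."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {sorted(SUPPORTED_BACKENDS)}"
        )
    os.environ["KERAS_BACKEND"] = name


select_backend("jax")
# Import Keras only after the variable is set:
# import keras
```

Setting the variable after Keras has been imported has no effect on the already loaded backend, so a call like this belongs at the very top of a script.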

## Using the code

@@ -68,7 +89,7 @@ Replace `torch` by `tensorflow` accordingly.
`root` option could be `/home/bob/data/my_dataset` and the content of the files would
then be like:

```txt
```text
part1/file1.wav 1
part1/file2.wav 0
```
@@ -127,7 +148,7 @@ obtain the following curve in `results/seg-f1/plot.png`:
probabilities. If you need the results per speaker, configure it accordingly (see the
script for details). The default output format is:

```txt
```text
<speakerID> <label> [<posterior_probability_vector>]
```

16 changes: 16 additions & 0 deletions conda/rsclf-jax.yaml
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright © Idiap Research Institute <[email protected]>
#
# SPDX-License-Identifier: GPL-3.0-only

name: rsclf
dependencies:
  - python=3.11
  - pip
  - pip:
      - keras
      - h5py
      - scipy
      - jax[cuda12]
      - matplotlib
      - numpy
      - polars
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "raw-speech-classification"
version = "1.0.1"
version = "1.0.2"
license = {file = "LICENSES/GPL-3.0-only.txt"}
authors = [
{ name = "S. Pavankumar Dubagunta" },
@@ -50,6 +50,9 @@ torch = [
tensorflow = [
"tensorflow[and-cuda]",
]
jax = [
"jax[cuda12]",
]
dev = [
"pre-commit",
]