Skip to content

Commit

Permalink
Add example for custom CSV dataset (tracel-ai#1129)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jan 11, 2024
1 parent f43b686 commit 535458e
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ wasm-bindgen-futures = "0.4.38"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"
console_error_panic_hook = "0.1.7"
reqwest = "0.11.23"


# WGPU stuff
Expand Down
37 changes: 29 additions & 8 deletions burn-dataset/src/dataset/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,18 @@ where

/// Create from a csv file.
///
/// The first line of the csv file must be the header. The header must contain the name of the fields in the struct.
/// The provided `csv::ReaderBuilder` can be configured to fit your csv format.
///
/// The supported field types are: String, integer, float, and bool.
///
/// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut rdr = csv::Reader::from_reader(reader);
/// See:
/// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
/// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)
pub fn from_csv<P: AsRef<Path>>(
path: P,
builder: &csv::ReaderBuilder,
) -> Result<Self, std::io::Error> {
let mut rdr = builder.from_path(path)?;

let mut items = Vec::new();

Expand All @@ -97,6 +100,7 @@ mod tests {
const DB_FILE: &str = "tests/data/sqlite-dataset.db";
const JSON_FILE: &str = "tests/data/dataset.json";
const CSV_FILE: &str = "tests/data/dataset.csv";
const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";

type SqlDs = SqliteDataset<Sample>;

Expand All @@ -110,7 +114,7 @@ mod tests {
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SampleCvs {
pub struct SampleCsv {
column_str: String,
column_int: i64,
column_bool: bool,
Expand Down Expand Up @@ -147,7 +151,24 @@ mod tests {

#[test]
pub fn from_csv_rows() {
let dataset = InMemDataset::<SampleCvs>::from_csv(CSV_FILE).unwrap();
let rdr = csv::ReaderBuilder::new();
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();

let non_existing_record_index: usize = 10;
let record_index: usize = 1;

assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
assert!(!dataset.get(record_index).unwrap().column_bool);
assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
}

#[test]
pub fn from_csv_rows_fmt() {
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b' ').has_headers(false);
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();

let non_existing_record_index: usize = 10;
let record_index: usize = 1;
Expand Down
2 changes: 2 additions & 0 deletions burn-dataset/tests/data/dataset-fmt.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
HI1 1 true 1.0
HI2 1 false 1.0
2 changes: 2 additions & 0 deletions examples/custom-csv-dataset/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Ignore downloaded csv file
*.csv
22 changes: 22 additions & 0 deletions examples/custom-csv-dataset/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
edition.workspace = true
license.workspace = true
name = "custom-csv-dataset"
description = "Example implementation for loading a custom CSV dataset from disk"
publish = false
version.workspace = true

[features]
default = ["burn/dataset"]

[dependencies]
burn = {path = "../../burn"}

# File download
reqwest = {workspace = true, features = ["blocking"]}
tempfile = {workspace = true}

# CSV parsing
csv = {workspace = true}
serde = {workspace = true, features = ["std", "derive"]}
11 changes: 11 additions & 0 deletions examples/custom-csv-dataset/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Custom CSV Dataset

The [custom-csv-dataset](src/dataset.rs) example implements the `Dataset` trait to retrieve dataset elements from a `.csv` file on disk. For this example, we use the [diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) (original [source](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html)).

The dataset only contains 442 records, so we use [`InMemDataset::from_csv(path)`](src/dataset.rs#L80) method to read the csv dataset file into a vector (in-memory) of [`DiabetesPatient`](src/dataset.rs#L13) records (struct) with the help of `serde`.

## Example Usage

```sh
cargo run --example custom-csv-dataset
```
15 changes: 15 additions & 0 deletions examples/custom-csv-dataset/examples/custom-csv-dataset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use burn::data::dataset::Dataset;
use custom_csv_dataset::dataset::DiabetesDataset;

fn main() {
let dataset = DiabetesDataset::new().expect("Could not load diabetes dataset");

println!("Dataset loaded with {} rows", dataset.len());

// Display first and last elements
let item = dataset.get(0).unwrap();
println!("First item:\n{:?}", item);

let item = dataset.get(441).unwrap();
println!("Last item:\n{:?}", item);
}
121 changes: 121 additions & 0 deletions examples/custom-csv-dataset/src/dataset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use burn::data::dataset::{Dataset, InMemDataset};
use serde::{Deserialize, Serialize};
use std::{
fs::File,
io::copy,
path::{Path, PathBuf},
};

/// Diabetes patient record.
/// For each field, we manually specify the expected header name for serde as all names
/// are capitalized and some field names are not very informative.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DiabetesPatient {
/// Age in years
#[serde(rename = "AGE")]
pub age: u8,

/// Sex categorical label
#[serde(rename = "SEX")]
pub sex: u8,

/// Body mass index
#[serde(rename = "BMI")]
pub bmi: f32,

/// Average blood pressure
#[serde(rename = "BP")]
pub bp: f32,

/// S1: total serum cholesterol
#[serde(rename = "S1")]
pub tc: u16,

/// S2: low-density lipoproteins
#[serde(rename = "S2")]
pub ldl: f32,

/// S3: high-density lipoproteins
#[serde(rename = "S3")]
pub hdl: f32,

/// S4: total cholesterol
#[serde(rename = "S4")]
pub tch: f32,

/// S5: possibly log of serum triglycerides level
#[serde(rename = "S5")]
pub ltg: f32,

/// S6: blood sugar level
#[serde(rename = "S6")]
pub glu: u8,

/// Y: quantitative measure of disease progression one year after baseline
#[serde(rename = "Y")]
pub response: u16,
}

/// Diabetes patients dataset, also used in [scikit-learn](https://scikit-learn.org/stable/).
/// See [Diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset).
///
/// The data is parsed from a single csv file (tab as the delimiter).
/// The dataset contains 10 baseline variables (age, sex, body mass index, average blood pressure and
/// 6 blood serum measurements for a total of 442 diabetes patients.
/// For each patient, the response of interest, a quantitative measure of disease progression one year
/// after baseline, was collected. This represents the target variable.
pub struct DiabetesDataset {
dataset: InMemDataset<DiabetesPatient>,
}

impl DiabetesDataset {
pub fn new() -> Result<Self, std::io::Error> {
// Download dataset csv file
let path = DiabetesDataset::download();

// Build dataset from csv with tab ('\t') delimiter
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b'\t');

let dataset = InMemDataset::from_csv(path, rdr).unwrap();

let dataset = Self { dataset };

Ok(dataset)
}
/// Download the CSV file from its original source on the web.
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
fn download() -> PathBuf {
// Point file to current example directory
let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap();
let file_name = example_dir.join("diabetes.csv");

if file_name.exists() {
println!("File already downloaded at {:?}", file_name);
} else {
// Get file from web
println!("Downloading file to {:?}", file_name);
let url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt";
let mut response = reqwest::blocking::get(url).unwrap();

// Create file to write the downloaded content to
let mut file = File::create(&file_name).unwrap();

// Copy the downloaded contents
copy(&mut response, &mut file).unwrap();
};

file_name
}
}

// Implement the `Dataset` trait which requires `get` and `len`
impl Dataset<DiabetesPatient> for DiabetesDataset {
fn get(&self, index: usize) -> Option<DiabetesPatient> {
self.dataset.get(index)
}

fn len(&self) -> usize {
self.dataset.len()
}
}
1 change: 1 addition & 0 deletions examples/custom-csv-dataset/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod dataset;

0 comments on commit 535458e

Please sign in to comment.