forked from tracel-ai/burn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example for custom CSV dataset (tracel-ai#1129)
- Loading branch information
Showing
9 changed files
with
204 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
HI1 1 true 1.0 | ||
HI2 1 false 1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Ignore downloaded csv file | ||
*.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
[package] | ||
authors = ["guillaumelagrange <[email protected]>"] | ||
edition.workspace = true | ||
license.workspace = true | ||
name = "custom-csv-dataset" | ||
description = "Example implementation for loading a custom CSV dataset from disk" | ||
publish = false | ||
version.workspace = true | ||
|
||
[features] | ||
default = ["burn/dataset"] | ||
|
||
[dependencies] | ||
burn = {path = "../../burn"} | ||
|
||
# File download | ||
reqwest = {workspace = true, features = ["blocking"]} | ||
tempfile = {workspace = true} | ||
|
||
# CSV parsing | ||
csv = {workspace = true} | ||
serde = {workspace = true, features = ["std", "derive"]} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Custom CSV Dataset | ||
|
||
The [custom-csv-dataset](src/dataset.rs) example implements the `Dataset` trait to retrieve dataset elements from a `.csv` file on disk. For this example, we use the [diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) (original [source](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html)). | ||
|
||
The dataset only contains 442 records, so we use the [`InMemDataset::from_csv(path)`](src/dataset.rs#L80) method to read the csv dataset file into an in-memory vector of [`DiabetesPatient`](src/dataset.rs#L13) records (a struct) with the help of `serde`. | ||
|
||
## Example Usage | ||
|
||
```sh | ||
cargo run --example custom-csv-dataset | ||
``` |
15 changes: 15 additions & 0 deletions
15
examples/custom-csv-dataset/examples/custom-csv-dataset.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
use burn::data::dataset::Dataset; | ||
use custom_csv_dataset::dataset::DiabetesDataset; | ||
|
||
fn main() { | ||
let dataset = DiabetesDataset::new().expect("Could not load diabetes dataset"); | ||
|
||
println!("Dataset loaded with {} rows", dataset.len()); | ||
|
||
// Display first and last elements | ||
let item = dataset.get(0).unwrap(); | ||
println!("First item:\n{:?}", item); | ||
|
||
let item = dataset.get(441).unwrap(); | ||
println!("Last item:\n{:?}", item); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
use burn::data::dataset::{Dataset, InMemDataset}; | ||
use serde::{Deserialize, Serialize}; | ||
use std::{ | ||
fs::File, | ||
io::copy, | ||
path::{Path, PathBuf}, | ||
}; | ||
|
||
/// Diabetes patient record. | ||
/// For each field, we manually specify the expected header name for serde as all names | ||
/// are capitalized and some field names are not very informative. | ||
#[derive(Deserialize, Serialize, Debug, Clone)] | ||
pub struct DiabetesPatient { | ||
/// Age in years | ||
#[serde(rename = "AGE")] | ||
pub age: u8, | ||
|
||
/// Sex categorical label | ||
#[serde(rename = "SEX")] | ||
pub sex: u8, | ||
|
||
/// Body mass index | ||
#[serde(rename = "BMI")] | ||
pub bmi: f32, | ||
|
||
/// Average blood pressure | ||
#[serde(rename = "BP")] | ||
pub bp: f32, | ||
|
||
/// S1: total serum cholesterol | ||
#[serde(rename = "S1")] | ||
pub tc: u16, | ||
|
||
/// S2: low-density lipoproteins | ||
#[serde(rename = "S2")] | ||
pub ldl: f32, | ||
|
||
/// S3: high-density lipoproteins | ||
#[serde(rename = "S3")] | ||
pub hdl: f32, | ||
|
||
/// S4: total cholesterol | ||
#[serde(rename = "S4")] | ||
pub tch: f32, | ||
|
||
/// S5: possibly log of serum triglycerides level | ||
#[serde(rename = "S5")] | ||
pub ltg: f32, | ||
|
||
/// S6: blood sugar level | ||
#[serde(rename = "S6")] | ||
pub glu: u8, | ||
|
||
/// Y: quantitative measure of disease progression one year after baseline | ||
#[serde(rename = "Y")] | ||
pub response: u16, | ||
} | ||
|
||
/// Diabetes patients dataset, also used in [scikit-learn](https://scikit-learn.org/stable/). | ||
/// See [Diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset). | ||
/// | ||
/// The data is parsed from a single csv file (tab as the delimiter). | ||
/// The dataset contains 10 baseline variables (age, sex, body mass index, average blood pressure and | ||
/// 6 blood serum measurements for a total of 442 diabetes patients. | ||
/// For each patient, the response of interest, a quantitative measure of disease progression one year | ||
/// after baseline, was collected. This represents the target variable. | ||
pub struct DiabetesDataset { | ||
dataset: InMemDataset<DiabetesPatient>, | ||
} | ||
|
||
impl DiabetesDataset { | ||
pub fn new() -> Result<Self, std::io::Error> { | ||
// Download dataset csv file | ||
let path = DiabetesDataset::download(); | ||
|
||
// Build dataset from csv with tab ('\t') delimiter | ||
let mut rdr = csv::ReaderBuilder::new(); | ||
let rdr = rdr.delimiter(b'\t'); | ||
|
||
let dataset = InMemDataset::from_csv(path, rdr).unwrap(); | ||
|
||
let dataset = Self { dataset }; | ||
|
||
Ok(dataset) | ||
} | ||
/// Download the CSV file from its original source on the web. | ||
/// Panics if the download cannot be completed or the content of the file cannot be written to disk. | ||
fn download() -> PathBuf { | ||
// Point file to current example directory | ||
let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap(); | ||
let file_name = example_dir.join("diabetes.csv"); | ||
|
||
if file_name.exists() { | ||
println!("File already downloaded at {:?}", file_name); | ||
} else { | ||
// Get file from web | ||
println!("Downloading file to {:?}", file_name); | ||
let url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"; | ||
let mut response = reqwest::blocking::get(url).unwrap(); | ||
|
||
// Create file to write the downloaded content to | ||
let mut file = File::create(&file_name).unwrap(); | ||
|
||
// Copy the downloaded contents | ||
copy(&mut response, &mut file).unwrap(); | ||
}; | ||
|
||
file_name | ||
} | ||
} | ||
|
||
// Implement the `Dataset` trait which requires `get` and `len` | ||
impl Dataset<DiabetesPatient> for DiabetesDataset { | ||
fn get(&self, index: usize) -> Option<DiabetesPatient> { | ||
self.dataset.get(index) | ||
} | ||
|
||
fn len(&self) -> usize { | ||
self.dataset.len() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub mod dataset; |