-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added datasets module, test cases and updated Read the docs examples (#…
…117)
- Loading branch information
1 parent
c436275
commit 4595264
Showing
6 changed files
with
286 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,4 +22,5 @@ | |
"Explainer", | ||
"MultiLabelLocalClassifierPerNode", | ||
"MultiLabelLocalClassifierPerParentNode", | ||
"datasets", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""Datasets util for downloading and maintaining sample datasets.""" | ||
|
||
import requests | ||
import pandas as pd | ||
import os | ||
import tempfile | ||
import logging | ||
from sklearn.model_selection import train_test_split | ||
|
||
# Emit INFO-level records so the "Downloading ..." progress messages are shown
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Downloaded datasets are cached in the system temporary directory,
# which is created up front in case it does not exist yet
CACHE_DIR = tempfile.gettempdir()
os.makedirs(CACHE_DIR, exist_ok=True)

# Remote locations of the sample datasets
HIERARCHICAL_TEXT_CLASSIFICATION_URL = (
    "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
)
PLATYPUS_URL = (
    "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
)
|
||
|
||
def _download_file(url, destination, timeout=60):
    """
    Download file from given URL to specified destination.

    Parameters
    ----------
    url : str
        Address of the file to download.
    destination : str
        Path of the local file the downloaded content is written to.
    timeout : float or tuple, default=60
        Seconds to wait for the server before giving up; forwarded to
        ``requests.get``.

    Raises
    ------
    RuntimeError
        If the request fails, times out or returns an error status code.
    """
    try:
        # Without a timeout, requests waits indefinitely on a stalled server
        response = requests.get(url, timeout=timeout)
        # Raise HTTPError if response code is not OK
        response.raise_for_status()
    except requests.RequestException as e:
        # Chain the original exception so the root cause stays in the traceback
        raise RuntimeError(f"Failed to download file from {url}: {str(e)}") from e

    # Only reached after a successful response, so failed requests never
    # create or truncate the destination file
    with open(destination, "wb") as f:
        f.write(response.content)
|
||
|
||
def load_platypus(test_size=0.3, random_state=42):
    """
    Load platypus diseases dataset.

    The CSV file is downloaded on first use and cached in ``CACHE_DIR``;
    later calls read the cached copy.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    list
        List containing train-test split of inputs, in the order
        ``[X_train, X_test, y_train, y_test]``.

    Raises
    ------
    RuntimeError
        If the dataset could not be downloaded.

    Examples
    --------
    >>> from hiclass.datasets import load_platypus
    >>> X_train, X_test, Y_train, Y_test = load_platypus()
    >>> X_train[:3]
         fever  diarrhea  stomach pain  skin rash  cough  sniffles  short breath  headache  size
    220   37.8         0             3          5      1         1             0         2  27.6
    539   37.2         0             6          1      1         1             0         3  28.4
    326   39.9         0             2          5      1         1             1         2  30.7
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((572, 9), (246, 9), (572,), (246,))
    """
    # Function-scope import: only this loader needs to parse label literals
    import ast

    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading platypus diseases dataset..")
            _download_file(PLATYPUS_URL, cached_file_path)
        except Exception as e:
            # Chain the original exception so the root cause stays visible
            raise RuntimeError(
                f"Failed to access or download dataset: {str(e)}"
            ) from e

    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data.drop(["label"], axis=1)
    # SECURITY: the labels come from a downloaded file, so they are parsed
    # with ast.literal_eval instead of eval, which would execute arbitrary
    # code. Assumes every label is a plain Python literal (presumably a list
    # of class names) -- confirm against the dataset.
    y = pd.Series([ast.literal_eval(val) for val in data["label"]])

    # Return list [X_train, X_test, y_train, y_test]
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
|
||
|
||
def load_hierarchical_text_classification(test_size=0.3, random_state=42):
    """
    Load hierarchical text classification dataset.

    The CSV file is downloaded on first use and cached in ``CACHE_DIR``;
    later calls read the cached copy. The ``Title`` column is used as input
    and the ``Cat1``, ``Cat2`` and ``Cat3`` columns as labels.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    list
        List containing train-test split of inputs, in the order
        ``[X_train, X_test, y_train, y_test]``.

    Raises
    ------
    RuntimeError
        If the dataset could not be downloaded.

    Examples
    --------
    >>> from hiclass.datasets import load_hierarchical_text_classification
    >>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
    >>> X_train[:3]
    38015                                Nature's Way Selenium
    2281         Music In Motion Developmental Mobile W Remote
    36629    Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20...
    Name: Title, dtype: object
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((28000,), (12000,), (28000, 3), (12000, 3))
    """
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading hierarchical text classification dataset..")
            _download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path)
        except Exception as e:
            # Chain the original exception so the root cause stays visible
            raise RuntimeError(
                f"Failed to access or download dataset: {str(e)}"
            ) from e

    # Missing cells are replaced by a single space
    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data["Title"]
    y = data[["Cat1", "Cat2", "Cat3"]]

    # Return list [X_train, X_test, y_train, y_test]
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
import hiclass.datasets | ||
from hiclass.datasets import load_platypus, load_hierarchical_text_classification | ||
import os | ||
import tempfile | ||
|
||
|
||
def test_load_platypus_output_shape():
    """Each platypus split must hold as many label rows as input rows."""
    split = load_platypus(test_size=0.2, random_state=42)
    X_train, X_test, y_train, y_test = split

    n_train_inputs, n_train_labels = X_train.shape[0], y_train.shape[0]
    assert n_train_inputs == n_train_labels

    n_test_inputs, n_test_labels = X_test.shape[0], y_test.shape[0]
    assert n_test_inputs == n_test_labels
|
||
|
||
def test_load_platypus_random_state():
    """Two loads with the same random_state must produce identical splits."""
    first = load_platypus(test_size=0.2, random_state=42)
    second = load_platypus(test_size=0.2, random_state=42)
    X_train_a, X_test_a, y_train_a, y_test_a = first
    X_train_b, X_test_b, y_train_b, y_test_b = second

    # Inputs are compared value by value
    train_values_match = X_train_a.values == X_train_b.values
    assert train_values_match.all()
    test_values_match = X_test_a.values == X_test_b.values
    assert test_values_match.all()

    # Labels are compared through the row index they kept from the dataset
    train_index_match = y_train_a.index == y_train_b.index
    assert train_index_match.all()
    test_index_match = y_test_a.index == y_test_b.index
    assert test_index_match.all()
|
||
|
||
def test_load_hierarchical_text_classification_shape():
    """Each text classification split must hold as many label rows as inputs."""
    split = load_hierarchical_text_classification(test_size=0.2, random_state=42)
    X_train, X_test, y_train, y_test = split

    n_train_inputs, n_train_labels = X_train.shape[0], y_train.shape[0]
    assert n_train_inputs == n_train_labels

    n_test_inputs, n_test_labels = X_test.shape[0], y_test.shape[0]
    assert n_test_inputs == n_test_labels
|
||
|
||
def test_load_hierarchical_text_classification_random_state():
    """Two loads with the same random_state must produce identical splits."""
    first = load_hierarchical_text_classification(test_size=0.2, random_state=42)
    second = load_hierarchical_text_classification(test_size=0.2, random_state=42)
    X_train_a, X_test_a, y_train_a, y_test_a = first
    X_train_b, X_test_b, y_train_b, y_test_b = second

    # Inputs are compared element by element
    train_inputs_match = X_train_a == X_train_b
    assert train_inputs_match.all()
    test_inputs_match = X_test_a == X_test_b
    assert test_inputs_match.all()

    # Labels are compared through the row index they kept from the dataset
    train_index_match = y_train_a.index == y_train_b.index
    assert train_index_match.all()
    test_index_match = y_test_a.index == y_test_b.index
    assert test_index_match.all()
|
||
|
||
def test_load_hierarchical_text_classification_file_exists():
    """Loading the dataset must leave a cached copy in the temp directory."""
    cache_path = os.path.join(
        tempfile.gettempdir(), "hierarchical_text_classification.csv"
    )

    # Start from an empty cache so the loader has to download the file
    already_cached = os.path.exists(cache_path)
    if already_cached:
        os.remove(cache_path)

    still_cached = os.path.exists(cache_path)
    if not still_cached:
        load_hierarchical_text_classification()
    assert os.path.exists(cache_path)
|
||
|
||
def test_load_platypus_file_exists():
    """Loading the dataset must leave a cached copy in the temp directory."""
    cache_path = os.path.join(tempfile.gettempdir(), "platypus_diseases.csv")

    # Start from an empty cache so the loader has to download the file
    already_cached = os.path.exists(cache_path)
    if already_cached:
        os.remove(cache_path)

    still_cached = os.path.exists(cache_path)
    if not still_cached:
        load_platypus()
    assert os.path.exists(cache_path)
|
||
|
||
def test_download_dataset():
    """The download helper must create the destination file."""
    source_url = hiclass.datasets.PLATYPUS_URL
    target_path = os.path.join(tempfile.gettempdir(), "platypus_diseases_test.csv")

    # Remove any file left over from an earlier run
    leftover = os.path.exists(target_path)
    if leftover:
        os.remove(target_path)

    missing = not os.path.exists(target_path)
    if missing:
        hiclass.datasets._download_file(source_url, target_path)
    assert os.path.exists(target_path)
|
||
|
||
def test_download_error_load_platypus():
    """A failed download of the platypus dataset must raise RuntimeError."""
    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    backup_url = hiclass.datasets.PLATYPUS_URL
    # An empty URL makes the download fail
    hiclass.datasets.PLATYPUS_URL = ""
    try:
        # Drop the cached copy so the loader is forced to download
        if os.path.exists(cached_file_path):
            os.remove(cached_file_path)

        with pytest.raises(RuntimeError):
            load_platypus()
    finally:
        # Always restore the real URL, even when the assertion fails, so
        # tests that run afterwards are not affected by this one
        hiclass.datasets.PLATYPUS_URL = backup_url
|
||
|
||
def test_download_error_load_hierarchical_text():
    """A failed download of the text dataset must raise RuntimeError."""
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    backup_url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL
    # An empty URL makes the download fail
    hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = ""
    try:
        # Drop the cached copy so the loader is forced to download
        if os.path.exists(cached_file_path):
            os.remove(cached_file_path)

        with pytest.raises(RuntimeError):
            load_hierarchical_text_classification()
    finally:
        # Always restore the real URL, even when the assertion fails, so
        # tests that run afterwards are not affected by this one
        hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = backup_url
|
||
|
||
def test_url_links():
    """Both dataset URLs must be set to a non-empty value."""
    dataset_urls = (
        hiclass.datasets.PLATYPUS_URL,
        hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL,
    )
    for url in dataset_urls:
        assert url != ""