Skip to content

Commit

Permalink
added datasets module, test cases and updated Read the docs examples (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ashishpatel16 authored Mar 28, 2024
1 parent c436275 commit 4595264
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 34 deletions.
20 changes: 3 additions & 17 deletions docs/examples/plot_lcppn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,11 @@
"""
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer
import requests
import pandas as pd
import shap
from hiclass.datasets import load_platypus

# Download training data
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
path = "platypus_diseases.csv"
response = requests.get(url)
with open(path, "wb") as file:
file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")

# Define data
X_train = training_data.drop(["label"], axis=1)
X_test = X_train[:100] # Use first 100 samples as test set
Y_train = training_data["label"]
Y_train = [eval(my) for my in Y_train]
# Load train and test splits
X_train, X_test, Y_train, Y_test = load_platypus()

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
Expand Down
20 changes: 3 additions & 17 deletions docs/examples/plot_parallel_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,15 @@
"""
import sys
from os import cpu_count

import pandas as pd
import requests
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from hiclass import LocalClassifierPerParentNode
from hiclass.datasets import load_hierarchical_text_classification


# Download training data
url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
path = "train_40k.csv"
response = requests.get(url)
with open(path, "wb") as file:
file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")
# Load train and test splits
X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()

# We will use logistic regression classifiers for every parent node
lr = LogisticRegression(max_iter=1000)
Expand All @@ -51,10 +41,6 @@
]
)

# Select training data
X_train = training_data["Title"]
Y_train = training_data[["Cat1", "Cat2", "Cat3"]]

# Fixes bug AttributeError: '_LoggingTee' object has no attribute 'fileno'
# This only happens when building the documentation
# Hence, you don't actually need it for your code to work
Expand Down
20 changes: 20 additions & 0 deletions docs/source/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ F-score
^^^^^^^

.. autofunction:: metrics.f1

..................................


Datasets
----------

Platypus diseases dataset
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: datasets.load_platypus

..................................

Hierarchical text classification dataset
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: datasets.load_hierarchical_text_classification

..................................
1 change: 1 addition & 0 deletions hiclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
"Explainer",
"MultiLabelLocalClassifierPerNode",
"MultiLabelLocalClassifierPerParentNode",
"datasets",
]
138 changes: 138 additions & 0 deletions hiclass/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Datasets util for downloading and maintaining sample datasets."""

import ast
import logging
import os
import tempfile

import pandas as pd
import requests
from sklearn.model_selection import train_test_split

# Configure logging
# NOTE(review): basicConfig() runs at import time and configures the root
# logger of whatever application imports this module. The Python logging
# HOWTO advises libraries to leave handler configuration to the application
# -- confirm this is intentional.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use temp directory to store cached datasets
# (the default temporary directory reported by the tempfile module; the
# loaders below join dataset file names onto this path).
CACHE_DIR = tempfile.gettempdir()

# Ensure cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)

# Dataset urls
# These are read at call time by the loader functions, so they can be
# overridden on the module (the test suite does this to simulate failures).
PLATYPUS_URL = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
HIERARCHICAL_TEXT_CLASSIFICATION_URL = (
"https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
)


def _download_file(url, destination, timeout=60):
    """
    Download a file from the given URL to the specified destination.

    The payload is first written to a temporary sibling file and only moved
    to ``destination`` once it has been written completely. An interrupted
    write therefore never leaves a truncated file at ``destination`` that a
    later existence check would mistake for a valid cached dataset.

    Parameters
    ----------
    url : str
        Address of the file to download.
    destination : str
        Path where the downloaded file is stored.
    timeout : float or tuple, default=60
        Timeout in seconds forwarded to ``requests.get``. Without a timeout
        the request could block indefinitely on an unresponsive server.

    Raises
    ------
    RuntimeError
        If the request fails, times out or returns an error status code.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Raise HTTPError if response code is not OK
        response.raise_for_status()
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to download file from {url}: {str(e)}") from e

    # The process id keeps concurrent downloads from sharing a partial file.
    partial_path = f"{destination}.{os.getpid()}.part"
    try:
        with open(partial_path, "wb") as f:
            f.write(response.content)
        # Both paths live in the same directory, so the rename replaces the
        # destination in a single step.
        os.replace(partial_path, destination)
    finally:
        # Only present if writing or renaming failed part-way through.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def load_platypus(test_size=0.3, random_state=42):
    """
    Load platypus diseases dataset.

    The CSV file is downloaded into ``CACHE_DIR`` when no cached copy exists
    there; otherwise the cached copy is read.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible
        output across multiple function calls.

    Returns
    -------
    list
        List containing train-test split of inputs, in the order
        ``X_train, X_test, Y_train, Y_test``.

    Raises
    ------
    RuntimeError
        If the dataset could not be downloaded.

    Examples
    --------
    >>> from hiclass.datasets import load_platypus
    >>> X_train, X_test, Y_train, Y_test = load_platypus()
    >>> X_train[:3]
        fever  diarrhea  stomach pain  skin rash  cough  sniffles  short breath  headache  size
    220  37.8         0             3          5      1         1             0         2  27.6
    539  37.2         0             6          1      1         1             0         3  28.4
    326  39.9         0             2          5      1         1             1         2  30.7
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((572, 9), (246, 9), (572,), (246,))
    """
    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading platypus diseases dataset..")
            _download_file(PLATYPUS_URL, cached_file_path)
        except Exception as e:
            # Chain the cause explicitly so the underlying failure stays
            # visible in the traceback.
            raise RuntimeError(
                f"Failed to access or download dataset: {str(e)}"
            ) from e

    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data.drop(["label"], axis=1)
    # Each label cell is parsed as a Python literal. The file comes from the
    # network, so literal_eval is used instead of eval: it accepts literals
    # only and cannot execute arbitrary code embedded in the CSV.
    y = pd.Series([ast.literal_eval(val) for val in data["label"]])

    # Return list [X_train, X_test, y_train, y_test]
    return train_test_split(X, y, test_size=test_size, random_state=random_state)


def load_hierarchical_text_classification(test_size=0.3, random_state=42):
    """
    Load hierarchical text classification dataset.

    The CSV file is downloaded into ``CACHE_DIR`` when no cached copy exists
    there; otherwise the cached copy is read. The ``Title`` column is used as
    input and the ``Cat1``, ``Cat2`` and ``Cat3`` columns as labels.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible
        output across multiple function calls.

    Returns
    -------
    list
        List containing train-test split of inputs, in the order
        ``X_train, X_test, Y_train, Y_test``.

    Raises
    ------
    RuntimeError
        If the dataset could not be downloaded.

    Examples
    --------
    >>> from hiclass.datasets import load_hierarchical_text_classification
    >>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
    >>> X_train[:3]
    38015                                Nature's Way Selenium
    2281         Music In Motion Developmental Mobile W Remote
    36629    Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20...
    Name: Title, dtype: object
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((28000,), (12000,), (28000, 3), (12000, 3))
    """
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading hierarchical text classification dataset..")
            _download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path)
        except Exception as e:
            # Chain the cause explicitly so the underlying failure stays
            # visible in the traceback.
            raise RuntimeError(
                f"Failed to access or download dataset: {str(e)}"
            ) from e

    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data["Title"]
    y = data[["Cat1", "Cat2", "Cat3"]]

    # Return list [X_train, X_test, y_train, y_test]
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
121 changes: 121 additions & 0 deletions tests/test_Datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np
import pytest

import hiclass.datasets
from hiclass.datasets import load_platypus, load_hierarchical_text_classification
import os
import tempfile


def test_load_platypus_output_shape():
    """Features and labels of each platypus split must have the same length."""
    splits = load_platypus(test_size=0.2, random_state=42)
    features_train, features_test, labels_train, labels_test = splits
    assert features_train.shape[0] == labels_train.shape[0]
    assert features_test.shape[0] == labels_test.shape[0]


def test_load_platypus_random_state():
    """Two calls with the same seed must yield the same platypus splits."""
    first_run = load_platypus(test_size=0.2, random_state=42)
    second_run = load_platypus(test_size=0.2, random_state=42)

    # Feature frames (train, test) are compared cell by cell.
    for frame_a, frame_b in zip(first_run[:2], second_run[:2]):
        assert (frame_a.values == frame_b.values).all()

    # Label series (train, test) are compared through their row indices.
    for labels_a, labels_b in zip(first_run[2:], second_run[2:]):
        assert (labels_a.index == labels_b.index).all()


def test_load_hierarchical_text_classification_shape():
    """Inputs and labels of each text split must have the same length."""
    splits = load_hierarchical_text_classification(test_size=0.2, random_state=42)
    titles_train, titles_test, labels_train, labels_test = splits
    assert titles_train.shape[0] == labels_train.shape[0]
    assert titles_test.shape[0] == labels_test.shape[0]


def test_load_hierarchical_text_classification_random_state():
    """Two calls with the same seed must yield the same text splits."""
    first_run = load_hierarchical_text_classification(test_size=0.2, random_state=42)
    second_run = load_hierarchical_text_classification(test_size=0.2, random_state=42)

    # Input series (train, test) are compared element by element.
    for titles_a, titles_b in zip(first_run[:2], second_run[:2]):
        assert (titles_a == titles_b).all()

    # Label frames (train, test) are compared through their row indices.
    for labels_a, labels_b in zip(first_run[2:], second_run[2:]):
        assert (labels_a.index == labels_b.index).all()


def test_load_hierarchical_text_classification_file_exists():
    """Loading the text dataset must create its cache file when it is absent."""
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Start without a cached copy so the loader has to create the file.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    # Run the check unconditionally: a guard around it would let the test
    # pass without asserting anything.
    load_hierarchical_text_classification()
    assert os.path.exists(cached_file_path)


def test_load_platypus_file_exists():
    """Loading the platypus dataset must create its cache file when absent."""
    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Start without a cached copy so the loader has to create the file.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    # Run the check unconditionally: a guard around it would let the test
    # pass without asserting anything.
    load_platypus()
    assert os.path.exists(cached_file_path)


def test_download_dataset():
    """_download_file must create the destination file for a valid URL."""
    dataset_name = "platypus_diseases_test.csv"
    url = hiclass.datasets.PLATYPUS_URL
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Remove leftovers from earlier runs so the existence check is meaningful.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    try:
        hiclass.datasets._download_file(url, cached_file_path)
        assert os.path.exists(cached_file_path)
    finally:
        # This file is only used by this test; do not leave it behind in the
        # temporary directory.
        if os.path.exists(cached_file_path):
            os.remove(cached_file_path)


def test_download_error_load_platypus():
    """load_platypus must raise RuntimeError when the dataset URL is unusable."""
    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
    backup_url = hiclass.datasets.PLATYPUS_URL

    # Drop any cached copy so the loader has to attempt a download.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    hiclass.datasets.PLATYPUS_URL = ""
    try:
        with pytest.raises(RuntimeError):
            load_platypus()
    finally:
        # Restore the real URL even if the assertion fails; otherwise the
        # blank URL would leak into every test that runs afterwards.
        hiclass.datasets.PLATYPUS_URL = backup_url


def test_download_error_load_hierarchical_text():
    """The text loader must raise RuntimeError when its URL is unusable."""
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
    backup_url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL

    # Drop any cached copy so the loader has to attempt a download.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = ""
    try:
        with pytest.raises(RuntimeError):
            load_hierarchical_text_classification()
    finally:
        # Restore the real URL even if the assertion fails; otherwise the
        # blank URL would leak into every test that runs afterwards.
        hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = backup_url


def test_url_links():
    """Both dataset URL constants must be non-empty."""
    datasets_module = hiclass.datasets
    platypus_url = datasets_module.PLATYPUS_URL
    assert platypus_url != ""
    text_classification_url = datasets_module.HIERARCHICAL_TEXT_CLASSIFICATION_URL
    assert text_classification_url != ""

0 comments on commit 4595264

Please sign in to comment.