Skip to content

Commit

Permalink
Update clustering script. Add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
SGenheden committed Apr 13, 2021
1 parent d4cd1e7 commit 6775d39
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 32 deletions.
65 changes: 33 additions & 32 deletions route_distances/tools/cluster_aizynth_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,27 @@
import warnings
import functools
import time
from typing import List, Optional
from typing import List

import pandas as pd
from tqdm import tqdm

from route_distances.route_distances import route_distances_calculator
from route_distances.clustering import ClusteringHelper
from route_distances.utils.type_utils import StrDict, RouteDistancesCalculator
from route_distances.utils.type_utils import RouteDistancesCalculator


def _get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
"Tool to calculate pairwise distances for AiZynthFinder output"
)
parser.add_argument("--files", nargs="+", required=True)
parser.add_argument("--only_clustering", action="store_true", default=False)
parser.add_argument("--nclusters", type=int, default=None)
parser.add_argument("--output", default="finder_output_dist.hdf5")
return parser.parse_args()


def _make_empty_dict(nclusters: Optional[int]) -> StrDict:
dict_ = {"distance_matrix": [[0.0]], "distances_time": 0}
if nclusters is not None:
dict_["cluster_labels"] = []
dict_["cluster_time"] = 0.0
return dict_


def _merge_inputs(filenames: List[str]) -> pd.DataFrame:
data = None
for filename in filenames:
Expand All @@ -44,49 +37,57 @@ def _merge_inputs(filenames: List[str]) -> pd.DataFrame:
return data


def _calc_distances(
row: pd.Series, nclusters: Optional[int], calculator: RouteDistancesCalculator
) -> pd.Series:
def _calc_distances(row: pd.Series, calculator: RouteDistancesCalculator) -> pd.Series:
if len(row.trees) == 1:
return pd.Series(_make_empty_dict(nclusters))
return pd.Series({"distance_matrix": [[0.0]], "distances_time": 0})

time0 = time.perf_counter_ns()
distances = calculator(row.trees)
dict_ = {
"distance_matrix": distances.tolist(),
"distances_time": (time.perf_counter_ns() - time0) * 1e-9,
}
return pd.Series(dict_)

if nclusters is not None:
time0 = time.perf_counter_ns()
dict_["cluster_labels"] = ClusteringHelper.cluster(
distances, nclusters
).tolist()
dict_["cluster_time"] = (time.perf_counter_ns() - time0) * 1e-9

return pd.Series(dict_)
def _do_clustering(row: pd.Series, nclusters: int) -> pd.Series:
if row.distance_matrix == [[0.0]] or len(row.trees) < 3:
return pd.Series({"cluster_labels": [], "cluster_time": 0})

time0 = time.perf_counter_ns()
labels = ClusteringHelper.cluster(row.distance_matrix, nclusters).tolist()
cluster_time = (time.perf_counter_ns() - time0) * 1e-9
return pd.Series({"cluster_labels": labels, "cluster_time": cluster_time})


def main() -> None:
""" Entry-point for CLI tool """
args = _get_args()
tqdm.pandas()

calculator = route_distances_calculator("ted", content="both")
data = _merge_inputs(args.files)

func = functools.partial(
_calc_distances, nclusters=args.nclusters, calculator=calculator
)
dist_data = data.progress_apply(func, axis=1)
data = data.assign(
distance_matrix=dist_data.distance_matrix,
distances_time=dist_data.distances_time,
)
if args.only_clustering:
calculator = None
else:
calculator = route_distances_calculator("ted", content="both")

if not args.only_clustering:
func = functools.partial(
_calc_distances, calculator=calculator
)
dist_data = data.progress_apply(func, axis=1)
data = data.assign(
distance_matrix=dist_data.distance_matrix,
distances_time=dist_data.distances_time,
)

if args.nclusters is not None:
func = functools.partial(_do_clustering, nclusters=args.nclusters)
cluster_data = data.progress_apply(func, axis=1)
data = data.assign(
cluster_labels=dist_data.cluster_labels,
cluster_time=dist_data.cluster_time,
cluster_labels=cluster_data.cluster_labels,
cluster_time=cluster_data.cluster_time,
)

with warnings.catch_warnings(): # This will suppress a PerformanceWarning
Expand Down
Binary file added tests/data/finder_output_example.hdf5
Binary file not shown.
102 changes: 102 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import sys

import pytest
import pandas as pd

from route_distances.tools.cluster_aizynth_output import (
main as calc_route_dist_main,
)


@pytest.fixture
def add_cli_arguments():
    """Yield a callable that installs command-line arguments in ``sys.argv``.

    The callable takes a single string, splits it on single spaces and places
    the pieces after the current program name. The ``sys.argv`` captured
    before the test is put back once the test has finished.
    """
    original_argv = list(sys.argv)

    def _install(args):
        program_name = sys.argv[0]
        sys.argv = [program_name] + args.split(" ")

    yield _install
    sys.argv = original_argv


def test_calc_route_distances(shared_datadir, add_cli_arguments):
    """Run the CLI without clustering and check the stored distance matrices."""
    cli_parts = [
        f"--files {shared_datadir / 'finder_output_example.hdf5'}",
        f"--output {shared_datadir/ 'temp_out.hdf5'}",
    ]
    add_cli_arguments(" ".join(cli_parts))

    calc_route_dist_main()

    result = pd.read_hdf(str(shared_datadir / "temp_out.hdf5"), "table")
    columns = result.columns

    # Only distance information is expected when --nclusters is not given.
    assert "distances_time" in columns
    for absent in ("cluster_time", "cluster_labels"):
        assert absent not in columns

    first_matrix = result.iloc[0].distance_matrix
    assert len(first_matrix) == 3
    expected_pairs = (((0, 1), 4.0596), ((0, 2), 4.7446), ((2, 1), 1.3149))
    for (i, j), expected in expected_pairs:
        assert pytest.approx(first_matrix[i][j], abs=1e-2) == expected

    second_matrix = result.iloc[1].distance_matrix
    assert len(second_matrix) == 2
    assert pytest.approx(second_matrix[0][1], abs=1e-2) == 4.0596

    # The third row must hold the [[0.0]] placeholder matrix.
    assert result.iloc[2].distance_matrix == [[0.0]]


def test_calc_route_clustering(shared_datadir, add_cli_arguments):
    """Run the CLI with ``--nclusters 0`` and check the cluster output."""
    cli_parts = [
        f"--files {shared_datadir / 'finder_output_example.hdf5'}",
        f"--output {shared_datadir/ 'temp_out.hdf5'}",
        "--nclusters 0",
    ]
    add_cli_arguments(" ".join(cli_parts))

    calc_route_dist_main()

    result = pd.read_hdf(str(shared_datadir / "temp_out.hdf5"), "table")

    for column in ("distances_time", "cluster_time", "cluster_labels"):
        assert column in result.columns

    # Only the first row is expected to receive labels; the others stay empty.
    expected_labels = ([1, 0, 0], [], [])
    for row_index, labels in enumerate(expected_labels):
        assert result.iloc[row_index].cluster_labels == labels


def test_calc_route_only_clustering(shared_datadir, add_cli_arguments):
    """Cluster a file of pre-computed distances with ``--only_clustering``.

    A first CLI run produces the distance matrices. The output is then reduced
    to the ``trees`` and ``distance_matrix`` columns and fed to a second run
    that only performs clustering.
    """
    temp_file = str(shared_datadir / "temp_out.hdf5")
    first_run = [
        f"--files {shared_datadir / 'finder_output_example.hdf5'}",
        f"--output {temp_file}",
    ]
    add_cli_arguments(" ".join(first_run))
    calc_route_dist_main()

    # Rewrite the created file without the distances_time column.
    distances_only = pd.read_hdf(temp_file, "table")[["trees", "distance_matrix"]]
    distances_only.to_hdf(temp_file, "table")

    second_run = [
        f"--files {temp_file}",
        f"--output {shared_datadir / 'temp_out2.hdf5'}",
        "--nclusters 0",
        "--only_clustering",
    ]
    add_cli_arguments(" ".join(second_run))

    calc_route_dist_main()

    result = pd.read_hdf(str(shared_datadir / "temp_out2.hdf5"), "table")
    columns = result.columns

    assert "distances_time" not in columns
    for column in ("cluster_time", "cluster_labels"):
        assert column in columns

    expected_labels = ([1, 0, 0], [], [])
    for row_index, labels in enumerate(expected_labels):
        assert result.iloc[row_index].cluster_labels == labels

0 comments on commit 6775d39

Please sign in to comment.