diff --git a/src/paste3/io.py b/src/paste3/io.py
new file mode 100644
index 0000000..f4f1032
--- /dev/null
+++ b/src/paste3/io.py
@@ -0,0 +1,141 @@
+import scanpy as sc
+import numpy as np
+
+from pathlib import Path
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+def process_files(g_fpath, s_fpath, w_fpath=None):
+    """Load one AnnData object per slice from CSV or h5ad files.
+
+    The suffix of the first entry of ``g_fpath`` selects the loader.
+
+    ``.csv``: every gene file is read with ``scanpy.read_csv``. Spatial and
+    weight files are matched to a slice by row count (see ``get_shape``),
+    not by list position, and stored in ``obsm["spatial"]`` and
+    ``obsm["weights"]``. Without weight files each slice gets uniform
+    weights summing to 1.
+
+    ``.h5ad``: every file is read with ``scanpy.read_h5ad``; ``s_fpath`` and
+    ``w_fpath`` are not used.
+
+    Parameters
+    ----------
+    g_fpath : sequence of path-like
+        Gene expression files, one per slice.
+    s_fpath : sequence of path-like or None
+        Spatial coordinate CSV files, one per slice (CSV input only).
+    w_fpath : sequence of path-like, optional
+        Weight CSV files, one per slice (CSV input only).
+
+    Returns
+    -------
+    list
+        The loaded AnnData objects, in ``g_fpath`` order.
+
+    Raises
+    ------
+    ValueError
+        If no gene files are given, the file type is unsupported, the
+        number of spatial or weight files differs from the number of gene
+        files, two gene files have the same row count, or a spatial or
+        weight file matches no gene file.
+    """
+    if not g_fpath:
+        raise ValueError("No gene files provided")
+
+    ext = Path(g_fpath[0]).suffix
+
+    if ext == ".csv":
+        s_fpath = s_fpath or []
+        if len(s_fpath) != len(g_fpath):
+            raise ValueError(
+                "Length of spatial files doesn't equal number of gene files"
+            )
+
+        # Keyed by row count so companion files can be matched to a slice.
+        _slices = {}
+        for file in g_fpath:
+            n_rows = get_shape(file)[0]
+            if n_rows in _slices:
+                # Equal-sized slices are indistinguishable by row count;
+                # carrying on would silently drop one of them.
+                raise ValueError(
+                    f"Multiple gene files have {n_rows} rows; "
+                    "slices cannot be matched by row count"
+                )
+            # The header of this file is alphanumeric, so this file has to
+            # be imported as a string
+            _slices[n_rows] = sc.read_csv(file)
+
+        for file in s_fpath:
+            _slice = _matching_slice(_slices, file)
+            _slice.obsm["spatial"] = np.genfromtxt(
+                file, delimiter=",", dtype="float64"
+            )
+
+        if w_fpath:
+            if len(w_fpath) != len(g_fpath):
+                raise ValueError(
+                    "Length of weight files doesn't equal number of gene files"
+                )
+            for file in w_fpath:
+                _slice = _matching_slice(_slices, file)
+                _slice.obsm["weights"] = np.genfromtxt(
+                    file, delimiter=",", dtype="float64"
+                )
+        else:
+            for _slice in _slices.values():
+                n_obs = _slice.shape[0]
+                _slice.obsm["weights"] = np.ones((n_obs,)) / n_obs
+
+        slices = list(_slices.values())
+    elif ext == ".h5ad":
+        slices = [sc.read_h5ad(file) for file in g_fpath]
+    else:
+        raise ValueError("Incorrect file type provided ")
+
+    return slices
+
+
+def _matching_slice(slices_by_rows, file):
+    """Return the slice whose row count equals that of the CSV ``file``."""
+    try:
+        return slices_by_rows[get_shape(file)[0]]
+    except KeyError:
+        raise ValueError("Incomplete information for a slice") from None
+
+
+def get_shape(file_path):
+    """Determine the shape of a CSV file without parsing its values.
+
+    The first line counts as a header, and is left out of the row count,
+    unless every one of its comma-separated fields parses as a float.
+
+    Returns
+    -------
+    tuple of int
+        ``(num_rows, num_columns)``, with the column count taken from the
+        first line.
+    """
+
+    def is_numeric(value):
+        try:
+            float(value)
+        except ValueError:
+            return False
+        return True
+
+    with open(file_path, "r") as file:
+        first_fields = file.readline().strip().split(",")
+        num_rows = sum(1 for _ in file)
+
+    # A fully numeric first line is data rather than a header, so count it.
+    if all(is_numeric(val) for val in first_fields):
+        num_rows += 1
+
+    return num_rows, len(first_fields)
diff --git a/tests/data/input/slice1.h5ad b/tests/data/input/slice1.h5ad
new file mode 100644
index 0000000..e2dd51b
Binary files /dev/null and b/tests/data/input/slice1.h5ad differ
diff --git a/tests/data/input/slice2.h5ad b/tests/data/input/slice2.h5ad
new file mode 100644
index 0000000..2260734
Binary files /dev/null and b/tests/data/input/slice2.h5ad differ
diff --git a/tests/data/input/slice3.h5ad b/tests/data/input/slice3.h5ad
new file mode 100644
index 0000000..10fc00c
Binary files /dev/null and b/tests/data/input/slice3.h5ad differ
diff --git a/tests/data/input/slice4.h5ad b/tests/data/input/slice4.h5ad
new file mode 100644
index 0000000..782ae31
Binary files /dev/null and b/tests/data/input/slice4.h5ad differ
diff --git a/tests/test_paste_cmd_line.py b/tests/test_paste_cmd_line.py
index ab05af2..823771e 100644
--- a/tests/test_paste_cmd_line.py
+++ b/tests/test_paste_cmd_line.py
@@ -1,8 +1,10 @@
 import pandas as pd
+import anndata as ad
 from pandas.testing import assert_frame_equal
 from pathlib import Path
 from collections import namedtuple
 from paste3.paste_cmd_line import main as paste_cmd_line
+from paste3.io import get_shape, process_files
 
 test_dir = Path(__file__).parent
 input_dir = test_dir / "data/input"
@@ -111,3 +113,38 @@ def test_cmd_line_pairwise(tmp_path):
         pd.read_csv(tmp_path / "paste_output/slice1_slice2_pairwise.csv"),
         pd.read_csv(output_dir / "slices_1_2_pairwise.csv"),
     )
+
+
+def test_process_files_csv():
+    """Ensure process files works with csv inputs."""
+    gene_fpath = []
+    spatial_fpath = []
+    for i in range(1, 5):
+        gene_fpath.append(Path(f"{input_dir}/slice{i}.csv"))
+        spatial_fpath.append(Path(f"{input_dir}/slice{i}_coor.csv"))
+
+    ad_objs = process_files(
+        gene_fpath,
+        spatial_fpath,
+    )
+    for obj in ad_objs:
+        assert isinstance(obj, ad.AnnData)
+
+
+def test_process_files_ann_data():
+    """Ensure process files works with Ann Data inputs."""
+    gene_fpath = []
+    for i in range(1, 5):
+        gene_fpath.append(Path(f"{input_dir}/slice{i}.h5ad"))
+
+    ad_objs = process_files(gene_fpath, s_fpath=None)
+    for obj in ad_objs:
+        assert isinstance(obj, ad.AnnData)
+
+
+def test_get_shape():
+    s_fpath = Path(f"{input_dir}/slice1.csv")
+    c_fpath = Path(f"{input_dir}/slice1_coor.csv")
+
+    assert get_shape(s_fpath) == (254, 7999)
+    assert get_shape(c_fpath) == (254, 2)