diff --git a/pyproject.toml b/pyproject.toml index af2fe83..d7f466c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ dynamic = ["version"] [project.scripts] -paste = "paste.__main__:main" +paste = "paste3.__main__:main" [tool.setuptools] package-dir = {"" = "src"} diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/paste3/__init__.py b/src/paste3/__init__.py index e69de29..dfee482 100644 --- a/src/paste3/__init__.py +++ b/src/paste3/__init__.py @@ -0,0 +1,32 @@ +import logging.config + + +# The _version.py file is managed by setuptools-scm +# and is not in version control. +try: + from paste3._version import version as __version__ # type: ignore +except ModuleNotFoundError: + # We're likely running as a source package without installation + __version__ = "src" + + +logging.config.dictConfig( + { + "version": 1, + "formatters": { + "standard": { + "format": "(%(levelname)s) (%(filename)s) (%(asctime)s) %(message)s", + "datefmt": "%d-%b-%y %H:%M:%S", + } + }, + "handlers": { + "default": { + "level": "NOTSET", + "formatter": "standard", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + } + }, + "loggers": {"": {"handlers": ["default"], "level": "INFO"}}, + } +) diff --git a/src/paste3/__main__.py b/src/paste3/__main__.py new file mode 100644 index 0000000..d21c3e9 --- /dev/null +++ b/src/paste3/__main__.py @@ -0,0 +1,41 @@ +import logging +import argparse +import os +from paste3 import align +import paste3 + +logger = logging.getLogger("paste3") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--version", action="version", version=paste3.__version__) + + modules = [align] + + subparsers = parser.add_subparsers(title="Choose a command") + subparsers.required = True + + def get_str_name(module): + return os.path.splitext(os.path.basename(module.__file__))[0] + + for module in modules: + this_parser = subparsers.add_parser( + get_str_name(module), description=module.__doc__ + ) + this_parser.add_argument( + "-v", "--verbose", action="store_true", help="Increase verbosity" + ) + + module.add_args(this_parser) + this_parser.set_defaults(func=module.main) + + args = parser.parse_args() + if args.verbose: + logger.setLevel(logging.DEBUG) + + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/src/paste3/align.py b/src/paste3/align.py new file mode 100644 index 0000000..d163005 --- /dev/null +++ b/src/paste3/align.py @@ -0,0 +1,258 @@ +import ot.backend +import numpy as np +from pathlib import Path + +import pandas as pd + +from paste3.io import process_files +import logging +from paste3.paste import pairwise_align, center_align +from paste3.visualization import stack_slices_pairwise, stack_slices_center + +logger = logging.getLogger(__name__) + + +def align( + mode, + gene_fpath, + spatial_fpath=None, + output_directory="", + alpha=0.1, + cost="kl", + n_components=15, + lmbda=None, + initial_slice=1, + threshold=0.001, + coordinates=False, + weight_fpath=None, + overlap_fraction=None, + start=None, + seed=None, + cost_matrix=None, + max_iter=10, + norm=False, + numItermax=200, + use_gpu=False, + return_obj=False, + optimizeTheta=True, + eps=1e-4, + is_histology=False, + armijo=False, +): + slices = process_files(gene_fpath, spatial_fpath, weight_fpath) + n_slices = len(slices) + + if not (mode == "pairwise" or mode == "center"): + raise (ValueError("Please select either pairwise or center alignment mode.")) + + if alpha < 0 or alpha > 1: + raise (ValueError("Alpha specified outside of 0-1 range.")) + + if initial_slice < 1 or initial_slice > n_slices: + raise (ValueError("Initial specified outside of 0 - n range")) + + if overlap_fraction: + if overlap_fraction < 0 or overlap_fraction > 1: + raise (ValueError("Overlap fraction specified outside of 0-1 range.")) + + if lmbda is None: + lmbda = n_slices * [1 / n_slices] + elif len(lmbda) != n_slices: + raise (ValueError("Length of lambda doesn't equal number of files")) + else: + if not all(i >= 0 for i in lmbda): + raise (ValueError("lambda includes negative weights")) + else: + print("Normalizing lambda weights into probability vector.") + lmbda = [float(i) / sum(lmbda) for i in lmbda] + + if cost_matrix: + cost_matrix = np.genfromtxt(cost_matrix, delimiter=",", dtype="float64") + + if start is None: + pis_init = [None] * (n_slices - 1) if mode == "pairwise" else None + elif mode == "pairwise" and not (len(start) == n_slices - 1): + raise ValueError( + f"Number of slices {n_slices} is not equal to number of start pi files {len(start)}" + ) + else: + pis_init = [np.genfromtxt(pi, delimiter=",") for pi in start] + + # make output directory if it doesn't exist + output_directory = Path(output_directory) + Path.mkdir(output_directory, exist_ok=True) + + if mode == "pairwise": + logger.info("Computing Pairwise Alignment ") + pis = [] + for i in range(n_slices - 1): + pi = pairwise_align( + sliceA=slices[i], + sliceB=slices[i + 1], + s=overlap_fraction, + M=cost_matrix, + alpha=alpha, + dissimilarity=cost, + use_rep=None, + G_init=pis_init[i], + a_distribution=slices[i].obsm["weights"], + b_distribution=slices[i + 1].obsm["weights"], + norm=norm, + numItermax=numItermax, + backend=ot.backend.NumpyBackend(), + use_gpu=use_gpu, + return_obj=return_obj, + maxIter=max_iter, + optimizeTheta=optimizeTheta, + eps=eps, + is_histology=is_histology, + armijo=armijo, + ) + pis.append(pi) + pd.DataFrame( + pi, index=slices[i].obs.index, columns=slices[i + 1].obs.index + ).to_csv(output_directory / f"slice_{i+1}_{i+2}_pairwise.csv") + + if coordinates: + new_slices = stack_slices_pairwise( + slices, pis, is_partial=overlap_fraction is not None + ) + + elif mode == "center": + logger.info("Computing Center Alignment") + initial_slice = slices[initial_slice - 1].copy() + + center_slice, pis = center_align( + A=initial_slice, + slices=slices, + lmbda=lmbda, + alpha=alpha, + n_components=n_components, + threshold=threshold, + max_iter=max_iter, + dissimilarity=cost, + norm=norm, + random_seed=seed, + pis_init=pis_init, + distributions=[slice.obsm["weights"] for slice in slices], + backend=ot.backend.NumpyBackend(), + use_gpu=use_gpu, + ) + + center_slice.write(output_directory / "center_slice.h5ad") + for i in range(len(pis) - 1): + pd.DataFrame( + pis[i], index=center_slice.obs.index, columns=slices[i].obs.index + ).to_csv(output_directory / f"slice_{i}_{i+1}_pairwise.csv") + + if coordinates: + new_slices = stack_slices_center(center_slice, slices, pis) + + if coordinates: + if mode == "center": + center, new_slices = new_slices + center.write(output_directory / "new_center.h5ad") + + for i, slice in enumerate(new_slices, start=1): + slice.write(output_directory / f"new_slices_{i}.h5ad") + + +def add_args(parser): + parser.add_argument( + "mode", type=str, help="Alignment type: 'pairwise' or 'center'." + ) + parser.add_argument( + "--g_fpath", type=str, nargs="+", help="Paths to gene exp files (.csv/ .h5ad)." + ) + parser.add_argument( + "--s_fpath", type=str, nargs="*", help="Paths to spatial data files (.csv)." + ) + parser.add_argument( + "--w_fpath", type=str, nargs="*", help="Paths to spot weight files (.csv)." + ) + parser.add_argument( + "--output_dir", default="./output", help="Directory to save output files." + ) + parser.add_argument( + "--alpha", type=float, default=0.1, help="Alpha param for alignment (0 to 1)." + ) + parser.add_argument( + "--cost", + choices=["kl", "euc", "gkl", "selection_kl", "pca", "glmpca"], + default="kl", + help="Expression dissimilarity cost", + ) + + parser.add_argument( + "--cost_mat", type=str, help="Paths to exp dissimilarity cost matrix." + ) + parser.add_argument( + "--n_comp", type=int, default=15, help="Components for NMF in center alignment." + ) + parser.add_argument( + "--lmbda", type=float, nargs="+", help="Weight vector for each slice." + ) + parser.add_argument( + "--init_slice", type=int, default=1, help="First slice for alignment (1 to n)." + ) + parser.add_argument( + "--thresh", + type=float, + default=1e-3, + help="Convergence threshold for alignment.", + ) + + parser.add_argument( + "--coor", action="store_true", help="Compute and save new coordinates." + ) + parser.add_argument( + "--ovlp_frac", type=float, default=None, help="Overlap fraction (0-1)." + ) + parser.add_argument( + "--start", type=str, nargs="+", help="Paths to initial alignment files." + ) + parser.add_argument( + "--norm", action="store_true", help="Normalize expression data if True." + ) + parser.add_argument("--max_iter", type=int, help="Maximum number of iterations.") + parser.add_argument( + "--gpu", action="store_true", help="Use GPU for processing if True." + ) + parser.add_argument("--r_info", action="store_true", help="Returns log if True.") + parser.add_argument( + "--hist", action="store_true", help="Use histological images if True." + ) + parser.add_argument( + "--armijo", action="store_true", help="Run Armijo line search if True." + ) + parser.add_argument( + "--seed", type=int, default=0, help="Random seed for reproducibility." + ) + return parser + + +def main(args): + align( + mode=args.mode, + gene_fpath=args.g_fpath, + spatial_fpath=args.s_fpath, + output_directory=args.output_dir, + alpha=args.alpha, + cost=args.cost, + n_components=args.n_comp, + lmbda=args.lmbda, + initial_slice=args.init_slice, + threshold=args.thresh, + coordinates=args.coor, + weight_fpath=args.w_fpath, + overlap_fraction=args.ovlp_frac, + start=args.start, + seed=args.seed, + cost_matrix=args.cost_mat, + norm=args.norm, + numItermax=args.max_iter, + use_gpu=args.gpu, + return_obj=args.r_info, + is_histology=args.hist, + armijo=args.armijo, + ) diff --git a/src/paste3/io.py b/src/paste3/io.py new file mode 100644 index 0000000..f4f1032 --- /dev/null +++ b/src/paste3/io.py @@ -0,0 +1,77 @@ +import scanpy as sc +import numpy as np + +from pathlib import Path +from collections import defaultdict +import logging + + +logger = logging.getLogger(__name__) + + +def process_files(g_fpath, s_fpath, w_fpath=None): + """Returns a list of AnnData objects.""" + + ext = Path(g_fpath[0]).suffix + + if ext == ".csv": + if not (len(s_fpath) == len(g_fpath)): + ValueError("Length of spatial files doesn't equal number of gene files") + _slices = defaultdict() + for file in g_fpath: + # The header of this file is alphanumeric, so this file has to be imported as a string + _slices[get_shape(file)[0]] = sc.read_csv(file) + + for file in s_fpath: + try: + _slice = _slices[get_shape(file)[0]] + except KeyError: + raise ValueError("Incomplete information for a slice") + else: + _slice.obsm["spatial"] = np.genfromtxt( + file, delimiter=",", dtype="float64" + ) + + if w_fpath: + if not (len(w_fpath) == len(g_fpath)): + ValueError("Length of weight files doesn't equal number of gene files") + for file in w_fpath: + _slice = _slices[get_shape(file)[0]] + _slice.obsm["weights"] = np.genfromtxt( + file, delimiter=",", dtype="float64" + ) + else: + for k, v in _slices.items(): + v.obsm["weights"] = np.ones((v.shape[0],)) / v.shape[0] + + slices = list(_slices.values()) + elif ext == ".h5ad": + slices = [sc.read_h5ad(file) for file in g_fpath] + + else: + raise ValueError("Incorrect file type provided ") + + return slices + + +def get_shape(file_path): + """Determines the shapes of the csv without opening the files""" + + def is_numeric(value): + try: + float(value) + return True + except ValueError: + return False + + with open(file_path, "r") as file: + first_line = file.readline().strip() + num_columns = len(first_line.split(",")) + + num_rows = sum(1 for _ in file) + + # Determine if the first row is a header + if all(is_numeric(val) for val in first_line.split(",")): + num_rows += 1 + + return num_rows, num_columns diff --git a/src/paste3/paste_cmd_line.py b/src/paste3/paste_cmd_line.py deleted file mode 100644 index c162571..0000000 --- a/src/paste3/paste_cmd_line.py +++ /dev/null @@ -1,257 +0,0 @@ -import scanpy as sc -import numpy as np -import pandas as pd -import argparse -import os -from paste3.paste import pairwise_align, center_align -from paste3.visualization import stack_slices_pairwise, stack_slices_center - - -def main(args): - # print(args) - n_slices = int(len(args.filename) / 2) - # Error check arguments - if args.mode != "pairwise" and args.mode != "center": - raise (ValueError("Please select either 'pairwise' or 'center' mode.")) - - if args.alpha < 0 or args.alpha > 1: - raise (ValueError("alpha specified outside [0, 1]")) - - if args.initial_slice < 1 or args.initial_slice > n_slices: - raise (ValueError("Initial slice specified outside [1, n]")) - - if len(args.lmbda) == 0: - lmbda = n_slices * [1.0 / n_slices] - elif len(args.lmbda) != n_slices: - raise (ValueError("Length of lambda does not equal number of files")) - else: - if not all(i >= 0 for i in args.lmbda): - raise (ValueError("lambda includes negative weights")) - else: - print("Normalizing lambda weights into probability vector.") - lmbda = args.lmbda - lmbda = [float(i) / sum(lmbda) for i in lmbda] - - # create slices - slices = [] - for i in range(n_slices): - s = sc.read_csv(args.filename[2 * i]) - s.obsm["spatial"] = np.genfromtxt(args.filename[2 * i + 1], delimiter=",") - slices.append(s) - - if len(args.weights) == 0: - for i in range(n_slices): - slices[i].obsm["weights"] = ( - np.ones((slices[i].shape[0],)) / slices[i].shape[0] - ) - elif len(args.weights) != n_slices: - raise ( - ValueError( - "Number of slices {0} != number of weight files {1}".format( - n_slices, len(args.weights) - ) - ) - ) - else: - for i in range(n_slices): - slices[i].obsm["weights"] = np.genfromtxt(args.weights[i], delimiter=",") - slices[i].obsm["weights"] = slices[i].obsm["weights"] / np.sum( - slices[i].obsm["weights"] - ) - - if len(args.start) == 0: - pis_init = (n_slices - 1) * [None] if args.mode == "pairwise" else None - elif (args.mode == "pairwise" and len(args.start) != n_slices - 1) or ( - args.mode == "center" and len(args.start) != n_slices - ): - raise ( - ValueError( - "Number of slices {0} != number of start pi files {1}".format( - n_slices, len(args.start) - ) - ) - ) - else: - pis_init = [ - pd.read_csv(args.start[i], index_col=0).to_numpy() - for i in range(len(args.start)) - ] - - # create output folder - output_path = os.path.join(args.direc, "paste_output") - if not os.path.exists(output_path): - os.mkdir(output_path) - - if args.mode == "pairwise": - print("Computing pairwise alignment.") - # compute pairwise align - pis = [] - for i in range(n_slices - 1): - pi = pairwise_align( - slices[i], - slices[i + 1], - alpha=args.alpha, - dissimilarity=args.cost, - a_distribution=slices[i].obsm["weights"], - b_distribution=slices[i + 1].obsm["weights"], - G_init=pis_init[i], - ) - pis.append(pi) - pi = pd.DataFrame( - pi, index=slices[i].obs.index, columns=slices[i + 1].obs.index - ) - output_filename = ( - "paste_output/slice" - + str(i + 1) - + "_slice" - + str(i + 2) - + "_pairwise.csv" - ) - pi.to_csv(os.path.join(args.direc, output_filename)) - if args.coordinates: - new_slices = stack_slices_pairwise(slices, pis) - for i in range(n_slices): - output_filename = ( - "paste_output/slice" + str(i + 1) + "_new_coordinates.csv" - ) - np.savetxt( - os.path.join(args.direc, output_filename), - new_slices[i].obsm["spatial"], - delimiter=",", - ) - elif args.mode == "center": - print("Computing center alignment.") - initial_slice = slices[args.initial_slice - 1].copy() - # compute center align - center_slice, pis = center_align( - initial_slice, - slices, - lmbda, - args.alpha, - args.n_components, - args.threshold, - random_seed=args.seed, - dissimilarity=args.cost, - distributions=[slices[i].obsm["weights"] for i in range(n_slices)], - pis_init=pis_init, - ) - W = pd.DataFrame(center_slice.uns["paste_W"], index=center_slice.obs.index) - H = pd.DataFrame(center_slice.uns["paste_H"], columns=center_slice.var.index) - W.to_csv(os.path.join(args.direc, "paste_output/W_center")) - H.to_csv(os.path.join(args.direc, "paste_output/H_center")) - for i in range(len(pis)): - output_filename = ( - "paste_output/slice_center_slice" + str(i + 1) + "_pairwise.csv" - ) - pi = pd.DataFrame( - pis[i], index=center_slice.obs.index, columns=slices[i].obs.index - ) - pi.to_csv(os.path.join(args.direc, output_filename)) - if args.coordinates: - center, new_slices = stack_slices_center(center_slice, slices, pis) - for i in range(n_slices): - output_filename = ( - "paste_output/slice" + str(i + 1) + "_new_coordinates.csv" - ) - np.savetxt( - os.path.join(args.direc, output_filename), - new_slices[i].obsm["spatial"], - delimiter=",", - ) - np.savetxt( - os.path.join(args.direc, "paste_output/center_new_coordinates.csv"), - center.obsm["spatial"], - delimiter=",", - ) - return - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-f", - "--filename", - help="path to data files (.csv). Alternate between gene expression and spatial data. Ex: slice1_gene.csv, slice1_coor.csv, slice2_gene.csv, slice2_coor.csv", - type=str, - default=[], - nargs="+", - ) - parser.add_argument( - "-m", - "--mode", - help="either 'pairwise' or 'center' ", - type=str, - default="pairwise", - ) - parser.add_argument("-d", "--direc", help="directory to save files", default="") - parser.add_argument( - "-a", - "--alpha", - help="alpha param for PASTE (float from [0,1])", - type=float, - default=0.1, - ) - parser.add_argument( - "-c", - "--cost", - help="expression dissimilarity cost, either 'kl' or 'euclidean' ", - type=str, - default="kl", - ) - parser.add_argument( - "-p", - "--n_components", - help="n_components for NMF step in center_align", - type=int, - default=15, - ) - parser.add_argument( - "-l", - "--lmbda", - help="lambda param in center_align (weight vector of length n) ", - type=float, - default=[], - nargs="+", - ) - parser.add_argument( - "-i", - "--initial_slice", - help="specify which slice is the intial slice for center_align (int from 1-n)", - type=int, - default=1, - ) - parser.add_argument( - "-t", - "--threshold", - help="convergence threshold for center_align", - type=float, - default=0.001, - ) - parser.add_argument( - "-x", - "--coordinates", - help="output new coordinates", - action="store_true", - default=False, - ) - parser.add_argument( - "-w", - "--weights", - help="path to files containing weights of spots in each slice. The format of the files is the same as the coordinate files used as input", - type=str, - default=[], - nargs="+", - ) - parser.add_argument( - "-s", - "--start", - help="path to files containing initial starting alignmnets. If not given the OT starts the search with uniform alignments. The format of the files is the same as the alignments files output by PASTE", - type=str, - default=[], - nargs="+", - ) - parser.add_argument( - "--seed", help="random seed for reproducibility", type=int, default=None - ) - args = parser.parse_args() - main(args) diff --git a/tests/test_align.py b/tests/test_align.py new file mode 100644 index 0000000..58b7dbf --- /dev/null +++ b/tests/test_align.py @@ -0,0 +1,149 @@ +import pandas as pd +import anndata as ad +from pandas.testing import assert_frame_equal +import scanpy as sc +from pathlib import Path +from paste3.io import get_shape, process_files +from paste3.align import align +import sys +import subprocess as sp +import paste3 + + +test_dir = Path(__file__).parent +input_dir = test_dir / "data/input" +output_dir = test_dir / "data/output" + + +def test_cmd_line_center(tmp_path): + print(f"Running command in {tmp_path}") + result = align( + "center", + [f"{input_dir}/slice{i}.csv" for i in range(1, 4)], + [f"{input_dir}/slice{i}_coor.csv" for i in range(1, 4)], + f"{tmp_path}", + 0.1, + "kl", + 15, + None, + 1, + 0.001, + False, + None, + None, + None, + 0, + None, + ) + + assert result is None + result = sc.read(tmp_path / "center_slice.h5ad") + + assert_frame_equal( + pd.DataFrame( + result.uns["paste_W"], + index=result.obs.index, + columns=[str(i) for i in range(15)], + ), + pd.read_csv(output_dir / "W_center", index_col=0), + check_names=False, + check_index_type=False, + rtol=1e-05, + atol=1e-08, + ) + assert_frame_equal( + pd.DataFrame(result.uns["paste_H"], columns=result.var.index), + pd.read_csv(output_dir / "H_center", index_col=0), + rtol=1e-05, + atol=1e-08, + ) + + for i, pi in enumerate(range(2)): + assert_frame_equal( + pd.read_csv(tmp_path / f"slice_{i}_{i+1}_pairwise.csv"), + pd.read_csv( + output_dir / f"slice_center_slice{i + 1}_pairwise.csv", + ), + ) + + +def test_cmd_line_pairwise_csv(tmp_path): + print(f"Running command in {tmp_path}") + result = align( + "pairwise", + [ + f"{input_dir}/slice1.csv", + f"{input_dir}/slice2.csv", + f"{input_dir}/slice3.csv", + ], + [ + f"{input_dir}/slice1_coor.csv", + f"{input_dir}/slice2_coor.csv", + f"{input_dir}/slice3_coor.csv", + ], + f"{tmp_path}", + 0.1, + "kl", + 15, + None, + 1, + 0.001, + False, + None, + None, + None, + 0, + None, + max_iter=1000, + ) + + assert result is None + assert_frame_equal( + pd.read_csv(tmp_path / "slice_1_2_pairwise.csv"), + pd.read_csv( + output_dir / "slices_1_2_pairwise.csv", + ), + ) + + +def test_process_files_csv(): + """Ensure process files works with csv inputs.""" + gene_fpath = [] + spatial_fpath = [] + for i in range(1, 5): + gene_fpath.append(Path(f"{input_dir}/slice{i}.csv")) + spatial_fpath.append(Path(f"{input_dir}/slice{i}_coor.csv")) + + ad_objs = process_files( + gene_fpath, + spatial_fpath, + ) + for obj in ad_objs: + assert isinstance(obj, ad.AnnData) + + +def test_process_files_ann_data(): + """Ensure process files works with Ann Data inputs.""" + gene_fpath = [] + for i in range(3, 7): + gene_fpath.append(Path(f"{input_dir}/15167{i}.h5ad")) + + ad_objs = process_files(gene_fpath, s_fpath=None) + for obj in ad_objs: + assert isinstance(obj, ad.AnnData) + + +def test_get_shape(): + s_fpath = Path(f"{input_dir}/slice1.csv") + c_fpath = Path(f"{input_dir}/slice1_coor.csv") + + assert get_shape(s_fpath) == (254, 7999) + assert get_shape(c_fpath) == (254, 2) + + +def test_version(): + result = sp.run( + [sys.executable, "-m", "paste3", "--version"], capture_output=True, text=True + ) + assert result.returncode == 0 + assert result.stdout.strip() == paste3.__version__ diff --git a/tests/test_paste_cmd_line.py b/tests/test_paste_cmd_line.py deleted file mode 100644 index ab05af2..0000000 --- a/tests/test_paste_cmd_line.py +++ /dev/null @@ -1,113 +0,0 @@ -import pandas as pd -from pandas.testing import assert_frame_equal -from pathlib import Path -from collections import namedtuple -from paste3.paste_cmd_line import main as paste_cmd_line - -test_dir = Path(__file__).parent -input_dir = test_dir / "data/input" -output_dir = test_dir / "data/output" - -args = namedtuple( - "args", - [ - "filename", - "mode", - "direc", - "alpha", - "cost", - "n_components", - "lmbda", - "initial_slice", - "threshold", - "coordinates", - "weights", - "start", - "seed", - ], -) - - -def test_cmd_line_center(tmp_path): - print(f"Running command in {tmp_path}") - result = paste_cmd_line( - args( - [ - f"{input_dir}/slice1.csv", - f"{input_dir}/slice1_coor.csv", - f"{input_dir}/slice2.csv", - f"{input_dir}/slice2_coor.csv", - f"{input_dir}/slice3.csv", - f"{input_dir}/slice3_coor.csv", - ], - "center", - f"{tmp_path}", - 0.1, - "kl", - 15, - [], - 1, - 0.001, - False, - [], - [], - 0, - ) - ) - - assert result is None - assert_frame_equal( - pd.read_csv(tmp_path / "paste_output/W_center"), - pd.read_csv(output_dir / "W_center"), - check_names=False, - rtol=1e-05, - atol=1e-08, - ) - assert_frame_equal( - pd.read_csv(tmp_path / "paste_output/H_center"), - pd.read_csv(output_dir / "H_center"), - rtol=1e-05, - atol=1e-08, - ) - - for i, pi in enumerate(range(3)): - assert_frame_equal( - pd.read_csv( - tmp_path / f"paste_output/slice_center_slice{i + 1}_pairwise.csv" - ), - pd.read_csv(output_dir / f"slice_center_slice{i + 1}_pairwise.csv"), - ) - - -def test_cmd_line_pairwise(tmp_path): - print(f"Running command in {tmp_path}") - result = paste_cmd_line( - args( - [ - f"{input_dir}/slice1.csv", - f"{input_dir}/slice1_coor.csv", - f"{input_dir}/slice2.csv", - f"{input_dir}/slice2_coor.csv", - f"{input_dir}/slice3.csv", - f"{input_dir}/slice3_coor.csv", - ], - "pairwise", - f"{tmp_path}", - 0.1, - "kl", - 15, - [], - 1, - 0.001, - False, - [], - [], - 0, - ) - ) - - assert result is None - assert_frame_equal( - pd.read_csv(tmp_path / "paste_output/slice1_slice2_pairwise.csv"), - pd.read_csv(output_dir / "slices_1_2_pairwise.csv"), - )