Skip to content

Commit

Permalink
Merge pull request #59 from akshayparopkari/master
Browse files Browse the repository at this point in the history
Streamline LDA.py, move bubble plotting into LDA_bubble.py
  • Loading branch information
smdabdoub authored Jul 20, 2016
2 parents 2726561 + 2208e1f commit b590938
Show file tree
Hide file tree
Showing 6 changed files with 426 additions and 192 deletions.
283 changes: 113 additions & 170 deletions bin/LDA.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
#!/usr/bin/env python
import os
"""
Abstract: This script calculates and returns LDA plots based on normalized relative
abundances or distance matrices (for e.g. unifrac distance matrix).
"""

import sys
import argparse
from phylotoast import util, biom_calc as bc, otu_calc as oc, graph_util as gu
from phylotoast import util, biom_calc as bc, graph_util as gu
errors = []
try:
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
mpl.rc("font", family="Arial") # define font for figure text
except ImportError as ie:
errors.append(ie)
try:
Expand All @@ -27,57 +33,59 @@
errors.append(ie)
if len(errors) != 0:
for item in errors:
print "Import Error:", item
print("Import Error:", item)
sys.exit()


def get_relative_abundance(biomfile):
"""
Return relative abundance from a OTU table. OTUIDs are converted to their
genus-species identifier.
"""
biomf = biom.load_table(biomfile)
norm_biomf = biomf.norm(inplace=False)
rel_abd = {}
for sid in norm_biomf.ids():
rel_abd[sid] = {}
for otuid in norm_biomf.ids("observation"):
otuname = oc.otu_name(norm_biomf.metadata(otuid, axis="observation")["taxonomy"])
abd = norm_biomf.get_value_by_ids(otuid, sid)
rel_abd[sid][otuname] = abd
ast_rel_abd = bc.arcsine_sqrt_transform(rel_abd)
return ast_rel_abd


def plot_LDA(X_lda, y_lda, class_colors, exp_var, style, out_fp=""):
def plot_LDA(X_lda, y_lda, class_colors, exp_var, style, fig_size, label_pad, font_size,
dim=2, zangles=None, out_fp=""):
"""
Plot transformed LDA data.
"""
cats = class_colors.keys()
group_lda = {c: [] for c in cats}
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111)
for i, target_name in zip(range(len(cats)), cats):
cat_x = X_lda[:, 0][y_lda == target_name]
if X_lda.shape[1] == 1:
cat_y = np.ones((cat_x.shape[0], 1)) + i
else:
fig = plt.figure(figsize=fig_size)
if dim == 3:
ax = fig.add_subplot(111, projection="3d")
ax.view_init(elev=zangles[1], azim=zangles[0])
try:
ax.set_zlabel("LD3 (Percent Explained Variance: {:.3f}%)".
format(exp_var[2]*100), fontsize=font_size, labelpad=label_pad)
except:
ax.set_zlabel("LD3", fontsize=font_size, labelpad=label_pad)
for i, target_name in zip(range(len(cats)), cats):
cat_x = X_lda[:, 0][y_lda == target_name]
cat_y = X_lda[:, 1][y_lda == target_name]
group_lda[target_name].append(cat_x)
group_lda[target_name].append(cat_y)
plt.scatter(x=cat_x, y=cat_y, label=target_name,
color=class_colors[target_name],
alpha=0.85, s=250, edgecolors="k")
mpl.rc("font", family="Arial") # define font for figure text
mpl.rc('xtick', labelsize=12) # increase X axis ticksize
mpl.rc('ytick', labelsize=12) # increase Y axis ticksize
cat_z = X_lda[:, 2][y_lda == target_name]
ax.scatter(xs=cat_x, ys=cat_y, zs=cat_z, label=target_name,
c=class_colors[target_name], alpha=0.85, s=250,
edgecolors="k", zdir="z")
else:
ax = fig.add_subplot(111)
for i, target_name in zip(range(len(cats)), cats):
cat_x = X_lda[:, 0][y_lda == target_name]
if X_lda.shape[1] == 1:
cat_y = np.ones((cat_x.shape[0], 1)) + i
else:
cat_y = X_lda[:, 1][y_lda == target_name]
ax.scatter(x=cat_x, y=cat_y, label=target_name,
color=class_colors[target_name],
alpha=0.85, s=250, edgecolors="k")
if X_lda.shape[1] == 1:
plt.ylim((0.5, 2.5))
plt.xlabel("LD1 (Percent Explained Variance: {:.3f}%)".format(exp_var[0]*100), fontsize=16)
plt.ylabel("LD2 (Percent Explained Variance: {:.3f}%)".format(exp_var[1]*100), fontsize=16)
leg = plt.legend(loc="best", frameon=True, framealpha=1, fontsize=16)
try:
ax.set_xlabel("LD1 (Percent Explained Variance: {:.3f}%)".
format(exp_var[0]*100), fontsize=font_size, labelpad=label_pad)
except:
ax.set_xlabel("LD1", fontsize=font_size, labelpad=label_pad)
try:
ax.set_ylabel("LD2 (Percent Explained Variance: {:.3f}%)".
format(exp_var[1]*100), fontsize=font_size, labelpad=label_pad)
except:
ax.set_ylabel("LD2", fontsize=font_size, labelpad=label_pad)

leg = plt.legend(loc="best", scatterpoints=3, frameon=True, framealpha=1, fontsize=15)
leg.get_frame().set_edgecolor('k')
if style:
if dim == 2 and style:
gu.ggplot2_style(ax)
fc = "0.8"
else:
Expand All @@ -94,7 +102,7 @@ def plot_LDA(X_lda, y_lda, class_colors, exp_var, style, out_fp=""):
def run_LDA(df):
"""
Run LinearDiscriminantAnalysis on input dataframe (df) and return
transformed data, scalings and
transformed data, scalings and explained variance by discriminants.
"""
# Prep variables for sklearn LDA
X = df[range(1, df.shape[1])].values # input data matrix
Expand All @@ -103,48 +111,37 @@ def run_LDA(df):
# Calculate LDA
sklearn_lda = LDA()
X_lda_sklearn = sklearn_lda.fit_transform(X, y)
exp_var = sklearn_lda.explained_variance_ratio_

try:
exp_var = sklearn_lda.explained_variance_ratio_
except AttributeError as ae:
print("\n{}: explained variance cannot be computed.\nPlease check this GitHub PR:"
" https://github.com/scikit-learn/scikit-learn/pull/6027".format(ae))
return X_lda_sklearn, y, "NA"
return X_lda_sklearn, y, exp_var


def handle_program_options():
parser = argparse.ArgumentParser(description="Create an LDA plot from\
sample-grouped OTU data. It is necessary\
to remove the header cell '#OTU ID'\
before running this program.")
parser.add_argument("-i", "--input_data_type", required=True,
choices=["biom", "unifrac_dm"],
default="unifrac_dm",
help="Specify if the input file is biom file format OTU \
table or unifrac distance matrix. If biom file is \
provided, the arc-sine transformed relative abundances \
eill be used as input whereas, if unifrac distance matrix.\
is given, unifrac distances will be used as input to LDA.\
[REQUIRED]")
parser.add_argument("-bf", "--biom_file", required=True,
help="Input biom file format. [REQUIRED]")
parser = argparse.ArgumentParser(description="This script calculates and returns LDA "
"plots based on normalized relative "
"abundances or distance matrices "
"(for e.g. unifrac distance matrix).")
parser.add_argument("-i", "--otu_table", required=True,
help="Input biom file format OTU table. [REQUIRED]")
parser.add_argument("-m", "--map_fp", required=True,
help="Metadata mapping file. [REQUIRED]")
parser.add_argument("-uf", "--unifrac_file",
help="Input unifrac datdistance matrix file. This is the \
output from ")
parser.add_argument("-g", "--group_by", required=True,
help="A column name in the mapping file containing\
categorical values that will be used to identify \
groups. Each sample ID must have a group entry. \
Default is no categories and all the data will be \
treated as a single group.")
treated as a single group. [REQUIRED]")
parser.add_argument("-c", "--color_by", required=True,
help="A column name in the mapping file containing\
hexadecimal (#FF0000) color values that will\
be used to color the groups. Each sample ID must\
have a color entry.")
parser.add_argument("--bubble",
help="If set, provide a file with 1 OTU name per line \
for bubble plotting. OTU name must be condensed to \
genus-species identifier. Default parameter value \
will not plot bubble plots.")
have a color entry. [REQUIRED]")
parser.add_argument("-dm", "--dist_matrix_file",
help="Input distance matrix file.")
parser.add_argument("--save_lda_input",
help="Save a CSV-format file of the transposed LDA-input\
table to the file specifed by this option.")
Expand All @@ -155,15 +152,18 @@ def handle_program_options():
If specified, the figure will be saved directly\
instead of opening a window in which the plot \
can be viewed before saving")
parser.add_argument("-od", "--output_dir", default=".",
help="The directory to save the LDA bubble plots to.")
parser.add_argument("--scale_by", default=1000, type=float,
help="Species relative abundance is multiplied by this \
factor in order to make appropriate visible \
bubbles in the output plots. Default is 1000.")
parser.add_argument("-s", "--save_as", default="svg",
help="The type of image file for LDA plots. By default,\
files will be saved in SVG format.")
parser.add_argument("-d", "--dimensions", default=2, type=int, choices=[2, 3],
help="Choose whether to plot 2D or 3D.")
parser.add_argument("--z_angles", type=float, nargs=2, default=[45., 30.],
help="Specify the azimuth and elevation angles for a 3D plot.")
parser.add_argument("--figsize", default=[14, 8], type=int, nargs=2,
help="Specify the 'width height' in inches for LDA plots."
"By default, figure size is 14x8 inches.")
parser.add_argument("--font_size", default=12, type=int,
help="Sets the font size for text elements in the plot.")
parser.add_argument("--label_padding", default=15, type=int,
help="Sets the spacing in points between the each axis and its \
label.")
parser.add_argument("--ggplot2_style", action="store_true",
help="Apply ggplot2 styling to the figure.")
return parser.parse_args()
Expand All @@ -172,123 +172,66 @@ def handle_program_options():
def main():
args = handle_program_options()

# Parse and read mapping file
try:
with open(args.map_fp):
pass
header, imap = util.parse_map_file(args.map_fp)
category_idx = header.index(args.group_by)
except IOError as ioe:
err_msg = "\nError in metadata mapping filepath (-m): {}\n"
sys.exit(err_msg.format(ioe))

# Parse and read mapping file and obtain group colors
header, imap = util.parse_map_file(args.map_fp)
# Obtain group colors
class_colors = util.color_mapping(imap, header, args.group_by, args.color_by)

if args.input_data_type == "unifrac_dm":
if args.dist_matrix_file:
try:
with open(args.unifrac_file):
with open(args.dist_matrix_file):
pass
except IOError as ioe:
err_msg = "\nError with unifrac distance matrix file (-d): {}\n"
sys.exit(err_msg.format(ioe))
uf_data = pd.read_csv(args.unifrac_file, sep="\t", index_col=0)
uf_data.insert(0, "Condition", [imap[sid][header.index(args.group_by)]
for sid in uf_data.index])
sampleids = uf_data.index
uf_data = pd.read_csv(args.dist_matrix_file, sep="\t", index_col=0)
uf_data.insert(0, "Condition", [imap[sid][category_idx] for sid in uf_data.index])
if args.save_lda_input:
uf_data.to_csv(args.save_lda_input, sep="\t")
# Run LDA
X_lda, y_lda, exp_var = run_LDA(uf_data)
# Plot LDA
plot_LDA(X_lda, y_lda, class_colors, exp_var, style=args.ggplot2_style,
out_fp=args.out_fp)
if args.dimensions == 3:
plot_LDA(X_lda, y_lda, class_colors, exp_var, style=args.ggplot2_style,
fig_size=args.figsize, label_pad=args.label_padding,
font_size=args.font_size, dim=3, zangles=args.z_angles,
out_fp=args.out_fp)
else:
plot_LDA(X_lda, y_lda, class_colors, exp_var, style=args.ggplot2_style,
fig_size=args.figsize, label_pad=args.label_padding,
font_size=args.font_size, out_fp=args.out_fp)
else:
# Load biom file and calculate relative abundance
try:
rel_abd = get_relative_abundance(args.biom_file)
except ValueError as ve:
biomf = biom.load_table(args.otu_table)
except IOError as ioe:
err_msg = "\nError with biom format file (-d): {}\n"
sys.exit(err_msg.format(ve))
sys.exit(err_msg.format(ioe))
# Get normalized relative abundances
rel_abd = bc.relative_abundance(biomf)
rel_abd = bc.arcsine_sqrt_transform(rel_abd)
df_rel_abd = pd.DataFrame(rel_abd).T
df_rel_abd.insert(0, "Condition", [imap[sid][header.index(args.group_by)]
for sid in df_rel_abd.index])
sampleids = df_rel_abd.index
df_rel_abd.insert(0, "Condition", [imap[sid][category_idx] for sid in df_rel_abd.index])
if args.save_lda_input:
df_rel_abd.to_csv(args.save_lda_input, sep="\t")
# Run LDA
X_lda, y_lda, exp_var = run_LDA(df_rel_abd)
# Plot LDA
plot_LDA(X_lda, y_lda, class_colors, exp_var, style=args.ggplot2_style,
out_fp=args.out_fp)

if args.bubble:
# Get otus for LDA bubble plots
try:
with open(args.bubble) as hojiehr:
for line in hojiehr.readlines():
bubble_otus = line.strip().split("\r")
except IOError as ioe:
err_msg = "\nError in OTU name list file (--bubble): {}\n"
sys.exit(err_msg.format(ioe))

# Load biom file and calculate relative abundance
try:
rel_abd = get_relative_abundance(args.biom_file)
except ValueError as ve:
err_msg = "\nError with biom format file (-d): {}\n"
sys.exit(err_msg.format(ve))
category_idx = header.index(args.group_by)

# Calculate position and size of SampleIDs to plot for each OTU
for otuname in bubble_otus:
plot_data = {cat: {"x": [], "y": [], "size": [], "label": []}
for cat in class_colors.keys()}
for sid, data in zip(sampleids, X_lda):
category = plot_data[imap[sid][category_idx]]
try:
size = rel_abd[sid][otuname] * args.scale_by
except KeyError as ke:
print "{} not found in {} sample.".format(ke, sid)
continue
category["x"].append(float(data[0]))
category["y"].append(float(data[1]))
category["size"].append(size)

# Plot LDA bubble for each OTU
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(111)
for i, cat in enumerate(plot_data):
plt.scatter(plot_data[cat]["x"], plot_data[cat]["y"],
plot_data[cat]["size"], label=cat,
color=class_colors[cat],
alpha=0.85, marker="o", edgecolor="k")
mpl.rc("font", family="Arial") # define font for figure text
mpl.rc("xtick", labelsize=12) # increase X axis ticksize
mpl.rc("ytick", labelsize=12) # increase Y axis ticksize
if X_lda.shape[1] == 1:
plt.ylim((0.5, 2.5))
plt.title(" ".join(otuname.split("_")), style="italic")
plt.xlabel("LD1 (Percent Explained Variance: {:.3f}%)".format(exp_var[0]*100),
fontsize=12)
plt.ylabel("LD2 (Percent Explained Variance: {:.3f}%)".format(exp_var[1]*100),
fontsize=12)
lgnd = plt.legend(loc="best", scatterpoints=3, fontsize=12)
# Change the legend marker size manually
for i in range(len(class_colors.keys())):
lgnd.legendHandles[i]._sizes = [75]

# Set style for LDA bubble plots
if args.ggplot2_style:
gu.ggplot2_style(ax)
fc = "0.8"
else:
fc = "none"
if args.dimensions == 3:
plot_LDA(X_lda, y_lda, class_colors, exp_var, style=args.ggplot2_style,
fig_size=args.figsize, label_pad=args.label_padding,
font_size=args.font_size, dim=3, zangles=args.z_angles,
out_fp=args.out_fp)
else:
plot_LDA(X_lda, y_lda, class_colors, exp_var, style=args.ggplot2_style,
fig_size=args.figsize, label_pad=args.label_padding,
font_size=args.font_size, out_fp=args.out_fp)

# Save LDA bubble plots to output directory
print "Saving chart for {}".format(" ".join(otuname.split("_")))
fig.savefig(os.path.join(args.output_dir, "_".join(otuname.split())) + "." + args.save_as,
facecolor=fc, edgecolor="none", dpi=300,
bbox_inches="tight", pad_inches=0.2)
plt.close(fig)

if __name__ == "__main__":
sys.exit(main())
Loading

0 comments on commit b590938

Please sign in to comment.