import glob
import gzip
import math
import re
from collections import defaultdict
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from allel import (
GenotypeArray,
read_vcf,
windowed_count,
windowed_diversity,
windowed_watterson_theta,
)
from numba import int64, njit
# from pybedtools import BedTool
from polars_bio import merge, nearest, overlap
from . import Parallel, delayed, np, pl
from .fv import get_cm
################## Plotting
[docs]
def _windowed_diversity_frame(pos, ac, contig, size, step):
    """Build one row per sliding window for a single contig.

    Runs ``windowed_diversity``, ``windowed_watterson_theta`` and
    ``windowed_count`` with the same ``size``/``step`` and collects the
    results in a polars DataFrame with columns ``contig``, ``start``,
    ``end``, ``pi``, ``theta_w``, ``s`` and ``window_size``.
    """
    pi_w, windows, _, _ = windowed_diversity(pos, ac, size=size, step=step)
    theta_w, _, _, _ = windowed_watterson_theta(pos, ac, size=size, step=step)
    counts, _ = windowed_count(pos, size=size, step=step)
    return pl.DataFrame(
        {
            "contig": contig,
            "start": windows[:, 0],
            "end": windows[:, 1],
            "pi": pi_w,
            "theta_w": theta_w,
            "s": counts,
            "window_size": size,
        }
    )


def plot_diversity(data_dir, figsize=None, title=None, out=None, nthreads=1):
    """
    Plot windowed π, Watterson's θ and window counts along the genome.

    Every ``*vcf.gz`` file in ``data_dir`` whose path does not contain
    "masked" is read. Statistics are computed in 100 kb windows (50 kb step)
    and in 1.2 Mb windows (100 kb step); only the 1.2 Mb windows are drawn.
    Contigs are laid end to end on a shared x-axis, with a dashed grey line
    at each contig's mean and a vertical line at each contig boundary.

    Parameters
    ----------
    data_dir : str
        Directory containing the VCF files.
    figsize : tuple, optional
        Matplotlib figure size. Defaults to ``(12, 6)``.
    title : str, optional
        Figure title.
    out : str, optional
        If given, save the figure to this path instead of displaying it.
    nthreads : int, default 1
        Parallel workers for reading the VCF files.

    Returns
    -------
    polars.DataFrame
        Both window sizes, sorted by contig, ``window_size`` and ``start``.

    Raises
    ------
    FileNotFoundError
        If no matching VCF file is found in ``data_dir``.
    ValueError
        If none of the VCF files could be read.
    """
    vcf_files = [p for p in glob.glob(f"{data_dir}/*vcf.gz") if "masked" not in p]
    if not vcf_files:
        raise FileNotFoundError(f"No VCF files found in {data_dir}")

    with Parallel(n_jobs=nthreads, verbose=2) as parallel:
        vcf_data = parallel(delayed(read_vcf)(i) for i in vcf_files)

    # plot_sfs treats a None result from read_vcf as "nothing to read"; do the same here.
    vcf_data = [v for v in vcf_data if v is not None]
    if not vcf_data:
        raise ValueError("No valid VCF data could be read from the input files.")

    per_contig = []
    for v in vcf_data:
        hap = GenotypeArray(v["calldata/GT"]).to_haplotypes()
        ac = hap.count_alleles()
        pos = v["variants/POS"]
        # NOTE(review): only the first unique CHROM value is used as the label,
        # so each VCF is presumably expected to hold a single contig.
        contig = np.unique(v["variants/CHROM"])[0]
        per_contig.append(
            pl.concat(
                [
                    _windowed_diversity_frame(pos, ac, contig, int(1e5), int(5e4)),
                    _windowed_diversity_frame(pos, ac, contig, int(1.2e6), int(1e5)),
                ]
            )
        )

    df = pl.concat(per_contig).sort(
        pl.col("contig").str.replace("chr", "").cast(pl.Int64),
        "window_size",
        "start",
    )

    df_plot = df.with_columns(
        ((pl.col("start") + pl.col("end")) / 2).alias("mid")
    ).filter(pl.col("window_size") == int(1.2e6))

    # Contig lengths in the (sorted) order contigs appear in df_plot.
    # maintain_order=True is required: a plain group_by returns groups in
    # arbitrary order, which would shuffle contigs along the x-axis.
    contig_lengths = df_plot.group_by("contig", maintain_order=True).agg(
        pl.max("end").alias("length")
    )

    # Cumulative offsets: each contig starts where the previous one ended.
    contig_offsets = contig_lengths.with_columns(
        pl.col("length").cum_sum().shift(1).fill_null(0).alias("offset")
    )

    # Genome-wide positions
    df_plot = df_plot.join(
        contig_offsets.select(["contig", "offset"]), on="contig"
    ).with_columns((pl.col("mid") + pl.col("offset")).alias("genome_pos"))
    pdf = df_plot.sort("genome_pos").to_pandas()
    boundaries = contig_offsets.to_pandas()

    # Mean of each statistic per contig, with the x-span it covers.
    stats = ["pi", "theta_w", "s"]
    means = df_plot.group_by("contig").agg([pl.mean(s).alias(s) for s in stats])
    means_pdf = means.join(contig_offsets, on="contig").to_pandas()
    means_pdf["start_pos"] = means_pdf["offset"]
    means_pdf["end_pos"] = means_pdf["offset"] + means_pdf["length"]

    colors = {
        "pi": "#1f77b4",  # blue
        "theta_w": "#ff7f0e",  # orange
        "s": "#2ca02c",  # green
    }
    ylabels = ["π", "θ watterson", "S"]

    fig, axes = plt.subplots(
        3, 1, figsize=(12, 6) if figsize is None else figsize, sharex=True
    )
    for ax, stat, ylabel in zip(axes, stats, ylabels):
        ax.plot(pdf["genome_pos"], pdf[stat], color=colors[stat], linewidth=1)

        # Vertical lines between contigs (the first contig starts at 0).
        if len(boundaries) > 1:
            for offset in boundaries["offset"].iloc[1:]:
                ax.axvline(offset, color="black", linestyle="-.", alpha=0.5)

        # Mean lines per contig
        for _, row in means_pdf.iterrows():
            ax.hlines(
                y=row[stat],
                xmin=row["start_pos"],
                xmax=row["end_pos"],
                colors="gray",
                linestyles="dashed",
                alpha=0.7,
            )

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_ylabel(ylabel, fontsize=12)

    axes[-1].set_xlabel("genome position")

    # Chromosome labels centred on each contig
    if len(boundaries) > 1:
        centers = boundaries.copy()
        centers["center"] = centers["offset"] + centers["length"] / 2
        axes[-1].set_xticks(centers["center"])
        axes[-1].set_xticklabels(centers["contig"], rotation=90)

    if title:
        fig.suptitle(title)

    if out is not None:
        plt.savefig(out, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.tight_layout()
        plt.show()
    return df
def _lncomb(N, k):
    """Natural log of the binomial coefficient C(N, k), elementwise in k.

    Entries whose log-gamma expression is not finite (k outside ``0..N``)
    are reported as ``-inf``.
    """
    from scipy.special import gammaln

    # inf - inf would emit an "invalid value" warning; it is mapped to -inf below.
    with np.errstate(invalid="ignore"):
        ln_n_fact = gammaln(N + 1)
        ln_k_fact = gammaln(k + 1)
        ln_rest_fact = gammaln(N - k + 1)
        log_binom = ln_n_fact - ln_k_fact - ln_rest_fact
    usable = np.isfinite(log_binom)
    return np.where(usable, log_binom, -np.inf)
def _project_sfs(sfs, n_proj):
    """
    Down-project a 1-D SFS to ``n_proj`` chromosomes.

    ``sfs`` has ``proj_from + 1`` entries. Each segregating class is spread
    over the target classes with hypergeometric weights (same formula as the
    dadi/moments ``_cached_projection``). The two fixed classes (index 0 and
    ``proj_from``) are not projected. The result has ``n_proj + 1`` entries.
    """
    n_from = len(sfs) - 1
    targets = np.arange(n_proj + 1)
    ln_choose_target = _lncomb(n_proj, targets)
    n_dropped = n_from - n_proj

    projected = np.zeros(n_proj + 1)
    for k in range(1, n_from):
        weight = sfs[k]
        if weight == 0:
            continue
        ln_prob = (
            ln_choose_target
            + _lncomb(n_dropped, k - targets)
            - float(_lncomb(n_from, k))
        )
        prob = np.exp(np.where(np.isfinite(ln_prob), ln_prob, -np.inf))
        # Only target classes reachable from class k receive mass.
        lo = max(n_proj - (n_from - k), 0)
        hi = min(k, n_proj)
        projected[lo : hi + 1] += weight * prob[lo : hi + 1]
    return projected
def plot_sfs(
    vcf_path,
    n_proj=None,
    fold=False,
    merge=False,
    nthreads=1,
    figsize=None,
    title=None,
    out=None,
):
    """
    Plot the Site Frequency Spectrum (SFS) from one or more VCF files.

    Each bar shows the percentage of segregating sites at that frequency class
    (singletons, doubletons, …). For multiple files the bars are plotted
    side-by-side within each frequency class, one bar per file.

    Parameters
    ----------
    vcf_path : str
        Path to a single VCF/VCF.gz file *or* a directory of ``*.vcf.gz``
        files (files containing "masked" in the name are skipped).
    n_proj : int, optional
        Diploid sample size to project down to (e.g. 50 → 100 haplotypes,
        SFS runs 1–99). Must be smaller than the actual diploid sample size.
    fold : bool, default False
        If True, use minor-allele count (folded SFS). If False, treat the
        ALT allele as derived (unfolded SFS).
    merge : bool, default False
        If True, pool all VCF files into a single SFS labelled with
        ``vcf_path``. If False, plot one bar group per file.
    nthreads : int, default 1
        Parallel workers for reading multiple VCF files.
    figsize : tuple, optional
        Matplotlib figure size. Defaults to ``(max(8, n_classes * 0.6 + 2), 4)``.
    title : str, optional
        Figure title.
    out : str, optional
        If given, save the figure to this path instead of displaying it.

    Returns
    -------
    polars.DataFrame
        Columns: ``dataset``, ``dac``, ``count``, ``pct``.

    Raises
    ------
    FileNotFoundError
        If ``vcf_path`` is a directory with no usable ``*.vcf.gz`` file.
    ValueError
        If no file yields any biallelic site, or if ``n_proj`` is not smaller
        than the observed diploid sample size.
    """
    import os

    if os.path.isdir(vcf_path):
        vcf_files = sorted(glob.glob(f"{vcf_path}/*.vcf.gz"))
        vcf_files = [p for p in vcf_files if "masked" not in p]
    else:
        vcf_files = [vcf_path]
    if not vcf_files:
        raise FileNotFoundError(f"No VCF files found at {vcf_path}")

    def _read_ac(vcf_file):
        """Return (dac array, n_max) for one VCF file, or None on failure."""
        raw = read_vcf(vcf_file)
        if raw is None:
            return None
        ac = GenotypeArray(raw["calldata/GT"]).count_alleles()
        ac = ac[ac.is_biallelic_01()]
        if len(ac) == 0:
            return None
        n_called = ac.sum(axis=1)
        n_max = int(n_called.max())
        dac = np.minimum(ac[:, 0], ac[:, 1]) if fold else ac[:, 1]
        # Keep only sites where the called allele total equals the maximum seen.
        mask = n_called == n_max
        return dac[mask].astype(int), n_max

    def _sfs_frame(dac_full, n_max, label):
        """Build a pl.DataFrame from raw dac counts."""
        sfs = np.bincount(dac_full, minlength=n_max + 1).astype(float)
        if n_proj is not None:
            n_hap = n_proj * 2
            if n_hap >= n_max:
                raise ValueError(
                    f"n_proj={n_proj} diploid ({n_hap} haplotypes) must be "
                    f"< actual diploid sample size {n_max // 2} ({n_max} haplotypes)."
                )
            sfs = _project_sfs(sfs, n_hap)
            if fold:
                # Projection can place a minor-allele class above n_hap / 2.
                # Mirror those classes back (j <- j + (n_hap - j)) so no mass
                # is lost when only the lower half is kept below.
                half = n_hap // 2
                refolded = sfs + sfs[::-1]
                refolded[half] = sfs[half]  # the midpoint is its own mirror
                refolded[half + 1 :] = 0.0
                sfs = refolded
            n_total = n_hap
        else:
            n_total = n_max
        sfs_trim = sfs[1 : n_total // 2 + 1] if fold else sfs[1:-1]
        total = sfs_trim.sum()
        sfs_pct = sfs_trim / total * 100 if total > 0 else sfs_trim
        dac_idx = np.arange(1, len(sfs_trim) + 1)
        return pl.DataFrame(
            {"dataset": label, "dac": dac_idx, "count": sfs_trim, "pct": sfs_pct}
        )

    with Parallel(n_jobs=nthreads, verbose=2) as parallel:
        ac_results = parallel(delayed(_read_ac)(f) for f in vcf_files)

    # Pair each result with its file BEFORE dropping failures, so that labels
    # cannot shift onto the wrong dataset when a file yields nothing.
    loaded = [(f, r) for f, r in zip(vcf_files, ac_results) if r is not None]
    if not loaded:
        raise ValueError("No valid SFS could be computed from the input files.")

    if merge:
        # Pool all dac arrays, capped at the smallest n_max across files.
        n_max = min(n for _, (_, n) in loaded)
        dac_all = np.concatenate([d[d <= n_max] for _, (d, _) in loaded])
        df = _sfs_frame(dac_all, n_max, os.path.basename(vcf_path.rstrip("/")))
    else:
        frames = []
        for vcf_file, (dac_full, n_max) in loaded:
            label = (
                os.path.basename(vcf_file).replace(".vcf.gz", "").replace(".vcf", "")
            )
            frames.append(_sfs_frame(dac_full, n_max, label))
        df = pl.concat(frames)

    # Align to minimum shared dac range across datasets
    max_shared = df.group_by("dataset").agg(pl.max("dac")).select(pl.min("dac")).item()
    df = df.filter(pl.col("dac") <= max_shared)

    # Plot
    labels = df.select("dataset").unique(maintain_order=True).to_series().to_list()
    n_files = len(labels)
    n_classes = max_shared
    x = np.arange(1, n_classes + 1)
    bar_width = 0.8 / n_files
    offsets = (np.arange(n_files) - n_files / 2 + 0.5) * bar_width
    colors = plt.cm.tab10(np.linspace(0, 0.9, n_files))
    if figsize is None:
        figsize = (max(8, n_classes * 0.6 + 2), 4)

    fig, ax = plt.subplots(figsize=figsize)
    for i, label in enumerate(labels):
        pct = df.filter(pl.col("dataset") == label).sort("dac")["pct"].to_numpy()
        ax.bar(
            x + offsets[i],
            pct,
            width=bar_width,
            label=label,
            color=colors[i],
            alpha=0.85,
            edgecolor="none",
        )

    xlabel = "Minor allele count" if fold else "Derived allele count"
    if n_proj is not None:
        xlabel += f" (projected n={n_proj} diploid)"
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel("Segregating sites (%)", fontsize=12)
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    if n_files > 1:
        ax.legend(fontsize=9, frameon=False, bbox_to_anchor=(0.75, 1), loc="upper left")
    if title:
        ax.set_title(title)
    plt.tight_layout()

    if out is not None:
        plt.savefig(out, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return df
[docs]
def plot_manhattan(
    prediction,
    recombination_map: str | None = None,
    eps: float = 1e-10,
    chr_col: str | None = None,
    pos_col: str | None = None,
    p_col: str | None = None,
    log_transform: bool = True,
    threshold_lines: list | None = None,
    figsize: tuple = (14, 5),
    out: str | None = None,
    title: str | None = None,
    chrom: str | int | None = None,
    center: int | None = None,
    window_bp: int = 5_000_000,
    chr_prefix_pattern: str = r"^chr",
):
    """Manhattan plot of per-window sweep probabilities.

    **Genome-wide mode** (``chrom=None``): chromosomes 1–25 (X=23, Y=24,
    MT/M=25) are laid end to end and coloured alternately.

    **Regional mode** (``chrom`` given): only that chromosome is drawn; when
    ``center`` is also given the x-axis is limited to
    ``[center - window_bp, center + window_bp]`` and a dashed vertical line
    marks ``center``.

    Parameters
    ----------
    prediction : str | pl.DataFrame
        Path to a CSV file or an already-loaded DataFrame.
    recombination_map : str, optional
        Tab-separated file read with columns ``chr, start, end, cm_mb, cm``.
        When given, the per-window rate is overlaid on a secondary y-axis.
    eps : float
        Lower clip applied to the probability and to ``1 - P`` before the log.
    chr_col, pos_col, p_col : str, optional
        Column names; default to ``"chr"``, ``"start"`` and ``"prob_sweep"``.
        The window end is read from ``"end"`` or ``"stop"`` if present,
        otherwise the start column is reused.
    log_transform : bool
        If True plot ``-log10(1 - P)`` with y-limits (0, 8); otherwise plot
        ``P`` with y-limits (0, 1.1).
    threshold_lines : list of (y_value, linestyle, label), optional
        Horizontal lines. When empty/None and ``log_transform`` is True,
        lines at y=3 and y=2 are drawn.
    figsize : tuple
        Matplotlib figure size.
    out : str, optional
        Save path. If None, the figure is shown.
    title : str, optional
        Axes title.
    chrom : str | int, optional
        Chromosome for regional mode (``"22"``, ``"chr22"``, ``"X"`` …).
    center : int, optional
        Centre position (bp) for regional mode.
    window_bp : int
        Half-width of the regional window in bp.
    chr_prefix_pattern : str
        Regex for the chromosome-name prefix to strip; matched
        case-insensitively.

    Returns
    -------
    matplotlib.figure.Figure
    """

    # The prefix is matched case-insensitively because names are upper-cased
    # first; a plain "^chr" would never match "CHR22".
    prefix_regex = f"(?i){chr_prefix_pattern}"
    special_chroms = {"X": "23", "Y": "24", "MT": "25", "M": "25"}

    # ----------------------------
    # Stable chromosome encoding
    # ----------------------------
    def _chr_to_int(expr: pl.Expr) -> pl.Expr:
        cleaned = (
            expr.cast(pl.Utf8)
            .str.to_uppercase()
            .str.replace(prefix_regex, "")
            .str.extract(r"(\d+|X|Y|MT|M)$", 1)
        )
        return cleaned.replace(special_chroms).cast(pl.Int32)

    def _chr_str_to_int(val: str | int) -> int:
        s = (
            pl.Series([str(val).upper()])
            .str.replace(prefix_regex, "")
            .to_list()[0]
        )
        return int(special_chroms.get(s, s))

    # Odd chromosomes black, even ones light grey.
    _CHR_COLORS = {i: ("black" if i % 2 == 1 else "0.8") for i in range(1, 26)}

    # ----------------------------
    # Load data
    # ----------------------------
    df = pl.read_csv(prediction) if isinstance(prediction, str) else prediction
    _chr_col = chr_col or "chr"
    _pos_col = pos_col or "start"
    _end_col = (
        "end" if "end" in df.columns else ("stop" if "stop" in df.columns else _pos_col)
    )
    is_regional = chrom is not None

    # ----------------------------
    # Normalize input
    # ----------------------------
    df = (
        df.with_columns(
            [
                _chr_to_int(pl.col(_chr_col)).alias("CHR_INT"),
                pl.col(_pos_col).cast(pl.Int64).alias("BP_START"),
                pl.col(_end_col).cast(pl.Int64).alias("BP_END"),
            ]
        )
        .filter(pl.col("CHR_INT").is_not_null())
        .with_columns(((pl.col("BP_START") + pl.col("BP_END")) / 2).alias("midpoint"))
        .sort(["CHR_INT", "BP_START"])
    )
    p_target = p_col or "prob_sweep"
    df = df.with_columns(
        pl.col(p_target).cast(pl.Float64).clip(lower_bound=eps).alias("P")
    )

    # ----------------------------
    # Recombination map
    # ----------------------------
    if recombination_map:
        df_rec = (
            pl.read_csv(
                recombination_map,
                separator="\t",
                comment_prefix="#",
                schema=pl.Schema(
                    [
                        ("chr", pl.String),
                        ("start", pl.Int64),
                        ("end", pl.Int64),
                        ("cm_mb", pl.Float64),
                        ("cm", pl.Float64),
                    ]
                ),
            )
            .with_columns(_chr_to_int(pl.col("chr")).alias("chr"))
            .filter(pl.col("chr").is_not_null())
        )
        df_temp = df.select(
            [
                pl.col("CHR_INT").alias("chr"),
                pl.col("BP_START").alias("start"),
                pl.col("BP_END").alias("end"),
            ]
        )
        # Computed once instead of on every loop iteration.
        rec_chroms = set(df_rec["chr"].unique().to_list())
        rate_frames = [
            get_cm(
                df_rec.filter(pl.col("chr") == c),
                df_temp.filter(pl.col("chr") == c)
                .select("start", "end")
                .unique()
                .to_numpy(),
                cm_mb=True,
            )
            for c in df_temp["chr"].unique().to_list()
            if c in rec_chroms
        ]
        if rate_frames:
            df_rate = pl.concat(rate_frames).select(["chr", "start", "end", "cm_mb"])
            df = df.join(
                df_rate,
                left_on=["CHR_INT", "BP_START", "BP_END"],
                right_on=["chr", "start", "end"],
                how="left",
            )

    chrom_order = list(range(1, 26))
    has_window = is_regional and center is not None

    # ----------------------------
    # Regional mode
    # ----------------------------
    if is_regional:
        target_chrom = _chr_str_to_int(chrom)
        df_plot = df.filter(pl.col("CHR_INT") == target_chrom)
        if center is not None:
            xlim_lo, xlim_hi = center - window_bp, center + window_bp
            # Keep 1 Mb of padding beyond the visible window.
            df_plot = df_plot.filter(
                (pl.col("midpoint") >= xlim_lo - 1_000_000)
                & (pl.col("midpoint") <= xlim_hi + 1_000_000)
            )
        df_plot = df_plot.sort("BP_START")
    # ----------------------------
    # Global mode
    # ----------------------------
    else:
        df = df.filter(pl.col("CHR_INT").is_in(chrom_order))
        chr_lens = (
            df.group_by("CHR_INT")
            .agg(pl.col("BP_START").max().alias("chr_len"))
            .sort("CHR_INT")
            .with_columns(
                (pl.col("chr_len").cum_sum() - pl.col("chr_len")).alias("tot")
            )
        )
        df_plot = (
            df.join(chr_lens, on="CHR_INT", how="left")
            .sort(["CHR_INT", "BP_START"])
            .with_columns((pl.col("BP_START") + pl.col("tot")).alias("BPcum"))
        )
        axisdf = (
            df_plot.group_by("CHR_INT")
            .agg(((pl.col("BPcum").max() + pl.col("BPcum").min()) / 2).alias("center"))
            .sort("CHR_INT")
        )

    def _y_values(frame):
        """-log10(1 - P) (clipped at eps) or the raw probability."""
        if log_transform:
            return -np.log10((1 - frame["P"]).clip(lower_bound=eps).to_numpy())
        return frame["P"].to_numpy()

    # ----------------------------
    # Plot
    # ----------------------------
    fig, ax = plt.subplots(figsize=figsize)

    # ----------------------------
    # Recombination (drawn on a twin axis placed above the main one)
    # ----------------------------
    if "cm_mb" in df_plot.columns:
        ax2 = ax.twinx()
        ax2.set_zorder(ax.get_zorder() + 1)
        ax2.patch.set_alpha(0)
        x_rec = (
            df_plot["midpoint"].to_numpy()
            if is_regional
            else df_plot["BPcum"].to_numpy()
        )
        y_rec = df_plot["cm_mb"].fill_null(0).to_numpy()
        recomb_color = "#e60000"
        if is_regional:
            ax2.plot(x_rec, y_rec, color=recomb_color, lw=1.6, alpha=0.6, zorder=50)
        else:
            ax2.scatter(x_rec, y_rec, color=recomb_color, s=2, alpha=0.45, zorder=50)
        ax2.set_ylabel("Recombination Rate (cM/Mb)", fontsize=9)

    # ----------------------------
    # Manhattan scatter
    # ----------------------------
    if is_regional:
        colors = [_CHR_COLORS.get(c, "black") for c in df_plot["CHR_INT"].to_list()]
        ax.scatter(
            df_plot["midpoint"].to_numpy(),
            _y_values(df_plot),
            color=colors,
            s=8,
            lw=0,
            zorder=10,
        )
    else:
        for c_val in chrom_order:
            sub = df_plot.filter(pl.col("CHR_INT") == c_val)
            if sub.is_empty():
                continue
            ax.scatter(
                sub["BPcum"].to_numpy(),
                _y_values(sub),
                color=_CHR_COLORS[c_val],
                s=8,
                lw=0,
                zorder=10,
            )

    # ----------------------------
    # Center line
    # ----------------------------
    if has_window:
        ax.axvline(float(center), color="black", lw=1.2, ls="--", alpha=0.6, zorder=100)

    # ----------------------------
    # Thresholds
    # ----------------------------
    if threshold_lines:
        for y_v, ls, lbl in threshold_lines:
            ax.axhline(y_v, color="black", linestyle=ls, lw=1.0, label=lbl, zorder=20)
    elif log_transform:
        ax.axhline(
            3, color="black", ls="-", lw=1.2, label=r"$p_{sweep} > 0.999$", zorder=20
        )
        ax.axhline(
            2, color="black", ls="--", lw=1.2, label=r"$p_{sweep} > 0.99$", zorder=20
        )

    # ----------------------------
    # Formatting
    # ----------------------------
    if log_transform:
        ax.set_ylabel(r"$-\log_{10}(1 - P)$")
        ax.set_ylim(0, 8)
    else:
        ax.set_ylabel("Probability of Sweep")
        ax.set_ylim(0, 1.1)

    if is_regional:
        ax.xaxis.set_major_formatter(
            mticker.FuncFormatter(lambda x, _: f"{x / 1e6:.2f} Mb")
        )
        ax.set_xlabel(f"Chromosome {chrom}")
    else:
        ax.set_xticks(axisdf["center"].to_list())
        ax.set_xticklabels([str(c) for c in axisdf["CHR_INT"].to_list()], fontsize=7)

    ax.spines[["top", "right"]].set_visible(False)
    ax.grid(axis="y", color="lightgray", lw=0.5)
    if log_transform or threshold_lines:
        ax.legend(fontsize=9)
    if title:
        ax.set_title(title)

    # Only zoom when a centre was requested; genome-wide calls have no centre.
    # The twin recombination axis shares this x-axis, so one call covers both.
    if has_window:
        ax.set_xlim(xlim_lo, xlim_hi)

    plt.tight_layout()
    if out:
        fig.savefig(out, dpi=150, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return fig
[docs]
def plot_sweep_density(prediction, output_path=None):
    """
    Draw one histogram of ``prob_sweep`` per chromosome on a 4-column grid.

    Parameters
    ----------
    prediction : pl.DataFrame | str
        DataFrame, or path to a parquet/CSV file, with at least the columns
        ``chr`` and ``prob_sweep``. A path ending in ``.parquet`` is read as
        parquet; any other path is read as CSV.
    output_path : str, optional
        If given, the figure is saved to this path.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    if not isinstance(prediction, str):
        df = prediction
    elif prediction.endswith(".parquet"):
        df = pl.read_parquet(prediction)
    else:
        df = pl.read_csv(prediction)

    def _chr_key(x):
        # Numeric names sort by value; anything else by its first character code.
        s = str(x).replace("chr", "")
        if s.isdigit():
            return int(s)
        return ord(s[0])

    chroms = sorted(df["chr"].unique().to_list(), key=_chr_key)
    n_panels = len(chroms)

    ncols = 4
    nrows = math.ceil(n_panels / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows), sharey=True)
    axes = np.array(axes).flatten()

    for idx, chrom in enumerate(chroms):
        panel = axes[idx]
        vals = df.filter(pl.col("chr") == chrom)["prob_sweep"].to_numpy()
        panel.hist(
            vals, bins=30, range=(0, 1), color="steelblue", edgecolor="none", alpha=0.8
        )
        share_above = (vals > 0.5).mean() * 100
        panel.set_title(f"{chrom} ({share_above:.1f}% > 0.5)", fontsize=9)
        panel.set_xlabel("P(sweep)", fontsize=8)
        panel.axvline(0.5, color="tomato", linewidth=0.9, linestyle="--")
        for side in ("top", "right"):
            panel.spines[side].set_visible(False)

    # Hide the unused cells of the grid.
    for spare in axes[n_panels:]:
        spare.set_visible(False)

    fig.tight_layout()
    if output_path is not None:
        fig.savefig(output_path, bbox_inches="tight")
    return fig
[docs]
def plot_fv_pca(train_data, empirical_data=None, subsample=5000, output_path=None):
    """
    PCA of the feature vector matrix colored by neutral/sweep label.

    Parameters
    ----------
    train_data : str | pl.DataFrame
        Path to fvs*.parquet or already-loaded DataFrame. Must have a 'model' column.
    empirical_data : optional
        Accepted for call compatibility but not used by this function; it may
        be omitted.
    subsample : int
        Max rows to use (avoids slow PCA on very large datasets). Default 5000.
    output_path : str, optional
        If given, the figure is saved to this path.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    from sklearn.decomposition import PCA

    if isinstance(train_data, str):
        df = pl.read_parquet(train_data)
    else:
        df = train_data
    if df.shape[0] > subsample:
        # Fixed seed so repeated calls draw the same rows.
        df = df.sample(n=subsample, seed=42)

    # Collapse every non-neutral model label into a single "sweep" class.
    df = df.with_columns(
        pl.when(pl.col("model") != "neutral")
        .then(pl.lit("sweep"))
        .otherwise(pl.lit("neutral"))
        .alias("model")
    )

    meta_cols = {"iter", "s", "t", "f_i", "f_t", "mu", "r", "model"}
    feat_cols = [c for c in df.columns if c not in meta_cols]
    # PCA rejects NaN input, so NaN is zero-filled the same way nulls are.
    X = df.select(feat_cols).fill_null(0).fill_nan(0).to_numpy()
    labels = df["model"].to_numpy()

    pca = PCA(n_components=2)
    Z = pca.fit_transform(X)

    fig, ax = plt.subplots(figsize=(7, 6))
    for cls, color in [("neutral", "steelblue"), ("sweep", "tomato")]:
        mask = labels == cls
        ax.scatter(
            Z[mask, 0],
            Z[mask, 1],
            c=color,
            label=cls,
            alpha=0.4,
            s=8,
            edgecolors="none",
        )

    var = pca.explained_variance_ratio_
    ax.set_xlabel(f"PC1 ({var[0] * 100:.1f}%)")
    ax.set_ylabel(f"PC2 ({var[1] * 100:.1f}%)")
    ax.legend(markerscale=3, frameon=False)
    ax.set_title("Feature vector PCA — neutral vs sweep")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()

    if output_path is not None:
        fig.savefig(output_path, bbox_inches="tight")
    return fig
[docs]
def plot_stat_distributions(
    train_data,
    empirical_data=None,
    stats=(
        "dind",
        "dist_kurtosis",
        "dist_skew",
        "dist_var",
        "h1",
        "h12",
        "h2_h1",
        "haf",
        "hapdaf_o",
        "hapdaf_s",
        "high_freq",
        "ihs",
        "isafe",
        "k_counts",
        "low_freq",
        "max_fda",
        "nsl",
        "omega_max",
        "pi",
        "s_ratio",
        "tajima_d",
        "theta_h",
        "theta_w",
        "zns",
    ),
    output_path=None,
):
    """
    Violin plots of feature stats split by neutral/sweep (and optionally empirical).

    Parameters
    ----------
    train_data : str | pl.DataFrame
        Training fvs*.parquet — must have a 'model' column.
    empirical_data : str | pl.DataFrame, optional
        Empirical fvs*.parquet; plotted as a third distribution for every
        column it contains.
    stats : sequence of str or None, optional
        Stat base names to plot (e.g. ['pi', 'h12', 'ihs']). For each name the
        first feature column starting with ``"<name>_"`` is used; names with
        no matching column are skipped. Pass ``None`` to plot one
        representative column per unique stat base name instead.
    output_path : str, optional
        If given, the figure is saved to this path.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If no feature column matches the requested stats.
    """
    if isinstance(train_data, str):
        df = pl.read_parquet(train_data)
    else:
        df = train_data

    # Collapse every non-neutral model label into a single "sweep" class.
    df = df.with_columns(
        pl.when(pl.col("model") != "neutral")
        .then(pl.lit("sweep"))
        .otherwise(pl.lit("neutral"))
        .alias("model")
    )

    if empirical_data is None:
        df_emp = None
    elif isinstance(empirical_data, str):
        df_emp = pl.read_parquet(empirical_data)
    else:
        df_emp = empirical_data

    meta_cols = {"iter", "s", "t", "f_i", "f_t", "mu", "r", "model"}
    feat_cols = [c for c in df.columns if c not in meta_cols]

    # Pick one representative column per stat base name
    if stats is None:
        seen = {}
        for c in feat_cols:
            # A name with at least two underscores is read as "<base>_<x>_<y>".
            parts = c.rsplit("_", 2)
            base = parts[0] if len(parts) == 3 else c
            seen.setdefault(base, c)
        plot_cols = list(seen.values())
        stat_labels = list(seen.keys())
    else:
        plot_cols, stat_labels = [], []
        for s in stats:
            col = next((c for c in feat_cols if c.startswith(s + "_")), None)
            if col is not None:
                plot_cols.append(col)
                stat_labels.append(s)

    if not plot_cols:
        raise ValueError("No feature columns matched the requested stats.")

    ncols = 4
    nrows = math.ceil(len(plot_cols) / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3.5, nrows * 3.2))
    axes = np.array(axes).flatten()

    for ax, col, label in zip(axes, plot_cols, stat_labels):
        # (x position, tick label, colour, values); positions stay fixed per group.
        groups = [
            (
                0,
                "neutral",
                "steelblue",
                df.filter(pl.col("model") == "neutral")[col].drop_nulls().to_numpy(),
            ),
            (
                1,
                "sweep",
                "tomato",
                df.filter(pl.col("model") == "sweep")[col].drop_nulls().to_numpy(),
            ),
        ]
        if df_emp is not None and col in df_emp.columns:
            groups.append(
                (2, "empirical", "goldenrod", df_emp[col].drop_nulls().to_numpy())
            )

        # violinplot cannot draw an empty sample, so drop empty groups.
        groups = [g for g in groups if len(g[3]) > 0]
        if groups:
            positions = [g[0] for g in groups]
            parts = ax.violinplot(
                [g[3] for g in groups],
                positions=positions,
                showmedians=True,
                widths=0.6,
            )
            for body, group in zip(parts["bodies"], groups):
                body.set_facecolor(group[2])
                body.set_alpha(0.7)
            # Labels are built from the groups actually drawn so that the
            # number of labels always equals the number of ticks.
            ax.set_xticks(positions)
            ax.set_xticklabels([g[1] for g in groups], fontsize=8)

        ax.set_title(label, fontsize=9)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    # Hide the unused cells of the grid.
    for ax in axes[len(plot_cols) :]:
        ax.set_visible(False)

    fig.tight_layout()
    if output_path is not None:
        fig.savefig(output_path, bbox_inches="tight")
    return fig
def _load_scan(stat_file: str, stat_col: str) -> pl.DataFrame:
    """Read a tab-separated scan file and check that ``stat_col`` is a column.

    Raises ``ValueError`` listing the available columns when it is not.
    """
    table = pl.read_csv(stat_file, separator="\t")
    if stat_col in table.columns:
        return table
    raise ValueError(
        f"Column '{stat_col}' not found in {stat_file}. Available: {table.columns}"
    )
def _resolve_scan_inputs(stats, stat_cols):
    """Normalise flexible scan inputs to (list[pl.DataFrame], list[str]).

    Accepts:
    - ``dict`` from ``scan()`` — keys are stat names, values are DataFrames.
      When ``stat_cols`` is None, each key is mapped to its
      ``STAT_REGISTRY[key].rank_col`` if registered, otherwise to the key.
    - ``str`` — path to a single TSV file; ``stat_cols`` must be provided.
      The file is loaded once per requested column.
    - ``list[str]`` / ``tuple[str]`` — paths to multiple TSV files;
      ``stat_cols`` must be a single name (applied to every file) or a
      sequence with one name per file.

    Raises
    ------
    ValueError
        If ``stat_cols`` is missing for file input, or its length differs
        from the number of paths.
    TypeError
        If ``stats`` is none of the accepted types.
    """
    if isinstance(stats, dict):
        if stat_cols is None:
            from .scan import STAT_REGISTRY

            stat_cols = [
                STAT_REGISTRY[s].rank_col if s in STAT_REGISTRY else s for s in stats
            ]
        elif isinstance(stat_cols, str):
            stat_cols = [stat_cols]
        dfs = list(stats.values())
    elif isinstance(stats, str):
        if stat_cols is None:
            raise ValueError("stat_cols required when stats is a file path")
        if isinstance(stat_cols, str):
            stat_cols = [stat_cols]
        dfs = [_load_scan(stats, s) for s in stat_cols]
    elif isinstance(stats, (list, tuple)):
        if stat_cols is None:
            raise ValueError("stat_cols required when stats is a list of file paths")
        if isinstance(stat_cols, str):
            stat_cols = [stat_cols] * len(stats)
        elif len(stat_cols) != len(stats):
            # zip() below would otherwise drop the unmatched files silently.
            raise ValueError(
                f"stat_cols has {len(stat_cols)} entries but {len(stats)} "
                "file paths were given"
            )
        dfs = [_load_scan(f, s) for f, s in zip(stats, stat_cols)]
    else:
        raise TypeError(
            "stats must be a dict (scan() output), a file path str, or a list of paths"
        )
    return dfs, stat_cols
def _chr_to_float(col_expr):
    """Polars expression: drop a leading 'chr' and cast the rest to Float64.

    Used to obtain a numeric key for sorting chromosomes.
    """
    as_text = col_expr.cast(pl.Utf8)
    without_prefix = as_text.str.replace("^chr", "")
    return without_prefix.cast(pl.Float64)
[docs]
def _scan_pvalue_label(stat_col):
    """Y-axis label for a -log10 empirical p-value panel."""
    return rf"$\mathrm{{{stat_col}}}:\ -\log_{{10}}(p_{{\mathrm{{emp}}}})$"


def _scan_neg_log10(frame, pval_col):
    """Return -log10 of ``frame[pval_col]``, with p clipped below at 1e-10."""
    p_raw = frame[pval_col].to_numpy(allow_copy=True).astype(np.float64)
    return -np.log10(np.clip(p_raw, 1e-10, None))


def _scan_outlier_masks(y, top_pct):
    """Split values into (has_neg, pos_out, neg_out, background) boolean masks.

    If any value is negative the cutoff is the ``(1 - top_pct)`` percentile of
    ``|y|`` and is applied to both tails; otherwise it is the same percentile
    of ``y`` applied to the upper tail only. An empty input yields empty masks.
    """
    if y.size == 0:
        # nanmin / nanpercentile raise on zero-size input.
        empty = np.zeros(0, dtype=bool)
        return False, empty, empty.copy(), empty.copy()
    has_neg = np.nanmin(y) < 0
    if has_neg:
        cutoff = np.nanpercentile(np.abs(y), (1 - top_pct) * 100)
        pos_out = y >= cutoff
        neg_out = y <= -cutoff
        background = ~pos_out & ~neg_out
    else:
        cutoff = np.nanpercentile(y, (1 - top_pct) * 100)
        pos_out = y >= cutoff
        neg_out = np.zeros(len(y), dtype=bool)
        background = ~pos_out
    return has_neg, pos_out, neg_out, background


def _scan_draw_outliers(ax, x, y, has_neg, pos_out, neg_out, top_pct, size):
    """Overlay upper-tail (red) and lower-tail (blue) outliers on ``ax``."""
    if pos_out.any():
        ax.scatter(
            x[pos_out],
            y[pos_out],
            s=size,
            color="#d62728",
            alpha=0.8,
            linewidths=0,
            rasterized=True,
            label=f"Top {top_pct * 100:g}%{' (+)' if has_neg else ''}",
        )
    if has_neg and neg_out.any():
        ax.scatter(
            x[neg_out],
            y[neg_out],
            s=size,
            color="#1f77b4",
            alpha=0.8,
            linewidths=0,
            rasterized=True,
            label=f"Top {top_pct * 100:g}% (−)",
        )


def _scan_draw_hlines(ax, lines):
    """Draw black horizontal lines from (y_value, linestyle, label) tuples."""
    for y_val, ls, label in lines:
        ax.axhline(y_val, color="black", linestyle=ls, linewidth=1.0, label=label)


def _scan_decorate_pvalue_axis(ax, threshold_lines):
    """Add threshold lines (default p = 0.01 / 0.001) and fix y to [0, 10]."""
    lines = (
        threshold_lines
        if threshold_lines is not None
        else [(2, "--", "p = 0.01"), (3, "-", "p = 0.001")]
    )
    _scan_draw_hlines(ax, lines)
    ax.set_ylim(0, 10)
    ax.set_yticks([2, 4, 6, 8, 10])


def _scan_finish_axis(ax, y_label):
    """Apply the y-label, legend, spine and grid styling shared by all panels."""
    ax.set_ylabel(y_label, fontsize=8)
    handles, _ = ax.get_legend_handles_labels()
    if handles:  # skip the legend when nothing on the axis carries a label
        ax.legend(fontsize=8, frameon=False, markerscale=3)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="y", color="lightgray", linewidth=0.5)
    ax.grid(axis="x", visible=False)


def _scan_save_or_show(fig, out):
    """Save ``fig`` to ``out`` and close it, or show it when ``out`` is falsy."""
    fig.tight_layout()
    if out:
        fig.savefig(out, dpi=150)
        plt.close(fig)
    else:
        plt.show()


def plot_scan(
    stats,
    stat_cols=None,
    pvalue: bool = False,
    top_pct: float = 0.01,
    threshold_lines: list | None = None,
    out: str | None = None,
    figsize: tuple | None = None,
    title: str | None = None,
    sharey: bool = False,
    chrom: str | None = None,
    center: int | None = None,
    window_bp: int = 500_000,
) -> plt.Figure:
    """Genome-wide or regional scan plot — single stat or stacked multi-stat.

    Accepts output from ``scan()`` (dict) or paths to ``{prefix}.{stat}.txt`` files.
    One panel is drawn per stat, stacked vertically with a shared x-axis.

    **Genome-wide mode** (default):
        Chromosomes are laid end to end using offsets computed from the first
        DataFrame and coloured alternately.

    **Zoom mode** (both ``chrom`` and ``center`` provided):
        Data are filtered to ``chrom`` and ``[center ± window_bp]``. A black
        dashed vertical line marks the centre.

    In both modes a panel shows ``-log10(p_emp)`` when ``pvalue=True`` and a
    ``{stat_col}_pvalue`` column exists; otherwise it shows raw values with
    the top ``top_pct`` highlighted (both tails when the stat is signed).

    Parameters
    ----------
    stats : dict | str | list[str]
        - ``dict`` from ``scan()`` — keys are stat names, values are DataFrames.
        - ``str`` — path to a single ``{stat}.txt`` scan output file.
        - ``list[str]`` — paths to multiple scan output files.
    stat_cols : str | list[str], optional
        Stat column name(s). Required when ``stats`` is a file path / list of paths.
        When ``stats`` is a dict the columns are resolved by
        ``_resolve_scan_inputs``.
    pvalue : bool
        If True, plot ``-log10(p_emp)`` with y=[0,10] and threshold lines where
        a p-value column is available. If False, plot raw stat values.
    top_pct : float
        Fraction highlighted as outliers in raw mode.
    threshold_lines : list of (y_value, linestyle, label), optional
        Horizontal lines. ``pvalue=True`` default: y=2 dashed (p=0.01), y=3 solid.
        Pass ``[]`` to suppress.
    out : str, optional
        Save path. If None, shows interactively.
    figsize : tuple, optional
        Genome-wide default: (14, 4) / (14, 3×n). Zoom default: (10, 4) / (10, 3×n).
    title : str, optional
        Title for single-stat plots.
    sharey : bool
        Share y-axis across panels (default False).
    chrom : str, optional
        Chromosome for zoom mode (e.g. ``"22"`` or ``"chr22"``).
    center : int, optional
        Centre position (bp) for zoom mode.
    window_bp : int
        Half-window size in bp for zoom mode (default 500 000 = ±500 kb).

    Returns
    -------
    matplotlib.figure.Figure
    """
    dfs, stat_cols = _resolve_scan_inputs(stats, stat_cols)
    n = len(dfs)

    # ── Zoom mode ────────────────────────────────────────────────────────────
    if chrom is not None and center is not None:
        lo, hi = center - window_bp, center + window_bp
        # removeprefix, not lstrip: lstrip("chr") strips any leading c/h/r chars.
        chrom_str = str(chrom).removeprefix("chr")
        if figsize is None:
            figsize = (10, 4) if n == 1 else (10, 3.0 * n)
        fig, axes = plt.subplots(n, 1, figsize=figsize, sharex=True, sharey=sharey)
        axes = [axes] if n == 1 else list(axes)

        for ax, df, stat_col in zip(axes, dfs, stat_cols):
            pval_col = f"{stat_col}_pvalue"
            region = df.filter(
                pl.col("chrom").cast(pl.Utf8).str.replace("^chr", "") == chrom_str
            ).filter((pl.col("pos") >= lo) & (pl.col("pos") <= hi))
            pos = region["pos"].to_numpy()

            if pvalue and pval_col in region.columns:
                y_label = _scan_pvalue_label(stat_col)
                ax.scatter(
                    pos,
                    _scan_neg_log10(region, pval_col),
                    s=4,
                    color="#333333",
                    alpha=0.7,
                    linewidths=0,
                    rasterized=True,
                )
                _scan_decorate_pvalue_axis(ax, threshold_lines)
            else:
                y = region[stat_col].to_numpy(allow_copy=True).astype(np.float64)
                y_label = stat_col
                has_neg, pos_out, neg_out, background = _scan_outlier_masks(y, top_pct)
                ax.scatter(
                    pos[background],
                    y[background],
                    s=3,
                    color="#333333",
                    alpha=0.4,
                    linewidths=0,
                    rasterized=True,
                )
                _scan_draw_outliers(
                    ax, pos, y, has_neg, pos_out, neg_out, top_pct, size=8
                )
                if threshold_lines:
                    _scan_draw_hlines(ax, threshold_lines)

            ax.axvline(center, color="black", lw=1, ls="--", alpha=1)
            _scan_finish_axis(ax, y_label)

        if n == 1 and title:
            axes[0].set_title(title)
        axes[-1].set_xlabel("Position")
        axes[-1].xaxis.set_major_formatter(
            mticker.FuncFormatter(lambda x, _: f"{x / 1e6:.2f} Mb")
        )
        _scan_save_or_show(fig, out)
        return fig

    # ── Genome-wide mode ─────────────────────────────────────────────────────
    if figsize is None:
        figsize = (14, 4) if n == 1 else (14, 3.0 * n)
    fig, axes = plt.subplots(n, 1, figsize=figsize, sharex=True, sharey=sharey)
    axes = [axes] if n == 1 else list(axes)

    chr_bp_cols = [
        _chr_to_float(pl.col("chrom")).alias("CHR"),
        pl.col("pos").cast(pl.Float64).alias("BP"),
    ]

    # Chr offsets: computed once from the first DataFrame
    chr_lens = (
        dfs[0]
        .with_columns(chr_bp_cols)
        .group_by("CHR")
        .agg(pl.col("BP").max().alias("chr_len"))
        .sort("CHR")
        .with_columns((pl.col("chr_len").cum_sum() - pl.col("chr_len")).alias("tot"))
        .select(["CHR", "tot"])
    )

    axisdf = None
    alt_colors = ["#333333", "#aaaaaa"]
    for ax, df, stat_col in zip(axes, dfs, stat_cols):
        pval_col = f"{stat_col}_pvalue"
        df_plot = (
            df.with_columns(chr_bp_cols)
            .join(chr_lens, on="CHR", how="left")
            .sort(["CHR", "BP"])
            .with_columns((pl.col("BP") + pl.col("tot")).alias("BPcum"))
        )
        if axisdf is None:
            # Tick positions: midpoint of each chromosome on the cumulative axis.
            axisdf = (
                df_plot.group_by("CHR")
                .agg(
                    ((pl.col("BPcum").max() + pl.col("BPcum").min()) / 2).alias(
                        "center"
                    )
                )
                .sort("CHR")
            )

        bpcum = df_plot["BPcum"].to_numpy()
        chr_vals = df_plot["CHR"].to_numpy()
        chromosomes = sorted(df_plot["CHR"].unique().to_list())

        if pvalue and pval_col in df_plot.columns:
            y = _scan_neg_log10(df_plot, pval_col)
            y_label = _scan_pvalue_label(stat_col)
            for i, chr_id in enumerate(chromosomes):
                mask = chr_vals == chr_id
                ax.scatter(
                    bpcum[mask],
                    y[mask],
                    s=4,
                    color=alt_colors[i % 2],
                    alpha=0.7,
                    linewidths=0,
                    rasterized=True,
                )
            _scan_decorate_pvalue_axis(ax, threshold_lines)
        else:
            y = df_plot[stat_col].to_numpy(allow_copy=True).astype(np.float64)
            y_label = stat_col
            has_neg, pos_out, neg_out, background = _scan_outlier_masks(y, top_pct)
            for i, chr_id in enumerate(chromosomes):
                mask = (chr_vals == chr_id) & background
                ax.scatter(
                    bpcum[mask],
                    y[mask],
                    s=2,
                    color=alt_colors[i % 2],
                    alpha=0.4,
                    linewidths=0,
                    rasterized=True,
                )
            _scan_draw_outliers(
                ax, bpcum, y, has_neg, pos_out, neg_out, top_pct, size=6
            )
            if threshold_lines:
                _scan_draw_hlines(ax, threshold_lines)

        _scan_finish_axis(ax, y_label)

    if n == 1 and title:
        axes[0].set_title(title)
    axes[-1].set_xticks(axisdf["center"].to_list())
    axes[-1].set_xticklabels([str(int(c)) for c in axisdf["CHR"].to_list()], fontsize=7)
    axes[-1].set_xlabel("Chromosome")
    _scan_save_or_show(fig, out)
    return fig
def plot_scan_zoom(
    stats, stat_cols, chrom, center, window_bp=500_000, **kwargs
) -> plt.Figure:
    """Deprecated alias — use ``plot_scan(..., chrom=chrom, center=center)``.

    Every argument is forwarded unchanged to ``plot_scan``: ``stats`` goes
    positionally, everything else (including any extra ``kwargs``) by keyword.
    """
    forwarded = {
        "stat_cols": stat_cols,
        "chrom": chrom,
        "center": center,
        "window_bp": window_bp,
        **kwargs,
    }
    return plot_scan(stats, **forwarded)
################## Sorting methods
@njit(parallel=False)
def corr_sorting(matrix):
    """Reorder the rows of ``matrix`` by similarity to its most connected row.

    A pairwise score is computed for every pair of rows from the row totals
    and the number of columns where both rows equal 1 (squared covariance
    divided by the product of the two Bernoulli variances; 0.0 when that
    product is zero).  The diagonal is pinned to 1.000001 so a row always
    scores highest against itself.  The row with the largest summed score
    becomes the reference, and rows are emitted in descending order of their
    score against that reference.

    Parameters
    ----------
    matrix : 2D array
        Unpacked as ``(samples, sites)``.  The score presumably expects 0/1
        entries -- TODO confirm with callers.

    Returns
    -------
    2D array
        A new array (allocated with ``np.empty_like``) holding the reordered
        rows; the input is left untouched.
    """
    n_rows, n_cols = matrix.shape

    # Row totals (the number of 1s when the matrix is binary).
    row_sum = np.zeros(n_rows, dtype=np.int32)
    for r in range(n_rows):
        for c in range(n_cols):
            row_sum[r] += matrix[r, c]

    # Pairwise scores, accumulated into per-row totals as they are produced
    # (same left-to-right summation order as a separate pass would use).
    score = np.zeros((n_rows, n_rows), dtype=np.float64)
    total = np.zeros(n_rows, dtype=np.float64)
    for a in range(n_rows):
        for b in range(n_rows):
            if a != b:
                both = 0
                for c in range(n_cols):
                    if matrix[a, c] == 1 and matrix[b, c] == 1:
                        both += 1
                freq_a = row_sum[a] / n_cols
                freq_b = row_sum[b] / n_cols
                numer = (both / n_cols - freq_a * freq_b) ** 2
                denom = freq_a * (1 - freq_a) * freq_b * (1 - freq_b)
                score[a, b] = numer / denom if denom != 0 else 0.0
            else:
                score[a, b] = 1.000001
            total[a] += score[a, b]

    # Reference row: first row reaching the maximum summed score.
    ref = 0
    for r in range(1, n_rows):
        if total[r] > total[ref]:
            ref = r

    # Exchange sort of the row indices, descending by score against ``ref``.
    order = np.arange(n_rows)
    for lo in range(n_rows):
        for hi in range(lo + 1, n_rows):
            if score[ref, order[hi]] > score[ref, order[lo]]:
                order[lo], order[hi] = order[hi], order[lo]

    # Materialise the permuted rows.
    reordered = np.empty_like(matrix)
    for r in range(n_rows):
        src = order[r]
        for c in range(n_cols):
            reordered[r, c] = matrix[src, c]
    return reordered
@njit
def daf_sorting(matrix):
    """Sort the columns of ``matrix`` by descending count of entries equal to 1.

    The sort is an in-place exchange sort: ``matrix`` itself is modified and
    the same array is returned.  Columns with equal counts may end up in a
    different relative order than they started in.

    Parameters
    ----------
    matrix : 2D array
        Unpacked as ``(samples, sites)``; columns are the units being sorted.

    Returns
    -------
    2D array
        The input array, with its columns rearranged.
    """
    n_rows, n_cols = matrix.shape

    # Number of 1s in every column.
    ones = np.zeros(n_cols, dtype=int64)
    for col in range(n_cols):
        for row in range(n_rows):
            if matrix[row, col] == 1:
                ones[col] += 1

    # Exchange sort: whenever a later column has a larger count, swap both
    # the column contents and the bookkeeping counts.
    for left in range(n_cols):
        for right in range(left + 1, n_cols):
            if ones[right] > ones[left]:
                for row in range(n_rows):
                    matrix[row, left], matrix[row, right] = (
                        matrix[row, right],
                        matrix[row, left],
                    )
                ones[left], ones[right] = ones[right], ones[left]
    return matrix
@njit
def freq_sorting(matrix):
    """Sort the rows of ``matrix`` by descending count of entries equal to 1.

    The sort is an in-place exchange sort: ``matrix`` itself is modified and
    the same array is returned.  Rows with equal counts may end up in a
    different relative order than they started in.

    Parameters
    ----------
    matrix : 2D array
        Unpacked as ``(samples, sites)``; rows are the units being sorted.

    Returns
    -------
    2D array
        The input array, with its rows rearranged.
    """
    n_rows, n_cols = matrix.shape

    # Hamming weight (number of 1s) of every row.
    weight = np.zeros(n_rows, dtype=np.int32)
    for row in range(n_rows):
        for col in range(n_cols):
            if matrix[row, col] == 1:
                weight[row] += 1

    # Exchange sort: a heavier later row trades places with the current one.
    for upper in range(n_rows - 1):
        for lower in range(upper + 1, n_rows):
            if weight[lower] > weight[upper]:
                weight[upper], weight[lower] = weight[lower], weight[upper]
                for col in range(n_cols):
                    matrix[upper, col], matrix[lower, col] = (
                        matrix[lower, col],
                        matrix[upper, col],
                    )
    return matrix
@njit
def pcc_column_sort_numba(matrix):
    """Sort the columns of ``matrix`` by their summed pairwise correlation score.

    For every ordered pair of distinct columns the score is the squared
    covariance divided by the product of the two Bernoulli variances (0.0
    when that product is zero); the diagonal is pinned to 1.000001.  Each
    column's total score is the sum over its row of that score matrix, and
    columns are then exchange-sorted in place by descending total.

    Parameters
    ----------
    matrix : 2D array
        Columns are the units being scored and sorted.  The score presumably
        expects 0/1 entries -- TODO confirm with callers.

    Returns
    -------
    2D array
        The input array itself, with its columns rearranged in place.
    """
    n_rows, n_cols = matrix.shape
    pair_score = np.zeros((n_cols, n_cols), dtype=np.float64)
    col_score = np.zeros(n_cols, dtype=np.float64)

    # Pairwise scores between columns.
    for a in range(n_cols):
        for b in range(n_cols):
            if a != b:
                sum_a = 0
                sum_b = 0
                sum_ab = 0
                for r in range(n_rows):
                    sum_a += matrix[r, a]
                    sum_b += matrix[r, b]
                    sum_ab += matrix[r, a] * matrix[r, b]
                numer = (sum_ab / n_rows - (sum_a * sum_b) / (n_rows * n_rows)) ** 2
                denom = (
                    (sum_a / n_rows)
                    * (1 - sum_a / n_rows)
                    * (sum_b / n_rows)
                    * (1 - sum_b / n_rows)
                )
                pair_score[a, b] = numer / denom if denom != 0 else 0.0
            else:
                pair_score[a, b] = 1.000001

    # Total score of every column.
    for a in range(n_cols):
        for b in range(n_cols):
            col_score[a] += pair_score[a, b]

    # Exchange sort by descending total score, moving column contents along
    # with their scores.
    for left in range(n_cols - 1):
        for right in range(left + 1, n_cols):
            if col_score[right] > col_score[left]:
                col_score[left], col_score[right] = col_score[right], col_score[left]
                for r in range(n_rows):
                    matrix[r, left], matrix[r, right] = (
                        matrix[r, right],
                        matrix[r, left],
                    )
    return matrix
def haplotype_freq_sorting_hamming(matrix):
    """
    Reorder matrix columns by haplotype frequency, breaking frequency ties by
    Hamming distance to the most common haplotype.

    Identical columns are grouped together.  Groups are ordered by descending
    frequency, then by ascending Hamming distance between the group's
    haplotype and the most frequent haplotype, then by the index of the
    group's first column.

    Parameters
    ----------
    matrix : (S, n) array-like
        One haplotype per column.

    Returns
    -------
    matrix_reordered : (S, n) ndarray
        Columns of ``matrix`` in the new order.
    f : (k,) ndarray of float
        Frequency of each distinct haplotype, in descending order.
    """
    # Based on scikit haplotype counts
    matrix = np.asarray(matrix)
    _, n = matrix.shape
    if n == 0:
        # Nothing to group; avoids argmax/concatenate on empty input.
        return matrix, np.array([], dtype=float)

    # Group column indices by exact column content.  The raw bytes are used
    # as the key (not ``hash(bytes)``) so that a hash collision can never
    # merge two different haplotypes.
    groups = defaultdict(list)
    for i in range(n):
        groups[matrix[:, i].tobytes()].append(i)

    # Most common first; ``sorted`` is stable, so ties keep first-seen order.
    counts = sorted(groups.values(), key=len, reverse=True)
    f = np.array([len(g) / n for g in counts], dtype=float)

    # All columns in a group are identical, so the first one represents it.
    reps = np.array([g[0] for g in counts], dtype=int)

    # ``f`` is in descending order, so the first group is the most frequent.
    ref = matrix[:, reps[0]]

    # Hamming distance of every representative to the reference haplotype.
    distances = (matrix[:, reps] != ref[:, None]).sum(0)

    # np.lexsort treats the LAST key as primary: -f, then distance, then reps.
    group_order = np.lexsort((reps, distances, -f))

    col_order = np.concatenate([np.asarray(counts[g], dtype=int) for g in group_order])
    return matrix[:, col_order], f
def haplotype_freq_sorting(matrix):
    """
    Reorder matrix columns by haplotype frequency (descending), placing
    identical columns next to each other.

    Parameters
    ----------
    matrix : (S, n) array-like
        One haplotype per column.

    Returns
    -------
    matrix_reordered : (S, n) ndarray
        Columns of ``matrix`` in the new order.
    col_order : (n,) ndarray of int
        Column permutation that was applied.
    groups_sorted : list[list[int]]
        Column indices of each distinct haplotype, most common first.
    f : (k,) ndarray of float
        Frequency of each distinct haplotype, aligned with ``groups_sorted``.
    """
    matrix = np.asarray(matrix)
    _, n = matrix.shape

    # Group column indices by exact column content.  The raw bytes are used
    # as the key (not ``hash(bytes)``) so that a hash collision can never
    # merge two different haplotypes.
    groups = defaultdict(list)
    for i in range(n):
        groups[matrix[:, i].tobytes()].append(i)

    # Most common first; ``sorted`` is stable, so ties keep first-seen order.
    groups_sorted = sorted(groups.values(), key=len, reverse=True)
    f = np.array([len(g) / n for g in groups_sorted], dtype=float)

    if groups_sorted:
        col_order = np.concatenate([np.asarray(g, dtype=int) for g in groups_sorted])
    else:
        col_order = np.array([], dtype=int)

    matrix_reordered = matrix[:, col_order] if n > 0 else matrix
    return matrix_reordered, col_order, groups_sorted, f
def disrupt_genomic_positions(matrix, seed=None):
    """
    Randomly permute (shuffle) the columns of `matrix`.

    Parameters
    ----------
    matrix : array-like
        2D data whose columns will be permuted.
    seed : optional
        Forwarded to ``np.random.default_rng``.  ``None`` (the default) draws
        fresh entropy on every call; pass a fixed value to get a reproducible
        permutation.

    Returns
    -------
    permuted : np.ndarray
        New array with the columns permuted; the input is not modified.

    Raises
    ------
    ValueError
        If `matrix` is not two-dimensional.
    """
    # No up-front copy is needed: the fancy indexing below already returns
    # a new array.
    arr = np.asarray(matrix)
    if arr.ndim != 2:
        raise ValueError("`matrix` must be 2D.")

    rng = np.random.default_rng(seed)
    perm = rng.permutation(arr.shape[1])
    return arr[:, perm]
def disrupt_ld(matrix):
    """
    Shuffle the order of the columns of `matrix` and, independently, the
    entries within each column.

    Randomness comes from NumPy's global ``np.random`` state, so results can
    be made reproducible with ``np.random.seed``.

    Parameters
    ----------
    matrix : 2D array-like
        Each column is shuffled as a unit and then internally.

    Returns
    -------
    np.ndarray
        The shuffled data in TRANSPOSED orientation: shape
        ``(n_columns, n_rows)`` of the input.  The input is not modified.
    """
    # ``.T`` alone is a view, and np.random.shuffle works in place, so
    # shuffling it directly would scramble the caller's matrix.  Work on an
    # independent copy instead.
    arr = np.asarray(matrix).T.copy()

    # Permute the rows of the transposed copy (i.e. the original columns).
    np.random.shuffle(arr)

    for j in range(arr.shape[0]):
        # shuffle within each row, which is one of the original columns
        np.random.shuffle(arr[j])
    return arr
def disrupt_af(matrix):
    """
    Shuffle every entry of `matrix` across the whole array.

    A flattened copy of the input is shuffled with NumPy's global
    ``np.random`` state and reshaped back, so the total number of each value
    is preserved while row and column structure is destroyed.

    Parameters
    ----------
    matrix : 2D ndarray

    Returns
    -------
    np.ndarray
        New array with the same shape as `matrix`; the input is not modified.
    """
    n_rows, n_cols = matrix.shape
    flat = matrix.copy().reshape(-1)
    np.random.shuffle(flat)
    return flat.reshape(n_rows, n_cols)
def mediant_af(matrix):
    """
    Build a float array shaped like `matrix` with the same number of non-zero
    entries, all packed at the end in column-major order.

    The result holds ``0.0`` everywhere except for the last ``k`` positions
    (walking down each column, left to right), which hold ``1.0``; ``k`` is
    the number of non-zero entries of `matrix`.

    Parameters
    ----------
    matrix : 2D ndarray

    Returns
    -------
    np.ndarray
        Float array with the same shape as `matrix`.
    """
    n_rows, n_cols = matrix.shape
    total = n_rows * n_cols
    n_ones = np.count_nonzero(matrix)

    # Zeros first, ones last, laid out flat.
    flat = np.zeros(total)
    flat[total - n_ones :] = 1.0

    # Filling (cols, rows) row by row and transposing is a column-major fill.
    return flat.reshape((n_cols, n_rows)).T
def mediant_af_left(matrix):
    """
    Build a float array with as many ``1.0`` entries as `matrix` has non-zero
    entries, returned in transposed orientation.

    A flat run of zeros followed by ones is laid out row by row in the shape
    of `matrix` and then transposed, so the result has shape
    ``(n_columns, n_rows)`` of the input.

    Parameters
    ----------
    matrix : 2D ndarray

    Returns
    -------
    np.ndarray
        Float array of shape ``matrix.T.shape``.
    """
    n_rows, n_cols = matrix.shape
    total = n_rows * n_cols
    n_ones = np.count_nonzero(matrix)

    # Zeros first, ones last, laid out flat.
    flat = np.zeros(total)
    flat[total - n_ones :] = 1.0

    return flat.reshape((n_rows, n_cols)).T
################## Ranking
def merge_regions(prediction, p):
    """
    Merge genomic regions where prob_sweep > p

    Parameters:
    -----------
    prediction : str or pl.DataFrame
        File path to a comma-separated file with a header, or a Polars
        DataFrame. Must provide the columns ``chr``, ``start``, ``end`` and
        ``prob_sweep``. When a path is given, only chromosomes 1-22 (with or
        without a ``chr`` prefix) are kept and the names are normalised to
        the ``chr``-prefixed form; a DataFrame is used as provided.
    p : float
        Probability threshold for filtering (strictly greater than)

    Returns:
    --------
    tuple: (df_merged, summary_stats)
        - df_merged: collected DataFrame of merged intervals, with an extra
          column ``d`` = ``end - start``
        - summary_stats: DataFrame with chr, total_span, merged_span, pct,
          sorted by chr. Chromosomes without any window above the threshold
          report ``merged_span`` = 0 and ``pct`` = 0.

    Raises:
    -------
    ValueError
        If ``prediction`` is neither a str nor a Polars DataFrame.
    """
    # Load or clone the data
    if isinstance(prediction, str):
        bare_chr = pl.col("chr").str.replace("chr", "")
        df_pred = (
            pl.read_csv(prediction, has_header=True, separator=",")
            .select("chr", "start", "end", "prob_sweep")
            .filter(bare_chr.is_in([str(i) for i in range(1, 23)]))
            .with_columns((pl.lit("chr") + bare_chr).alias("chr"))
            .sort(["chr", "start"])
        )
    elif isinstance(prediction, pl.DataFrame):
        df_pred = (
            prediction.clone()
            .select("chr", "start", "end", "prob_sweep")
            .sort(["chr", "start"])
        )
    else:
        raise ValueError("prediction must be a file path (str) or a Polars DataFrame")

    # Total genomic span analyzed per chromosome: first start to last end
    total_span_analyzed = (
        df_pred.group_by("chr")
        .agg(
            [
                pl.col("start").min().alias("min_start"),
                pl.col("end").max().alias("max_end"),
            ]
        )
        .with_columns((pl.col("max_end") - pl.col("min_start")).alias("total_span"))
        .select(["chr", "total_span"])
    )

    # Keep only windows above the probability threshold
    filtered = df_pred.filter(pl.col("prob_sweep") > p)

    # Merge consecutive/overlapping windows (polars_bio.merge, min_dist=0)
    df_merged = (
        merge(
            filtered,
            min_dist=0,
            cols=["chr", "start", "end"],
            on_cols=None,
            output_type="polars.LazyFrame",
            projection_pushdown=True,
        )
        .with_columns((pl.col("end") - pl.col("start")).alias("d"))
        .collect()
    )

    # Merged span per chromosome
    merged_span = df_merged.group_by("chr").agg(pl.col("d").sum().alias("merged_span"))

    # Join and calculate percentage.  The null fill and the percentage live
    # in separate with_columns calls on purpose: expressions inside a single
    # with_columns all read the input frame, so a combined call would compute
    # pct from the still-null merged_span and yield null instead of 0.
    summary_stats = (
        total_span_analyzed.join(merged_span, on="chr", how="left")
        .with_columns(pl.col("merged_span").fill_null(0))
        .with_columns(
            (pl.col("merged_span") / pl.col("total_span") * 100).alias("pct")
        )
        .sort("chr")
    )
    return df_merged, summary_stats
[docs]
def rank_probabilities(prediction, feature_coordinates, rank_distance=False, k=111):
    """
    Rank genes by the sweep probability of the prediction windows around them.

    Goal: match the original pybedtools/bedtools output *exactly*, including
    tie order.

    Exact tie order:
      - Every gene receives an explicit, deterministic ``gene_order`` taken
        from the gene table sorted by (chr, start), mirroring the order in
        which bedtools streams file A.
      - ``gene_order`` travels through the whole pipeline and is the last
        sort key of the final ranking.

    bedtools ``closest -k`` parity:
      - ``-k K`` is emulated by sorting the hits of each gene by
        (d, chrom_pred, start_pred, end_pred) and keeping the first K, which
        makes the selection deterministic when windows share a distance.

    Genes and prediction windows are both reduced to their 1-based midpoint
    before distances are measured.

    Parameters
    ----------
    prediction : str or pl.DataFrame
        Comma-separated file with a header, or a DataFrame, providing
        ``chr``, ``start``, ``end`` and ``prob_sweep``.  A file is restricted
        to chromosomes 1-22 and normalised to ``chr``-prefixed names.
    feature_coordinates : str or pl.DataFrame
        Headerless tab-separated file with columns chr, start, end, gene_id,
        strand (chromosomes 1-22 without prefix are kept and prefixed with
        ``chr``), or a DataFrame.  In a DataFrame, ``feature_id`` is renamed
        to ``gene_id`` when ``gene_id`` is absent, and ``gene_order`` is
        generated when missing.
    rank_distance : bool
        If True, rank by the summed probability of all hits whose distance is
        within 500 kb of the nearest one, then by their summed distance.
        Otherwise rank by the best probability among the nearest hits, with
        those sums used as tie-breakers.
    k : int or None
        Maximum number of hits kept per gene; ``None`` keeps all of them.

    Returns
    -------
    tuple
        ``(ranking, n_rank_max)``: a DataFrame with ``gene_id``, ``rank``
        (starting at 1) and ``prob_sweep``, and the number of genes whose
        ``prob_sweep`` equals the maximum (0 for an empty ranking).

    Raises
    ------
    ValueError
        If either input is neither a str nor a Polars DataFrame.
    """
    flank = 500_000
    autosomes = [str(i) for i in range(1, 23)]

    def _to_point_1based_closed(df: pl.DataFrame, chr_col: str) -> pl.DataFrame:
        # midpoint -> 1-based coordinate; represent as point interval [pos1, pos1]
        midpoint = ((pl.col("start") + pl.col("end")) // 2) + 1
        return df.with_columns(
            midpoint.alias("start"),
            midpoint.alias("end"),
        ).rename({chr_col: "chrom"})

    def _n_top(ranking: pl.DataFrame) -> int:
        # number of genes tied at the highest prob_sweep
        if not ranking.height:
            return 0
        best = ranking["prob_sweep"].max()
        return int(ranking.filter(pl.col("prob_sweep") == best).height)

    # ---- gene table -------------------------------------------------------
    if isinstance(feature_coordinates, str):
        gene_schema = {
            "chr": pl.Utf8,
            "start": pl.Int64,
            "end": pl.Int64,
            "gene_id": pl.Utf8,
            "strand": pl.Utf8,
        }
        raw_genes = pl.read_csv(
            feature_coordinates,
            has_header=False,
            separator="\t",
            schema=gene_schema,
        )
        df_genes = (
            raw_genes.select("chr", "start", "end", "strand", "gene_id")
            .filter(pl.col("chr").is_in(autosomes))
            .with_columns((pl.lit("chr") + pl.col("chr")).alias("chr"))
            .sort(["chr", "start"])
            # explicit deterministic order to break full ties exactly like
            # the bedtools A-stream
            .with_row_index("gene_order", offset=0)
        )
    elif isinstance(feature_coordinates, pl.DataFrame):
        df_genes = feature_coordinates.clone()
        if "gene_id" not in df_genes.columns and "feature_id" in df_genes.columns:
            df_genes = df_genes.rename({"feature_id": "gene_id"})
        if "gene_order" not in df_genes.columns:
            # caller did not supply an order: derive a deterministic one
            df_genes = df_genes.sort(["chr", "start"])
            df_genes = df_genes.with_row_index("gene_order", offset=0)
    else:
        raise ValueError(
            "feature_coordinates must be a file path (str) or a Polars DataFrame"
        )

    # ---- prediction table -------------------------------------------------
    if isinstance(prediction, str):
        bare_chr = pl.col("chr").str.replace("chr", "")
        df_pred = (
            pl.read_csv(prediction, has_header=True, separator=",")  # adjust if TSV
            .select("chr", "start", "end", "prob_sweep")
            .filter(bare_chr.is_in(autosomes))
            .with_columns((pl.lit("chr") + bare_chr).alias("chr"))
            .sort(["chr", "start"])
        )
    elif isinstance(prediction, pl.DataFrame):
        df_pred = prediction.clone().select("chr", "start", "end", "prob_sweep")
        df_pred = df_pred.sort(["chr", "start"])
    else:
        raise ValueError("prediction must be a file path (str) or a Polars DataFrame")

    genes_lf = _to_point_1based_closed(df_genes, chr_col="chr").lazy()
    preds_lf = _to_point_1based_closed(df_pred, chr_col="chr").lazy()
    interval_cols = ["chrom", "start", "end"]

    # ---- nearest prediction midpoint of every gene -------------------------
    nearest_raw = (
        nearest(
            genes_lf,
            preds_lf,
            suffixes=("_gene", "_pred"),
            cols1=interval_cols,
            cols2=interval_cols,
            output_type="polars.LazyFrame",
        )
        .rename(
            {
                "gene_id_gene": "gene_id",
                "strand_gene": "strand",
                "gene_order_gene": "gene_order",
                "distance": "d_min",
            }
        )
        .select(
            "gene_id",
            "strand",
            "gene_order",
            "chrom_gene",
            "start_gene",
            "end_gene",
            "d_min",
        )
    )

    # ---- search window: gene midpoint +/- (nearest distance + flank) -------
    gene_pos = pl.col("start_gene")
    radius = pl.col("d_min") + flank
    gene_windows = nearest_raw.select(
        pl.col("chrom_gene").alias("chrom"),
        (gene_pos - radius).clip(lower_bound=1).alias("start"),
        (gene_pos + radius).alias("end"),
        "gene_id",
        "strand",
        "gene_order",
        gene_pos.alias("gene_pos"),
        "d_min",
    )

    # ---- every prediction midpoint falling into a gene's window ------------
    hits = (
        overlap(
            gene_windows,
            preds_lf,
            suffixes=("_win", "_pred"),
            cols1=interval_cols,
            cols2=interval_cols,
            output_type="polars.LazyFrame",
        )
        .rename(
            {
                "gene_id_win": "gene_id",
                "strand_win": "strand",
                "gene_order_win": "gene_order",
                "gene_pos_win": "gene_pos",
                "d_min_win": "d_min",
                "chrom_win": "chrom",
                "start_win": "start",
                "end_win": "end",
            }
        )
        .with_columns((pl.col("gene_pos") - pl.col("start_pred")).abs().alias("d"))
    )

    # k elements like in bedtools
    if k is not None:
        k = int(k)
        per_gene_index = pl.int_range(0, pl.len()).over("gene_id")
        hits = (
            hits.sort(["gene_id", "d", "chrom_pred", "start_pred", "end_pred"])
            .with_columns(per_gene_index.alias("_k"))
            .filter(pl.col("_k") < k)
            .drop("_k")
        )

    # hits whose distance lies within ``flank`` of the gene's nearest hit
    within_flank = (pl.col("d") - pl.col("d_min")).abs() <= flank

    if rank_distance:
        # Distance-aware composite: prob_sweep sum desc, d_sum asc, and
        # finally gene_order asc to match the original tie order.
        w_rank = (
            hits.filter(within_flank)
            .group_by("gene_id")
            .agg(
                pl.col("prob_sweep_pred").sum().alias("prob_sweep"),
                pl.col("d").sum().alias("d_sum"),
                pl.col("gene_order").min().alias("gene_order"),
            )
            .sort(
                ["prob_sweep", "d_sum", "gene_order"],
                descending=[True, False, False],
            )
            .with_row_index("rank", offset=1)
            .select("gene_id", "rank", "prob_sweep")
            .collect()
        )
        return w_rank, _n_top(w_rank)

    # best probability among the hits sitting exactly at the nearest distance
    rank_unique = (
        hits.filter(pl.col("d") == pl.col("d_min"))
        .group_by("gene_id")
        .agg(
            pl.col("prob_sweep_pred").max().alias("prob_sweep"),
            pl.col("gene_order").min().alias("gene_order"),
        )
    )
    # neighbourhood sums, used only as tie-breakers
    rank_rep = (
        hits.filter(within_flank)
        .group_by("gene_id")
        .agg(
            pl.col("prob_sweep_pred").sum().alias("win_prob_sum"),
            pl.col("d").sum().alias("win_d_sum"),
        )
    )
    ranked = (
        rank_unique.join(rank_rep, on="gene_id", how="left")
        .with_columns(
            pl.col("win_prob_sum").fill_null(0),
            pl.col("win_d_sum").fill_null(0),
        )
        # FINAL KEY: gene_order ensures exact tie ordering consistent with
        # the original A-stream
        .sort(
            ["prob_sweep", "win_prob_sum", "win_d_sum", "gene_order"],
            descending=[True, True, False, False],
        )
        .with_row_index("rank", offset=1)
        .select("gene_id", "rank", "prob_sweep")
        .collect()
    )
    return ranked, _n_top(ranked)
def interpolate_rates(
    prediction, recombination_map, prediction_lr=None, corr=False, bins=10, out=None
):
    """
    Interpolate recombination rates (cM/Mb) onto prediction windows and optionally
    replace low-recombination probabilities using an alternate predictions table.

    Given prediction windows and a recombination map, this function computes
    recombination rates for each unique prediction window per chromosome (via
    ``get_cm(..., cm_mb=True)``) and joins the resulting ``cm_mb`` column back into
    the prediction table using the join keys ``chr``, ``start``, and ``end``.

    If ``prediction_lr`` is provided, ``prob_sweep`` and ``prob_neutral`` of the
    main predictions are replaced, for windows with ``cm_mb < 0.5``, by the
    corresponding values from ``prediction_lr``.

    If ``corr=True``, the function additionally computes the per-chromosome
    correlation between ``prob_sweep`` and ``cm_mb``, bins the windows of each
    chromosome by ``cm_mb`` rank into ``bins`` groups, and plots the binned trend
    per chromosome.

    Input expectations:

    - **prediction** must contain at least ``chr``, ``start``, and ``end``, plus
      ``prob_sweep`` / ``prob_neutral`` if replacement or correlation is used.
      Additional columns are preserved.
    - **recombination_map** must be a tab-separated file (comment lines allowed,
      prefixed with ``#``) with columns ``chr``, ``start``, ``end``, ``cm_mb``, ``cm``.
    - **prediction_lr** (if provided) must contain ``chr``, ``start``, ``end``,
      ``prob_sweep``, and ``prob_neutral``.

    :param str | polars.DataFrame prediction: Path to a CSV file, or a DataFrame,
        containing prediction windows and associated probabilities.
    :param str recombination_map: Path to a tab-separated recombination map with
        columns ``chr``, ``start``, ``end``, ``cm_mb``, ``cm``; may include ``#``
        comment lines.
    :param str | polars.DataFrame | None prediction_lr: Optional alternate
        predictions (CSV path or DataFrame). If provided, ``prob_sweep`` and
        ``prob_neutral`` are replaced for windows where ``cm_mb < 0.5``.
    :param bool corr: If ``True``, compute correlations, bin and plot the trend
        per chromosome, and return three tables (see below).
    :param int bins: Number of ``cm_mb`` bins per chromosome used when
        ``corr=True``.
    :param str | None out: Only used when ``corr=True``. If given, the figure is
        saved to this path (300 dpi) and closed; otherwise it is shown with
        ``plt.show()``.

    :returns: If ``corr=False``, **df_pred_rate**: all original prediction columns
        plus ``cm_mb``. If ``corr=True``, a triple
        ``(df_pred_rate, df_trend, df_corr)`` where **df_trend** is the
        per-chromosome binned table (sorted by ``chr`` and ``bin``) and
        **df_corr** holds the per-chromosome window count ``n`` and the unbinned
        correlation ``corr``.
    :rtype: polars.DataFrame | tuple[polars.DataFrame, polars.DataFrame, polars.DataFrame]

    :notes:
        - Recombination rate interpolation is delegated to ``get_cm`` and is
          performed independently per chromosome. ``get_cm`` is assumed to return
          a Polars DataFrame that includes ``chr``, ``start``, ``end``, and ``cm_mb``.
        - The join key for all merges is ``(chr, start, end)``. Rates are attached
          with an inner join, so prediction windows without a matching rate row
          are dropped from the result.
        - When ``prediction_lr`` is provided, only windows with ``cm_mb < 0.5`` in
          the LR-joined table are eligible for overwriting probabilities.
        - When ``corr=True``, this function produces matplotlib output as a side
          effect.
    """

    def _chr_key(s):
        # Natural chromosome order: numbered chromosomes, then X/Y/M(T),
        # then anything else alphabetically.
        s = str(s)
        m = re.match(r"^chr(\d+)$", s)
        if m:
            return (0, int(m.group(1)))
        m = re.match(r"^chr([XYM]|MT)$", s)
        if m:
            order = {"X": 23, "Y": 24, "M": 25, "MT": 25}
            return (1, order[m.group(1)])
        return (2, s)

    def _bin_group(g, n_bins=bins):
        # Split one chromosome's windows into n_bins equally sized groups by
        # cm_mb rank and summarise every group.
        g = g.sort("cm_mb")
        n = g.height
        g = g.with_columns(pl.arange(0, n).alias("i"))
        g = g.with_columns(
            ((pl.col("i") * n_bins) / n).floor().cast(pl.Int64).alias("bin")
        )
        binned = (
            g.group_by("bin")
            .agg(
                pl.mean("cm_mb").alias("cm_mb"),
                pl.mean("prob_sweep").alias("prob_sweep"),
                pl.len().alias("n_in_bin"),
                pl.min("cm_mb").alias("cm_mb_min"),
                pl.max("cm_mb").alias("cm_mb_max"),
                # share of windows in the bin with prob_sweep > 0.9
                ((pl.col("prob_sweep") > 0.9).sum() / pl.len()).alias("cm_mb_prop"),
            )
            .sort("cm_mb_min")
            .with_columns(
                pl.lit(g["chr"][0]).alias("chr"),
                # optional: readable range label
                pl.format(
                    "[{}, {}]",
                    pl.col("cm_mb_min").round(4),
                    pl.col("cm_mb_max").round(4),
                ).alias("cm_mb_range"),
            )
        )
        return binned

    if isinstance(prediction, pl.DataFrame):
        df_pred = prediction
    else:
        df_pred = pl.read_csv(prediction)

    df_rec = pl.read_csv(
        recombination_map,
        separator="\t",
        comment_prefix="#",
        schema=pl.Schema(
            [
                ("chr", pl.String),
                ("start", pl.Int64),
                ("end", pl.Int64),
                ("cm_mb", pl.Float64),
                ("cm", pl.Float64),
            ]
        ),
    ).sort(["chr", "start"])

    # Compute recombination rates per chr
    rate_frames = [
        get_cm(
            df_rec.filter(pl.col("chr") == chr_),  # assumes get_cm returns a Polars df
            df_pred.filter(pl.col("chr") == chr_)
            .select("start", "end")
            .unique()
            .to_numpy(),
            cm_mb=True,
        )
        for chr_ in df_pred["chr"].unique()
    ]
    df_rate = (
        pl.concat(rate_frames)
        .with_columns(
            # Sort key only.  strict=False turns non-numeric names (chrX, ...)
            # into null instead of raising, matching what _chr_key accepts.
            pl.col("chr")
            .str.replace("chr", "")
            .cast(pl.Int32, strict=False)
            .alias("nchr")
        )
        .sort(["nchr", "start"])
        .select(["chr", "start", "end", "cm_mb"])
    )

    # Merge recombination rate into the predictions (inner join)
    keys = ["chr", "start", "end"]
    df_pred_rate = df_pred.join(df_rate, on=keys)

    if prediction_lr is not None:
        if isinstance(prediction_lr, pl.DataFrame):
            df_pred_lr = prediction_lr
        else:
            df_pred_lr = pl.read_csv(prediction_lr)
        df_pred_rate_lr = df_pred_lr.join(df_rate, on=keys)

        # Replace probabilities where cm_mb < 0.5
        df_lr_filtered = df_pred_rate_lr.filter(pl.col("cm_mb") < 0.5).select(
            ["chr", "start", "end", "prob_sweep", "prob_neutral"]
        )
        df_pred_rate = (
            df_pred_rate.join(df_lr_filtered, on=keys, how="left")
            .with_columns(
                prob_sweep=pl.when(pl.col("prob_sweep_right").is_not_null())
                .then(pl.col("prob_sweep_right"))
                .otherwise(pl.col("prob_sweep")),
                prob_neutral=pl.when(pl.col("prob_neutral_right").is_not_null())
                .then(pl.col("prob_neutral_right"))
                .otherwise(pl.col("prob_neutral")),
            )
            .select(df_pred.columns + ["cm_mb"])
        )
    else:
        df_pred_rate = df_pred_rate.select(df_pred.columns + ["cm_mb"])

    if not corr:
        return df_pred_rate

    # Unbinned correlation per chromosome
    df_corr = df_pred_rate.group_by("chr").agg(
        pl.len().alias("n"),
        pl.corr("prob_sweep", "cm_mb").alias("corr"),
    )

    # Binned trend per chromosome, plus the correlation across the bin means
    df_bins = df_pred_rate.group_by("chr").map_groups(_bin_group)
    df_corr_binned = df_bins.group_by("chr").agg(
        pl.corr("prob_sweep", "cm_mb").alias("corr"), pl.len().alias("n")
    )
    df_trend = df_corr_binned.join(df_bins, on="chr", how="left")

    chrs = sorted(df_trend["chr"].unique().to_list(), key=_chr_key)
    ncols = min(4, len(chrs))
    nrows = math.ceil(len(chrs) / ncols)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 3.0 * nrows), sharey=True
    )
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array
    axes = np.array(axes).reshape(-1)

    for ax, c in zip(axes, chrs):
        t = df_trend.filter(pl.col("chr") == c)
        x = t.get_column("cm_mb").to_numpy()
        y = t.get_column("prob_sweep").to_numpy()
        ax.plot(x, y, linewidth=1.2, color="#2166ac")

        r = t.get_column("corr")[0]
        r_str = f"r={r:.2f}" if r is not None else "r=NA"
        ax.set_title(f"{c} ({r_str})", fontsize=8, pad=3)
        ax.axhline(0.5, linewidth=0.8, linestyle="--", color="gray", alpha=0.7)
        ax.set_ylim(0, 1)
        ax.set_xlabel("cM/Mb", fontsize=8, labelpad=2)
        ax.set_ylabel(
            "P(sweep)" if ax.get_subplotspec().is_first_col() else "", fontsize=8
        )
        ax.tick_params(labelsize=7)
        ax.tick_params(axis="x", pad=2)
        ax.spines[["top", "right"]].set_visible(False)

    # Hide the unused cells of the grid
    for ax in axes[len(chrs) :]:
        ax.axis("off")

    fig.tight_layout(h_pad=3.5, w_pad=1.5)  # h_pad is the main lever here
    if out is not None:
        fig.savefig(out, dpi=300, bbox_inches="tight")
        # Release the figure once written so repeated calls do not pile up
        # open figures.
        plt.close(fig)
    else:
        plt.show()

    return (
        df_pred_rate,
        df_trend.sort("chr", "bin").select(pl.exclude("n", "n_in_bin")),
        df_corr,
    )