Source code for flexsweep.scan

"""
flexsweep/scan.py — Standalone outlier scan.

Completely separate from the CNN pipeline (fvs-vcf / fvs*.parquet).
No neutral simulations, no fixed window grid. Each stat runs at its natural resolution.
Outlier ranking uses the genome-wide empirical distribution.

API usage::

    from flexsweep.scan import scan

    results = scan(
        "vcf_folder/",          # directory of *.vcf.gz, one per chromosome
        "YRI",
        stats=["ihs", "nsl", "h12", "lassip"],
        nthreads=8,
        config={"lassip": {"max_extend": 1e5}, "raisd": {"window_size": 100}},
    )
    # results["ihs"]    → Polars DataFrame at SNP resolution, ihs_pvalue column included
    # results["lassip"] → Polars DataFrame at window resolution, lassip_pvalue column included

CLI usage::

    flexsweep scan --vcf_path vcf_folder/ --out YRI \\
        --stats ihs,nsl,h12,lassip --w_size 201 --min_maf 0.05 --nthreads 8

Available stats
---------------
Per-SNP (output at SNP resolution):
    ihs, nsl, isafe, dind, high_freq, low_freq, s_ratio, hapdaf_o, hapdaf_s, haf

Sliding SNP-window (output at window resolution):
    h12, garud, lassi, lassip, raisd

Sliding bp-window (output at window resolution):
    tajima_d, pi, theta_w, fay_wu_h, zeng_e, achaz_y,
    fuli_f, fuli_f_star, fuli_d, fuli_d_star, neutrality, omega, zns, beta, ncd

Notes
-----
- iHS and nSL are z-scored within genome-wide DAF bins (100 bins by default).
  Add ``--recombination_map`` to normalize by (DAF × recomb_bin) jointly.
- dind, s_ratio, hapdaf_o, hapdaf_s use REF/ALT as ancestral/derived (same as fvs-vcf).
  Pass ``recombination_map`` to use genetic distances for window boundaries.
- lassip and lassi derive their neutral spectrum from the VCF itself (average
  haplotype frequency spectrum); no simulations are needed.
- delta_ihh is intentionally excluded.
- window_mode="auto" uses per-stat defaults (h12/garud/lassi/lassip/raisd=snp;
  tajima_d/pi/theta_w/fay_wru_h/zeng_e/achaz_y/fuli_*/neutrality/omega/zns/beta=bp).
  Pass window_mode="snp" or window_mode="bp" to override for all window stats.
  Default "auto" is correct when mixing stats with different natural modes.
"""

import glob
import os
import warnings
from collections import namedtuple
from math import ceil

from allel import nsl

from . import Parallel, delayed, np, pl
from .fv import (
    Lambda_statistic_fast,
    LASSI_spectrum_and_Kspectrum,
    Ld,
    T_m_statistic_fast,
    achaz_y,
    compute_r2_matrix_upper,
    dind_high_low_from_pairs,
    fast_sq_freq_pairs,
    fay_wu_h_norm,
    fs_stats_dataframe,
    fuli_d,
    fuli_d_star,
    fuli_f,
    fuli_f_star,
    garud_h,
    genome_reader,
    haf_top,
    hapdaf_from_pairs,
    hscan,
    ihs_ihh,
    mu_stat,
    ncd1,
    neut_average,
    neutrality_stats,
    omega_linear_correct,
    run_beta_window,
    run_isafe,
    s_ratio_from_pairs,
    tajima_d,
    theta_pi,
    theta_watterson,
    zeng_e,
)

# Stat definition

# resolution: "snp" | "window"
# tier: 1=hap+pos only, 2=+allele counts, 3=+rec_map/polarization
# rank_col: primary column used for ranking (higher = more sweepy)
# default_params: defaults merged with shared params at call time

StatDef = namedtuple("StatDef", ["resolution", "tier", "rank_col", "default_params"])

STAT_REGISTRY: dict[str, StatDef] = {
    # -- Per-SNP stats --
    "ihs": StatDef(
        "snp",
        1,
        "ihs",
        {
            "min_maf": 0.05,
            "include_edges": False,
            "gap_scale": 20000,
            "max_gap": 200000,
        },
    ),
    "nsl": StatDef("snp", 1, "nsl", {"min_maf": 0.05}),
    "isafe": StatDef(
        "snp",
        1,
        "isafe",
        {
            "region_size_bp": 1_000_000,
            "isafe_window": 300,
            "isafe_step": 150,
            "top_k": 1,
            "max_rank": 15,
        },
    ),
    "dind": StatDef(
        "snp",
        3,
        "dind",
        {"window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95},
    ),
    "high_freq": StatDef(
        "snp",
        3,
        "high_freq",
        {"window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95},
    ),
    "low_freq": StatDef(
        "snp",
        3,
        "low_freq",
        {"window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95},
    ),
    "s_ratio": StatDef(
        "snp",
        3,
        "s_ratio",
        {"window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95},
    ),
    "hapdaf_o": StatDef(
        "snp",
        3,
        "hapdaf_o",
        {
            "window_size": 50000,
            "min_focal_freq": 0.25,
            "max_focal_freq": 0.95,
            "max_ancest_freq": 0.25,
            "min_tot_freq": 0.25,
        },
    ),
    "hapdaf_s": StatDef(
        "snp",
        3,
        "hapdaf_s",
        {
            "window_size": 50000,
            "min_focal_freq": 0.25,
            "max_focal_freq": 0.95,
            "max_ancest_freq": 0.10,
            "min_tot_freq": 0.10,
        },
    ),
    "haf": StatDef(
        "window", 1, "haf", {"window_mode": "snp", "w_size": 201, "step": 10}
    ),
    "hscan": StatDef(
        "snp", 1, "hscan", {"max_gap": 200_000, "dist_mode": 0, "hscan_step": 1}
    ),
    # -- Sliding-window stats (SNP-count mode) --
    "h12": StatDef(
        "window", 1, "h12", {"window_mode": "snp", "w_size": 200, "step": 10}
    ),
    "garud": StatDef(
        "window", 1, "h12", {"window_mode": "snp", "w_size": 200, "step": 10}
    ),
    "lassi": StatDef(
        "window",
        1,
        "T_m",
        {
            "window_mode": "snp",
            "w_size": 201,
            "step": 10,
            "K_truncation": 10,
            "sweep_mode": 4,
        },
    ),
    "lassip": StatDef(
        "window",
        1,
        "Lambda",
        {
            "window_mode": "snp",
            "w_size": 201,
            "step": 10,
            "K_truncation": 10,
            "sweep_mode": 4,
            "max_extend": 1e5,
            "n_A": 100,
        },
    ),
    "raisd": StatDef(
        "window", 1, "mu_total", {"window_mode": "snp", "window_size": 50}
    ),
    # -- Sliding-window stats (physical bp mode) --
    "tajima_d": StatDef(
        "window",
        2,
        "tajima_d",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "pi": StatDef(
        "window",
        2,
        "pi",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "theta_w": StatDef(
        "window",
        2,
        "theta_w",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "fay_wu_h": StatDef(
        "window",
        2,
        "fay_wu_h",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "zeng_e": StatDef(
        "window",
        2,
        "zeng_e",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "achaz_y": StatDef(
        "window",
        2,
        "achaz_y",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "fuli_f": StatDef(
        "window",
        2,
        "fuli_f",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "fuli_f_star": StatDef(
        "window",
        2,
        "fuli_f_star",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "fuli_d": StatDef(
        "window",
        2,
        "fuli_d",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "fuli_d_star": StatDef(
        "window",
        2,
        "fuli_d_star",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "neutrality": StatDef(
        "window",
        2,
        "tajima_d",
        {"window_mode": "bp", "w_size_bp": 1_000_000, "step_bp": 10_000},
    ),
    "omega": StatDef(
        "window",
        1,
        "omega_max",
        {"window_mode": "bp", "w_size_bp": 100_000, "step_bp": 10_000},
    ),
    "zns": StatDef(
        "window",
        1,
        "zns",
        {"window_mode": "bp", "w_size_bp": 100_000, "step_bp": 10_000},
    ),
    "beta": StatDef(
        "window",
        2,
        "beta1",
        {"window_mode": "bp", "w_size_bp": 50_000, "step_bp": 5_000, "m": 0.1},
    ),
    "ncd": StatDef("window", 3, "ncd1", {"tf": 0.5, "w": 3000, "minIS": 2}),
}


[docs] def available_stats() -> list[str]: """Return list of all available stat keys.""" return list(STAT_REGISTRY.keys())
[docs] def stat_params(stat_key: str | None = None) -> dict: """Return parameter documentation for scan stats. Parameters ---------- stat_key : str, optional A specific stat key (e.g. ``"ihs"``, ``"hscan"``). If None, returns the full table for all stats. Returns ------- dict Keys are stat names. Each value is a dict with keys ``rank_col``, ``resolution``, ``window_mode``, ``default_window``, ``default_step``, ``shared_params``, and ``stat_params``. Notes ----- Shared params always injected by ``scan()`` for every stat: ``w_size`` (201), ``step`` (10), ``w_size_bp`` (1000000), ``step_bp`` (10000), ``min_maf`` (0.05), ``window_mode`` ("auto"). Per-stat params are passed via ``config={"stat": {"param": value}}``. Examples -------- >>> from flexsweep.scan import stat_params >>> stat_params("hscan") >>> stat_params() # full table """ _SHARED = { "w_size": 201, "step": 10, "w_size_bp": 1_000_000, "step_bp": 10_000, "min_maf": 0.05, "window_mode": "auto", } # Stat-specific params: registry defaults merged with runner-read kwargs. # These are the keys a user can pass via config={"stat": {key: val}}. _STAT_SPECIFIC: dict[str, dict] = { "ihs": {"include_edges": False, "gap_scale": 20000, "max_gap": 200000}, "nsl": {}, "isafe": { "region_size_bp": 1_000_000, "isafe_window": 300, "isafe_step": 150, "top_k": 1, "max_rank": 15, }, "dind": {"window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95}, "high_freq": { "window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95, }, "low_freq": { "window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95, }, "s_ratio": { "window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95, }, "hapdaf_o": { "window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95, "max_ancest_freq": 0.25, "min_tot_freq": 0.25, }, "hapdaf_s": { "window_size": 50000, "min_focal_freq": 0.25, "max_focal_freq": 0.95, "max_ancest_freq": 0.10, "min_tot_freq": 0.10, }, "haf": {}, "hscan": { "max_gap": 200_000, "dist_mode": 0, "hscan_step": 1, }, # note: NOT "step" — avoids shared override "h12": {}, "garud": {}, "neutrality": {}, "omega": {}, "zns": {}, "tajima_d": {}, "pi": {}, "theta_w": {}, "fay_wu_h": {}, "zeng_e": {}, "achaz_y": {}, "fuli_f": {}, "fuli_f_star": {}, "fuli_d": {}, "fuli_d_star": {}, "lassi": {"K_truncation": 10, "sweep_mode": 4}, "lassip": {"K_truncation": 10, "sweep_mode": 4, "max_extend": 1e5, "n_A": 100}, "raisd": {"window_size": 50}, "beta": {"m": 0.1}, "ncd": {"tf": 0.5, "w": 3000, "minIS": 2}, } def _entry(key): defn = STAT_REGISTRY[key] dp = defn.default_params # window_mode: per-stat default from registry, or "snp"/"bp" if set there wm = dp.get("window_mode", "n/a (per-SNP stat)") # effective window size and step shown for window stats if wm == "snp": win_default = f"{dp.get('w_size', _SHARED['w_size'])} SNPs" step_default = f"{dp.get('step', _SHARED['step'])} SNPs" elif wm == "bp": win_default = f"{dp.get('w_size_bp', _SHARED['w_size_bp']):,} bp" step_default = f"{dp.get('step_bp', _SHARED['step_bp']):,} bp" else: win_default = "n/a" step_default = "n/a" return { "rank_col": defn.rank_col, "resolution": defn.resolution, "window_mode": wm, "default_window": win_default, "default_step": step_default, "shared_params": _SHARED, "stat_params": _STAT_SPECIFIC.get(key, {}), } if stat_key is not None: if stat_key not in STAT_REGISTRY: raise ValueError( f"Unknown stat: {stat_key!r}. Available: {available_stats()}" ) return {stat_key: _entry(stat_key)} return {k: _entry(k) for k in STAT_REGISTRY}
# Utilities def _sliding_windows(positions, w_size: int, step: int): """Yield (start_idx, end_idx, center_pos) for SNP-count sliding windows.""" n = len(positions) for i in range(0, n - w_size + 1, step): center = int(positions[i + w_size // 2]) yield i, i + w_size, center def _bp_windows(positions, w_size_bp: int, step_bp: int): """Yield (start_idx, end_idx, center_pos) for physical bp sliding windows.""" max_pos = int(positions[-1]) win_start = int(positions[0]) while win_start <= max_pos: win_end = win_start + w_size_bp i = int(np.searchsorted(positions, win_start)) j = int(np.searchsorted(positions, win_end, side="right")) if j - i >= 2: yield i, j, win_start + w_size_bp // 2 win_start += step_bp if win_start >= max_pos: break def _get_windows(positions, params): """Dispatch to SNP-count or bp sliding windows based on window_mode in params.""" mode = params.get("window_mode", "snp") if mode == "bp": w_size_bp = int(params.get("w_size_bp", 1_000_000)) step_bp = int(params.get("step_bp", 10_000)) return _bp_windows(positions, w_size_bp, step_bp) else: w_size = int(params.get("w_size", 201)) step = int(params.get("step", 10)) return _sliding_windows(positions, w_size, step) def _snp_cm_mb(positions: np.ndarray, rec_map: np.ndarray) -> np.ndarray: """Local recombination rate (cM/Mb) at each SNP position. Assigns each SNP to the rec_map segment it falls in and returns that segment's rate (Δ cM / (Δ bp / 1e6)). No arbitrary window size needed. Parameters ---------- positions : 1D int64 array of SNP physical positions (bp) rec_map : numpy array from genome_reader; col 0 = bp, last col = cumulative cM Returns ------- 1D float64 array, length == len(positions), values in cM/Mb (>= 0). """ map_pos = rec_map[:, 0].astype(np.float64) map_cm = rec_map[:, -1].astype(np.float64) delta_bp = np.diff(map_pos) delta_cm = np.diff(map_cm) with np.errstate(divide="ignore", invalid="ignore"): seg_rates = np.where(delta_bp > 0, delta_cm / (delta_bp / 1e6), 0.0) seg_rates = np.maximum(0.0, seg_rates) # Assign each SNP to the leftmost segment that starts at or before its position idx = np.clip( np.searchsorted(map_pos, positions, side="right") - 1, 0, len(seg_rates) - 1, ) return seg_rates[idx] def _normalize_daf_bins( values: np.ndarray, daf: np.ndarray, recomb: np.ndarray | None = None, n_daf_bins: int = 50, n_r_bins: int | None = None, ) -> np.ndarray: """Z-score values within genome-wide DAF bins (+ recomb bins if provided). When ``recomb`` is provided, creates a joint (DAF × recomb_rate) grid: ``n_daf_bins`` equal-frequency DAF bins × ``n_r_bins`` equal-frequency recombination rate bins (Johnson et al. approach: 10 r_bins default). """ if recomb is not None: daf_edges = np.nanpercentile(daf, np.linspace(0, 100, n_daf_bins + 1)) daf_edges[0] -= 1e-10 r_edges = np.nanpercentile(recomb, np.linspace(0, 100, n_r_bins + 1)) r_edges[0] -= 1e-10 daf_bin = np.clip(np.digitize(daf, daf_edges) - 1, 0, n_daf_bins - 1) r_bin = np.clip(np.digitize(recomb, r_edges) - 1, 0, n_r_bins - 1) bin_key = daf_bin * n_r_bins + r_bin else: edges = np.nanpercentile(daf, np.linspace(0, 100, n_daf_bins + 1)) edges[0] -= 1e-10 bin_key = np.clip(np.digitize(daf, edges) - 1, 0, n_daf_bins - 1) normalized = np.full_like(values, np.nan, dtype=np.float64) for b in np.unique(bin_key): mask = bin_key == b if mask.sum() < 2: continue v = values[mask].astype(np.float64) mu, std = np.nanmean(v), np.nanstd(v) if std > 0: normalized[mask] = (v - mu) / std return normalized
[docs] def empirical_pvalues( df: pl.DataFrame, stat_col: str, abs_rank: bool = False ) -> pl.DataFrame: """Empirical p-value following the empirical outlier approach (Akey 2009). p = rank(−value, na.last=keep) / N_valid Small p → outlier/candidate (locus in extreme upper tail of the empirical distribution). NaN → null in output (excluded from ranking and from N_valid denominator). abs_rank=True: rank by ``abs(value)`` before negating — for signed stats where large magnitude in either direction signals selection (iHS, nSL, Tajima's D). """ s = df[stat_col] if abs_rank: s = s.abs() n_valid = int(s.is_not_null().sum() - s.is_nan().sum()) # Negative: largest original value → most negative → rank 1 → p = 1/N_valid ≈ 0 (outlier) p_emp = (-s).fill_nan(None).rank(method="average") / n_valid return df.with_columns(p_emp.alias(f"{stat_col}_pvalue"))
# Per-SNP stat runners def _run_ihs(hap, positions, ac, rec_map, genetic_pos, **params): min_maf = params.get("min_maf", 0.05) include_edges = params.get("include_edges", False) gap_scale = params.get("gap_scale", 20000) max_gap = params.get("max_gap", 200000) map_pos = genetic_pos if genetic_pos is not None else None df = ihs_ihh( hap, positions, map_pos=map_pos, min_maf=min_maf, min_ehh=0.05, include_edges=include_edges, gap_scale=gap_scale, max_gap=max_gap, use_threads=False, ) if df is None or len(df) == 0: return None return df.select(["positions", "daf", "ihs"]).rename({"positions": "pos"}) def _run_nsl(hap, positions, ac, rec_map, genetic_pos, **params): min_maf = params.get("min_maf", 0.05) freqs = ac[:, 1] / ac.sum(axis=1) mask = (freqs >= min_maf) & (freqs <= 1 - min_maf) if mask.sum() < 2: return None nsl_vals = nsl(hap[mask], use_threads=False) return pl.DataFrame( { "pos": positions[mask].tolist(), "daf": freqs[mask].tolist(), "nsl": nsl_vals.tolist(), } ) def _run_isafe(hap, positions, ac, rec_map, genetic_pos, **params): region_size_bp = int(params.get("region_size_bp", 1_000_000)) isafe_window = params.get("isafe_window", 300) isafe_step = params.get("isafe_step", 150) top_k = params.get("top_k", 1) max_rank = params.get("max_rank", 15) pos = positions results = [] region_start = int(pos[0]) max_pos = int(pos[-1]) while region_start <= max_pos: region_end = region_start + region_size_bp mask = (pos >= region_start) & (pos < region_end) if mask.sum() >= 300: df_r = run_isafe( hap[mask], pos[mask], window=isafe_window, step=isafe_step, top_k=top_k, max_rank=max_rank, ) if df_r is not None and len(df_r) > 0: results.append( df_r.select(["positions", "daf", "isafe"]).rename( {"positions": "pos"} ) ) region_start = region_end return pl.concat(results) if results else None def _run_dind(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50000) min_focal_freq = params.get("min_focal_freq", 0.25) max_focal_freq = params.get("max_focal_freq", 0.95) sq_freqs, info, _ = fast_sq_freq_pairs( hap, ac, rec_map, min_focal_freq=min_focal_freq, max_focal_freq=max_focal_freq, window_size=window_size, ) if info.shape[0] == 0: return None r_dind, r_high, r_low = dind_high_low_from_pairs(sq_freqs, info) df, _, _, _ = fs_stats_dataframe(info, r_dind, r_high, r_low, [], [], []) if df is None or len(df) == 0: return None return df.select(["positions", "daf", "dind", "high_freq", "low_freq"]).rename( {"positions": "pos"} ) def _run_high_freq(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50000) min_focal_freq = params.get("min_focal_freq", 0.25) max_focal_freq = params.get("max_focal_freq", 0.95) sq_freqs, info, _ = fast_sq_freq_pairs( hap, ac, rec_map, min_focal_freq=min_focal_freq, max_focal_freq=max_focal_freq, window_size=window_size, ) if info.shape[0] == 0: return None r_dind, r_high, r_low = dind_high_low_from_pairs(sq_freqs, info) df, _, _, _ = fs_stats_dataframe(info, r_dind, r_high, r_low, [], [], []) if df is None or len(df) == 0: return None return df.select(["positions", "daf", "high_freq"]).rename({"positions": "pos"}) def _run_low_freq(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50000) min_focal_freq = params.get("min_focal_freq", 0.25) max_focal_freq = params.get("max_focal_freq", 0.95) sq_freqs, info, _ = fast_sq_freq_pairs( hap, ac, rec_map, min_focal_freq=min_focal_freq, max_focal_freq=max_focal_freq, window_size=window_size, ) if info.shape[0] == 0: return None r_dind, r_high, r_low = dind_high_low_from_pairs(sq_freqs, info) df, _, _, _ = fs_stats_dataframe(info, r_dind, r_high, r_low, [], [], []) if df is None or len(df) == 0: return None return df.select(["positions", "daf", "low_freq"]).rename({"positions": "pos"}) def _run_s_ratio(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50000) min_focal_freq = params.get("min_focal_freq", 0.25) max_focal_freq = params.get("max_focal_freq", 0.95) sq_freqs, info, _ = fast_sq_freq_pairs( hap, ac, rec_map, min_focal_freq=min_focal_freq, max_focal_freq=max_focal_freq, window_size=window_size, ) if info.shape[0] == 0: return None r_s = s_ratio_from_pairs(sq_freqs) _, df, _, _ = fs_stats_dataframe(info, [], [], [], r_s, [], []) if df is None or len(df) == 0: return None return df.select(["positions", "daf", "s_ratio"]).rename({"positions": "pos"}) def _run_hapdaf_o(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50000) min_focal_freq = params.get("min_focal_freq", 0.25) max_focal_freq = params.get("max_focal_freq", 0.95) max_ancest_freq = params.get("max_ancest_freq", 0.25) min_tot_freq = params.get("min_tot_freq", 0.25) sq_freqs, info, _ = fast_sq_freq_pairs( hap, ac, rec_map, min_focal_freq=min_focal_freq, max_focal_freq=max_focal_freq, window_size=window_size, ) if info.shape[0] == 0: return None r_hdo = hapdaf_from_pairs(sq_freqs, max_ancest_freq, min_tot_freq) _, _, df_o, _ = fs_stats_dataframe(info, [], [], [], [], r_hdo, []) if df_o is None or len(df_o) == 0: return None return df_o.select(["positions", "daf", "hapdaf_o"]).rename({"positions": "pos"}) def _run_hapdaf_s(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50000) min_focal_freq = params.get("min_focal_freq", 0.25) max_focal_freq = params.get("max_focal_freq", 0.95) max_ancest_freq = params.get("max_ancest_freq", 0.10) min_tot_freq = params.get("min_tot_freq", 0.10) sq_freqs, info, _ = fast_sq_freq_pairs( hap, ac, rec_map, min_focal_freq=min_focal_freq, max_focal_freq=max_focal_freq, window_size=window_size, ) if info.shape[0] == 0: return None r_hds = hapdaf_from_pairs(sq_freqs, max_ancest_freq, min_tot_freq) _, _, _, df_s = fs_stats_dataframe(info, [], [], [], [], [], r_hds) if df_s is None or len(df_s) == 0: return None return df_s.select(["positions", "daf", "hapdaf_s"]).rename({"positions": "pos"}) def _run_haf(hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params): if _single_window is not None: i, j, center = _single_window val = haf_top(hap[i:j], positions[i:j]) return pl.DataFrame({"pos": [center], "haf": [float(val)]}) rows = [] for i, j, center in _get_windows(positions, params): val = haf_top(hap[i:j], positions[i:j]) rows.append({"pos": center, "haf": float(val)}) if not rows: return None return pl.DataFrame(rows) def _run_hscan(hap, positions, ac, rec_map, genetic_pos, **params): max_gap = params.get("max_gap", 200_000) dist_mode = params.get("dist_mode", 0) step = params.get("hscan_step", 1) pos_out, h_out = hscan( hap, positions, max_gap=max_gap, dist_mode=dist_mode, step=step ) if len(pos_out) == 0: return None return pl.DataFrame({"pos": pos_out.astype(int).tolist(), "hscan": h_out.tolist()}) # Sliding-window stat runners def _run_h12(hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params): if _single_window is not None: h12_val, h2_h1, h1, h123, _n = garud_h(hap) return pl.DataFrame( [ { "pos": _single_window, "n_snps": len(positions), "h12": float(h12_val), "h2_h1": float(h2_h1), } ] ) rows = [] for i, j, center in _get_windows(positions, params): h12_val, h2_h1, h1, h123, _n = garud_h(hap[i:j]) rows.append( { "pos": center, "n_snps": j - i, "h12": float(h12_val), "h2_h1": float(h2_h1), } ) return pl.DataFrame(rows) if rows else None def _run_garud(hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params): if _single_window is not None: h12_val, h2_h1, h1, h123, _n = garud_h(hap) return pl.DataFrame( [ { "pos": _single_window, "n_snps": len(positions), "h1": float(h1), "h12": float(h12_val), "h2_h1": float(h2_h1), } ] ) rows = [] for i, j, center in _get_windows(positions, params): h12_val, h2_h1, h1, h123, _n = garud_h(hap[i:j]) rows.append( { "pos": center, "n_snps": j - i, "h1": float(h1), "h12": float(h12_val), "h2_h1": float(h2_h1), } ) return pl.DataFrame(rows) if rows else None def _run_neutrality( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): if _single_window is not None: res = neutrality_stats(ac.astype(np.int32), positions) return pl.DataFrame( [ { "pos": _single_window, "n_snps": len(positions), "tajima_d": float(res[0]), "fay_wu_h_norm": float(res[3]), "pi": float(res[4]), "theta_w": float(res[5]), } ] ) rows = [] for i, j, center in _get_windows(positions, params): res = neutrality_stats(ac[i:j].astype(np.int32), positions[i:j]) rows.append( { "pos": center, "n_snps": j - i, "tajima_d": float(res[0]), "fay_wu_h_norm": float(res[3]), "pi": float(res[4]), "theta_w": float(res[5]), } ) return pl.DataFrame(rows) if rows else None def _run_omega(hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params): if _single_window is not None: hap_f = np.ascontiguousarray(hap.astype(np.float64)) r2 = compute_r2_matrix_upper(hap_f) omega = omega_linear_correct(r2) return pl.DataFrame( [ { "pos": _single_window, "n_snps": len(positions), "omega_max": float(omega), } ] ) rows = [] for i, j, center in _get_windows(positions, params): hap_f = np.ascontiguousarray(hap[i:j].astype(np.float64)) r2 = compute_r2_matrix_upper(hap_f) omega = omega_linear_correct(r2) rows.append({"pos": center, "n_snps": j - i, "omega_max": float(omega)}) return pl.DataFrame(rows) if rows else None def _run_zns(hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params): if _single_window is not None: hap_f = np.ascontiguousarray(hap.astype(np.float64)) zns_val, _ = Ld(hap_f) return pl.DataFrame( [{"pos": _single_window, "n_snps": len(positions), "zns": float(zns_val)}] ) rows = [] for i, j, center in _get_windows(positions, params): hap_f = np.ascontiguousarray(hap[i:j].astype(np.float64)) zns_val, _ = Ld(hap_f) rows.append({"pos": center, "n_snps": j - i, "zns": float(zns_val)}) return pl.DataFrame(rows) if rows else None def _run_lassi_scan(hap, positions, ac, rec_map, genetic_pos, **params): K_truncation = params.get("K_truncation", 10) w_size = params.get("w_size", 201) step = params.get("step", 10) sweep_mode = params.get("sweep_mode", 4) hap_data = [hap, positions] K_counts, K_spectrum, windows_lassi = LASSI_spectrum_and_Kspectrum( hap_data, K_truncation, w_size, int(step) ) K_neutral = neut_average(np.vstack(K_spectrum)) t_m = T_m_statistic_fast( K_counts, K_neutral, windows_lassi, K_truncation, sweep_mode=sweep_mode ) return t_m.select( [ pl.col("window_lassi").cast(pl.Int64).alias("pos"), pl.col("T").alias("T_m"), pl.col("m"), pl.col("frequency").alias("epsilon"), ] ) def _run_lassip_scan(hap, positions, ac, rec_map, genetic_pos, nthreads=1, **params): K_truncation = params.get("K_truncation", 10) w_size = params.get("w_size", 201) step = params.get("step", 10) sweep_mode = params.get("sweep_mode", 4) max_extend = params.get("max_extend", 1e5) n_A = params.get("n_A", 100) K_counts, K_spectrum, windows_centers = LASSI_spectrum_and_Kspectrum( [hap, positions], K_truncation, w_size, int(step) ) K_neutral = neut_average(np.vstack(K_spectrum)) return ( Lambda_statistic_fast( K_counts, K_neutral, windows_centers, K_truncation, n_A=n_A, sweep_mode=sweep_mode, nthreads=nthreads, max_extend=max_extend, ) .rename({"window_lassip": "pos"}) .select(pl.exclude("iter")) ) def _run_raisd(hap, positions, ac, rec_map, genetic_pos, **params): window_size = params.get("window_size", 50) return mu_stat(hap, positions, window_size).rename({"positions": "pos"}) def _run_beta(hap, positions, ac, rec_map, genetic_pos, **params): m = params.get("m", 0.1) # use run_beta_window which takes ac and positions directly out = run_beta_window(ac, positions, m=m) if out is None or out.shape[0] == 0: return None return pl.DataFrame( { "pos": out[:, 0].astype(int).tolist(), "beta1": out[:, 1].tolist(), "beta2": out[:, 2].tolist(), } ) def _run_ncd(hap, positions, ac, rec_map, genetic_pos, **params): tf = params.get("tf", 0.5) w = params.get("w", 3000) minIS = params.get("minIS", 2) n_hap = ac.sum(axis=1) freqs = ac[:, 1] / n_hap # ncd1 returns results[valid_mask] with no positions — recompute valid window centers w1 = w / 2.0 start_positions = np.arange(positions[0], positions[-1], w1) n_snps = len(positions) valid_centers = [] j_start = 0 j_end = 0 for widx in range(len(start_positions)): start = start_positions[widx] end = start + w while j_start < n_snps and positions[j_start] < start: j_start += 1 while j_end < n_snps and positions[j_end] <= end: j_end += 1 if j_end - j_start >= minIS: valid_centers.append(int(start + w / 2.0)) ncd_vals = ncd1(positions, freqs, tf=tf, w=w, minIS=minIS) if len(valid_centers) == 0 or len(ncd_vals) == 0: return None n_valid = min(len(valid_centers), len(ncd_vals)) return pl.DataFrame( {"pos": valid_centers[:n_valid], "ncd1": ncd_vals[:n_valid].tolist()} ) # Individual neutrality runners (bp-window mode) def _run_single_neutrality_array( hap, positions, ac, params, col_name, arr_idx, _single_window=None ): """Run per-window, extract from neutrality_stats() array by index.""" if _single_window is not None: res = neutrality_stats(ac.astype(np.int32), positions) return pl.DataFrame( [ { "pos": _single_window, "n_snps": len(positions), col_name: float(res[arr_idx]), } ] ) rows = [] for i, j, center in _get_windows(positions, params): res = neutrality_stats(ac[i:j].astype(np.int32), positions[i:j]) rows.append({"pos": center, "n_snps": j - i, col_name: float(res[arr_idx])}) return pl.DataFrame(rows) if rows else None def _run_single_neutrality_fn( hap, positions, ac, params, col_name, fn, _single_window=None ): """Run per-window, call individual function fn(ac_win[, pos_win]).""" import inspect _fn_nparams = len(inspect.signature(fn).parameters) def _call(ac_win, pos_win): return float(fn(ac_win, pos_win) if _fn_nparams >= 2 else fn(ac_win)) if _single_window is not None: val = _call(ac, positions) return pl.DataFrame( [{"pos": _single_window, "n_snps": len(positions), col_name: val}] ) rows = [] for i, j, center in _get_windows(positions, params): val = _call(ac[i:j], positions[i:j]) rows.append({"pos": center, "n_snps": j - i, col_name: val}) return pl.DataFrame(rows) if rows else None def _run_tajima_d( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_array( hap, positions, ac, params, "tajima_d", 0, _single_window=_single_window ) def _run_pi(hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params): return _run_single_neutrality_array( hap, positions, ac, params, "pi", 4, _single_window=_single_window ) def _run_theta_w( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_array( hap, positions, ac, params, "theta_w", 5, _single_window=_single_window ) def _run_fay_wu_h( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "fay_wu_h", fay_wu_h_norm, _single_window=_single_window, ) def _run_zeng_e( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "zeng_e", zeng_e, _single_window=_single_window ) def _run_achaz_y( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "achaz_y", achaz_y, _single_window=_single_window ) def _run_fuli_f( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "fuli_f", fuli_f, _single_window=_single_window ) def _run_fuli_f_star( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "fuli_f_star", fuli_f_star, _single_window=_single_window, ) def _run_fuli_d( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "fuli_d", fuli_d, _single_window=_single_window ) def _run_fuli_d_star( hap, positions, ac, rec_map, genetic_pos, _single_window=None, **params ): return _run_single_neutrality_fn( hap, positions, ac, params, "fuli_d_star", fuli_d_star, _single_window=_single_window, ) # Map from stat key to runner function _RUNNERS = { "ihs": _run_ihs, "nsl": _run_nsl, "isafe": _run_isafe, "dind": _run_dind, "high_freq": _run_high_freq, "low_freq": _run_low_freq, "s_ratio": _run_s_ratio, "hapdaf_o": _run_hapdaf_o, "hapdaf_s": _run_hapdaf_s, "haf": _run_haf, "hscan": _run_hscan, "h12": _run_h12, "garud": _run_garud, "tajima_d": _run_tajima_d, "pi": _run_pi, "theta_w": _run_theta_w, "fay_wu_h": _run_fay_wu_h, "zeng_e": _run_zeng_e, "achaz_y": _run_achaz_y, "fuli_f": _run_fuli_f, "fuli_f_star": _run_fuli_f_star, "fuli_d": _run_fuli_d, "fuli_d_star": _run_fuli_d_star, "neutrality": _run_neutrality, "omega": _run_omega, "zns": _run_zns, "lassi": _run_lassi_scan, "lassip": _run_lassip_scan, "raisd": _run_raisd, "beta": _run_beta, "ncd": _run_ncd, } # Stats that need DAF-bin normalization after computation _NORMALIZE_BY_DAF = { "ihs", "nsl", "dind", "high_freq", "low_freq", "s_ratio", "hapdaf_o", "hapdaf_s", } # Stats where rank is by absolute value (signed stats) _ABS_RANK = { "ihs", "nsl", "tajima_d", "zeng_e", "fay_wu_h", "neutrality", } # Stat category sets for global task-pool parallelism # Per-SNP flat: one task per stat per chromosome (whole chromosome in one shot) _SNP_FLAT = { "ihs", "nsl", "dind", "high_freq", "low_freq", "s_ratio", "hapdaf_o", "hapdaf_s", "hscan", } # Per-SNP regional: isafe runs as non-overlapping 1–5 Mb chunks, one task per region _SNP_REGIONAL = {"isafe"} # Window batchable: windows are independent → split into nthreads batches per chromosome _WINDOW_BATCHABLE = { "h12", "garud", "haf", "neutrality", "omega", "zns", "tajima_d", "pi", "theta_w", "fay_wu_h", "zeng_e", "achaz_y", "fuli_f", "fuli_f_star", "fuli_d", "fuli_d_star", } # Window whole: stat needs all windows at once (LASSI spatial kernel, RAiSD, beta/ncd) _WINDOW_WHOLE = {"lassi", "lassip", "raisd", "beta", "ncd"} # Parallelism helpers (called from global task pool in scan()) def _region_bounds(positions, region_size_bp: int): """Yield non-overlapping (lo, hi) bp bounds covering all positions.""" lo = int(positions[0]) max_pos = int(positions[-1]) while lo <= max_pos: yield lo, lo + region_size_bp lo += region_size_bp def _run_isafe_region(hap_slice, pos_slice, ac_slice, params): """Run isafe on a pre-sliced non-overlapping chromosome region (one task).""" isafe_window = params.get("isafe_window", 300) isafe_step = params.get("isafe_step", 150) top_k = params.get("top_k", 1) max_rank = params.get("max_rank", 15) if len(pos_slice) < 300: return None df_r = run_isafe( hap_slice, pos_slice, window=isafe_window, step=isafe_step, top_k=top_k, max_rank=max_rank, ) if df_r is None or len(df_r) == 0: return None return df_r.select(["positions", "daf", "isafe"]).rename({"positions": "pos"}) def _run_window_batch( stat_key, hap, positions, ac, rec_map, genetic_pos, window_batch, **params ): """Run a batchable window stat on a pre-enumerated list of (i, j, center_pos) windows. Called from the global joblib pool. Each window is computed independently. The runner is called with pre-sliced arrays and _single_window=center_pos so it skips its internal _get_windows loop and computes directly on the slice. """ runner = _RUNNERS[stat_key] parts = [] for i, j, center_pos in window_batch: try: df = runner( hap[i:j], positions[i:j], ac[i:j], rec_map, genetic_pos, _single_window=center_pos, **params, ) except Exception: continue if df is not None and len(df) > 0: parts.append(df) return pl.concat(parts) if parts else None # Main scan function
[docs] def scan( vcf_path, out_prefix, stats, config=None, w_size=201, step=10, w_size_bp=1_000_000, step_bp=10_000, min_maf=0.05, recombination_map=None, n_daf_bins=50, n_r_bins=None, nthreads=1, window_mode="auto", **kwargs, ) -> dict[str, pl.DataFrame]: """Standalone outlier scan from a directory of per-chromosome VCF files. Uses a global task pool across ALL VCF files simultaneously: all chromosomes are pre-loaded, then one ``Parallel(n_jobs=nthreads)`` call processes every stat × chromosome combination. This fully exploits nthreads regardless of how many chromosomes are present. Parameters ---------- vcf_path: Directory containing ``*.vcf.gz`` (or ``*.bcf.gz``) files, one per chromosome/contig. Must be a directory; single-file input is not supported. out_prefix: Output file prefix. Writes ``{out_prefix}.{stat}.txt`` for each stat. stats: List of stat keys to compute. See ``available_stats()`` for options. Per-SNP: ihs, nsl, isafe, dind, high_freq, low_freq, s_ratio, hapdaf_o, hapdaf_s. SNP-window: h12, garud, lassi, lassip, raisd. bp-window: tajima_d, pi, theta_w, fay_wu_h, zeng_e, achaz_y, fuli_f, fuli_f_star, fuli_d, fuli_d_star, neutrality, omega, beta, ncd. config: Per-stat parameter overrides, e.g. ``{"raisd": {"window_size": 100}, "lassip": {"max_extend": 5e4}}``. Overrides ``w_size``, ``step``, and any kwargs for that stat only. w_size: SNP-count window size for SNP-mode sliding-window stats (default 201). step: SNP step for SNP-mode sliding-window stats (default 10). w_size_bp: Physical window size in bp for bp-mode stats (default 1 Mb). step_bp: Physical step size in bp for bp-mode stats (default 10 kb). min_maf: Minimum minor allele frequency for iHS and nSL (default 0.05). recombination_map: Path to recombination map TSV (chr, start, end, cm_mb, cm). If provided: genetic distances used for T3 stat windows, and frequency-sensitive stats (iHS, nSL, dind, …) are normalized by joint (DAF × recomb_rate) bins (Johnson et al. approach). n_daf_bins: Number of equal-frequency DAF bins for normalization (default 50). n_r_bins: Number of equal-frequency recombination rate bins for joint (DAF × r_bins) normalization. ``None`` (default) → DAF-only normalization. Set to 10 (Johnson et al.) to enable joint normalization when ``recombination_map`` is provided. nthreads: Total worker threads for the global task pool (default 1). Window-batchable stats split each chromosome into ``nthreads`` window batches so every thread stays busy. window_mode: Override window mode for all sliding-window stats. "auto" (default) uses per-stat defaults from STAT_REGISTRY. "snp" forces SNP-count windows for all window stats. "bp" forces physical bp windows for all window stats. kwargs: Shared overrides forwarded to all stats: max_extend, K_truncation, sweep_mode, raisd_window, tf, etc. Returns ------- dict[str, polars.DataFrame] Keys are stat names; each DataFrame has a ``{rank_col}_pvalue`` column. Files written to ``{out_prefix}.{stat}.txt`` (tab-separated). """ unknown = [s for s in stats if s not in STAT_REGISTRY] if unknown: raise ValueError(f"Unknown stats: {unknown}. Available: {available_stats()}") config = config or {} # vcf_path must be a directory if not os.path.isdir(vcf_path): raise ValueError( f"vcf_path must be a directory of *.vcf.gz files, got: {vcf_path!r}. " "Use --vcf_path pointing to a directory." ) vcf_files = sorted( glob.glob(os.path.join(vcf_path, "*.vcf.gz")) + glob.glob(os.path.join(vcf_path, "*.bcf.gz")) ) if not vcf_files: raise FileNotFoundError(f"No *.vcf.gz or *.bcf.gz files found in {vcf_path}") def _make_params(stat_key): """Merge params: registry defaults < shared < per-stat config.""" p = {**STAT_REGISTRY[stat_key].default_params} if window_mode != "auto": p["window_mode"] = window_mode p.update( { "w_size": w_size, "step": step, "w_size_bp": w_size_bp, "step_bp": step_bp, "min_maf": min_maf, } ) p.update(config.get(stat_key, {})) return p # ------------------------------------------------------------------ # Phase 1: Pre-load all chromosomes sequentially (genome_reader uses # pysam which is not thread-safe, so this must stay sequential). # ------------------------------------------------------------------ chrom_data: dict = ( {} ) # chrom → (hap_int, rec_map, ac, positions, genetic_pos, recomb_vals) for vcf_file in vcf_files: hap_int, rec_map, ac, _, position_masked, genetic_pos = genome_reader( vcf_file, recombination_map=recombination_map ) chrom = str(int(rec_map[0, 0])) recomb_vals = ( _snp_cm_mb(position_masked, rec_map) if recombination_map is not None else None ) chrom_data[chrom] = ( hap_int, rec_map, ac, position_masked, genetic_pos, recomb_vals, ) # ------------------------------------------------------------------ # Phase 2: Build one global flat task list across ALL chromosomes # and ALL stats. Three task categories: # snp — 1 task per stat per chromosome (whole-chromosome runner) # isafe — 1 task per non-overlapping 1-5 Mb region per chromosome # win_batch — window batches (nthreads tasks per chrom per batchable stat) # win_whole — 1 whole-chromosome task per chromosome (lassi/lassip/raisd/…) # ------------------------------------------------------------------ all_tasks: list = [] all_labels: list = [] for chrom, ( hap_int, rec_map, ac, position_masked, genetic_pos, _, ) in chrom_data.items(): for stat_key in stats: params = _make_params(stat_key) if stat_key in _SNP_REGIONAL: # isafe: non-overlapping region tasks if "region_size_bp" not in params: raise KeyError( f"Stat '{stat_key}' (snp-regional) requires 'region_size_bp' " f"in STAT_REGISTRY default_params or config override." ) region_size_bp = int(params.get("region_size_bp", 1_000_000)) for region_idx, (lo, hi) in enumerate( _region_bounds(position_masked, region_size_bp) ): mask = (position_masked >= lo) & (position_masked < hi) if mask.sum() >= 300: all_tasks.append( delayed(_run_isafe_region)( hap_int[mask], position_masked[mask], ac[mask], params ) ) all_labels.append(("isafe", chrom, region_idx)) elif stat_key in _SNP_FLAT: # Per-SNP flat: one whole-chromosome task per stat all_tasks.append( delayed(_RUNNERS[stat_key])( hap_int, position_masked, ac, rec_map, genetic_pos, **params ) ) all_labels.append(("snp", chrom, stat_key)) elif stat_key in _WINDOW_BATCHABLE: # Split windows into up to nthreads batches per chromosome wm = params.get("window_mode", "snp") if wm == "bp" and ( "w_size_bp" not in params or "step_bp" not in params ): raise KeyError( f"Stat '{stat_key}' uses window_mode='bp' but 'w_size_bp' or " f"'step_bp' are missing from STAT_REGISTRY default_params." ) if wm != "bp" and ("w_size" not in params or "step" not in params): raise KeyError( f"Stat '{stat_key}' uses window_mode='{wm}' but 'w_size' or " f"'step' are missing from STAT_REGISTRY default_params." ) all_windows = list(_get_windows(position_masked, params)) if not all_windows: continue chunk_size = max(1, ceil(len(all_windows) / nthreads)) for batch_idx, start in enumerate( range(0, len(all_windows), chunk_size) ): batch = all_windows[start : start + chunk_size] all_tasks.append( delayed(_run_window_batch)( stat_key, hap_int, position_masked, ac, rec_map, genetic_pos, batch, **params, ) ) all_labels.append(("win_batch", chrom, stat_key, batch_idx)) else: # _WINDOW_WHOLE: whole-chromosome task (lassi, lassip, raisd, beta, ncd) # Pass nthreads=1 to avoid nested parallelism inside the pool. all_tasks.append( delayed(_RUNNERS[stat_key])( hap_int, position_masked, ac, rec_map, genetic_pos, nthreads=1, **params, ) ) all_labels.append(("win_whole", chrom, stat_key)) # ------------------------------------------------------------------ # Phase 3: Single global Parallel call — exploits all nthreads across # every chromosome and stat simultaneously. # ------------------------------------------------------------------ with warnings.catch_warnings(): warnings.simplefilter("ignore") task_results = Parallel(n_jobs=nthreads, backend="loky", verbose=2)(all_tasks) # ------------------------------------------------------------------ # Phase 4: Collect results and reassemble into raw_per_stat dict. # ------------------------------------------------------------------ raw_per_stat: dict[str, list] = {s: [] for s in stats} isafe_parts: dict[str, list] = {} # chrom → list of region DataFrames win_batch_parts: dict[tuple, list] = {} # (chrom, stat) → list of batch DataFrames for label, result in zip(all_labels, task_results): if result is None or (hasattr(result, "__len__") and len(result) == 0): continue kind = label[0] chrom = label[1] if kind == "snp": stat_key = label[2] df = result.with_columns(pl.lit(chrom).alias("chrom")).select( ["chrom", "pos"] + [c for c in result.columns if c not in ("chrom", "pos")] ) pos_masked = chrom_data[chrom][3] recomb_vals = chrom_data[chrom][5] raw_per_stat[stat_key].append((df, pos_masked, recomb_vals)) elif kind == "isafe": df = result.with_columns(pl.lit(chrom).alias("chrom")) isafe_parts.setdefault(chrom, []).append(df) elif kind == "win_batch": stat_key = label[2] df = result.with_columns(pl.lit(chrom).alias("chrom")).select( ["chrom", "pos"] + [c for c in result.columns if c not in ("chrom", "pos")] ) win_batch_parts.setdefault((chrom, stat_key), []).append(df) elif kind == "win_whole": stat_key = label[2] df = result.with_columns(pl.lit(chrom).alias("chrom")).select( ["chrom", "pos"] + [c for c in result.columns if c not in ("chrom", "pos")] ) pos_masked = chrom_data[chrom][3] recomb_vals = chrom_data[chrom][5] raw_per_stat[stat_key].append((df, pos_masked, recomb_vals)) # Consolidate isafe: regions are non-overlapping, just concat per chromosome if "isafe" in stats: for chrom, parts in isafe_parts.items(): df_iso = pl.concat(parts).select( ["chrom", "pos"] + [c for c in parts[0].columns if c not in ("chrom", "pos")] ) pos_masked = chrom_data[chrom][3] recomb_vals = chrom_data[chrom][5] raw_per_stat["isafe"].append((df_iso, pos_masked, recomb_vals)) # Consolidate window batches: concat all batches per (chrom, stat) for (chrom, stat_key), parts in win_batch_parts.items(): df_win = pl.concat(parts) pos_masked = chrom_data[chrom][3] recomb_vals = chrom_data[chrom][5] raw_per_stat[stat_key].append((df_win, pos_masked, recomb_vals)) # ------------------------------------------------------------------ # Phase 5: Genome-wide DAF normalization, ranking, and output writing. # ------------------------------------------------------------------ results: dict[str, pl.DataFrame] = {} for stat_key in stats: if not raw_per_stat[stat_key]: continue defn = STAT_REGISTRY[stat_key] rank_col = defn.rank_col df_all = pl.concat([t[0] for t in raw_per_stat[stat_key]]) # Genome-wide DAF-bin normalization for frequency-sensitive per-SNP stats if stat_key in _NORMALIZE_BY_DAF and "daf" in df_all.columns: daf = df_all["daf"].to_numpy() recomb_for_norm = None if recombination_map is not None and n_r_bins is not None: aligned_parts = [] for df_contig, pos_masked, rec_vals in raw_per_stat[stat_key]: if rec_vals is None: aligned_parts.append(np.full(len(df_contig), np.nan)) else: pos_arr = df_contig["pos"].to_numpy() idx = np.clip( np.searchsorted(pos_masked, pos_arr), 0, len(rec_vals) - 1, ) aligned_parts.append(rec_vals[idx]) recomb_for_norm = np.concatenate(aligned_parts) if rank_col in df_all.columns: vals = df_all[rank_col].to_numpy().astype(np.float64) normalized = _normalize_daf_bins( vals, daf, recomb_for_norm, n_daf_bins, n_r_bins ) df_all = df_all.with_columns(pl.Series(rank_col, normalized)) if rank_col in df_all.columns: df_all = empirical_pvalues( df_all, rank_col, abs_rank=(stat_key in _ABS_RANK) ) df_all.write_csv(f"{out_prefix}.{stat_key}.txt", separator="\t") results[stat_key] = df_all.sort("chrom", "pos") return results