Source code for flexsweep.fv

import glob
import gzip
import heapq
import os
import pickle
import re
from collections import OrderedDict, defaultdict, namedtuple
from contextlib import contextmanager
from copy import deepcopy
from functools import lru_cache, partial, reduce
from itertools import product
from math import ceil, comb
from multiprocessing.pool import ThreadPool
from warnings import filterwarnings, warn

from allel import nsl
from allel.compat import memoryview_safe
from allel.opt.stats import ihh01_scan
from allel.stats.selection import compute_ihh_gaps
from allel.util import asarray_ndim, check_dim0_aligned, check_integer_dtype

os.environ.setdefault("NUMBA_CACHE_DIR", os.path.expanduser("~/.cache/numba/flexsweep"))
from numba import float32, float64, int32, int64, njit, prange, uint64
from numba.typed import Dict

from . import np, pl

# Silence the "invalid INFO header" warning coming from allel's VCF reader module.
filterwarnings("ignore", message="invalid INFO header", module="allel.io.vcf_read")
# Silence RuntimeWarnings whose message reports an invalid scalar division.
filterwarnings(
    "ignore",
    category=RuntimeWarning,
    message="invalid value encountered in scalar divide",
)
# NumPy floating-point error handling: do not report divide-by-zero or
# invalid-operation events (set at import time).
np.seterr(divide="ignore", invalid="ignore")

# Lightweight result containers used throughout this module:
#   summaries    -> (stats, parameters)
#   binned_stats -> (mean, std)
summaries = namedtuple("summaries", ["stats", "parameters"])
binned_stats = namedtuple("binned_stats", ["mean", "std"])


################## Utils


def load_pickle(f):
    """Return the object stored in the pickle file at path ``f``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(f, mode="rb") as stream:
        payload = pickle.load(stream)
    return payload


def save_pickle(f, data):
    """Pickle ``data`` into the file at path ``f``, overwriting any existing file."""
    with open(f, mode="wb") as stream:
        pickle.Pickler(stream).dump(data)


def reset_sims_bins(results, r_bins=None, nthreads=1):
    """
    Recompute the neutral normalization bins from already-estimated simulation stats.

    Args:
        results (dict): Mapping ``sim_type -> summaries``-like object exposing
            ``.parameters`` (2D array; only the last column is read here) and
            ``.stats`` (sequence of dicts with ``"snps"`` and ``"window"`` entries).
            Must contain the key ``"neutral"``.
        r_bins (sequence | None, default=None): Break points passed to
            ``polars.Expr.cut``. When None, no recombination-bin table is built.
        nthreads (int, default=1): Number of parallel jobs used to repack the
            neutral stats.

    Returns:
        tuple:
            - neutral_binned (binned_stats): ``(mean, std)`` built from the output
              of ``normalize_neutral`` on the neutral simulations.
            - df_r (dict): ``sim_type -> DataFrame[iter, r_bins]`` (or None per
              entry when ``r_bins`` is None).

    Raises:
        KeyError: If ``results`` has no ``"neutral"`` entry.
    """
    from . import Parallel, delayed

    def make_r_bins(params):
        # One row per replicate; "iter" is 1-based and the last parameter
        # column is scaled by 1e8 before being cut into bins.
        if r_bins is None:
            return None
        return (
            pl.DataFrame({"r": params[:, -1] * 1e8})
            .with_row_index(name="iter", offset=1)
            .with_columns(pl.col("r").cut(breaks=r_bins).alias("r_bins"))
            .select(["iter", "r_bins"])
        )

    def process_stat(i):
        # Stored entries use the key "window"; normalize_neutral is fed "windows".
        return {
            "snps": i["snps"],
            "windows": i["window"],
        }

    def process_batch(batch):
        return [process_stat(i) for i in batch]

    def chunked(iterable, size):
        for i in range(0, len(iterable), size):
            yield iterable[i : i + size]

    df_r = {sim_type: make_r_bins(v.parameters) for sim_type, v in results.items()}

    neutral = results["neutral"]
    # Repack in batches of 50 replicates to keep per-task overhead low.
    batches = list(chunked(neutral.stats, 50))
    norm_stats = [
        item
        for batch_result in Parallel(n_jobs=nthreads)(
            delayed(process_batch)(batch) for batch in batches
        )
        for item in batch_result
    ]

    neutral_binned = binned_stats(
        *normalize_neutral(norm_stats, df_r_bins=df_r["neutral"])
    )
    return neutral_binned, df_r


def reset_empirical_bins(results, r_bins):
    """
    Recompute empirical normalization bins from already-estimated VCF stats.

    Args:
        results (dict): Mapping ``key -> summaries``-like object. ``.stats`` is
            indexed by statistic name (plus ``"window"``) and ``.parameters`` is
            a 2D array whose last column is read as ``cm_mb``.
        r_bins (sequence | None): Break points passed to ``polars.Expr.cut``;
            when None no per-file bin table is produced.

    Returns:
        tuple: ``(empirical_bins, df_r, regions)`` where ``empirical_bins`` is a
        ``binned_stats``, ``df_r`` is the concatenated bin table (None if the
        concatenation failed) and ``regions`` maps each key to its unique
        window ``iter`` values.
    """
    join_keys = ["iter", "positions", "daf"]
    stat_names = (
        "nsl",
        "ihs",
        "isafe",
        "dind_high_low",
        "s_ratio",
        "hapdaf_o",
        "hapdaf_s",
    )

    def make_r_bins(iter_ids, params):
        if r_bins is None:
            return None
        table = pl.DataFrame({"iter": iter_ids, "cm_mb": params[:, -1]})
        table = table.with_columns(
            pl.col("cm_mb").cut(breaks=r_bins).alias("r_bins")
        )
        return table.select(["iter", "cm_mb", "r_bins"])

    def join_stat_frames(stats):
        # Full outer join of every per-SNP statistic on the shared keys.
        lazy_frames = [stats[name].lazy() for name in stat_names]
        merged = lazy_frames[0]
        for frame in lazy_frames[1:]:
            merged = merged.join(frame, on=join_keys, how="full", coalesce=True)
        return merged.sort("positions").collect()

    regions = {}
    bin_tables = []
    norm_stats = []

    for k, v in results.items():
        print(k)
        iter_ids = v.stats["window"]["iter"].unique(maintain_order=True)
        regions[k] = iter_ids
        bin_tables.append(make_r_bins(iter_ids, v.parameters))
        entry = {
            "snps": join_stat_frames(v.stats),
            "windows": v.stats.get("window"),
        }
        norm_stats.append(entry)

    # Concatenation fails (and df_r falls back to None) when the per-file
    # tables cannot be stacked, e.g. when they are None because r_bins is None.
    try:
        df_r = pl.concat(bin_tables)
    except Exception as e:
        print(f"_process_vcf: failed to concat r_bins table: {e}")
        df_r = None

    mean_std = normalize_neutral(norm_stats, vcf=True, df_r_bins=df_r)
    empirical_bins = binned_stats(*mean_std)
    return empirical_bins, df_r, regions


################## Utils


def open_tree(ts, seq_len=1.2e6):
    """Read a tree sequence and return outputs matching the parse_ms_numpy layout.

    Parameters
    ----------
    ts : str or tree-sequence object
        Path loaded with ``tskit.load`` or an object exposing
        ``genotype_matrix()`` and ``variants()``.
    seq_len : float, default 1.2e6
        Locus length in bp. Used to rescale fractional positions and as the
        upper clipping bound.

    Returns
    -------
    hap_bi : np.ndarray, shape (n_biallelic_sites, n_samples), dtype int8
    rec_bi : np.ndarray, shape (n_biallelic_sites, 4), dtype int64
        Columns: constant 1, site index, position, position.
    ac_bi  : np.ndarray, shape (n_biallelic_sites, 2)
    biallelic_mask : np.ndarray, shape (n_sites,), dtype bool
    position_masked : np.ndarray, shape (n_biallelic_sites,), dtype int64
    genetic_position_masked : np.ndarray, shape (n_biallelic_sites,), dtype int64

    When the tree sequence has no sites, empty arrays with these layouts are
    returned.

    Raises
    ------
    ValueError
        If the tree sequence cannot be loaded or its genotypes/positions read.
    """
    try:
        if isinstance(ts, str):
            # Imported lazily: tskit is only required when a path is given and
            # the name is not otherwise brought into scope for this function.
            import tskit

            ts = tskit.load(ts)
        G = ts.genotype_matrix()
        positions_raw = np.array([v.position for v in ts.variants()])
    except Exception as e:
        raise ValueError(f"Could not load tree sequence: {e}") from e

    n_sites, n_samples = G.shape

    if n_sites == 0:
        # Nothing to filter; avoid max() on an empty array further down.
        return (
            np.empty((0, n_samples), dtype=np.int8),
            np.empty((0, 4), dtype=np.int64),
            np.empty((0, 2), dtype=np.int32),
            np.empty(0, dtype=bool),
            np.empty(0, dtype=np.int64),
            np.empty(0, dtype=np.int64),
        )

    from allel import HaplotypeArray

    if positions_raw.max() <= 1.0:
        # fractional [0, 1] coordinates — scale to bp
        positions_bp = np.round(positions_raw * seq_len).astype(np.int64)
    else:
        positions_bp = positions_raw.astype(np.int64)
    # Keep every coordinate inside [1, seq_len].
    positions_bp = np.clip(positions_bp, 1, int(seq_len))

    rec_map = np.column_stack(
        (
            np.ones(n_sites, dtype=np.int64),
            np.arange(n_sites, dtype=np.int64),
            positions_bp,
            positions_bp,
        )
    )

    hap = HaplotypeArray(G.astype(np.int8), copy=False)
    ac = hap.count_alleles()
    biallelic_mask = ac.is_biallelic_01()

    hap_bi = hap.compress(biallelic_mask, axis=0)
    ac_bi = ac.compress(biallelic_mask, axis=0)
    rec_bi = rec_map[biallelic_mask]

    position_masked = rec_bi[:, 3].astype(np.int64, copy=False)
    genetic_position_masked = rec_bi[:, 2]

    return (
        hap_bi.view(np.ndarray),
        rec_bi,
        ac_bi.view(np.ndarray),
        biallelic_mask,
        position_masked,
        genetic_position_masked,
    )


def best_window_idx_per_position_cm(
    pos_sorted,
    w_start,
    w_end,
    w_cm,
    mode="max",
):
    """
    Assign to every position the index of the covering window with the
    highest (``mode="max"``) or lowest (``mode="min"``) ``w_cm`` value.

    The sweep walks ``pos_sorted`` and ``w_start`` in order, so both are
    expected to be sorted ascending. Windows are kept in a heap and expired
    lazily: a window whose end lies before the current position is discarded
    only once it reaches the top of the heap.

    Ties on ``w_cm`` go to the window with the larger end, then the smaller index.
    Positions not covered by any window get -1.

    Raises:
        ValueError: If ``mode`` is neither 'max' nor 'min'.
    """
    best = np.full(pos_sorted.size, -1, dtype=np.int32)

    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    prefer_high = mode == "max"

    def heap_entry(w):
        # heapq is a min-heap: negate the values that should be maximised.
        cm = w_cm[w]
        return (-cm if prefer_high else cm, -w_end[w], w)

    n_windows = w_start.size
    active = []
    next_window = 0

    for slot, position in enumerate(pos_sorted):
        # Activate every window that starts at or before this position.
        while next_window < n_windows and w_start[next_window] <= position:
            heapq.heappush(active, heap_entry(next_window))
            next_window += 1

        # Drop best candidates that already ended before this position.
        while active and -active[0][1] < position:
            heapq.heappop(active)

        if active:
            best[slot] = active[0][2]

    return best


@contextmanager
def omp_num_threads(n_threads: int):
    """
    Temporarily export ``OMP_NUM_THREADS=<n_threads>`` for the enclosed block.

    Args:
        n_threads (int): Value written (as a string) to ``OMP_NUM_THREADS``.

    Usage:
        with omp_num_threads(10):
            heavy_compute()  # os.environ['OMP_NUM_THREADS'] == "10" here

    On exit, including when the block raises, the variable is restored to its
    previous value, or removed if it was not set before.

    Notes:
        - Only the process environment is modified; code has to read
          OMP_NUM_THREADS while the context is active to be affected.
    """
    var_name = "OMP_NUM_THREADS"
    was_set = var_name in os.environ
    previous = os.environ.get(var_name)
    os.environ[var_name] = str(n_threads)
    try:
        yield
    finally:
        if was_set:
            os.environ[var_name] = previous
        else:
            os.environ.pop(var_name, None)


def _ms_float_to_int(positions, total_phys_len):
    """
    Map float ms positions to strictly increasing integer bp coordinates.

    Pure-Python/numpy counterpart of msPositionsToIntegerPositions
    (diploshic msTools.py): each position is scaled by ``total_phys_len`` and
    truncated, 0 is bumped to 1, and a coordinate that does not exceed the
    previous one is moved to ``previous + 1``. Coordinates pushed beyond
    ``total_phys_len`` are re-seated into the highest free slots
    (mirrors fillInSnpSlotsWithOverflowers).

    Returns:
        np.ndarray: int64 coordinates, sorted ascending.
    """
    count = len(positions)
    new_positions = np.empty(count, dtype=np.int64)
    last_raw = -1.0
    last_bp = -1

    for idx in range(count):
        raw = float(positions[idx])
        # An exact repeat of the previous input is nudged forward before scaling.
        shifted = raw + 1e-6 if raw == last_raw else raw
        last_raw = raw

        bp = int(total_phys_len * shifted)
        if bp == 0:
            bp = 1
        bp = max(bp, last_bp + 1)
        last_bp = bp
        new_positions[idx] = bp

    too_large = new_positions > total_phys_len
    n_over = int(too_large.sum())
    if n_over == 0:
        return new_positions

    # Replace overflowing coordinates with the largest unused slots.
    kept = new_positions[~too_large]
    taken = set(kept.tolist())
    fillers = []
    candidate = int(total_phys_len)
    while candidate > 0 and len(fillers) < n_over:
        if candidate not in taken:
            fillers.append(candidate)
        candidate -= 1

    merged = np.concatenate([kept, np.array(fillers, dtype=np.int64)])
    return np.sort(merged)


def parse_and_filter_ms(
    ms_file: str,
    seq_len: float = 1.2e6,
    discretize_positions: bool = True,
):
    """
    Parse the first replicate of an ms-format file and keep biallelic sites.

    Args:
        ms_file (str): Path to the ms output (optionally gzip-compressed; a
            warning is issued for extensions other than .out/.ms[.gz]).
        seq_len (float, default=1.2e6): Locus length used to scale positions.
        discretize_positions (bool, default=True): If True, positions are
            converted with ``_ms_float_to_int``; otherwise they are simply
            multiplied by ``seq_len`` and kept as floats.

    Returns:
        tuple | None:
            ``(hap, rec_map, ac, biallelic_mask, positions, genetic_positions)``
            restricted to biallelic sites, where ``hap`` is (sites x samples)
            int8 and ``rec_map`` has columns [1, site index, position, position].
            A 6-tuple of empty arrays is returned when ``segsites`` is 0.
            None is returned (after a warning) when the file is malformed,
            including haplotype rows of unequal length or a number of
            positions that differs from the haplotype length.

    Notes:
        - Only the first ``//`` replicate is read; parsing stops at the second.
    """
    if not ms_file.endswith((".out", ".out.gz", ".ms", ".ms.gz")):
        warn(f"File {ms_file} has an unexpected extension.")

    open_function = gzip.open if ms_file.endswith(".gz") else open

    in_rep = False
    num_segsites = None
    pos_arr = None
    hap_rows = []

    with open_function(ms_file, "rt") as fh:
        for raw in fh:
            line = raw.strip()

            if line.startswith("//"):
                if in_rep:
                    # Second replicate marker: the first replicate is complete.
                    break
                in_rep = True
                num_segsites = None
                pos_arr = None
                hap_rows.clear()
                continue

            if not in_rep or not line:
                continue

            if line.startswith("segsites"):
                parts = line.split()
                try:
                    num_segsites = int(parts[1])
                except (IndexError, ValueError):
                    warn(f"File {ms_file} is malformed.")
                    return None

                if num_segsites == 0:
                    # match parse_ms_numpy: 6-tuple with empty arrays
                    return (
                        np.empty((0, 1), dtype=np.int8),
                        np.empty((0, 4), dtype=np.int64),
                        np.empty((0, 2), dtype=np.int64),
                        np.empty(0, dtype=bool),
                        np.empty(0, dtype=np.int64),
                        np.empty(0, dtype=np.float64),
                    )
                continue

            if line.startswith("positions"):
                try:
                    _, values = line.split(":", 1)
                    positions = np.fromstring(values, sep=" ", dtype=np.float64)
                except Exception:
                    warn(f"File {ms_file} is malformed.")
                    return None

                if discretize_positions:
                    new_positions = _ms_float_to_int(positions, seq_len)
                    # Safety clamp: nothing may lie beyond the locus end.
                    new_positions[new_positions > seq_len] = int(seq_len)
                    pos_arr = new_positions
                else:
                    pos_arr = positions * seq_len
                continue

            if line[0] in "01":
                hap_rows.append(line)

    if not hap_rows or pos_arr is None or num_segsites is None:
        warn(f"File {ms_file} is malformed.")
        return None

    n_samples = len(hap_rows)
    n_sites = len(hap_rows[0])

    # Ragged haplotype rows or a positions/haplotype length mismatch would
    # otherwise surface as an uncaught ValueError while building the arrays.
    if len(pos_arr) != n_sites or any(len(row) != n_sites for row in hap_rows):
        warn(f"File {ms_file} is malformed.")
        return None

    # Imported here so header-only outcomes do not depend on scikit-allel.
    from allel import HaplotypeArray

    H = np.empty((n_samples, n_sites), dtype=np.int8)
    for i, s in enumerate(hap_rows):
        # ASCII '0'/'1' -> 0/1
        H[i, :] = np.frombuffer(s.encode("ascii"), dtype=np.uint8) - 48
    H = H.T  # (num_segsites, num_samples)

    rec_map = np.column_stack(
        (
            np.ones(n_sites, dtype=np.int64),
            np.arange(n_sites, dtype=np.int64),
            pos_arr,
            pos_arr,
        )
    )

    hap = HaplotypeArray(H, copy=False)
    ac = hap.count_alleles()
    biallelic_mask = ac.is_biallelic_01()

    hap_bi = hap.compress(biallelic_mask, axis=0)
    ac_bi = ac.compress(biallelic_mask, axis=0)
    rec_bi = rec_map[biallelic_mask]

    position_masked = rec_bi[:, 3]
    genetic_position_masked = rec_bi[:, 2]

    return (
        np.ascontiguousarray(hap_bi.view(np.ndarray), dtype=np.int8),
        rec_bi,
        ac_bi.view(np.ndarray),
        biallelic_mask,
        np.ascontiguousarray(position_masked, dtype=np.int64),
        genetic_position_masked,
    )


def parse_ms_numpy(
    ms_file: str,
    seq_len: float = 1.2e6,
    discretize_positions: bool = True,
):
    """
    Vectorized ms parser: reads the whole file as bytes and avoids per-line loops.

    Only the first replicate (the block after the first ``//``) is parsed; any
    later replicate is ignored.

    Args:
        ms_file (str): Path to the ms output; ``.gz`` files are read with gzip.
        seq_len (float, default=1.2e6): Locus length used to scale positions.
        discretize_positions (bool, default=True): If True positions become
            ``floor(position * seq_len)`` as int64, capped at ``seq_len``;
            otherwise ``position * seq_len`` as floats.

    Returns:
        tuple | None:
            ``(hap, rec_map, ac, biallelic_mask, positions, genetic_positions)``
            restricted to biallelic sites; ``hap`` is (sites x samples) int8,
            ``rec_map`` has columns [1, index, position, position] and ``ac``
            holds (ref, alt) counts as int32. A 6-tuple of empty arrays is
            returned when ``segsites`` is 0. None is returned when the file
            cannot be read, has no ``//`` marker, or is malformed (a warning is
            issued for unreadable and malformed files).

    Notes:
        - The floor-based discretization differs from ``_ms_float_to_int``: it
          can yield position 0 and repeated positions.
    """
    open_func = gzip.open if ms_file.endswith(".gz") else open

    # Use 'rb' (binary) to avoid the overhead of UTF-8 decoding
    try:
        with open_func(ms_file, "rb") as fh:
            content = fh.read()
    except Exception:
        warn(f"Could not read file {ms_file}")
        return None

    seg_tag = b"segsites:"
    pos_tag = b"positions:"

    try:
        sep = b"//"
        start_rep = content.find(sep)
        if start_rep == -1:
            return None

        # Locate segsites (int() tolerates the surrounding whitespace)
        seg_idx = content.find(seg_tag, start_rep)
        if seg_idx == -1:
            raise ValueError("missing segsites header")
        line_end = content.find(b"\n", seg_idx)
        if line_end == -1:
            line_end = len(content)
        num_segsites = int(content[seg_idx + len(seg_tag) : line_end])
        if num_segsites < 0:
            raise ValueError("negative segsites")

        if num_segsites == 0:
            return (
                np.empty((0, 1), "i1"),
                np.empty((0, 4), "i8"),
                np.empty((0, 2), "i4"),
                np.empty(0, bool),
                np.empty(0, "i8"),
                np.empty(0, "f8"),
            )

        # Locate positions
        pos_idx = content.find(pos_tag, line_end)
        if pos_idx == -1:
            raise ValueError("missing positions header")
        pos_end = content.find(b"\n", pos_idx)
        if pos_end == -1:
            raise ValueError("missing haplotype block")
        pos_vals = content[pos_idx + len(pos_tag) : pos_end].split()
        positions = np.array(pos_vals, dtype=np.float64)
        if positions.size != num_segsites:
            raise ValueError("positions do not match segsites")

        # The haplotype block starts right after the positions line and ends
        # at the next replicate marker (or at end of file).
        next_rep = content.find(sep, pos_end)
        block_end = len(content) if next_rep == -1 else next_rep
        data_block = content[pos_end + 1 : block_end].strip()

        # Convert ASCII '0'/'1' to 0/1 after dropping line terminators.
        raw_haps = (
            np.frombuffer(
                data_block.replace(b"\n", b"").replace(b"\r", b""), dtype=np.uint8
            )
            - 48
        )

        n_samples = len(raw_haps) // num_segsites
        # reshape raises ValueError when the byte count is not a multiple of
        # num_segsites, which is reported as a malformed file below.
        H = raw_haps.reshape((n_samples, num_segsites)).T  # (n_sites, n_samples)

    except (ValueError, IndexError):
        warn(f"File {ms_file} is malformed.")
        return None

    # Position discretization
    if discretize_positions:
        pos_arr = np.floor(positions * seq_len).astype(np.int64)
        pos_arr[pos_arr > seq_len] = int(seq_len)
    else:
        pos_arr = positions * seq_len

    # Keep sites where both alleles are observed.
    alt_counts = H.sum(axis=1).astype(np.int32)
    ref_counts = np.int32(n_samples) - alt_counts
    biallelic_mask = (alt_counts > 0) & (ref_counts > 0)

    H_bi = np.ascontiguousarray(H[biallelic_mask], dtype=np.int8)
    ac_bi = np.column_stack((ref_counts[biallelic_mask], alt_counts[biallelic_mask]))
    pos_bi = pos_arr[biallelic_mask]

    # Recombination-map style matrix: [1, index, position, position]
    n_bi = H_bi.shape[0]
    rec_bi = np.column_stack(
        (
            np.ones(n_bi, dtype=np.int64),
            np.arange(n_bi, dtype=np.int64),
            pos_bi,
            pos_bi,
        )
    )

    return (
        H_bi,
        rec_bi,
        ac_bi,
        biallelic_mask,
        np.ascontiguousarray(rec_bi[:, 3], dtype=np.int64),
        rec_bi[:, 2],
    )


def cleaning_summaries(summ_stats, params, model):
    """
    Drop replicates whose summary statistics are missing (None).

    A replicate is dropped when its entry is None in either list of
    ``summ_stats`` while that list also holds non-None entries. A list that is
    None throughout (e.g. SNP stats when only window stats were computed) is
    treated as intentionally absent and never causes a drop.

    Args:
        summ_stats (list of 2 lists): Summary statistics [list1, list2],
            paired by index.
        params (np.ndarray | None): Parameter matrix with one row per replicate.
        model (str): Model identifier used in the reported messages.

    Returns:
        summ_stats_filtered (list of 2 lists): Statistics of the kept replicates.
        params (np.ndarray | None): ``params`` without the dropped rows.
        malformed_files (list of str): One message per dropped index.
    """
    first, second = summ_stats[0], summ_stats[1]

    # A uniformly-None axis is "not computed", not "broken".
    first_expected = not all(entry is None for entry in first)
    second_expected = not all(entry is None for entry in second)

    dropped = []
    kept_first = []
    kept_second = []
    for idx, (a, b) in enumerate(zip(first, second)):
        missing = (a is None and first_expected) or (b is None and second_expected)
        if missing:
            dropped.append(idx)
            continue
        kept_first.append(a)
        kept_second.append(b)

    malformed_files = [f"Model {model}, index {idx} is malformed." for idx in dropped]

    # Remove the matching rows from the parameter matrix.
    if dropped and params is not None:
        params = np.delete(params, dropped, axis=0)

    return [kept_first, kept_second], params, malformed_files


def genome_reader(hap_data, recombination_map=None, region=None, samples=None):
    """
    Read a VCF/BCF region and return haplotypes, recombination map, allele count
    array, biallelic mask, and physical and genetic position arrays.

    Args:
        hap_data (str): Path to VCF/BCF file.
        recombination_map (str | None, default=None):
            Optional headerless TSV map with columns: chr, start, end, cm_mb, cm
            (lines starting with '#' are skipped). Only rows whose ``chr`` equals
            ``"chr" + <chromosome of the first retained variant>`` are used.
        region (str | None, default=None): Region string passed to ``allel.read_vcf``.
        samples (list[str] | np.ndarray | None, default=None): Optional sample subset.

    Returns:
        tuple | dict:
            On success, the tuple
            ``(hap_int, rec_map, ac, biallelic_filter, position_masked, genetic_position_masked)``.
            ``{region: None}`` when the genotype array cannot be built or when
            no biallelic site remains.

        Where:
            - hap_int: (sites x haplotypes) np.int8 array, C-contiguous.
            - rec_map: 2D array with columns [chrom, idx, pos, cm].
            - ac: 2D allele-count array of the retained sites.
            - biallelic_filter: boolean mask over the sites read from the file.
            - position_masked: physical positions of the retained sites,
              returned as np.int32.
            - genetic_position_masked: last column of rec_map.

    Notes:
        - If `recombination_map` is None, the cm column is a copy of the physical positions.
        - If every interpolated cM value is 0, the cm column is replaced by the
          physical positions as well.
    """
    from allel import GenotypeArray, read_vcf

    filterwarnings("ignore", message="invalid INFO header", module="allel.io.vcf_read")

    raw_data = read_vcf(hap_data, region=region, samples=samples)

    # Any failure to build the genotype array (presumably including a read that
    # returned nothing for this region) is reported as "no data".
    try:
        gt = GenotypeArray(raw_data["calldata/GT"])
    except Exception:
        return {region: None}

    pos = raw_data["variants/POS"]
    # Strip the "chr" prefix; keep chromosome names as strings when they are
    # not numeric (best-effort integer conversion).
    np_chrom = np.char.replace(raw_data["variants/CHROM"].astype(str), "chr", "")
    try:
        np_chrom = np_chrom.astype(int)
    except Exception:
        pass
    _ac = gt.count_alleles()

    # Filtering monomorphic just in case
    biallelic_filter = _ac.is_biallelic_01()
    ac = _ac[biallelic_filter]
    hap_int = gt.to_haplotypes().values[biallelic_filter].astype(np.int8)
    position_masked = pos[biallelic_filter].astype(np.int64)
    np_chrom = np_chrom[biallelic_filter]

    if hap_int.shape[0] == 0:
        return {region: None}

    # if region is None:
    #     d_pos = dict(zip(np.arange(position_masked.size + 1), position_masked))
    # else:
    #     tmp = list(map(int, region.split(":")[-1].split("-")))
    #     d_pos = dict(zip(np.arange(tmp[0], tmp[1] + 1), np.arange(int(5e5)) + 1))

    if recombination_map is None:
        # No map: use physical positions as the genetic-distance column.
        rec_map = pl.DataFrame(
            {
                "chrom": np_chrom,
                "idx": np.arange(position_masked.size),
                "pos": position_masked,
                "cm": position_masked,
            }
        ).to_numpy()
    else:
        df_recombination_map = (
            pl.read_csv(
                recombination_map,
                separator="\t",
                comment_prefix="#",
                schema=pl.Schema(
                    [
                        ("chr", pl.String),
                        ("start", pl.Int64),
                        ("end", pl.Int64),
                        ("cm_mb", pl.Float64),
                        ("cm", pl.Float64),
                    ]
                ),
            )
            # The map is expected to name chromosomes with a "chr" prefix.
            .filter(pl.col("chr") == "chr" + str(np_chrom[0]))
            .sort("start")
        )
        # Interpolate cumulative cM at every retained position.
        genetic_distance = get_cm(df_recombination_map, position_masked)

        rec_map = pl.DataFrame(
            [
                np_chrom,
                np.arange(position_masked.size),
                position_masked,
                genetic_distance,
            ]
        ).to_numpy()

        # Fall back to physical positions when the map yields only zeros.
        if np.all(rec_map[:, -1] == 0):
            rec_map[:, -1] = rec_map[:, -2]

    genetic_position_masked = rec_map[:, -1]

    return (
        np.ascontiguousarray(hap_int, dtype=np.int8),
        rec_map,
        asarray_ndim(ac.values, 2),
        biallelic_filter,
        # NOTE(review): positions are narrowed to int32 here, whereas the ms
        # parsers in this module return int64 — confirm callers expect this.
        np.ascontiguousarray(position_masked, dtype=np.int32),
        genetic_position_masked,
    )


def get_cm(df_rec_map, positions, cm_mb=False):
    """
    Interpolate cumulative genetic distance (cM) from a recombination map.

    Args:
        df_rec_map (polars.DataFrame): Map with an ``end`` column (physical
            position, bp) and a ``cm`` column (cumulative cM); a ``chr`` column
            holding a single value is also required when ``cm_mb`` is True.
        positions (np.ndarray): 1D array of positions (bp) when ``cm_mb`` is
            False; 2D array of (start, end) intervals when ``cm_mb`` is True.
        cm_mb (bool, default=False): Return per-interval rates instead of
            per-position cumulative distances.

    Returns:
        np.ndarray | polars.DataFrame:
            - ``cm_mb=False``: interpolated cumulative cM, negatives set to 0.
            - ``cm_mb=True``: DataFrame with columns chr, start, end, cm_mb,
              where cm_mb is the cM difference divided by the interval length in Mb.

    Notes:
        - Linear interpolation on (end, cm), linearly extrapolated outside the map.
    """
    from scipy.interpolate import interp1d

    map_bp = df_rec_map.select("end").to_numpy().flatten()
    map_cm = df_rec_map.select("cm").to_numpy().flatten()
    to_cm = interp1d(map_bp, map_cm, kind="linear", fill_value="extrapolate")

    if not cm_mb:
        cumulative = to_cm(positions)
        # Extrapolation below the map start can go negative.
        cumulative[cumulative < 0] = 0
        return cumulative

    cm_at_start = to_cm(positions[:, 0])
    cm_at_end = to_cm(positions[:, 1])

    cm_at_start[cm_at_start < 0] = 0
    cm_at_end[cm_at_end < 0] = 0

    interval_mb = (positions[:, 1] - positions[:, 0]) / 1e6
    rate = (cm_at_end - cm_at_start) / interval_mb

    return pl.DataFrame(
        {
            "chr": df_rec_map.select("chr").unique().item(),
            "start": positions[:, 0],
            "end": positions[:, 1],
            "cm_mb": rate,
        }
    )


def center_window_cols(df, _iter=1):
    """
    Tag a statistics frame with its replicate identifier.

    Args:
        df (polars.DataFrame): Statistic rows for a single replicate/region.
        _iter (int, default=1): Identifier stored in the 'iter' column.

    Returns:
        polars.DataFrame:
            - If `df` is empty: the horizontal concatenation of a one-row
              'iter' frame with `df`.
            - Otherwise: `df` with an 'iter' column cast to Int64 and columns
              ordered as ['iter', 'positions', <remaining columns>].
    """
    if df.is_empty():
        iter_only = pl.DataFrame({"iter": _iter})
        return pl.concat([iter_only, df], how="horizontal")

    leading = ["iter", "positions"]
    tagged = df.with_columns(pl.lit(_iter).alias("iter"))
    tagged = tagged.with_columns(pl.col("iter").cast(pl.Int64))
    return tagged.select(pl.col(leading), pl.all().exclude(leading))


def pivot_feature_vectors(df_fv, vcf=False):
    """
    Label each row with a sweep model based on timing and fixation status,
    then pivot the per-window statistics into one wide row per replicate.

    Args:
        df_fv (polars.DataFrame): Feature vectors with columns including
            't', 'f_t', 'f_i', 's', 'mu', 'r', 'iter', 'window', 'center' and
            the statistics. The statistics are selected by position (columns
            from index 9 onwards), so the column order matters. When
            ``vcf=True`` and 'iter' is Int64, an 'nchr' column is also required.
        vcf (bool, default=False): Whether the input comes from VCF processing (special handling).

    Returns:
        polars.DataFrame: Wide/pivoted feature table indexed by
        ['iter', 's', 't', 'f_i', 'f_t', 'mu', 'r', 'model'], with '{', '}'
        removed and ',' replaced by '_' in the column names.

    Notes:
        - Model labels: 'old' when t >= 2000 else 'young'; 'complete' when
          f_t >= 0.9 else 'incomplete'; 'soft' instead of 'hard' when f_i differs
          from the minimum f_i in the table; 'neutral' when every s is 0.
        - When `vcf=True` and 'iter' is Int64, 'iter' is rebuilt as
          '<nchr>:<iter-600000>-<iter+600000>' and 'nchr' is dropped.
    """

    # Categorize sweeps based on age and completeness
    df_fv = df_fv.with_columns(
        pl.when((pl.col("t") >= 2000) & (pl.col("f_t") >= 0.9))
        .then(pl.lit("hard_old_complete"))
        .when((pl.col("t") >= 2000) & (pl.col("f_t") < 0.9))
        .then(pl.lit("hard_old_incomplete"))
        .when((pl.col("t") < 2000) & (pl.col("f_t") >= 0.9))
        .then(pl.lit("hard_young_complete"))
        .otherwise(pl.lit("hard_young_incomplete"))
        .alias("model")
    )

    # Further categorize as soft or hard sweep based on initial frequency:
    # only rows at the minimum observed f_i keep the "hard" label.
    df_fv = df_fv.with_columns(
        pl.when(pl.col("f_i") != df_fv["f_i"].min())
        .then(pl.col("model").str.replace("hard", "soft"))
        .otherwise(pl.col("model"))
        .alias("model")
    )

    # Handle the case where all selection coefficients are zero (neutral model)
    if (df_fv["s"] == 0).all():
        df_fv = df_fv.with_columns(pl.lit("neutral").alias("model"))

    # Determine sorting method based on iter column type
    # sort_multi = True if df_fv["iter"].dtype == pl.Utf8 else False

    # Pivot the data
    # Statistics are taken positionally: everything from column index 9 up to,
    # but excluding, the last column (presumably the 'model' column added
    # above — holds only if the input had no 'model' column already).
    # value_columns = df_fv.columns[7:-1]
    value_columns = df_fv.columns[9:-1]

    if vcf:
        if df_fv["iter"].dtype == pl.Int64:
            # remove nchr
            value_columns = value_columns[:-1]

            # NOTE(review): this first fv_center value is dead — it is
            # overwritten on the next line before being used.
            fv_center = np.linspace(6e5 - 1e5, 6e5 + 1e5, 21).astype(int)
            fv_center = df_fv["center"].unique().to_numpy()
            rows_per_center = df_fv["window"].unique().len()
            n_rows = df_fv.height
            # Repeat each center once per window size and tile the pattern
            # until it covers every row, then trim to the table height.
            full_center = np.tile(
                np.repeat(fv_center, rows_per_center),
                n_rows // (len(fv_center) * rows_per_center) + 1,
            )[:n_rows]

            df_fv = df_fv.with_columns(
                (
                    pl.col("nchr").cast(pl.String)
                    + ":"
                    + (pl.col("iter").cast(pl.Int64) - int(6e5)).cast(pl.String)
                    + "-"
                    + (pl.col("iter").cast(pl.Int64) + int(6e5)).cast(pl.String)
                ).alias("iter"),
                pl.lit(full_center).alias("center"),
            ).select(pl.exclude("nchr"))

    # One output column per (statistic, window, center) combination.
    df_fv_w = df_fv.pivot(
        values=value_columns,
        index=["iter", "s", "t", "f_i", "f_t", "mu", "r", "model"],
        on=["window", "center"],
    )

    # Clean up column names
    df_fv_w = df_fv_w.rename(
        {
            col: col.replace("{", "").replace("}", "").replace(",", "_")
            for col in df_fv_w.columns
        }
    )

    return df_fv_w


def get_closest_snps(position_array, center, N):
    """
    Given SNP positions and a center position, return the indices of the N closest SNPs.

    Args:
        position_array (np.ndarray): 1D array of SNP positions (bp).
        center (int | float): Central genomic coordinate.
        N (int): Number of SNPs to select. Must be <= len(position_array).

    Returns:
        np.ndarray: Indices of the N closest SNPs, in ascending index order
        (i.e. input order, not distance order).

    Raises:
        AssertionError: If `position_array` is not 1D or if `N` exceeds array length.

    Notes:
        - Ties are broken deterministically: a stable sort is used, so among
          SNPs at the same distance the one appearing first in the input is
          selected first.
    """
    position_array = np.asarray(position_array)
    assert position_array.ndim == 1, "position_array must be a 1D array"
    assert N <= len(position_array), "N exceeds the number of SNPs in the array"

    distances = np.abs(position_array - center)
    # The default argsort kind is not stable; request a stable sort so that
    # equidistant SNPs are always resolved the same way.
    closest_indices = np.argsort(distances, kind="stable")[:N]
    return np.sort(closest_indices)


################## Summaries


def _sort_by_genomic_region(df):
    """Return ``df`` sorted by the ``chr<N>:<start>-<end>`` key held in ``iter``.

    The three numeric fields are extracted only to build the sort key and are
    dropped again, so the result has the same columns as the input.
    """
    return (
        df.with_columns(
            pl.col("iter")
            .str.extract_groups(r"chr(\d+):(\d+)-(\d+)")
            .struct.rename_fields(["chrom", "start", "end"])
            .alias("g")
        )
        .unnest("g")
        .with_columns(pl.col(["chrom", "start", "end"]).cast(pl.Int64))
        .sort(["chrom", "start", "end"])
        .select(pl.exclude(["chrom", "start", "end"]))
    )


def _process_vcf(
    data_dir,
    nthreads,
    windows=None,
    step=1e5,
    step_vcf=int(1e4),
    locus_length=int(1.2e6),
    recombination_map=None,
    r_bins=None,
    min_rate=None,
    suffix=None,
    func=None,
    save_stats=False,
    stats=None,
):
    """
    Process VCF/BCF files to compute, normalize, and estimate feature vectors.

    This function scans a directory for bgzipped VCF/BCF files, computes summary
    statistics per file, normalizes them using empirical distributions estimated
    from those same files, and returns feature vectors to train/predict the CNN.

    :param str data_dir:
        Directory containing bgzipped VCF or BCF files (patterns ``*vcf.gz``
        and ``*bcf.gz``). Output Parquet and pickle files are written under
        the same directory.
    :param int nthreads:
        Number of joblib workers.
    :param list windows:
        List of window sizes (in base pairs) to compute summary statistics.
        ``None`` means ``[100000]``.
    :param int step:
        Step size (in base pairs) for sliding windows.
    :param int step_vcf:
        Forwarded to ``Data(..., step=step_vcf)`` when reading each file.
    :param int locus_length:
        Forwarded to ``Data(..., window_size=locus_length)`` and to the
        statistic/normalization functions.
    :param str recombination_map:
        Optional path to a tab-separated recombination map with columns
        ``chr, start, end, cm_mb, cm`` (lines starting with ``#`` are
        skipped). If ``None``, the ``r`` parameter column stays at zero.
    :param r_bins:
        Optional break points used to bin the per-region ``cm_mb`` rate.
        Requires ``recombination_map`` and a float ``min_rate``.
    :param float min_rate:
        Regions whose ``cm_mb`` is below this value are dropped when
        ``r_bins`` is given.
    :param str suffix:
        Optional suffix appended to output file names.
    :param callable func:
        Function to estimate summary statistics. Defaults to
        ``calculate_stats_vcf_flat``.
    :param bool save_stats:
        If true, pickle the raw per-file statistics to
        ``raw_statistics{suffix}.pickle``.
    :param stats:
        Forwarded unchanged to ``func``.

    :returns:
        A pair of Polars DataFrames:

        - **df_pred**: normalized feature vectors.
        - **df_pred_raw**: corresponding raw feature vectors.
    :rtype: tuple[polars.DataFrame, polars.DataFrame]

    :raises FileNotFoundError:
        If no file matching ``*vcf.gz`` / ``*bcf.gz`` is found.
    :raises ValueError:
        If ``r_bins`` is given without ``recombination_map``.
    :raises AssertionError:
        If ``r_bins`` is given and ``min_rate`` is not a float.

    :notes:
        - Writes the following files to ``data_dir``:

          * ``fvs{suffix}.parquet`` – normalized feature vectors
          * ``fvs_raw{suffix}.parquet`` – raw feature vectors
          * ``empirical_bins{suffix}.pickle`` – empirical normalization bins
        - Output rows are sorted by region; the sort key is parsed with the
          pattern ``chr(\\d+):(\\d+)-(\\d+)`` from the ``iter`` column.
    """
    from . import Parallel
    from .data import Data

    assert r_bins is None or (
        min_rate is not None and isinstance(min_rate, float)
    ), "If r_bins is not None, min_rate must be a float (minimum recombination rate simulated)."

    # The r_bins branch below reads per-region rates from the recombination
    # map; without one it would fail on an undefined name after all the
    # (expensive) VCF parsing was already done.
    if r_bins is not None and recombination_map is None:
        raise ValueError(
            "r_bins requires recombination_map: per-region rates are read from it."
        )

    if func is None:
        func = calculate_stats_vcf_flat

    center = [int(step // 2), int(locus_length - step // 2)]

    if windows is None:
        windows = [100000]

    suffix_str = f"_{suffix}" if suffix is not None else ""

    vcf_files = sorted(
        glob.glob(os.path.join(data_dir, "*vcf.gz"))
        + glob.glob(os.path.join(data_dir, "*bcf.gz"))
    )

    if not vcf_files:
        raise FileNotFoundError(f"No VCF/BCF files found in directory: {data_dir}")

    # Per-file containers, keyed by a normalized file stem.
    sims = {}
    regions = {}
    df_params_l = []

    for vcf_path in vcf_files:
        basename = os.path.basename(vcf_path)
        key = basename.replace(".vcf", "").replace(".bcf", "").replace(".gz", "")
        key = key.replace(".", "_").lower()

        fs_data = Data(
            vcf_path, nthreads=nthreads, window_size=locus_length, step=step_vcf
        )
        sim_dict = fs_data.read_vcf()

        # Placeholder parameter table: one zero-filled row per region, so
        # empirical data has the same parameter layout as the simulations.
        n = len(sim_dict["region"])
        df_params_l.append(
            pl.DataFrame(
                {
                    "model": sim_dict["region"],
                    "s": np.zeros(n),
                    "t": np.zeros(n),
                    "saf": np.zeros(n),
                    "eaf": np.zeros(n),
                    "mu": np.zeros(n),
                    "r": np.zeros(n),
                }
            )
        )

        sims[key] = sim_dict["sweep"]
        regions[key] = sim_dict["region"]

    df_params = pl.concat(df_params_l)

    if recombination_map is not None:
        df_recombination_map = pl.read_csv(
            recombination_map,
            separator="\t",
            comment_prefix="#",
            schema=pl.Schema(
                [
                    ("chr", pl.String),
                    ("start", pl.Int64),
                    ("end", pl.Int64),
                    ("cm_mb", pl.Float64),
                    ("cm", pl.Float64),
                ]
            ),
        )
    else:
        df_recombination_map = None

    results = {}
    tmp_bins = []
    df_r_l = []

    # Pool 1: summary statistics
    # Single pool shared across all VCF files — workers stay warm,
    # numba caches persist across chromosomes.
    with Parallel(n_jobs=nthreads, backend="loky", verbose=2) as stats_pool:
        for k, vcf_file in sims.items():
            print(k)

            # Region strings look like "chr:start-end"; parse (start, end).
            center_coords = [
                tuple(map(int, r.split(":")[-1].split("-"))) for r in regions[k]
            ]

            nchr = regions[k][0].split(":")[0]
            # Literal prefix match: a regex/substring match on "1:" would
            # also select contigs such as "11:" or "21:".
            params = df_params.filter(pl.col("model").str.starts_with(f"{nchr}:"))

            if recombination_map is not None:
                cm_mb = get_cm(
                    df_recombination_map.filter(pl.col("chr") == nchr),
                    np.asarray(center_coords),
                    cm_mb=True,
                )

            if r_bins is not None:
                tmp_r = cm_mb.with_columns(
                    [
                        pl.col("cm_mb").cut(breaks=r_bins).alias("r_bins"),
                        pl.format(
                            "{}:{}-{}", pl.col("chr"), pl.col("start"), pl.col("end")
                        ).alias("iter"),
                    ]
                ).select("iter", "cm_mb", "r_bins")

                params = params.with_columns((tmp_r["cm_mb"]).alias("r"))

                mask = (tmp_r["cm_mb"] >= min_rate).to_numpy()

                # Drop regions whose rate is below min_rate from regions[k].
                exclude_set = set(tmp_r.filter(~mask)["iter"].to_numpy())
                regions[k] = np.array([r for r in regions[k] if r not in exclude_set])

                tmp_r = tmp_r.filter(mask)
                # Column 0 is "model"; keep only the numeric parameters.
                params = params.filter(mask).to_numpy()[:, 1:].astype(float)
                df_r_l.append(tmp_r)
            else:
                tmp_r = None
                if recombination_map is not None:
                    params = params.with_columns(cm_mb["cm_mb"].alias("r"))
                params = (
                    params.select(["s", "t", "saf", "eaf", "mu", "r"])
                    .to_numpy()
                    .astype(float)
                )

            _tmp_stats = func(
                vcf_file,
                regions[k],
                center=center,
                windows=windows,
                step=step,
                recombination_map=recombination_map,
                locus_length=locus_length,
                stats=stats,
                nthreads=nthreads,
                parallel_manager=stats_pool,
            )

            snps_df, window_df = _tmp_stats
            tmp_bins.append({"snps": snps_df, "windows": window_df})

            if not np.all(params[:, 3] == 0):
                params[:, 0] = -np.log(params[:, 0])

            results[k] = summaries({"snps": snps_df, "window": window_df}, params)

    if save_stats:
        save_pickle(f"{data_dir}/raw_statistics{suffix_str}.pickle", results)

    # Best effort: without an r_bins table normalization runs un-binned.
    try:
        df_r = pl.concat(df_r_l)
    except Exception as e:
        if r_bins is not None:
            print(f"_process_vcf: failed to concat r_bins table: {e}")
        df_r = None

    empirical_bins = binned_stats(
        *normalize_neutral(tmp_bins, vcf=True, df_r_bins=df_r)
    )

    df_fv_cnn = {}
    df_fv_cnn_raw = {}

    # Pool 2: normalization
    with Parallel(n_jobs=nthreads, backend="loky", verbose=2) as norm_pool:
        for k, stats_values in results.items():
            print(k)
            df_w, df_w_raw = normalize_stats(
                stats_values,
                empirical_bins,
                region=regions[k],
                center=center,
                windows=windows,
                step=step,
                parallel_manager=norm_pool,
                nthreads=nthreads,
                vcf=True,
                df_r_bins=df_r,
                locus_length=locus_length,
            )
            df_fv_cnn[k] = df_w
            df_fv_cnn_raw[k] = df_w_raw

    df_pred = _sort_by_genomic_region(pl.concat(df_fv_cnn.values(), how="vertical"))
    df_pred_raw = _sort_by_genomic_region(
        pl.concat(df_fv_cnn_raw.values(), how="vertical")
    )

    with open(os.path.join(data_dir, f"empirical_bins{suffix_str}.pickle"), "wb") as f:
        pickle.dump(empirical_bins, f)

    df_pred.write_parquet(f"{data_dir}/fvs{suffix_str}.parquet")
    df_pred_raw.write_parquet(f"{data_dir}/fvs_raw{suffix_str}.parquet")

    return df_pred, df_pred_raw
def _process_sims(
    data_dir,
    nthreads,
    windows=None,
    step=1e5,
    r_bins=None,
    suffix=None,
    func=None,
    save_stats=False,
    locus_length=int(1.2e6),
    stats=None,
):
    """
    Process ms files from simulation to compute, normalize, and estimate feature vectors.

    It reads the neutral and sweep simulations under ``data_dir``, computes
    summary statistics for each replicate, normalizes both classes against
    bins derived from the neutral class, and exports Parquet feature vectors
    to train the CNN.

    :param str data_dir:
        Directory containing simulation output. It must contain non-empty
        ``neutral/`` and ``sweep/`` sub-folders.
    :param int nthreads:
        Number of joblib workers.
    :param list windows:
        List of window sizes (in base pairs) to compute summary statistics.
        ``None`` means ``[100000]``.
    :param int step:
        Step size (in base pairs) for sliding windows.
    :param r_bins:
        Optional break points used to bin the last parameter column
        (scaled by ``1e8``) into ``r_bins`` per replicate.
    :param str suffix:
        Optional suffix appended to output file names.
    :param callable func:
        Function to estimate summary statistics. Defaults to
        ``calculate_stats_simulations``.
    :param bool save_stats:
        If true, pickle the raw per-class statistics to
        ``raw_statistics{suffix}.pickle``.
    :param int locus_length:
        Locus length in base pairs, forwarded to ``func`` and normalization.
    :param stats:
        Forwarded to ``func`` and to ``resolve_stats``.

    :returns:
        A pair of Polars DataFrames:

        - **df_train**: normalized feature vectors for CNN training.
        - **df_train_raw**: raw feature vectors.
    :rtype: tuple[polars.DataFrame, polars.DataFrame]

    :raises ValueError:
        If ``neutral/`` or ``sweep/`` is missing or empty.

    :notes:
        - Writes the following files to ``data_dir``:

          * ``fvs{suffix}.parquet`` – normalized feature vectors.
          * ``fvs{suffix}_raw.parquet`` – raw feature vectors.
          * ``neutral_bins{suffix}.pickle`` – neutral normalization bins.
        - A class with more than 250000 replicates is randomly subsampled
          to 250000 distinct replicates.
    """
    from . import Parallel, delayed
    from .data import Data

    # A None func would make every worker call fail and be silently
    # recorded as a malformed replicate; fall back to the stock function.
    if func is None:
        func = calculate_stats_simulations

    center = [int(step // 2), int(locus_length - step // 2)]

    if windows is None:
        windows = [100000]

    suffix_str = f"_{suffix}" if suffix is not None else ""

    for folder in ("neutral", "sweep"):
        path = os.path.join(data_dir, folder)
        if not os.path.isdir(path):
            raise ValueError(f"Missing folder: {path}")
        if not glob.glob(os.path.join(path, "*")):
            raise ValueError(f"No files in folder: {path}")

    fs_data = Data(data_dir)
    sims, df_params = fs_data.read_simulations()

    results = {}
    malformed_files = {}
    binned_data = {}
    df_r = {}

    MAX_SIMS = 250000
    # Small-batch dispatch: ~50 files per task for dynamic load balancing
    # while keeping IPC overhead low (stacked numpy per batch).
    BATCH_SIZE = 50

    # Window column layout is the same for every class; resolve it once.
    window_cols_resolved, _, _, _ = resolve_stats(stats)
    win_schema = (
        ["iter", "center", "window"] + window_cols_resolved
        if window_cols_resolved
        else None
    )

    with Parallel(n_jobs=nthreads, backend="loky", verbose=2) as parallel:
        for sim_type, sim_list in sims.items():
            if len(sim_list) > MAX_SIMS:
                # replace=False: sampling with replacement would feed the
                # same replicate to the CNN more than once.
                mask = np.random.choice(len(sim_list), MAX_SIMS, replace=False)
            else:
                mask = np.arange(0, len(sim_list))

            # Column 0 of df_params is "model"; keep the numeric columns.
            params = df_params.filter(pl.col("model") == sim_type)[mask, 1:].to_numpy()

            if r_bins is not None:
                tmp_r = (
                    pl.DataFrame({"r": params[:, -1] * 1e8})
                    .with_row_index(name="iter", offset=1)
                    .with_columns(pl.col("r").cut(breaks=r_bins).alias("r_bins"))
                    .select(["iter", "r_bins"])
                )
            else:
                tmp_r = None
            df_r[sim_type] = tmp_r

            # Materialise the selection once instead of once per batch.
            selected = sim_list[mask]
            n_sims = len(selected)
            batches = [
                (selected[i : min(i + BATCH_SIZE, n_sims)], i + 1)
                for i in range(0, n_sims, BATCH_SIZE)
            ]

            _tmp_results = parallel(
                delayed(batch_simulations)(
                    batch,
                    start_idx,
                    func,
                    center,
                    windows,
                    step,
                    locus_length,
                    stats=stats,
                )
                for batch, start_idx in batches
            )

            # Reconstruct per-file Polars DataFrames from batched numpy results
            flat = []
            for win_stacked, snp_list, nwpf in _tmp_results:
                for i, snp_raw in enumerate(snp_list):
                    # Window DataFrame from this file's slice of the stack
                    win_df = None
                    if win_stacked is not None and nwpf > 0:
                        w = win_stacked[i * nwpf : (i + 1) * nwpf]
                        if not np.isnan(w[0, 0]):  # iter col set → valid file
                            win_df = pl.from_numpy(w, schema=win_schema).with_columns(
                                [
                                    pl.col("iter").cast(pl.Int64),
                                    pl.col("center").cast(pl.Int64),
                                    pl.col("window").cast(pl.Int64),
                                ]
                            )

                    # SNP DataFrame from dict of numpy arrays
                    snp_df = None
                    if snp_raw is not None:
                        snp_df = pl.DataFrame(snp_raw)
                        if "iter" in snp_df.columns:
                            snp_df = snp_df.with_columns(pl.col("iter").cast(pl.Int64))
                        if "positions" in snp_df.columns:
                            snp_df = snp_df.with_columns(
                                pl.col("positions").fill_nan(None).cast(pl.Int64)
                            )

                    flat.append((snp_df, win_df))

            _tmp_stats = tuple(zip(*flat))

            stats_values, params, malformed = cleaning_summaries(
                _tmp_stats, params, sim_type
            )
            malformed_files[sim_type] = malformed

            snps_list, windows_list = stats_values

            if sim_type == "neutral":
                norm_stats = [
                    {"snps": s, "windows": w} for s, w in zip(snps_list, windows_list)
                ]
                binned_data["neutral"] = binned_stats(
                    *normalize_neutral(norm_stats, df_r_bins=tmp_r)
                )

            # normalize_cut_raw expects {"snps": ..., "window": ...} per item
            raw_stats = [
                {"snps": s, "window": w} for s, w in zip(snps_list, windows_list)
            ]

            if not np.all(params[:, 3] == 0):
                params[:, 0] = -np.log(params[:, 0])

            results[sim_type] = summaries(raw_stats, params)

    if save_stats:
        save_pickle(f"{data_dir}/raw_statistics{suffix_str}.pickle", results)

    df_fv_cnn = {}
    df_fv_cnn_raw = {}

    # NOTE(review): this pool is opened but normalize_stats receives
    # parallel_manager=None (the VCF pipeline passes its pool) — confirm
    # whether that is intentional.
    with Parallel(n_jobs=nthreads, backend="loky", verbose=2) as parallel:
        for sim_type, stats_values in results.items():
            df_w, df_w_raw = normalize_stats(
                stats_values,
                bins=binned_data["neutral"],
                region=None,
                center=center,
                windows=windows,
                step=step,
                parallel_manager=None,
                nthreads=nthreads,
                vcf=False,
                df_r_bins=df_r[sim_type],
                locus_length=locus_length,
            )
            df_fv_cnn[sim_type] = df_w
            df_fv_cnn_raw[sim_type] = df_w_raw

    df_train = pl.concat(df_fv_cnn.values(), how="vertical")
    df_train_raw = pl.concat(df_fv_cnn_raw.values(), how="vertical")

    out_base = os.path.join(data_dir, f"fvs{suffix_str}.parquet")
    df_train.write_parquet(out_base)
    df_train_raw.write_parquet(out_base.replace(".parquet", "_raw.parquet"))

    with open(os.path.join(data_dir, f"neutral_bins{suffix_str}.pickle"), "wb") as f:
        pickle.dump(binned_data["neutral"], f)

    return df_train, df_train_raw
def summary_statistics(
    data_dir,
    vcf=False,
    nthreads=1,
    windows=None,
    step=1e5,
    step_vcf=1e4,
    locus_length=int(1.2e6),
    recombination_map=None,
    r_bins=None,
    min_rate=0.0,
    suffix=None,
    func=None,
    save_stats=False,
    only_normalize=False,
    stats=None,
):
    """
    Compute summary statistics to create needed feature vectors for CNN
    training/prediction from either simulated or VCF/BCF input data.

    This function dispatches to the appropriate backend depending on the
    ``vcf`` and ``only_normalize`` flags.

    :param str data_dir:
        Input directory. When ``vcf=False``, the root folder of simulation
        outputs containing subdirectories ``neutral/`` and ``sweep/``.
        When ``vcf=True``, a directory of bgzipped VCF/BCF files.
    :param bool vcf:
        Whether to use the VCF/BCF processing pipeline (``True``) or
        the simulation pipeline (``False``). **Default:** ``False``.
    :param int nthreads:
        Number of threads. **Default:** ``1``.
    :param list windows:
        List of window sizes (in base pairs) to compute summary statistics.
        Supports multi-scale: e.g. ``[50000, 100000, 500000]``.
        **Default:** ``None``, meaning ``[100000]``.
    :param int step:
        Step size (in base pairs) for sliding windows. **Default:** ``1e5``.
    :param int step_vcf:
        Forwarded to the VCF pipeline only. **Default:** ``1e4``.
    :param int locus_length:
        Total locus length in base pairs. **Default:** ``1200000``.
    :param str recombination_map:
        Optional path to a recombination map (VCF pipeline only).
        **Default:** ``None``.
    :param r_bins:
        Optional recombination-rate break points forwarded to the backend.
        **Default:** ``None``.
    :param float min_rate:
        Minimum recombination rate, forwarded to the VCF pipeline only.
        **Default:** ``0.0``.
    :param str suffix:
        Optional suffix appended to output file names. Required when a
        custom ``func`` is given (see Raises). **Default:** ``None``.
    :param callable func:
        Function used for computing summary statistics. See
        ``calculate_stats_simulations`` or ``calculate_stats_vcf_flat``.
    :param bool save_stats:
        Forwarded to the backend, which pickles raw statistics when true.
    :param bool only_normalize:
        If true, dispatch to ``_normalize_vcf_stats`` /
        ``_normalize_sims_stats`` instead of recomputing statistics.
    :param stats:
        Optional list of statistic names forwarded to the backend.

    :returns:
        Whatever the selected backend returns. ``_process_vcf`` and
        ``_process_sims`` return a ``(normalized, raw)`` pair of Polars
        DataFrames.
    :rtype: tuple[polars.DataFrame, polars.DataFrame]

    :raises AssertionError:
        If a custom ``func`` is supplied without ``suffix`` (in the VCF
        pipeline this is only enforced when ``stats`` is ``None``).

    :notes:
        - Center positions are derived from ``locus_length`` and ``step`` as
          ``[step//2, locus_length - step//2]``.
        - The backends also write Parquet and pickle artifacts into
          ``data_dir``.

    :examples:
        From simulated data:

        >>> df, df_raw = summary_statistics("./simulations", nthreads=8)

        Multi-scale windows:

        >>> df, df_raw = summary_statistics(
        ...     "./simulations", windows=[50000, 100000, 500000], nthreads=8
        ... )

        From VCF data with recombination map:

        >>> df, df_raw = summary_statistics(
        ...     "./vcf_data",
        ...     vcf=True,
        ...     recombination_map="recomb_map.csv",
        ...     nthreads=8
        ... )
    """
    # Resolve the default here rather than in the signature so the list is
    # never shared between calls.
    if windows is None:
        windows = [100000]

    if only_normalize:
        if vcf:
            return _normalize_vcf_stats(
                data_dir,
                nthreads,
                windows,
                step,
                locus_length,
                r_bins,
                suffix,
            )
        return _normalize_sims_stats(
            data_dir,
            nthreads,
            windows,
            step,
            locus_length,
            r_bins,
            suffix,
        )

    if vcf:
        if func is not None and stats is None:
            assert (
                suffix is not None
            ), "You are using a custom function. Please input a suffix string to avoid feature vectors duplications"
        return _process_vcf(
            data_dir,
            nthreads,
            windows=windows,
            step=step,
            step_vcf=step_vcf,
            locus_length=locus_length,
            recombination_map=recombination_map,
            r_bins=r_bins,
            min_rate=min_rate,
            suffix=suffix,
            func=func,
            save_stats=save_stats,
            stats=stats,
        )

    if func is not None:
        assert (
            suffix is not None
        ), "You are using a custom function. Please input a suffix string to avoid feature vectors duplications"
    return _process_sims(
        data_dir,
        nthreads,
        windows=windows,
        step=step,
        r_bins=r_bins,
        suffix=suffix,
        func=func if func is not None else calculate_stats_simulations,
        save_stats=save_stats,
        locus_length=locus_length,
        stats=stats,
    )
################## Stats


def run_fs_stats(
    hap,
    ac,
    rec_map,
    min_focal_freq=0.25,
    max_focal_freq=0.95,
    window_size=50000,
    hapdaf_o_max_ancest_freq=0.25,
    hapdaf_o_min_tot_freq=0.25,
    hapdaf_s_max_ancest_freq=0.10,
    hapdaf_s_min_tot_freq=0.10,
    _iter=None,
):
    """
    Compute the pair-based focal-SNP statistics and package them as DataFrames.

    Focal/neighbour pairs come from :func:`fast_sq_freq_pairs`; DIND,
    high-freq, low-freq, S-ratio and the two hapDAF variants are derived from
    those pairs and handed to ``fs_stats_dataframe``.

    :param numpy.ndarray hap:
        Haplotype matrix passed to ``fast_sq_freq_pairs``.
    :param numpy.ndarray ac:
        Allele counts passed to ``fast_sq_freq_pairs``.
    :param numpy.ndarray rec_map:
        Map array passed to ``fast_sq_freq_pairs``.
    :param float min_focal_freq:
        Minimum focal derived frequency. Default ``0.25``.
    :param float max_focal_freq:
        Maximum focal derived frequency. Default ``0.95``.
    :param int window_size:
        Window size forwarded to ``fast_sq_freq_pairs``. Default ``50000``.
    :param float hapdaf_o_max_ancest_freq: First threshold for hapDAF-o.
    :param float hapdaf_o_min_tot_freq: Second threshold for hapDAF-o.
    :param float hapdaf_s_max_ancest_freq: First threshold for hapDAF-s.
    :param float hapdaf_s_min_tot_freq: Second threshold for hapDAF-s.
    :param _iter:
        Forwarded to ``fs_stats_dataframe`` when at least one focal SNP exists.

    :returns:
        The four frames produced by ``fs_stats_dataframe``. This function
        binds them, in order, as DIND/high/low, S-ratio, hapDAF-o, hapDAF-s.
        When no focal SNP qualifies, the result of ``fs_stats_dataframe``
        called with empty result lists is returned directly.
    """
    pair_freqs, focal_info, _unused_indices = fast_sq_freq_pairs(
        hap,
        ac,
        rec_map,
        min_focal_freq=min_focal_freq,
        max_focal_freq=max_focal_freq,
        window_size=window_size,
    )

    # Nothing to score: let fs_stats_dataframe build its empty output.
    if focal_info.shape[0] == 0:
        return fs_stats_dataframe(focal_info, [], [], [], [], [], [])

    dind_vals, high_vals, low_vals = dind_high_low_from_pairs(pair_freqs, focal_info)
    s_ratio_vals = s_ratio_from_pairs(pair_freqs)
    hapdaf_o_vals = hapdaf_from_pairs(
        pair_freqs, hapdaf_o_max_ancest_freq, hapdaf_o_min_tot_freq
    )
    hapdaf_s_vals = hapdaf_from_pairs(
        pair_freqs, hapdaf_s_max_ancest_freq, hapdaf_s_min_tot_freq
    )

    frame_dind, frame_s_ratio, frame_hapdaf_o, frame_hapdaf_s = fs_stats_dataframe(
        focal_info,
        dind_vals,
        high_vals,
        low_vals,
        s_ratio_vals,
        hapdaf_o_vals,
        hapdaf_s_vals,
        _iter=_iter,
    )
    return frame_dind, frame_s_ratio, frame_hapdaf_o, frame_hapdaf_s


def run_windows_stat(_tmp_hap, _positions, _ac, _f, window_size=None, iu=None, ju=None):
    """
    Compute the default set of window statistics for one haplotype slice.

    :param _tmp_hap: Haplotype rows falling inside the window.
    :param _positions: Positions of those rows.
    :param _ac: Allele counts of those rows.
    :param _f: Frequencies of those rows; only their maximum is used.
    :param window_size:
        When given (and positive) ``pi`` and ``theta_w`` are reported as the
        absolute value divided by ``window_size``; when given at all, the
        pairwise differences are divided by it as well.
    :param iu: Optional precomputed pair indices; with ``None`` the
        non-precomputed ``pairwise_diffs`` is used.
    :param ju: Companion index array to ``iu``.

    :returns:
        A 15-tuple ordered as: pi, tajima_d, theta_w, theta_h, k_counts, haf,
        h1, h12, h2_h1, zns, omega_max, max_fda, dist_var, dist_skew,
        dist_kurtosis.
    """
    # Garud statistics are best effort; any failure yields NaNs.
    try:
        h12_v, h2_h1, h1_v, _, k_counts = garud_h_numba(_tmp_hap)
    except Exception:
        h12_v = h2_h1 = h1_v = k_counts = np.nan

    zns_v, omega_max = Ld(_tmp_hap)

    (
        tajima_d,
        theta_h,
        _h_raw,
        _h_norm,
        pi_abs,
        theta_w_abs,
        theta_w_per_base,
        pi_per_base,
    ) = neutrality_stats(_ac, _positions)

    has_window = window_size is not None

    # Prefer the window size as the per-base denominator when usable.
    if has_window and window_size > 0:
        pi_per_base = pi_abs / window_size
        theta_w_per_base = theta_w_abs / window_size

    haf_v = haf_top(_tmp_hap, _positions)
    max_fda = _f.max()

    if iu is not None:
        dists = pairwise_diffs_precomp(_tmp_hap, iu, ju, use_float32=False)
    else:
        dists = pairwise_diffs(_tmp_hap)
    if has_window:
        dists = dists / window_size

    dist_var, dist_skew, dist_kurtosis = fast_skew_kurt(dists, bias=True)

    return (
        pi_per_base,
        tajima_d,
        theta_w_per_base,
        theta_h,
        k_counts,
        haf_v,
        h1_v,
        h12_v,
        h2_h1,
        zns_v,
        omega_max,
        max_fda,
        dist_var,
        dist_skew,
        dist_kurtosis,
    )
def calculate_stats_simulations(
    hap_data,
    _iter=1,
    center=None,
    windows=None,
    step=1e5,
    locus_length=int(1.2e6),
    stats=None,
):
    """
    Compute SNP-level and window-level statistics for one simulated replicate.

    :param hap_data:
        Replicate input handed to ``parse_ms_numpy``.
    :param int _iter:
        Replicate identifier written into the ``iter`` column of both outputs.
    :param list center:
        ``[first, last]`` center positions (or a single-element list).
        ``None`` means ``[step // 2, locus_length - step // 2]``.
    :param list windows:
        Window sizes in base pairs. ``None`` means ``[100000]``.
    :param step:
        Spacing between consecutive centers.
    :param int locus_length:
        Locus length, forwarded to ``parse_ms_numpy`` as ``seq_len``; it also
        selects the iHS settings (``> 1e6`` uses ``min_ehh=0.05`` without
        edges, otherwise ``min_ehh=0.1`` with edges).
    :param stats:
        Optional list of statistic names resolved by ``resolve_stats``.
        ``None`` selects the default pipeline and uses ``run_windows_stat``.

    :returns:
        ``(snp_data, window_data)``:

        - ``snp_data``: dict mapping column name to numpy array (joined on
          ``iter``, ``positions``, ``daf`` and sorted by position), or
          ``None`` when no SNP statistic was produced.
        - ``window_data``: float array of shape
          ``(n_centers * len(windows), 3 + n_window_stats)`` whose first three
          columns are ``iter``, ``center``, ``window``; rows of empty windows
          keep NaN statistics. ``None`` when no window statistic is requested.

        ``(None, None)`` is returned when parsing fails or the parsed
        haplotype and map arrays disagree in length.
    """
    if windows is None:
        windows = [100000]
    if center is None:
        center = [int(step // 2), int(locus_length - step // 2)]

    # ── resolve which stats to compute ──────────────────────────────
    window_cols, snp_groups, compute_isafe, snp_cols = resolve_stats(stats)

    # Malformed replicates are reported as (None, None); the caller filters them.
    try:
        (
            hap_int,
            rec_map_01,
            ac,
            biallelic_mask,
            position_masked,
            genetic_position_masked,
        ) = parse_ms_numpy(hap_data, seq_len=locus_length)
        freqs = np.ascontiguousarray(ac[:, 1] / ac.sum(axis=1), dtype=np.float64)

        if hap_int.shape[0] != rec_map_01.shape[0]:
            return None, None
    except Exception:
        return None, None

    # SNP-level stats (full chromosome)
    snp_dfs = []

    if snp_groups or compute_isafe:
        _snp_window_size = int(np.asarray(windows).min()) // 2

        if "fs" in snp_groups:
            # NOTE: run_fs_stats returns hapDAF-o before hapDAF-s, so the two
            # local names below are swapped relative to their content. Every
            # frame is treated identically; the order is kept as-is so the
            # joined column order does not change.
            df_dind_high_low, df_s_ratio, df_hapdaf_s, df_hapdaf_o = run_fs_stats(
                hap_int, ac, rec_map_01, window_size=_snp_window_size
            )
            snp_dfs.extend(
                [
                    center_window_cols(df_dind_high_low, _iter=_iter),
                    center_window_cols(df_s_ratio, _iter=_iter),
                    center_window_cols(df_hapdaf_o, _iter=_iter),
                    center_window_cols(df_hapdaf_s, _iter=_iter),
                ]
            )

        if "ihs" in snp_groups:
            long_locus = locus_length > 1e6
            df_ihs = ihs_ihh(
                hap_int,
                position_masked,
                map_pos=genetic_position_masked,
                min_ehh=0.05 if long_locus else 0.1,
                min_maf=0.05,
                include_edges=not long_locus,
            )
            snp_dfs.append(center_window_cols(df_ihs, _iter=_iter))

        if "nsl" in snp_groups:
            common = freqs >= 0.05
            nsl_v = nsl(hap_int[common], use_threads=False)
            df_nsl = pl.DataFrame(
                {
                    "positions": position_masked[common],
                    "daf": freqs[common],
                    "nsl": nsl_v,
                }
            ).fill_nan(None)
            snp_dfs.append(center_window_cols(df_nsl, _iter=_iter))

        if "hscan" in snp_groups:
            pos_hscan, h_scores = hscan(hap_int, position_masked)
            df_hscan = pl.DataFrame(
                {
                    "positions": pos_hscan.astype(np.int64),
                    "daf": freqs,
                    "hscan": h_scores,
                }
            ).fill_nan(None)
            snp_dfs.append(center_window_cols(df_hscan, _iter=_iter))

        if compute_isafe:
            df_isafe = run_isafe(hap_int, position_masked)
            snp_dfs.append(center_window_cols(df_isafe, _iter=_iter))

        if "beta" in snp_groups:
            df_beta = run_beta_window(ac, position_masked, w=_snp_window_size)
            df_beta = df_beta.rename({"t": "beta_t"})
            daf_lookup = pl.DataFrame(
                {
                    "positions": position_masked.astype(np.int64),
                    "daf": freqs,
                }
            )
            df_beta = df_beta.join(daf_lookup, on="positions", how="left")
            snp_dfs.append(center_window_cols(df_beta, _iter=_iter))

    if snp_dfs:
        snps_joined = (
            reduce(
                lambda left, right: left.join(
                    right,
                    on=["iter", "positions", "daf"],
                    how="full",
                    coalesce=True,
                ),
                [d.lazy() for d in snp_dfs],
            )
            .sort("positions")
            .collect()
        )
        # Filter to only requested SNP columns
        if snp_cols is not None:
            keep = {"iter", "positions", "daf"} | snp_cols
            snps_joined = snps_joined.select(
                [c for c in snps_joined.columns if c in keep]
            )
        # Plain dict of numpy arrays: cheaper to ship between processes.
        snp_data = {col: snps_joined[col].to_numpy() for col in snps_joined.columns}
    else:
        snp_data = None

    # Window-level stats — return raw numpy array (no DataFrame in worker)
    if not window_cols:
        return snp_data, None

    last_center = center[0] if len(center) == 1 else center[1]
    centers = np.arange(center[0], last_center + step, step).astype(int)

    all_combos = list(product(centers, windows))
    lowers = np.array([c - w // 2 for c, w in all_combos])
    uppers = np.array([c + w // 2 for c, w in all_combos])
    left_idxs = np.searchsorted(position_masked, lowers)
    right_idxs = np.searchsorted(position_masked, uppers, side="right")

    iu, ju = np.triu_indices(hap_int.shape[1], k=1)
    hap_f = hap_int.astype(np.float64, copy=False)

    # Default pipeline uses run_windows_stat; an explicit stat list goes
    # through the registry-driven compute_window_stats.
    if stats is None:
        stat_func = run_windows_stat
    else:
        stat_func = partial(
            compute_window_stats,
            stat_cols=window_cols,
            groups_needed=frozenset(WINDOW_STAT_REGISTRY[c][0] for c in window_cols),
        )

    results = np.full((len(all_combos), 3 + len(window_cols)), np.nan)
    for idx, (c, w) in enumerate(all_combos):
        left = left_idxs[idx]
        right = right_idxs[idx]

        results[idx, :3] = [int(_iter), c, w]

        _tmp_hap = hap_f[left:right]
        if _tmp_hap.size == 0:
            # Empty window: keep the NaN statistics.
            continue

        results[idx, 3:] = stat_func(
            _tmp_hap,
            position_masked[left:right],
            ac[left:right],
            freqs[left:right],
            window_size=w,
            iu=iu,
            ju=ju,
        )

    return snp_data, results
def batch_simulations(
    batch_data,
    start_idx,
    func,
    center,
    windows,
    step,
    locus_length=int(1.2e6),
    stats=None,
):
    """Run ``func`` over a batch of replicates and pack the results for transfer.

    Each item of ``batch_data`` is passed to ``func`` together with a running
    index starting at ``start_idx``. A replicate whose call raises, or whose
    result is not a tuple/list of at least two items, is recorded as ``None``.

    Returns ``(win_stacked, snp_results, nwpf)``: ``snp_results`` holds one
    entry per replicate; ``win_stacked`` stacks the per-replicate window
    arrays into one block of ``nwpf`` rows each (NaN rows for failed
    replicates), or is ``None`` with ``nwpf == 0`` when no replicate produced
    a window array.
    """
    snp_results = []
    win_results = []

    for offset, hap_data in enumerate(batch_data):
        snp_part = None
        win_part = None
        try:
            out = func(
                hap_data,
                start_idx + offset,
                center=center,
                windows=windows,
                step=step,
                locus_length=locus_length,
                stats=stats,
            )
            if isinstance(out, (tuple, list)) and len(out) >= 2:
                snp_part, win_part = out[0], out[1]
        except Exception:
            snp_part = None
            win_part = None
        snp_results.append(snp_part)
        win_results.append(win_part)

    # The first usable window array fixes the per-file block shape.
    template = next((w for w in win_results if w is not None), None)
    if template is None:
        return None, snp_results, 0

    nwpf = template.shape[0]
    n_cols = template.shape[1]
    win_stacked = np.full((len(win_results) * nwpf, n_cols), np.nan)
    for slot, w in enumerate(win_results):
        if w is None:
            continue
        win_stacked[slot * nwpf : (slot + 1) * nwpf] = w

    return win_stacked, snp_results, nwpf


# Schema constants — used by calculate_stats_vcf_flat.
# To add a custom stat variant: define a new list here + a leaf run_windows_stat_* function.
_WINDOW_STAT_COLS = [
    "pi",
    "tajima_d",
    "theta_w",
    "theta_h",
    "k_counts",
    "haf",
    "h1",
    "h12",
    "h2_h1",
    "zns",
    "omega_max",
    "max_fda",
    "dist_var",
    "dist_skew",
    "dist_kurtosis",
]

# Stat registries — map user-facing names to computation groups
# Window stats: (group_key, index_in_group_return_tuple)
WINDOW_STAT_REGISTRY = {
    "pi": ("neutrality", 7),  # pi_per_base
    "tajima_d": ("neutrality", 0),
    "theta_w": ("neutrality", 6),  # theta_w_per_base
    "theta_h": ("neutrality", 1),  # Fay-Wu θH (absolute)
    "fay_wu_h": ("neutrality", 3),  # normalized Fay & Wu's H (h_norm)
    "k_counts": ("garud", 4),
    "h1": ("garud", 2),
    "h12": ("garud", 0),
    "h2_h1": ("garud", 1),
    "zns": ("ld", 0),
    "omega_max": ("ld", 1),
    "haf": ("haf", 0),
    "max_fda": ("max_fda", 0),
    "dist_var": ("pairwise", 0),
    "dist_skew": ("pairwise", 1),
    "dist_kurtosis": ("pairwise", 2),
    # Extended SFS-based neutrality tests — computed per window via ext_neutrality group
    "achaz_y": ("ext_neutrality", 0),
    "zeng_e": ("ext_neutrality", 1),
    "fuli_f_star": ("ext_neutrality", 2),
    "fuli_f": ("ext_neutrality", 3),
    "fuli_d_star": ("ext_neutrality", 4),
    "fuli_d": ("ext_neutrality", 5),
    "achaz_y_star": ("ext_neutrality", 6),
    "achaz_t": ("ext_neutrality", 7),
    # Balancing selection — sliding sub-window mean per fixed window
    "ncd1": ("ncd1_win", 0),
}

# SNP stats: stat_name -> (group_key, actual_column_name)
# The column name is what appears in the output DataFrame.
SNP_STAT_REGISTRY = {
    "ihs": ("ihs", "ihs"),
    "delta_ihh": ("ihs", "delta_ihh"),
    "nsl": ("nsl", "nsl"),
    "isafe": ("isafe", "isafe"),
    "dind": ("fs", "dind"),
    "dind_high_low": ("fs", "dind"),  # alias for dind
    "highfreq": ("fs", "high_freq"),
    "high_freq": ("fs", "high_freq"),
    "lowfreq": ("fs", "low_freq"),
    "low_freq": ("fs", "low_freq"),
    "s_ratio": ("fs", "s_ratio"),
    "hapdaf_o": ("fs", "hapdaf_o"),
    "hapdaf_s": ("fs", "hapdaf_s"),
    # Balancing selection — per-SNP beta statistic (Siewert & Voight 2020)
    "beta": ("beta", "beta"),
    "beta_t": ("beta", "beta_t"),
    # Haplotype homozygosity scan (Enard et al.)
    "hscan": ("hscan", "hscan"),
}

_ALL_VALID_STATS = sorted(set(WINDOW_STAT_REGISTRY) | set(SNP_STAT_REGISTRY))
def resolve_stats(stats):
    """Split a user-supplied stat list into window and SNP work items.

    Returns ``(window_cols, snp_groups, compute_isafe, snp_cols)``:

    - ``window_cols``: requested names found in ``WINDOW_STAT_REGISTRY``,
      in the order given.
    - ``snp_groups``: set of SNP computation groups to run, excluding the
      ``"isafe"`` group, which is reported through ``compute_isafe``.
    - ``compute_isafe``: whether any requested stat belongs to ``"isafe"``.
    - ``snp_cols``: set of output column names to keep for SNP stats.

    With ``stats=None`` the defaults are returned: a copy of
    ``_WINDOW_STAT_COLS``, the groups ``{"ihs", "nsl", "fs"}``, iSAFE enabled,
    and ``snp_cols=None`` (keep every SNP column).

    Raises ``ValueError`` when a name is in neither registry.
    """
    if stats is None:
        return list(_WINDOW_STAT_COLS), {"ihs", "nsl", "fs"}, True, None

    unknown = [
        name
        for name in stats
        if not (name in WINDOW_STAT_REGISTRY or name in SNP_STAT_REGISTRY)
    ]
    if unknown:
        raise ValueError(f"Unknown stat(s): {unknown}. Valid: {_ALL_VALID_STATS}")

    window_cols = [name for name in stats if name in WINDOW_STAT_REGISTRY]

    # (group, column) pairs for every requested SNP-level stat.
    snp_entries = [SNP_STAT_REGISTRY[name] for name in stats if name in SNP_STAT_REGISTRY]

    snp_cols = {column for _, column in snp_entries}
    snp_groups = {group for group, _ in snp_entries if group != "isafe"}
    compute_isafe = any(group == "isafe" for group, _ in snp_entries)

    return window_cols, snp_groups, compute_isafe, snp_cols
def compute_window_stats(
    hap,
    pos,
    ac,
    freqs,
    window_size=None,
    iu=None,
    ju=None,
    stat_cols=None,
    groups_needed=None,
):
    """Registry-driven window stat computation.

    Called in the inner center×window loop, normally through a
    ``functools.partial`` that bakes in ``stat_cols`` and ``groups_needed``
    (both picklable, hence joblib-safe).

    Parameters
    ----------
    hap, pos, ac, freqs
        Haplotype rows, positions, allele counts and frequencies of the SNPs
        inside the window; forwarded to the per-group stat functions.
    window_size : int or None
        Window width in bp.  When given, the per-base neutrality entries and
        the pairwise distances are divided by it.
    iu, ju : ndarray or None
        Precomputed upper-triangle index pair for the pairwise group; when
        ``iu`` is ``None`` the non-precomputed routine is used.
    stat_cols : list of str or None
        Window stat names to return, in order.  ``None`` selects
        ``_WINDOW_STAT_COLS``.
    groups_needed : collection of str or None
        Computation groups to run.  ``None`` derives them from ``stat_cols``
        via ``WINDOW_STAT_REGISTRY``.

    Returns
    -------
    tuple
        One value per entry of ``stat_cols``; ``np.nan`` where the owning
        group was not computed.
    """
    # The signature advertises None defaults; resolve them instead of failing
    # on ``"neutrality" in None``.
    if stat_cols is None:
        stat_cols = list(_WINDOW_STAT_COLS)
    if groups_needed is None:
        groups_needed = frozenset(WINDOW_STAT_REGISTRY[col][0] for col in stat_cols)

    group_results = {}

    if "neutrality" in groups_needed:
        _neut = list(neutrality_stats(ac, pos))
        if window_size is not None and window_size > 0:
            _neut[6] = _neut[5] / window_size  # theta_w_per_base
            _neut[7] = _neut[4] / window_size  # pi_per_base
        group_results["neutrality"] = tuple(_neut)

    if "garud" in groups_needed:
        # Best effort: a failing window yields NaNs rather than aborting the batch.
        try:
            h12_v, h2_h1, h1_v, _, k_counts = garud_h_numba(hap)
            group_results["garud"] = (h12_v, h2_h1, h1_v, None, k_counts)
        except Exception:
            group_results["garud"] = (np.nan, np.nan, np.nan, None, np.nan)

    if "ld" in groups_needed:
        group_results["ld"] = Ld(hap)

    if "haf" in groups_needed:
        group_results["haf"] = (haf_top(hap, pos),)

    if "max_fda" in groups_needed:
        group_results["max_fda"] = (freqs.max(),)

    if "pairwise" in groups_needed:
        dists = (
            pairwise_diffs_precomp(hap, iu, ju, use_float32=False)
            if iu is not None
            else pairwise_diffs(hap)
        )
        if window_size is not None:
            dists = dists / window_size
        group_results["pairwise"] = fast_skew_kurt(dists, bias=True)

    if "ext_neutrality" in groups_needed:
        # All functions take ac (S×2 derived allele counts); SFS built internally.
        # fuli_f and fuli_d require polarized ac (derived allele in ac[:,1]).
        group_results["ext_neutrality"] = (
            achaz_y(ac),  # [0] Achaz's Y (polarized, excludes ξ₁)
            zeng_e(ac),  # [1] Zeng's E
            fuli_f_star(ac),  # [2] Fu & Li F* (no outgroup)
            fuli_f(ac),  # [3] Fu & Li F (polarized)
            fuli_d_star(ac),  # [4] Fu & Li D* (no outgroup)
            fuli_d(ac),  # [5] Fu & Li D (polarized)
            achaz_y_star(ac),  # [6] Achaz's Y* (folded, excludes η₁)
            achaz_t(ac),  # [7] Achaz's T_Ω (exponential weights, bottleneck)
        )

    if "ncd1_win" in groups_needed:
        # NCD1 computed on the window's SNP positions/frequencies; summarised as mean.
        if len(pos) >= 2:
            _ncd1 = ncd1(pos, freqs)
            group_results["ncd1_win"] = (
                float(np.nanmean(_ncd1)) if len(_ncd1) > 0 else np.nan,
            )
        else:
            group_results["ncd1_win"] = (np.nan,)

    out = []
    for col in stat_cols:
        group_key, idx = WINDOW_STAT_REGISTRY[col]
        grp = group_results.get(group_key)
        out.append(grp[idx] if grp is not None else np.nan)
    return tuple(out)


def run_isafe_region(hap_int, position_masked, start, end):
    """Compute iSAFE/SAFE scores for one chromosome region.

    Worker for calculate_stats_vcf_flat.  Slices ``hap_int`` and
    ``position_masked`` to the SNPs whose position lies in ``[start, end]``
    (both bounds inclusive), calls ``run_isafe`` and converts NaN to null.

    Returns
    -------
    polars.DataFrame or None
        The ``run_isafe`` output, or ``None`` when the region holds fewer
        than 10 SNPs or the result is empty.
    """
    left = int(np.searchsorted(position_masked, start, side="left"))
    right = int(np.searchsorted(position_masked, end, side="right"))
    if right - left < 10:
        return None

    df = run_isafe(
        hap_int[left:right].astype(np.int8, copy=False),
        position_masked[left:right],
    ).fill_nan(None)
    if df.is_empty():
        return None
    return df


def batch_windowed_stats_flat(
    hap_f,
    ac,
    position_masked,
    combos,
    stat_func,
    stat_cols,
):
    """Worker for calculate_stats_vcf_flat.

    Computes stats for a batch of unique absolute (center, window_size) pairs
    directly on the chromosome hap matrix — no locus-window framing, no
    relative_position remapping.  Each (abs_center, window_size) pair is
    computed exactly once regardless of how many locus windows it belongs to.

    Parameters
    ----------
    hap_f : (S, N) float64 ndarray
        Full-chromosome haplotype matrix, pre-cast before joblib dispatch.
    ac : (S, 2) ndarray
        Allele counts array.
    position_masked : (S,) int32 ndarray
        Absolute physical positions after biallelic filtering.
    combos : list of (int, int)
        (abs_center, window_size) pairs to compute in this batch.
    stat_func : callable
        run_windows_stat or partial(compute_window_stats, ...).  Called as
        ``stat_func(hap, pos, ac, freqs, window_size=w, iu=iu, ju=ju)``.
    stat_cols : list of str
        Column names matching stat_func output length; only its length is
        used here, to size the NaN tuple of empty windows.

    Returns
    -------
    list of (abs_center, window_size, stats_tuple)
    """
    # Upper-triangle sample-pair indices, shared by every window of the batch.
    iu, ju = np.triu_indices(hap_f.shape[1], k=1)

    results = []
    for abs_center, w in combos:
        left = np.searchsorted(position_masked, abs_center - w // 2, side="left")
        right = np.searchsorted(position_masked, abs_center + w // 2, side="right")
        _hap = hap_f[left:right]
        if _hap.size == 0:
            # No SNP in the window: emit NaNs without calling stat_func.
            results.append((abs_center, w, (np.nan,) * len(stat_cols)))
            continue
        _ac = ac[left:right]
        _pos = position_masked[left:right]
        _f = _hap.mean(axis=1)
        sv = stat_func(_hap, _pos, _ac, _f, window_size=w, iu=iu, ju=ju)
        results.append((abs_center, w, sv))
    return results
def _isafe_tiles(first_pos, last_pos, region_size):
    """Return inclusive, non-overlapping (start, end) tiles covering [first_pos, last_pos]."""
    starts = list(range(first_pos, last_pos, region_size))
    if not starts:
        # first_pos == last_pos: a single zero-width tile.
        return [(first_pos, last_pos)]
    # run_isafe_region treats both bounds as inclusive, so each tile must stop
    # one base before the next one starts.
    tiles = [(s, s + region_size - 1) for s in starts]
    tiles[-1] = (starts[-1], last_pos)  # ensure the last SNP is included
    return tiles


def calculate_stats_vcf_flat(
    vcf_file,
    region,
    center=None,
    windows=[100000],
    step=1e5,
    _iter=1,
    recombination_map=None,
    nthreads=1,
    locus_length=int(1.2e6),
    stats=None,
    # Legacy params — used only when stats is None
    stat_func=None,
    stat_cols=None,
    compute_snp_stats=True,
    parallel_manager=None,
    isafe_region_size=int(2e6),
):
    """Compute per-locus-window summary statistics from a VCF with O(N) window work.

    Instead of computing stats for every (locus_window × center) pair — which
    would recompute each overlapping sub-window many times — this function:

    1. Enumerates the unique set of absolute (center, window_size) positions
       across all locus windows.
    2. Dispatches the unique pairs in batches (via batch_windowed_stats_flat).
    3. Tiles the chromosome span into non-overlapping regions of
       ``isafe_region_size`` bp and runs iSAFE once per tile.  Scores are keyed
       by absolute position so they join to any locus window containing them.
    4. Assembles output rows by looking up cached results per locus window.

    Parameters
    ----------
    vcf_file
        Passed to ``genome_reader``.
    region : sequence of str
        Locus windows formatted ``"chrom:start-end"``.
    center : sequence of int or None
        Relative center bound(s) inside a locus window.  ``None`` uses
        ``[step // 2, locus_length - step // 2]``.
    windows : sequence of int
        Window sizes in bp.
    step : number
        Spacing between consecutive centers.
    _iter
        Value written to the ``iter`` column of the per-stat SNP frames before
        they are joined.
    recombination_map
        Passed to ``genome_reader``; when ``None`` iHS runs without genetic
        positions.
    nthreads : int
        ``n_jobs`` for joblib when ``parallel_manager`` is not supplied.
    locus_length : int
        Locus window length; also selects the iHS ``min_ehh`` /
        ``include_edges`` settings (threshold 1e6).
    stats : list of str or None
        Stat names resolved through ``resolve_stats``.  ``None`` falls back to
        the legacy parameters below.
    stat_func, stat_cols, compute_snp_stats
        Legacy controls, used only when ``stats`` is ``None``.
    parallel_manager : callable or None
        If given, called with the task list instead of creating a ``Parallel``.
    isafe_region_size : int
        Size in bp of non-overlapping regions used to compute iSAFE.
        Valid range 1e6–5e6.  Default 2e6.

    Returns
    -------
    tuple or None
        ``(df_stats_norm, df_window_new)``; ``df_stats_norm`` is ``None`` when
        no SNP stat was produced.  Returns ``None`` (after a warning) when the
        VCF could not be read.
    """
    from . import Parallel, delayed

    if center is None:
        center = [int(step // 2), int(locus_length - step // 2)]

    filterwarnings(
        "ignore",
        category=RuntimeWarning,
        message="invalid value encountered in scalar divide",
    )
    np.seterr(divide="ignore", invalid="ignore")

    # resolve which stats to compute
    if stats is not None:
        window_cols, snp_groups, compute_isafe, snp_cols = resolve_stats(stats)
        if window_cols:
            stat_func_inner = partial(
                compute_window_stats,
                stat_cols=window_cols,
                groups_needed=frozenset(
                    WINDOW_STAT_REGISTRY[c][0] for c in window_cols
                ),
            )
        else:
            stat_func_inner = run_windows_stat
            window_cols = list(_WINDOW_STAT_COLS)
        stat_cols_inner = window_cols
    else:
        stat_func_inner = stat_func if stat_func is not None else run_windows_stat
        stat_cols_inner = (
            stat_cols if stat_cols is not None else list(_WINDOW_STAT_COLS)
        )
        snp_groups = {"ihs", "nsl", "fs"} if compute_snp_stats else set()
        compute_isafe = compute_snp_stats
        snp_cols = None

    # read VCF
    try:
        (
            hap_int,
            rec_map_01,
            ac,
            biallelic_mask,
            position_masked,
            genetic_position_masked,
        ) = genome_reader(vcf_file, recombination_map=recombination_map, region=None)
        freqs = np.ascontiguousarray(ac[:, 1] / ac.sum(axis=1), dtype=np.float64)
    except Exception as exc:
        # Best effort: callers receive None for an unreadable VCF, but the
        # reason is no longer swallowed silently.
        warn(f"calculate_stats_vcf_flat: could not read {vcf_file}: {exc}")
        return None

    if recombination_map is None:
        genetic_position_masked = None

    genomic_windows = np.asarray(
        [tuple(map(int, r.split(":")[-1].split("-"))) for r in region]
    )
    nchr = region[0].split(":")[0]

    if len(center) == 1:
        centers = np.arange(center[0], center[0] + step, step).astype(int)
    else:
        centers = np.arange(center[0], center[1] + step, step).astype(int)

    # build joblib tasks: SNP stats first, then iSAFE tiles, then window batches
    tasks = []
    snp_task_keys = []

    if "fs" in snp_groups:
        tasks.append(
            delayed(run_fs_stats)(
                hap_int,
                ac,
                rec_map_01,
                window_size=int(np.asarray(windows).min()) // 2,
            )
        )
        snp_task_keys.append("fs")

    if "ihs" in snp_groups:
        tasks.append(
            delayed(ihs_ihh)(
                hap_int,
                position_masked,
                map_pos=genetic_position_masked,
                min_ehh=0.05 if locus_length > 1e6 else 0.1,
                min_maf=0.05,
                include_edges=False if locus_length > 1e6 else True,
                use_threads=True,
            )
        )
        snp_task_keys.append("ihs")

    if "nsl" in snp_groups:
        tasks.append(delayed(nsl)(hap_int[freqs >= 0.05], use_threads=True))
        snp_task_keys.append("nsl")

    if "hscan" in snp_groups:
        tasks.append(delayed(hscan)(hap_int, position_masked))
        snp_task_keys.append("hscan")

    if "beta" in snp_groups:
        _snp_window_size = int(np.asarray(windows).min()) // 2
        tasks.append(delayed(run_beta_window)(ac, position_masked, w=_snp_window_size))
        snp_task_keys.append("beta")

    # iSAFE: non-overlapping region tiles
    # Each tile is one joblib task; run_isafe_region returns scores with
    # absolute coordinates.  Tiles cover the full SNP span so every SNP in any
    # locus window belongs to exactly one tile.
    n_isafe_tasks = 0
    if compute_isafe and len(position_masked) > 0:
        isafe_tiles = _isafe_tiles(
            int(position_masked[0]),
            int(position_masked[-1]),
            int(isafe_region_size),
        )
        n_isafe_tasks = len(isafe_tiles)
        for s, e in isafe_tiles:
            tasks.append(delayed(run_isafe_region)(hap_int, position_masked, s, e))

    # flat unique absolute (center, window_size) pairs
    # For locus window starting at locus_start, relative center c maps to
    # absolute position: locus_start + c - 1
    unique_abs_centers = sorted(
        {int(lw[0]) + int(c) - 1 for lw in genomic_windows for c in centers}
    )
    unique_combos = [(abs_c, w) for abs_c in unique_abs_centers for w in windows]

    hap_f = np.ascontiguousarray(hap_int.astype(np.float64))
    chunk_size = max(1, ceil(len(unique_combos) / (nthreads * 2)))
    tasks.extend(
        delayed(batch_windowed_stats_flat)(
            hap_f,
            ac,
            position_masked,
            unique_combos[i : i + chunk_size],
            stat_func_inner,
            stat_cols_inner,
        )
        for i in range(0, len(unique_combos), chunk_size)
    )

    # execute
    if parallel_manager is not None:
        results = parallel_manager(tasks)
    else:
        with Parallel(n_jobs=nthreads, backend="loky", verbose=2) as parallel:
            results = parallel(tasks)

    # build stat cache from flat window results
    n_snp_tasks = len(snp_task_keys)
    n_pre_window_tasks = n_snp_tasks + n_isafe_tasks
    stat_cache = {}
    for batch in results[n_pre_window_tasks:]:
        for abs_c, w, sv in batch:
            stat_cache[(abs_c, w)] = sv

    # assemble output rows per locus window
    # Output layout: center (relative), window, stat_cols..., iter ("start-end")
    n_stat_cols = len(stat_cols_inner)
    all_combos = list(product(centers, windows))
    num_combos = len(all_combos)
    out_window = []
    for lw in genomic_windows:
        locus_start = int(lw[0])
        iter_str = f"{lw[0]}-{lw[1]}"
        arr = np.full((num_combos, 2 + n_stat_cols), np.nan)
        for idx, (c, w) in enumerate(all_combos):
            abs_c = locus_start + int(c) - 1
            arr[idx, 0] = c
            arr[idx, 1] = w
            sv = stat_cache.get((abs_c, w))
            if sv is not None:
                arr[idx, 2:] = sv
        df = (
            pl.from_numpy(
                arr,
                schema=["center", "window"] + list(stat_cols_inner),
            )
            .with_columns(
                [
                    pl.col("center").cast(pl.Int64),
                    pl.col("window").cast(pl.Int64),
                ]
            )
            .with_columns(pl.lit(iter_str).alias("iter"))
        )
        out_window.append(df)

    df_window_new = pl.concat(out_window, how="vertical")
    df_window_new = df_window_new.with_columns(
        (nchr + ":" + pl.col("iter")).alias("iter")
    )

    # unpack SNP results
    snp_results = {key: results[i] for i, key in enumerate(snp_task_keys)}
    snp_dfs = []

    if "fs" in snp_results:
        df_dind_high_low, df_s_ratio, df_hapdaf_s, df_hapdaf_o = snp_results["fs"]
        df_dind_high_low = center_window_cols(df_dind_high_low, _iter=_iter)
        df_s_ratio = center_window_cols(df_s_ratio, _iter=_iter)
        df_hapdaf_o = center_window_cols(df_hapdaf_o, _iter=_iter)
        df_hapdaf_s = center_window_cols(df_hapdaf_s, _iter=_iter)
        snp_dfs.extend([df_dind_high_low, df_s_ratio, df_hapdaf_o, df_hapdaf_s])

    if "ihs" in snp_results:
        snp_dfs.append(center_window_cols(snp_results["ihs"], _iter=_iter))

    if "nsl" in snp_results:
        nsl_v = snp_results["nsl"]
        df_nsl = pl.DataFrame(
            {
                "positions": position_masked[freqs >= 0.05],
                "daf": freqs[freqs >= 0.05],
                "nsl": nsl_v,
            }
        ).fill_nan(None)
        snp_dfs.append(center_window_cols(df_nsl, _iter=_iter))

    if "hscan" in snp_results:
        pos_hscan, h_scores = snp_results["hscan"]
        df_hscan = pl.DataFrame(
            {
                "positions": pos_hscan.astype(np.int64),
                "daf": freqs,
                "hscan": h_scores,
            }
        ).fill_nan(None)
        snp_dfs.append(center_window_cols(df_hscan, _iter=_iter))

    if "beta" in snp_results:
        df_beta = snp_results["beta"].rename({"t": "beta_t"})
        daf_lookup = pl.DataFrame(
            {
                "positions": position_masked.astype(np.int64),
                "daf": freqs,
            }
        )
        df_beta = df_beta.join(daf_lookup, on="positions", how="left")
        snp_dfs.append(center_window_cols(df_beta, _iter=_iter))

    # unpack iSAFE results
    # Tile results carry absolute positions; concat directly — no mean
    # aggregation needed since each SNP appears in exactly one tile.
    if compute_isafe and n_isafe_tasks > 0:
        isafe_tile_results = results[n_snp_tasks : n_snp_tasks + n_isafe_tasks]
        valid_tiles = [r for r in isafe_tile_results if r is not None]
        if valid_tiles:
            df_isafe = (
                pl.concat(valid_tiles, how="vertical")
                .with_columns(pl.col("positions").cast(pl.Int64))
                .sort("positions")
            )
            df_isafe = df_isafe.with_columns(pl.lit(_iter).cast(pl.Int64).alias("iter"))
            snp_dfs.append(df_isafe)

    if snp_dfs:
        df_stats_norm = (
            reduce(
                lambda left, right: left.join(
                    right,
                    on=["iter", "positions", "daf"],
                    how="full",
                    coalesce=True,
                ),
                [d.lazy() for d in snp_dfs],
            )
            .sort("positions")
            .collect()
        ).with_columns(pl.lit(nchr).cast(pl.Utf8).alias("iter"))

        # h12/haf are window-level stats; drop them if a SNP frame carried them.
        _to_drop = [c for c in ("h12", "haf") if c in df_stats_norm.columns]
        if _to_drop:
            df_stats_norm = df_stats_norm.drop(_to_drop)

        if snp_cols is not None:
            keep = {"iter", "positions", "daf"} | snp_cols
            df_stats_norm = df_stats_norm.select(
                [c for c in df_stats_norm.columns if c in keep]
            )
    else:
        df_stats_norm = None

    return df_stats_norm, df_window_new
################## Normalization


@njit(cache=True)
def relative_position(positions, window):
    """Shift ``positions`` so that ``window[0]`` maps to 1."""
    return positions - window[0] + 1


# Columns that identify or bin a SNP rather than carry a statistic.
_NON_STAT_SNP_COLS = (
    "iter",
    "chr",
    "positions",
    "daf",
    "freq_bins",
    "r_bins",
    "cm_mb",
)


def cut_snps(df, centers, windows, stats_names, fixed_center=None, iter_value=1):
    """
    Average SNP statistics inside every (center, window) combination.

    Parameters
    ----------
    df : polars.DataFrame
        SNP table with a ``positions`` column plus the statistic columns.
    centers : sequence of int or None
        Center positions in simulation coordinates.  ``None`` uses
        500000..700000 in steps of 10000.
    windows : sequence of int
        Window sizes; a window spans ``center ± window // 2`` (inclusive).
    stats_names : list of str or None
        Statistic columns to average.  ``None`` uses every column of ``df``
        that is not an identifier/bin column (``iter``, ``chr``, ``positions``,
        ``daf``, ``freq_bins``, ``r_bins``, ``cm_mb``).
    fixed_center : int, optional
        Absolute coordinate of the locus midpoint.  When given, ``centers``
        are shifted so that their midpoint lands on ``fixed_center`` for
        filtering, while the reported ``center`` stays in simulation
        coordinates.
    iter_value : int
        Value written to the ``iter`` column.

    Returns
    -------
    polars.DataFrame or None
        Columns ``iter``, ``center``, ``window`` followed by the mean of each
        statistic (NaN treated as missing); ``None`` when there is no
        (center, window) combination.
    """
    if stats_names is None:
        stats_names = [c for c in df.columns if c not in _NON_STAT_SNP_COLS]
    stats_names = list(stats_names)

    if centers is None:
        centers = np.arange(5e5, 7e5 + 1e4, 1e4).astype(int)
        sim_mid = 6e5
    else:
        sim_mid = (centers[0] + centers[-1]) // 2
    centers = np.asarray(centers).astype(int)

    if fixed_center is not None:
        centers_abs = np.array([fixed_center + c - sim_mid for c in centers]).astype(
            int
        )
    else:
        centers_abs = centers

    query = df.lazy()
    results = []
    for c, w in product(centers_abs, windows):
        # Report the center in simulation coordinates even when filtering is
        # done in absolute coordinates.
        if fixed_center is not None:
            c_sim = int(c - fixed_center + sim_mid)
        else:
            c_sim = int(c)

        lower = c - (w // 2)
        upper = c + (w // 2)
        window_data = query.filter(
            (pl.col("positions") >= lower) & (pl.col("positions") <= upper)
        )
        window_stats = window_data.select(stats_names).fill_nan(None).mean().collect()

        metadata_cols = [
            pl.lit(iter_value).alias("iter"),
            pl.lit(c_sim).cast(pl.Int64).alias("center"),
            pl.lit(int(w)).cast(pl.Int64).alias("window"),
        ]
        results.append(window_stats.with_columns(metadata_cols))

    if not results:
        return None
    return pl.concat(results, how="vertical").select(
        ["iter", "center", "window"] + stats_names
    )


def bin_values(values, freq=0.02):
    """
    Add a ``freq_bins`` column that bins derived allele frequency.

    Parameters
    ----------
    values : polars.DataFrame
        Table with a ``daf`` column, a ``positions`` column, and either an
        ``iter`` or a ``chr`` column.
    freq : float, optional (default=0.02)
        Bin width; break points are ``np.arange(0, 1 + freq, freq)``.

    Returns
    -------
    polars.DataFrame
        ``values`` with the extra categorical column ``freq_bins`` (the
        result of ``Series.cut`` on ``daf``), sorted by ``("iter",
        "positions")`` when ``iter`` is present and by ``("chr",
        "positions")`` otherwise.
    """
    freq_bins = (
        values["daf"]
        .cut(np.arange(0, 1 + freq, freq))
        .to_frame()
        .rename({"daf": "freq_bins"})
    )
    values_copy = pl.concat([values, freq_bins], how="horizontal")

    # Simulation tables are keyed by "iter"; VCF tables renamed it to "chr".
    group_col = "iter" if "iter" in values_copy.columns else "chr"
    return values_copy.sort(group_col, "positions")


# Upper bound of an interval label such as "(0.5, 1.25]".
_RBIN_UPPER_RE = re.compile(r",\s*([^,\]\s]+)\s*\]$")


def _parse_breaks_from_rbins(df_r_bins):
    """Recover finite cut break points from the ``r_bins`` interval labels.

    Labels are expected to look like ``"(a, b]"``; the upper bound ``b`` of
    every distinct label is collected.  Null labels, labels that do not match,
    and non-finite bounds (the open-ended ``"(a, inf]"`` bin) are skipped, so
    the result can be passed straight back to ``cut``.

    Returns
    -------
    list of float
        Unique finite upper bounds in ascending order.
    """
    labels = df_r_bins["r_bins"].unique().drop_nulls().cast(pl.Utf8).to_list()

    breaks = set()
    for label in labels:
        match = _RBIN_UPPER_RE.search(label)
        if match is None:
            continue
        try:
            upper = float(match.group(1))
        except ValueError:
            continue
        if np.isfinite(upper):
            breaks.add(upper)
    return sorted(breaks)


def snps_to_r_bins(
    snps_df,
    df_r_bins_windows,
    mode="nearest_center",
):
    """
    Annotate SNPs with the recombination rate of the window(s) they fall in.

    Parameters
    ----------
    snps_df : polars.DataFrame
        SNP table with ``chr`` and ``positions`` columns.
    df_r_bins_windows : polars.DataFrame
        Window table with ``chr``, ``start``, ``end``, ``cm_mb`` and
        ``r_bins`` columns.
    mode : str
        - 'nearest_center': pick window whose center is closest to SNP position
        - 'mean_overlap': mean cm_mb across all overlapping windows
        - 'max_overlap': max cm_mb across all overlapping windows

    Returns
    -------
    polars.DataFrame
        The SNP table (``chr`` cast to Categorical, ``positions`` to Int64)
        left-joined with ``cm_mb`` and ``r_bins``.  ``snps_df`` is returned
        unchanged when no chromosome has matching windows.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the values above.
    """
    breaks = _parse_breaks_from_rbins(df_r_bins_windows)

    snps = snps_df.with_columns(
        [pl.col("chr").cast(pl.Categorical), pl.col("positions").cast(pl.Int64)]
    )
    windows = df_r_bins_windows.with_columns(
        [
            pl.col("chr").cast(pl.Categorical),
            pl.col("start").cast(pl.Int64),
            pl.col("end").cast(pl.Int64),
            pl.col("cm_mb").cast(pl.Float32),
        ]
    ).sort(["chr", "start"])

    snps_parts = snps.partition_by("chr", as_dict=True)
    win_parts = windows.partition_by("chr", as_dict=True)

    annotated_parts = []
    for chrom, snps_chrom in snps_parts.items():
        w = win_parts.get(chrom)
        if w is None or w.height == 0:
            continue

        # Work on unique positions only; the final join fans them back out.
        pos = (
            snps_chrom.select(pl.col("positions").unique().sort())
            .to_series()
            .to_numpy()
        )
        pos = pos.astype(np.int64, copy=False)
        w_start = w["start"].to_numpy().astype(np.int64, copy=False)
        w_end = w["end"].to_numpy().astype(np.int64, copy=False)
        w_cm = w["cm_mb"].to_numpy().astype(np.float32, copy=False)

        if mode == "nearest_center":
            w_center = (w_start + w_end) // 2
            idx = np.searchsorted(w_center, pos, side="left")
            idx = np.clip(idx, 1, len(w_center) - 1)
            left = idx - 1
            right = idx
            # Ties go to the left window.
            choose_right = np.abs(w_center[right] - pos) < np.abs(w_center[left] - pos)
            best = np.where(choose_right, right, left)
            cm_assigned = w_cm[best]
        elif mode in ("mean_overlap", "max_overlap"):
            # windows with start <= p are [0:right)
            right = np.searchsorted(w_start, pos, side="right")
            # overlapping also needs end >= p -> left boundary in w_end
            left = np.searchsorted(w_end, pos, side="left")
            cm_assigned = np.full(pos.shape[0], np.nan, dtype=np.float32)
            for i in range(pos.shape[0]):
                l_idx = left[i]
                r_idx = right[i]
                if l_idx >= r_idx:
                    continue
                vals = w_cm[l_idx:r_idx]
                cm_assigned[i] = (
                    float(vals.mean()) if mode == "mean_overlap" else float(vals.max())
                )
        else:
            raise ValueError(f"Unknown mode: {mode}")

        tmp = pl.DataFrame(
            {
                # partition_by may key the dict by 1-tuples.
                "chr": np.repeat(
                    chrom[0] if isinstance(chrom, tuple) else chrom, pos.size
                ),
                "positions": pos,
                "cm_mb": cm_assigned,
            },
            schema_overrides={
                "chr": pl.Categorical,
                "positions": pl.Int64,
                "cm_mb": pl.Float32,
            },
        ).with_columns(pl.col("cm_mb").cut(breaks=breaks).alias("r_bins"))
        annotated_parts.append(tmp)

    matches = pl.concat(annotated_parts, rechunk=False) if annotated_parts else None
    if matches is None:
        return snps_df
    return snps.join(matches, on=["chr", "positions"], how="left")
def normalize_neutral(d_stats_neutral, vcf=False, df_r_bins=None):
    """
    Build the neutral expectation (mean) and spread (std) tables used to
    normalize summary statistics.

    Parameters
    ----------
    d_stats_neutral : iterable of dict
        One mapping per neutral replicate with keys ``"snps"`` and
        ``"windows"``; either value may be ``None``.
    vcf : bool
        When ``True`` and ``df_r_bins`` is given, SNPs are assigned to r-bins
        through ``snps_to_r_bins`` using window coordinates parsed from the
        ``iter`` labels of ``df_r_bins``.  Otherwise ``df_r_bins`` is joined
        on ``iter``.
    df_r_bins : polars.DataFrame or None
        Optional recombination-bin table; when present (and an ``r_bins``
        column results) bins are additionally conditioned on ``r_bins``.

    Returns
    -------
    tuple of list
        ``([expected, df_window_mean], [stdev, df_window_std])``.  The SNP
        tables are grouped by ``freq_bins`` (plus ``r_bins``), the window
        tables by ``center`` and ``window`` (plus ``r_bins``).  A pair is
        ``None`` when its inputs are missing or its computation failed; the
        failure is printed, not raised.
    """
    snp_frames = [rep["snps"] for rep in d_stats_neutral if rep["snps"] is not None]
    window_frames = [
        rep["windows"] for rep in d_stats_neutral if rep["windows"] is not None
    ]
    has_r_bins = df_r_bins is not None

    # ---- SNP-level bins -------------------------------------------------
    try:
        if len(snp_frames) == 0:
            raise ValueError("No SNP stats available (window-only mode)")

        neutral_snps = pl.concat(snp_frames, how="vertical", rechunk=False)

        if has_r_bins and vcf:
            # VCF: "iter" holds the chromosome; map each SNP to the nearest
            # window center (snps_to_r_bins default) to obtain its r_bins.
            neutral_snps = neutral_snps.rename({"iter": "chr"})
            chrom_col = pl.col("iter").str.extract(r"^([^:]+)", 1).alias("chr")
            start_col = (
                pl.col("iter").str.extract(r":(\d+)-", 1).cast(pl.Int64).alias("start")
            )
            end_col = (
                pl.col("iter").str.extract(r"-(\d+)$", 1).cast(pl.Int64).alias("end")
            )
            r_bin_windows = (
                df_r_bins.with_columns([chrom_col, start_col, end_col])
                .select(["chr", "start", "end", "cm_mb", "r_bins", "iter"])
                .sort(["chr", "start"])
            )
            neutral_snps = snps_to_r_bins(neutral_snps, r_bin_windows)
        elif has_r_bins:
            neutral_snps = neutral_snps.join(df_r_bins, on="iter", how="left")

        df_binned = bin_values(neutral_snps).fill_nan(None)

        group_keys = ["freq_bins"]
        if has_r_bins and "r_bins" in df_binned.columns:
            group_keys.append("r_bins")

        non_stat = ("iter", "chr", "positions", "daf", "freq_bins", "r_bins", "cm_mb")
        stat_cols = [name for name in df_binned.columns if name not in non_stat]

        snp_groups = df_binned.group_by(group_keys)
        expected = (
            snp_groups.agg(pl.col(stat_cols).mean()).sort(group_keys).fill_nan(None)
        )
        stdev = (
            df_binned.group_by(group_keys)
            .agg(pl.col(stat_cols).std())
            .sort(group_keys)
            .fill_nan(None)
        )
    except Exception as e:
        print(f"normalize_neutral SNP bins failed: {e}")
        expected, stdev = None, None

    # ---- window-level bins ----------------------------------------------
    try:
        if len(window_frames) == 0:
            raise ValueError("No window stats available (SNP-only mode)")

        df_window = pl.concat(window_frames, rechunk=False).fill_nan(None)
        if has_r_bins:
            df_window = df_window.join(
                df_r_bins.select(pl.exclude("cm_mb")), on="iter", how="left"
            )

        group_w = ["center", "window"]
        if has_r_bins and "r_bins" in df_window.columns:
            group_w.append("r_bins")

        # Aggregate everything except iter and the grouping keys.
        win_keys = ("iter", "center", "window", "r_bins")
        win_stat_cols = [name for name in df_window.columns if name not in win_keys]

        df_window_mean = (
            df_window.group_by(group_w).agg(pl.col(win_stat_cols).mean()).sort(group_w)
        )
        df_window_std = (
            df_window.group_by(group_w).agg(pl.col(win_stat_cols).std()).sort(group_w)
        )
    except Exception as e:
        print(f"normalize_neutral window bins failed: {e}")
        df_window_mean = None
        df_window_std = None

    return ([expected, df_window_mean], [stdev, df_window_std])
def normalize_stats(
    stats_values,
    bins,
    region=None,
    center=[5e5, 7e5],
    windows=[50000, 100000, 200000, 500000, 1000000],
    step=1e4,
    parallel_manager=None,
    nthreads=1,
    vcf=False,
    df_r_bins=None,
    locus_length=int(1.2e6),
):
    """Normalize statistics and pivot them into wide feature vectors.

    Runs ``normalization_raw`` on a deep copy of ``stats_values`` and pivots
    both the normalized and the raw long tables with
    ``pivot_feature_vectors``.  In the normalized table NaN becomes null, rows
    whose null count reaches ``int(0.1 * n_feature_columns)`` are removed, and
    the remaining nulls are filled with 0.  For simulations (``vcf=False``)
    the normalized table is sorted by ``("iter", "model")``.

    Parameters
    ----------
    stats_values, bins, region, center, windows, step, parallel_manager,
    nthreads, vcf, df_r_bins, locus_length
        Forwarded unchanged to ``normalization_raw``; ``vcf`` is also passed
        to ``pivot_feature_vectors``.  The list defaults are never mutated
        here.

    Returns
    -------
    tuple of polars.DataFrame
        ``(df_fv_w, df_fv_w_raw)``: the filtered, zero-filled normalized
        feature vectors and the raw feature vectors (NaN converted to null,
        no row filtering).
    """
    # Parameter/identifier columns that do not count as features.
    meta_cols = ["iter", "s", "t", "f_i", "f_t", "mu", "r", "model"]

    df_fv, df_fv_raw = normalization_raw(
        deepcopy(stats_values),
        bins,
        region=region,
        center=center,
        windows=windows,
        step=step,
        parallel_manager=parallel_manager,
        nthreads=nthreads,
        vcf=vcf,
        df_r_bins=df_r_bins,
        locus_length=locus_length,
    )

    df_fv_w = pivot_feature_vectors(df_fv, vcf=vcf)
    df_fv_w_raw = pivot_feature_vectors(df_fv_raw, vcf=vcf)

    df_fv_w = df_fv_w.fill_nan(None)

    # Nulls per row: after transposing, each original row is a column.
    num_nans = (
        df_fv_w.select(pl.exclude(meta_cols))
        .transpose()
        .null_count()
        .to_numpy()
        .flatten()
    )
    max_nulls = int(df_fv_w.select(pl.exclude(meta_cols)).shape[1] * 0.1)
    df_fv_w = df_fv_w.filter(num_nans < max_nulls).fill_null(0)

    if not vcf:
        # DataFrame.sort returns a new frame; the result must be kept.
        df_fv_w = df_fv_w.sort(["iter", "model"])

    df_fv_w_raw = df_fv_w_raw.fill_nan(None)

    return df_fv_w, df_fv_w_raw
def batch_normalize_cut_raw(batch_data, bins, center, windows, step, df_r_bins):
    """Run ``normalize_cut_raw`` on every replicate of a batch.

    The two returned lists always have one entry per input replicate, in
    input order; a replicate whose processing raised is reported with a
    printed message and represented by ``None`` in both lists.

    Returns
    -------
    tuple of list
        ``(results_norm, results_raw)``.
    """
    results_norm = []
    results_raw = []
    for snps_values in batch_data:
        try:
            df_norm, df_raw = normalize_cut_raw(
                snps_values, bins, center, windows, step, df_r_bins
            )
            results_norm.append(df_norm)
            results_raw.append(df_raw)
        except Exception as e:
            print(f"Error in normalize_cut_raw: {e}")
            results_norm.append(None)
            results_raw.append(None)
    return results_norm, results_raw


def batch_cut_snps(batch_data, centers, windows, stats_names):
    """Run ``cut_snps`` on every ``(df, coord, iter_val)`` item of a batch.

    Output length and ordering match the input.  Exceptions propagate on
    purpose: the caller splits the flattened results into a normalized half
    and a raw half, so a silently skipped item would scramble that alignment.
    """
    return [
        cut_snps(
            df,
            centers,
            windows,
            stats_names,
            fixed_center=coord,
            iter_value=iter_val,
        )
        for df, coord, iter_val in batch_data
    ]


# Column names of the simulation parameter matrix.
_PARAM_COLUMNS = ["s", "t", "f_i", "f_t", "mu", "r"]


def _prepend_params(params, frame, n_combos):
    """Repeat each parameter row ``n_combos`` times and place it left of ``frame``."""
    df_params = pl.DataFrame(
        np.repeat(params, n_combos, axis=0),
        schema=_PARAM_COLUMNS,
    )
    return pl.concat([df_params, frame], how="horizontal")


def _iter_first(frame):
    """Reorder columns so that ``iter`` comes first."""
    return frame.select(["iter"] + [col for col in frame.columns if col != "iter"])


def _n_center_window_combos(frame):
    """Number of distinct (center, window) pairs in ``frame``."""
    return frame.select(["center", "window"]).unique().shape[0]


def normalization_raw(
    stats_values,
    bins,
    region=None,
    center=[5e5, 7e5],
    step=1e4,
    windows=[50000, 100000, 200000, 500000, 1000000],
    vcf=False,
    df_r_bins=None,
    nthreads=1,
    parallel_manager=None,
    locus_length=int(1.2e6),
):
    """Produce normalized and raw long-format feature tables.

    Parameters
    ----------
    stats_values : tuple
        ``(df_stats, params)``.  For ``vcf=True`` ``df_stats`` is a mapping
        with optional ``"snps"`` and ``"window"`` tables; otherwise it is a
        sequence of per-replicate inputs for ``normalize_cut_raw``, aligned
        row-for-row with ``params``.
    bins
        Neutral bins forwarded to ``normalize_snps_statistics`` /
        ``normalize_cut_raw``.
    region : sequence of str, optional
        ``"chrom:start-end"`` locus windows; required when ``vcf=True``.
    center, step, windows
        Center bounds, center spacing and window sizes.
    vcf : bool
        Select the VCF path (single table, loci sliced by coordinate) or the
        simulation path (one replicate per list entry).
    df_r_bins : polars.DataFrame or None
        Optional recombination-bin table.
    nthreads : int
        ``n_jobs`` for joblib when ``parallel_manager`` is not given.
    parallel_manager : callable or None
        If given, called with the delayed tasks instead of creating a
        ``Parallel``.
    locus_length : int
        Locus length used to rebuild the ``chrom:start-end`` labels (VCF).

    Returns
    -------
    tuple of polars.DataFrame
        ``(df_fv_n, df_fv_n_raw)`` with ``iter`` first, then the parameter
        columns ``s, t, f_i, f_t, mu, r``, then the features.

    Raises
    ------
    RuntimeError
        VCF path only: if ``cut_snps`` returned a wrong number of results or a
        ``None`` result.
    """
    from . import Parallel, delayed

    df_stats, params = stats_values
    center = np.asarray(center).astype(int)
    windows = np.asarray(windows).astype(int)

    if vcf:
        nchr = region[0].split(":")[0]
        center_coords = [tuple(map(int, r.split(":")[-1].split("-"))) for r in region]
        center_g = np.array([(a + b) // 2 for a, b in center_coords])

        df_window = df_stats.get("window")
        has_snps = df_stats.get("snps") is not None
        _half = locus_length // 2

        def _label_locus(frame):
            # "iter" holds the locus midpoint; turn it into "chrom:start-end".
            return frame.with_columns(
                (
                    f"{nchr}:"
                    + (pl.col("iter") - _half + 1).cast(int).cast(str)
                    + "-"
                    + (pl.col("iter") + _half).cast(int).cast(str)
                ).alias("iter")
            )

        if has_snps:
            snps_values = df_stats["snps"].sort(["iter", "positions"])
            stats_names = [
                c for c in snps_values.columns if c not in ("iter", "positions", "daf")
            ]

            try:
                if df_r_bins is not None:
                    breaks = _parse_breaks_from_rbins(df_r_bins)

                    tmp_neutral_snps = snps_values.rename({"iter": "chr"}).with_columns(
                        pl.col("chr").cast(pl.Categorical)
                    )
                    nchr = tmp_neutral_snps["chr"].unique().item()

                    df_r_bins_w = (
                        df_r_bins.with_columns(
                            [
                                pl.col("iter")
                                .str.extract(r"^([^:]+)", 1)
                                .alias("chr"),
                                pl.col("iter")
                                .str.extract(r":(\d+)-", 1)
                                .cast(pl.Int64)
                                .alias("start"),
                                pl.col("iter")
                                .str.extract(r"-(\d+)$", 1)
                                .cast(pl.Int64)
                                .alias("end"),
                            ]
                        )
                        .select(["chr", "start", "end", "cm_mb", "r_bins", "iter"])
                        .sort(["chr", "start"])
                        .filter(pl.col("chr") == nchr)
                        .sort("start")
                    )

                    # unique positions reduces work; join back later
                    pos = (
                        tmp_neutral_snps.select(pl.col("positions").unique().sort())
                        .to_series()
                        .to_numpy()
                    )
                    pos = pos.astype(np.int64, copy=False)

                    w_start = (
                        df_r_bins_w["start"].to_numpy().astype(np.int64, copy=False)
                    )
                    w_end = df_r_bins_w["end"].to_numpy().astype(np.int64, copy=False)
                    w_cm = (
                        df_r_bins_w["cm_mb"].to_numpy().astype(np.float32, copy=False)
                    )

                    # Nearest window center; ties go to the left window.
                    w_center = (w_start + w_end) // 2
                    idx = np.searchsorted(w_center, pos, side="left")
                    idx = np.clip(idx, 1, len(w_center) - 1)
                    left = idx - 1
                    right = idx
                    choose_right = np.abs(w_center[right] - pos) < np.abs(
                        w_center[left] - pos
                    )
                    best = np.where(choose_right, right, left)
                    cm_assigned = w_cm[best]

                    tmp = pl.DataFrame(
                        {
                            "chr": np.repeat(nchr, pos.size),
                            "positions": pos,
                            "cm_mb": cm_assigned,
                        },
                        schema_overrides={
                            "chr": pl.Categorical,
                            "positions": pl.Int64,
                            "cm_mb": pl.Float32,
                        },
                    ).with_columns(pl.col("cm_mb").cut(breaks=breaks).alias("r_bins"))

                    snps_values = tmp_neutral_snps.join(
                        tmp, on=["chr", "positions"], how="left"
                    ).select(pl.exclude("cm_mb"))

                    binned_values = (
                        bin_values(snps_values)
                        .sort("positions")
                        .select(pl.exclude("chr"))
                    )
                    df_window = df_window.join(df_r_bins, on=["iter"]).select(
                        pl.exclude("cm_mb")
                    )
                else:
                    binned_values = bin_values(snps_values)
            except Exception as exc:
                # Keep the original best-effort fallback, but surface the cause.
                warn(f"normalization_raw: SNP binning failed: {exc}")
                stats_names = None
                binned_values = None

            normalized_df, normalized_window = normalize_snps_statistics(
                binned_values, df_window, bins, stats_names
            )

            # centers range
            if len(center) == 2:
                centers = np.arange(center[0], center[1] + step, step).astype(int)
            else:
                centers = center

            # Row range of each locus: midpoint ± the midpoint of the centers.
            positions_np = normalized_df["positions"].to_numpy()
            half_span = (center[0] + center[-1]) // 2
            left_idxs = np.searchsorted(positions_np, center_g - half_span, side="left")
            right_idxs = np.searchsorted(
                positions_np, center_g + half_span, side="right"
            )

            tmp_normalized = [
                normalized_df.slice(start, end - start)
                for start, end in zip(left_idxs, right_idxs)
            ]
            tmp_raw = [
                binned_values.slice(start, end - start)
                for start, end in zip(left_idxs, right_idxs)
            ]

            # Deterministic ordering: all normalized loci, then all raw loci.
            all_data = [
                (df, coord, coord) for df, coord in zip(tmp_normalized, center_g)
            ] + [(df, coord, coord) for df, coord in zip(tmp_raw, center_g)]

            batch_size = max(100, len(all_data) // (max(nthreads, 1) * 2))
            batches = [
                all_data[i : i + batch_size]
                for i in range(0, len(all_data), batch_size)
            ]

            _cut_tasks = (
                delayed(batch_cut_snps)(batch, centers, windows, stats_names)
                for batch in batches
            )
            if parallel_manager is not None:
                batch_results = parallel_manager(_cut_tasks)
            else:
                with Parallel(n_jobs=nthreads, backend="loky", verbose=2) as parallel:
                    batch_results = parallel(_cut_tasks)

            # Flatten without dropping anything, then enforce the invariant
            # the normalized/raw split relies on.
            out_cut = [item for batch in batch_results for item in batch]

            expected_len = 2 * center_g.size
            if len(out_cut) != expected_len:
                raise RuntimeError(
                    f"cut_snps produced {len(out_cut)} results; expected {expected_len}. "
                    "Do not drop/skip items before splitting normalized/raw."
                )
            if any(x is None for x in out_cut):
                raise RuntimeError(
                    "cut_snps returned None for at least one element; "
                    "refuse to continue because it scrambles normalized/raw alignment."
                )

            df_fv_n = _label_locus(pl.concat(out_cut[: center_g.size]))
            df_fv_n_raw = _label_locus(pl.concat(out_cut[center_g.size :]))

            if normalized_window is not None:
                df_fv_n = df_fv_n.join(
                    normalized_window,
                    on=["iter", "center", "window"],
                    how="full",
                    coalesce=True,
                )
            if df_window is not None:
                df_fv_n_raw = df_fv_n_raw.join(
                    df_window,
                    on=["iter", "center", "window"],
                    how="full",
                    coalesce=True,
                )
        else:
            # No SNP stats — window-only path
            normalized_window = None
            if df_window is not None and bins is not None:
                # Mirror the has_snps path: join r_bins onto df_window before
                # normalization so window bins are conditioned on r correctly.
                _df_window_norm = df_window
                if df_r_bins is not None:
                    _df_window_norm = _df_window_norm.join(
                        df_r_bins.select(pl.exclude("cm_mb")), on="iter", how="left"
                    )
                normalized_window = normalize_snps_statistics(
                    None, _df_window_norm, bins, None
                )[1]

            if normalized_window is not None:
                df_fv_n = normalized_window
            elif df_window is not None:
                df_fv_n = df_window
            else:
                iter_labels = [
                    f"{nchr}:{int(cg - _half + 1)}-{int(cg + _half)}" for cg in center_g
                ]
                df_fv_n = pl.DataFrame({"iter": iter_labels})
            df_fv_n_raw = df_window if df_window is not None else df_fv_n

        n_combos = _n_center_window_combos(df_fv_n)
        df_fv_n = _iter_first(_prepend_params(params, df_fv_n, n_combos))
        df_fv_n_raw = _iter_first(_prepend_params(params, df_fv_n_raw, n_combos))
        return df_fv_n, df_fv_n_raw

    # ---- simulations: one replicate per df_stats entry ------------------
    batch_size = max(100, len(df_stats) // nthreads)
    batches = [df_stats[i : i + batch_size] for i in range(0, len(df_stats), batch_size)]

    _norm_tasks = (
        delayed(batch_normalize_cut_raw)(batch, bins, center, windows, step, df_r_bins)
        for batch in batches
    )
    if parallel_manager is None:
        batch_results = Parallel(n_jobs=nthreads, verbose=10)(_norm_tasks)
    else:
        batch_results = parallel_manager(_norm_tasks)
    batch_results = list(batch_results)

    # Flatten in replicate order, keeping failures (None) so positions still
    # line up with the rows of params.
    flat_norm = [item for batch_norm, _ in batch_results for item in batch_norm]
    flat_raw = [item for _, batch_raw in batch_results for item in batch_raw]
    kept_norm = [i for i, item in enumerate(flat_norm) if item is not None]
    kept_raw = [i for i, item in enumerate(flat_raw) if item is not None]

    def _params_for(kept, total):
        # A dropped replicate must drop its parameter row too, otherwise every
        # later replicate is paired with the wrong parameters.
        if len(kept) == total:
            return params
        return np.asarray(params)[kept]

    df_fv_n = pl.concat([flat_norm[i] for i in kept_norm]).with_columns(
        pl.col(["iter", "window", "center"]).cast(pl.Int64)
    )
    df_fv_n_raw = pl.concat([flat_raw[i] for i in kept_raw]).with_columns(
        pl.col(["iter", "window", "center"]).cast(pl.Int64)
    )

    n_combos = _n_center_window_combos(df_fv_n)
    df_fv_n = _iter_first(
        _prepend_params(_params_for(kept_norm, len(flat_norm)), df_fv_n, n_combos)
    )
    df_fv_n_raw = _iter_first(
        _prepend_params(_params_for(kept_raw, len(flat_raw)), df_fv_n_raw, n_combos)
    )
    return df_fv_n, df_fv_n_raw
def normalize_cut_raw(
    snps_values,
    bins,
    center=[5e5, 7e5],
    windows=[50000, 100000, 200000, 500000, 1000000],
    step=int(1e4),
    df_r_bins=None,
):
    """
    Normalize one batch of simulation statistics against neutral bins and cut
    the SNP-level statistics into per-(center, window) features.

    When ``df_r_bins`` is given it is left-joined on ``iter`` onto both the SNP
    table and the window table *before* binning/normalization, so that
    ``normalize_snps_statistics`` can condition on ``r_bins`` where the neutral
    bins carry that column. ``r_bins`` is removed again from the normalized
    tables and from the raw window table, and the final column selection keeps
    only the canonical statistic columns.

    :param dict snps_values:
        Mapping with keys ``"snps"`` and ``"window"``; each value is a polars
        DataFrame with an ``iter`` column, or ``None``.
    :param bins:
        Neutral expectation bins, forwarded to ``normalize_snps_statistics``.
    :param center:
        Either ``[first, last]`` (expanded to ``first..last`` every ``step``)
        or an explicit sequence of centers, used as given.
    :param windows:
        Window sizes forwarded to ``cut_snps``.
    :param int step:
        Spacing used when expanding a two-element ``center``.
    :param df_r_bins:
        Optional per-replicate table joined on ``iter``. Default ``None``.
    :returns:
        ``(normalized, raw)`` DataFrames with columns in canonical order.
        Both are ``None`` when neither input table is available or nothing
        could be normalized.
    """
    # Replicate identifiers come from whichever table is available.
    if snps_values["snps"] is not None:
        _iter = snps_values["snps"]["iter"].unique().to_numpy()
    elif snps_values["window"] is not None:
        _iter = snps_values["window"]["iter"].unique().to_numpy()
    else:
        return None, None

    centers = (
        np.arange(center[0], center[1] + step, step).astype(int)
        if len(center) == 2
        else center
    )

    # --- SNP-level statistics (best effort: failures disable the SNP path) ---
    df, binned_values, stats_names = None, None, None
    if snps_values["snps"] is not None:
        try:
            snp_df = snps_values["snps"]
            if df_r_bins is not None:
                snp_df = snp_df.join(df_r_bins, on="iter", how="left")
            # Everything that is neither a key nor a bin label is a statistic.
            non_stat_cols = ("iter", "positions", "daf", "freq_bins", "r_bins")
            snp_stats = [c for c in snp_df.columns if c not in non_stat_cols]
            snp_binned = bin_values(snp_df)
        except Exception as e:
            print(f"normalize_cut_raw SNP merge/bin failed: {e}")
        else:
            df, binned_values, stats_names = snp_df, snp_binned, snp_stats

    # --- Window-level statistics (same best-effort policy) ---
    df_window = None
    if snps_values["window"] is not None:
        try:
            win_df = snps_values["window"].select(pl.exclude("positions"))
            if df_r_bins is not None:
                win_df = win_df.join(df_r_bins, on="iter", how="left")
        except Exception as e:
            print(f"normalize_cut_raw window join failed: {e}")
        else:
            df_window = win_df

    normalized_df, normalized_window = normalize_snps_statistics(
        binned_values, df_window, bins, stats_names
    )

    # r_bins was only needed for conditioning; keep the feature schema stable.
    if normalized_df is not None and "r_bins" in normalized_df.columns:
        normalized_df = normalized_df.drop("r_bins")
    if normalized_window is not None and "r_bins" in normalized_window.columns:
        normalized_window = normalized_window.drop("r_bins")
    if df_window is not None and "r_bins" in df_window.columns:
        df_window = df_window.drop("r_bins")

    def _cut(table):
        return cut_snps(
            table,
            centers,
            windows,
            stats_names,
            fixed_center=None,
            iter_value=_iter,
        )

    if normalized_df is not None:
        df_out = _cut(normalized_df)
        df_out_raw = _cut(df)
        if normalized_window is not None:
            df_out = df_out.join(
                normalized_window,
                on=["iter", "center", "window"],
                how="full",
                coalesce=True,
            )
            df_out_raw = df_out_raw.join(
                df_window,
                on=["iter", "center", "window"],
                how="full",
                coalesce=True,
            )
    elif normalized_window is not None:
        df_out, df_out_raw = normalized_window, df_window
    else:
        df_out, df_out_raw = None, None

    _CANONICAL_STAT_ORDER = [
        "iter",
        "center",
        "window",
        "dind",
        "high_freq",
        "low_freq",
        "s_ratio",
        "hapdaf_s",
        "hapdaf_o",
        "ihs",
        "delta_ihh",
        "isafe",
        "nsl",
        "pi",
        "tajima_d",
        "theta_w",
        "theta_h",
        "k_counts",
        "haf",
        "h1",
        "h12",
        "h2_h1",
        "zns",
        "omega_max",
        "max_fda",
        "dist_var",
        "dist_skew",
        "dist_kurtosis",
    ]

    def _reorder(frame):
        # Keep only canonical columns that are present, in canonical order.
        if frame is None:
            return frame
        return frame.select([c for c in _CANONICAL_STAT_ORDER if c in frame.columns])

    return _reorder(df_out), _reorder(df_out_raw)
def normalize_snps_statistics(df_snps, df_window, bins, stats_names, dps_shape=False):
    """
    Z-score SNP-level and window-level statistics against neutral bins.

    SNP statistics are joined to ``bins.mean[0]`` / ``bins.std[0]`` on
    ``freq_bins`` (plus ``r_bins`` when both sides carry it); window statistics
    are joined to ``bins.mean[1]`` / ``bins.std[1]`` on ``center`` and
    ``window`` (plus ``r_bins`` under the same condition). Each statistic is
    replaced by ``(value - neutral_mean) / neutral_std``.

    A zero or missing neutral standard deviation is deliberately *not* clamped
    to 0, so the division result is propagated as-is.

    :param df_snps: SNP table with ``positions``, ``freq_bins`` and the
        columns in ``stats_names``; ``None`` skips SNP normalization.
    :param df_window: Window table keyed by ``iter``/``center``/``window``;
        ``None`` skips window normalization.
    :param bins: Object exposing ``mean`` and ``std`` pairs of
        ``(snp_table, window_table)``.
    :param stats_names: Names of the SNP statistic columns to normalize.
    :param dps_shape: Accepted but not used by this function.
    :returns: ``(normalized_snps, normalized_window)``; either entry is
        ``None`` when the corresponding input is ``None``. The window output
        never contains ``r_bins``.
    """
    # ---- SNP-level normalization ----
    normalized_df = None
    if df_snps is not None:
        snp_keys = ["freq_bins"]
        # Condition on recombination bins only when both tables provide them.
        if (
            "r_bins" in df_snps.columns
            and bins.mean[0] is not None
            and "r_bins" in bins.mean[0].columns
        ):
            snp_keys.append("r_bins")

        wanted = snp_keys + stats_names
        neutral_means = bins.mean[0].select(wanted)
        neutral_stds = bins.std[0].select(wanted)

        joined = df_snps.join(
            neutral_means,
            on=snp_keys,
            how="left",
            coalesce=True,
            suffix="_mean_neutral",
        )
        joined = joined.join(
            neutral_stds,
            on=snp_keys,
            how="left",
            coalesce=True,
            suffix="_std_neutral",
        ).fill_nan(None)

        snp_zscores = [
            (
                (pl.col(s) - pl.col(f"{s}_mean_neutral"))
                / pl.col(f"{s}_std_neutral")
            ).alias(s)
            for s in stats_names
        ]
        normalized_df = joined.with_columns(snp_zscores).select(
            ["positions"] + stats_names
        )

    # ---- Window-level normalization ----
    if df_window is None:
        return (normalized_df, None)

    win_keys = ["center", "window"]
    if (
        "r_bins" in df_window.columns
        and bins.mean[1] is not None
        and "r_bins" in bins.mean[1].columns
    ):
        win_keys.append("r_bins")

    # Every column that is not a key or a bin label is a window statistic.
    non_stat_cols = ("iter", "center", "window", "r_bins")
    window_stats = [c for c in df_window.columns if c not in non_stat_cols]

    def _without_iter(frame):
        return frame.select([c for c in frame.columns if c != "iter"])

    with_mean = df_window.join(
        _without_iter(bins.mean[1]),
        on=win_keys,
        how="left",
        suffix="_mean",
    )
    with_mean_std = with_mean.join(
        _without_iter(bins.std[1]),
        on=win_keys,
        how="left",
        suffix="_std",
    )

    window_zscores = [
        ((pl.col(c) - pl.col(f"{c}_mean")) / pl.col(f"{c}_std")).alias(c)
        for c in window_stats
    ]
    df_window_z = with_mean_std.with_columns(window_zscores)

    # Keys plus normalized statistics; r_bins is never part of the output.
    return (
        normalized_df,
        df_window_z.select(["iter", "center", "window"] + window_stats),
    )


################## Haplotype structure stats
def ihs_ihh(
    h,
    pos,
    map_pos=None,
    min_ehh=0.05,
    min_maf=0.05,
    include_edges=False,
    gap_scale=20000,
    max_gap=200000,
    is_accessible=None,
    use_threads=False,
):
    """
    Compute unstandardized iHS and delta iHH from a haplotype matrix.

    EHH decay is integrated on both sides of every SNP with scikit-allel's
    ``ihh01_scan`` (one forward scan, one scan over the reversed arrays). The
    two directions are summed per SNP into ``ihh0`` and ``ihh1`` (allele 0 and
    allele 1, i.e. ancestral and derived under a 0/1 ancestral/derived
    encoding), and the function reports ``ihs = log(ihh0 / ihh1)`` and
    ``delta_ihh = |ihh1 - ihh0|``.

    The parameter semantics below follow scikit-allel's ``allel.ihs``, whose
    primitives receive these arguments unchanged.

    :param numpy.ndarray h:
        Integer haplotype matrix of shape ``(n_snps, n_haplotypes)``.
    :param numpy.ndarray pos:
        Physical positions of the SNPs (length ``n_snps``).
    :param numpy.ndarray map_pos:
        Optional genetic-map positions; forwarded to ``compute_ihh_gaps``.
        Default ``None``.
    :param float min_ehh:
        EHH value below which the integration is truncated. Default ``0.05``.
    :param float min_maf:
        SNPs with minor-allele frequency below this value are not scored.
        Default ``0.05``.
    :param bool include_edges:
        If ``True``, report scores even when EHH has not decayed below
        ``min_ehh`` before the edge of the data is reached. Default ``False``.
    :param int gap_scale:
        Distances between consecutive SNPs larger than this are rescaled.
        Default ``20000``.
    :param int max_gap:
        Scores are not reported when EHH spans a gap larger than this many
        base pairs. Default ``200000``.
    :param numpy.ndarray is_accessible:
        Optional genome accessibility array forwarded to
        ``compute_ihh_gaps``. Default ``None``.
    :param bool use_threads:
        Run the forward and backward scans in two threads. Default ``False``.

    :returns:
        Polars DataFrame with columns ``positions``, ``daf`` (row sum of ``h``
        divided by the number of haplotypes), ``ihs`` and ``delta_ihh``. Rows
        containing NaN/null in any column, or an infinite ``ihs``, are
        removed.
    :rtype: polars.DataFrame
    """
    # check inputs
    h = asarray_ndim(h, 2)
    check_integer_dtype(h)
    pos = asarray_ndim(pos, 1)
    check_dim0_aligned(h, pos)
    h = memoryview_safe(h)
    pos = memoryview_safe(pos)

    # gaps between consecutive variants, used as integration widths
    gaps = compute_ihh_gaps(pos, map_pos, gap_scale, max_gap, is_accessible)

    kwargs = dict(min_ehh=min_ehh, min_maf=min_maf, include_edges=include_edges)

    if use_threads:
        # The context manager terminates the pool even if a scan raises.
        with ThreadPool(2) as pool:
            result_fwd = pool.apply_async(ihh01_scan, (h, gaps), kwargs)
            result_rev = pool.apply_async(
                ihh01_scan, (h[::-1], gaps[::-1]), kwargs
            )
            ihh0_fwd, ihh1_fwd = result_fwd.get()
            ihh0_rev, ihh1_rev = result_rev.get()
    else:
        ihh0_fwd, ihh1_fwd = ihh01_scan(h, gaps, **kwargs)
        ihh0_rev, ihh1_rev = ihh01_scan(h[::-1], gaps[::-1], **kwargs)

    # The backward scan ran on reversed inputs, so its outputs are in reversed
    # SNP order. Flip them back before combining with the forward scan;
    # otherwise forward and backward iHH of *different* SNPs get summed.
    ihh0_rev = ihh0_rev[::-1]
    ihh1_rev = ihh1_rev[::-1]

    # total iHH per allele (both directions)
    ihh0 = ihh0_fwd + ihh0_rev
    ihh1 = ihh1_fwd + ihh1_rev

    # Allele-0 over allele-1 ratio; zero or undefined iHH yields inf/NaN,
    # which are filtered out below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ihs = np.log(ihh0 / ihh1)

    delta_ihh = np.abs(ihh1 - ihh0)

    df_ihs = (
        pl.DataFrame(
            {
                "positions": pos,
                "daf": h.sum(axis=1) / h.shape[1],
                "ihs": ihs,
                "delta_ihh": delta_ihh,
            }
        )
        .fill_nan(None)
        .drop_nulls()
    )

    return df_ihs.filter(~pl.col("ihs").is_infinite())
def haf_top(
    hap,
    pos,
    cutoff=0.1,
    start=None,
    stop=None,
    window_size=None,
    n_snps=None,
    center=6e5,
):
    """
    Sum the upper tail of per-haplotype HAF scores in a region.

    Rows of ``hap`` are SNPs, columns are haplotypes. After restricting the
    SNPs to the requested region, only segregating sites (derived frequency
    strictly between 0 and 1) are used. For each haplotype the score is the
    summed derived-allele frequency of the alleles it carries divided by the
    number of derived alleles it carries. The scores are sorted ascending and
    the top ``cutoff`` fraction is summed.

    Region selection, in order of precedence:

    1. ``start`` and/or ``stop``: keep ``start <= pos <= stop``; a bound that
       is ``None`` is left open.
    2. ``window_size``: keep SNPs within ``center +/- window_size // 2``.
    3. ``n_snps``: keep about ``n_snps`` rows centred on the SNP closest to
       ``center`` (clipped at the array ends).
    4. otherwise all SNPs are used.

    :param numpy.ndarray hap:
        Haplotype matrix of shape ``(n_snps, n_haplotypes)`` with 0/1 values.
    :param numpy.ndarray pos:
        Physical positions of the SNPs (length ``n_snps``).
    :param float cutoff:
        Fraction of the highest scores that is summed (``0.1`` sums the top
        10%). Values ``<= 0`` or ``>= 1`` sum all scores. Default ``0.1``.
    :param float start:
        Optional inclusive lower bound of the region. Default ``None``.
    :param float stop:
        Optional inclusive upper bound of the region. Default ``None``.
    :param int window_size:
        Optional window width in base pairs around ``center``. Default
        ``None``.
    :param int n_snps:
        Optional number of SNPs to keep around ``center``. Default ``None``.
    :param float center:
        Coordinate on which ``window_size`` / ``n_snps`` windows are centred.
        Default ``6e5``.

    :returns:
        Sum of the selected upper-tail scores, ignoring NaN. ``0.0`` when no
        usable score exists (e.g. no segregating SNP in the region).
    :rtype: float

    .. note::
       Haplotypes carrying no derived allele have an undefined (NaN) score.
       NaNs sort to the end of the array, so they occupy slots of the upper
       tail and are then skipped by the NaN-aware sum.
    """
    if start is not None or stop is not None:
        # Build the mask bound by bound so a single open-ended limit works.
        loc = np.ones(len(pos), dtype=bool)
        if start is not None:
            loc &= pos >= start
        if stop is not None:
            loc &= pos <= stop
        hap = hap[loc, :]
    elif window_size is not None:
        half_window = window_size // 2
        loc = (pos >= (center - half_window)) & (pos <= (center + half_window))
        hap = hap[loc, :]
    elif n_snps is not None:
        closer_center_snp = np.argmin(np.abs(pos - center))
        loc = np.arange(
            max(closer_center_snp - n_snps // 2, 0),
            min(closer_center_snp + n_snps // 2 + 1, pos.size),
        )
        hap = hap[loc, :]

    n_haplotypes = hap.shape[1]
    freqs = hap.sum(axis=1) / n_haplotypes

    # Only segregating sites contribute.
    hap_tmp = hap.astype(np.float64, copy=False)[(freqs > 0) & (freqs < 1)]

    # Row i of the Gram matrix counts derived alleles shared between haplotype
    # i and every haplotype; its sum / n is the summed frequency of the
    # derived alleles carried by haplotype i.
    haf_num = (np.dot(hap_tmp.T, hap_tmp) / n_haplotypes).sum(axis=1)
    haf_den = hap_tmp.sum(axis=0)

    if 0 in haf_den:
        # Avoid 0/0 warnings: haplotypes without derived alleles become NaN.
        mask_zeros = haf_den != 0
        haf = np.full_like(haf_num, np.nan, dtype=np.float64)
        haf[mask_zeros] = haf_num[mask_zeros] / haf_den[mask_zeros]
        haf = np.sort(haf)
    else:
        haf = np.sort(haf_num / haf_den)

    if cutoff <= 0 or cutoff >= 1:
        cutoff = 1

    idx_high = int((1 - cutoff) * haf.size)

    return np.nansum(haf[idx_high:])
@njit(cache=True)
def fast_skew_kurt(data, bias=False):
    """
    Variance, skewness and excess kurtosis of a 1-D sample (numba kernel).

    The mean is computed in a first loop and the 2nd-4th central moments in a
    second loop. With ``bias=False`` the skewness and kurtosis carry the
    sample-size corrections used by ``scipy.stats.skew`` /
    ``scipy.stats.kurtosis`` when called with ``bias=False``; with
    ``bias=True`` the uncorrected moment ratios are returned. The variance is
    always scaled by ``n / (n - 1)``.

    Returns ``(variance, skewness, kurtosis)``. Fewer than 4 values give
    ``(0.0, 0.0, 0.0)``; a (numerically) constant sample gives
    ``(0.0, nan, nan)``.
    """
    n = len(data)
    if n < 4:
        return 0.0, 0.0, 0.0

    # pass 1: mean
    mu = 0.0
    for value in data:
        mu += value
    mu /= n

    # pass 2: central moments
    sum_d2 = 0.0
    sum_d3 = 0.0
    sum_d4 = 0.0
    for value in data:
        dev = value - mu
        dev_sq = dev * dev
        sum_d2 += dev_sq
        sum_d3 += dev_sq * dev
        sum_d4 += dev_sq * dev_sq
    m2 = sum_d2 / n
    m3 = sum_d3 / n
    m4 = sum_d4 / n

    # Variance indistinguishable from zero: higher moments are undefined.
    if m2 <= (1.1e-16 * abs(mu)) ** 2 or m2 < 1e-20:
        return 0.0, np.nan, np.nan

    g1 = m3 / (m2**1.5)
    g2 = (m4 / (m2**2)) - 3.0
    variance = m2 * n / (n - 1.0)

    if bias:
        return variance, g1, g2

    skew_unbiased = (np.sqrt(n * (n - 1.0)) / (n - 2.0)) * g1
    kurt_unbiased = (n - 1.0) / ((n - 2.0) * (n - 3.0)) * ((n + 1.0) * g2 + 6.0)
    return variance, skew_unbiased, kurt_unbiased


@njit(cache=True)
def garud_h_numba(h):
    """
    Garud's haplotype homozygosity statistics (numba kernel).

    ``h`` has shape ``(n_variants, n_haplotypes)``. Each column is reduced to
    a 64-bit rolling hash; sorting the hashes groups identical haplotypes, so
    their frequencies ``p1 >= p2 >= p3 >= ...`` are obtained from run lengths
    without a dictionary (this replaces an earlier ``Dict``-based version).

    Returns ``(H12, H2_H1, H1, H123, m)`` where

    - ``H1   = sum_i p_i^2``
    - ``H12  = (p1 + p2)^2 + sum_{i>=3} p_i^2``
    - ``H123 = (p1 + p2 + p3)^2 + sum_{i>=4} p_i^2``
    - ``H2_H1 = (H1 - p1^2) / H1`` (``0.0`` when ``H1`` is 0)
    - ``m`` is the number of distinct haplotypes, returned as a float.

    All five values are ``0.0`` when there are no haplotypes.
    """
    L, n = h.shape
    if n == 0:
        return 0.0, 0.0, 0.0, 0.0, 0.0

    # One XOR-then-multiply rolling hash per haplotype (column).
    hashes = np.empty(n, dtype=np.uint64)
    for col in range(n):
        acc = np.uint64(146527)
        for row in range(L):
            acc = (acc ^ np.uint64(h[row, col])) * np.uint64(1000003)
        hashes[col] = acc

    # Identical haplotypes become adjacent after sorting.
    hashes.sort()

    # Run lengths over the sorted hashes -> haplotype frequencies.
    run_freqs = np.empty(n, dtype=np.float64)
    n_runs = 0
    run_len = 1
    for col in range(1, n):
        if hashes[col] == hashes[col - 1]:
            run_len += 1
        else:
            run_freqs[n_runs] = run_len / float(n)
            n_runs += 1
            run_len = 1
    run_freqs[n_runs] = run_len / float(n)
    n_runs += 1

    # Descending frequencies.
    freqs = np.sort(run_freqs[:n_runs])[::-1]
    m = float(len(freqs))

    p1 = freqs[0] if freqs.size > 0 else 0.0
    p2 = freqs[1] if freqs.size > 1 else 0.0
    p3 = freqs[2] if freqs.size > 2 else 0.0

    # Single pass: full sum of squares plus the tails from rank 3 and rank 4.
    H1 = 0.0
    tail_from_3 = 0.0
    tail_from_4 = 0.0
    for rank in range(freqs.size):
        sq = freqs[rank] * freqs[rank]
        H1 += sq
        if rank >= 2:
            tail_from_3 += sq
        if rank >= 3:
            tail_from_4 += sq

    H12 = (p1 + p2) ** 2 + tail_from_3
    H123 = (p1 + p2 + p3) ** 2 + tail_from_4
    H2 = H1 - p1**2
    H2_H1 = H2 / H1 if H1 > 0 else 0.0

    return H12, H2_H1, H1, H123, m
def garud_h(h):
    """Compute the H12, H2/H1, H1 and H123 statistics for detecting signatures
    of soft sweeps, as defined in Garud et al. (2015).

    Parameters
    ----------
    h : array_like, int, shape (n_variants, n_haplotypes)
        Haplotype array.

    Returns
    -------
    h12 : float
        H12 statistic (sum of squares of haplotype frequencies, combining
        the two most common haplotypes into a single frequency).
    h2_h1 : float
        H2/H1 statistic, indicating the "softness" of a sweep.
    h1 : float
        H1 statistic (sum of squares of haplotype frequencies).
    h123 : float
        H123 statistic (sum of squares of haplotype frequencies, combining
        the three most common haplotypes into a single frequency).
    n_distinct : int
        Number of distinct haplotypes.

    Notes
    -----
    The values are returned in the order listed above, i.e.
    ``(h12, h2_h1, h1, h123, n_distinct)``. With no haplotypes all
    statistics are zero, matching ``garud_h_numba``.

    """
    from allel import HaplotypeArray

    # check inputs
    h = HaplotypeArray(h, copy=False)

    # Haplotype frequencies; the slicing below relies on scikit-allel
    # returning them sorted from most to least common.
    f = h.distinct_frequencies()

    # No haplotypes: nothing to rank, avoid indexing f[0].
    if f.size == 0:
        return 0.0, 0.0, 0.0, 0.0, 0

    # compute H1
    h1 = np.sum(f**2)

    # compute H12
    h12 = np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)

    # compute H123
    h123 = np.sum(f[:3]) ** 2 + np.sum(f[3:] ** 2)

    # compute H2/H1
    h2 = h1 - f[0] ** 2
    h2_h1 = h2 / h1

    return h12, h2_h1, h1, h123, f.size
@njit(cache=True)
def comparen_haplos_optimized(haplo1, haplo2):
    """
    Compare two haplotype vectors element by element.

    identical = number of positions where both values are 1
    different = number of positions where the two values differ
    total     = identical + different

    Positions where both values are equal but not 1 (e.g. 0/0) count towards
    neither tally. Iterates over ``len(haplo1)`` positions.
    Returns ``(identical, different, total)``.
    """
    identical = 0
    different = 0
    for i in range(len(haplo1)):
        h1 = haplo1[i]
        h2 = haplo2[i]
        if (h1 == 1) and (h2 == 1):
            identical += 1
        elif h1 != h2:
            different += 1
    total = identical + different
    return identical, different, total


@njit(cache=True)
def _compare_hap_cols(H, col_a, col_b):
    """
    Compare columns ``col_a`` and ``col_b`` of ``H`` without copying them.

    identical = #(1,1)
    different = #(mismatch: 1,0 or 0,1)
    total = identical + different

    Rows where both columns hold the same non-1 value (e.g. 0/0) are ignored.
    Returns ``(identical, different, total)``.
    """
    L = H.shape[0]
    identical = 0
    different = 0
    for i in range(L):
        a = H[i, col_a]
        b = H[i, col_b]
        if (a == 1) and (b == 1):
            identical += 1
        elif a != b:
            different += 1
    total = identical + different
    return identical, different, total


@njit(cache=True)
def _legacy_row_order_indices(hap, positions, focal_coord, window_size, min_freq):
    """
    Returns row indices exactly in the order legacy visits sites:
      - 100bp bins: left first (-1,-2,...) then right (+1,+2,...)
      - exclude focal bin (step==0)
      - freq filter in [min_freq, 1.0]
      - right cap: sup_i < 1_200_000
    Implementation is O(#rows_in_window): two passes + bucketize by step.

    ``hap`` is ``(L_total, n)`` and ``positions`` has one coordinate per row.
    A row is kept when its position differs from ``focal_coord``, lies within
    ``window_size // 2`` of the focal coordinate rounded down to a multiple of
    100, and its row mean (derived frequency) passes the filter. Within one
    100bp step, rows keep their original relative order. Returns an ``int64``
    array of row indices (empty when nothing qualifies).
    """
    L_total, n = hap.shape
    # Focal coordinate snapped down to its 100bp bin; the window is built
    # around this snapped value.
    int_coord = (focal_coord // 100) * 100
    half = window_size // 2
    low = int_coord - half
    high = int_coord + half
    # Number of whole 100bp steps available on each side.
    max_steps = half // 100

    # pass 1: collect "ok" rows and their step
    # store into temporary arrays
    ok_idx = np.empty(L_total, dtype=int64)
    ok_step = np.empty(L_total, dtype=int64)
    k = 0
    for i in range(L_total):
        pos = positions[i]
        if pos == focal_coord:
            continue
        if (pos < low) or (pos > high):
            continue
        # derived freq
        s = 0
        for j in range(n):
            s += hap[i, j]
        f = s / n
        if (f < min_freq) or (f > 1.0):
            continue
        # step in 100bp units
        bin_i = (pos // 100) * 100
        step = (bin_i - int_coord) // 100
        if step == 0:
            continue  # legacy skips focal bin
        ok_idx[k] = i
        ok_step[k] = step
        k += 1

    if k == 0:
        return np.empty(0, dtype=int64)

    # pass 2: bucketize by step with stability
    # LEFT
    left_limit = (int_coord - 1) // 100  # require inf_i > 0
    left_counts = np.zeros(max_steps, dtype=int64)  # index = abs(step)-1
    # RIGHT (with cap sup_i < 1_200_000)
    right_cap = (1_200_000 - int_coord - 1) // 100
    right_counts = np.zeros(max_steps, dtype=int64)  # index = step-1

    # count per step
    # Steps beyond max_steps can occur when `half` is not a multiple of 100;
    # such rows are dropped here and in the fill passes below.
    for t in range(k):
        step = ok_step[t]
        if step < 0:
            st = -step
            if (st <= max_steps) and (st <= left_limit):
                left_counts[st - 1] += 1
        else:  # step > 0
            st = step
            if (st <= max_steps) and (st <= right_cap):
                right_counts[st - 1] += 1

    left_total = int(left_counts.sum())
    right_total = int(right_counts.sum())
    out = np.empty(left_total + right_total, dtype=int64)

    # prefix sums to place indices in bin order (preserve ok order within step)
    left_off = np.zeros(max_steps, dtype=int64)
    right_off = np.zeros(max_steps, dtype=int64)

    # compute running offsets
    acc = 0
    for s in range(max_steps):
        left_off[s] = acc
        acc += left_counts[s]
    left_end = acc
    acc = 0
    for s in range(max_steps):
        right_off[s] = acc
        acc += right_counts[s]

    # pass 3: fill left, then right
    # LEFT fill in step order 1..max_steps
    # NOTE: `pos` is reused below as an output slot, not a genomic position.
    for t in range(k):
        step = ok_step[t]
        if step < 0:
            st = -step
            if (st <= max_steps) and (st <= left_limit):
                pos = left_off[st - 1]
                out[pos] = ok_idx[t]
                left_off[st - 1] = pos + 1

    # RIGHT fill appended after left block, in step order 1..max_steps
    base = left_end
    for t in range(k):
        step = ok_step[t]
        if step > 0:
            st = step
            if (st <= max_steps) and (st <= right_cap):
                pos = base + right_off[st - 1]
                out[pos] = ok_idx[t]
                right_off[st - 1] = right_off[st - 1] + 1
    return out


@njit(cache=True)
def _unique_hash_counts_reprs_and_assign(H):
    """
    H: (L, n) uint8
    Returns:
      cnts: (m,) int64 counts per unique haplotype
      reprj: (m,) int64 representative column index for each unique hap
      assign: (n,) int64 sample -> uid

    Columns are identified by a 64-bit multiply-then-XOR rolling hash only, so
    two different columns with the same hash would be treated as one
    haplotype. The representative of a haplotype is the first column seen
    with that hash; uids follow the iteration order of the ``counts`` dict.
    """
    L, n = H.shape
    counts = Dict.empty(key_type=uint64, value_type=int64)
    reprs = Dict.empty(key_type=uint64, value_type=int64)
    hashes = np.empty(n, dtype=uint64)
    for j in range(n):
        hsh = np.uint64(146527)
        for i in range(L):
            hsh = (hsh * np.uint64(1000003)) ^ np.uint64(H[i, j])
        hashes[j] = hsh
        counts[hsh] = counts.get(hsh, 0) + 1
        if hsh not in reprs:
            reprs[hsh] = j
    m = len(counts)
    cnts = np.empty(m, dtype=int64)
    reprj = np.empty(m, dtype=int64)
    # hash -> dense uid (0..m-1)
    key2id = Dict.empty(key_type=uint64, value_type=int64)
    k = 0
    for hsh in counts:
        cnts[k] = counts[hsh]
        reprj[k] = reprs[hsh]
        key2id[hsh] = k
        k += 1
    assign = np.empty(n, dtype=int64)
    for j in range(n):
        assign[j] = key2id[hashes[j]]
    return cnts, reprj, assign


@njit(cache=True, inline="always")
def _lex_less_cols(H, reprj, uid_a, uid_b):
    """
    True if col(uid_a) < col(uid_b) lexicographically
    Equivalent to comparing legacy "0 1 1 ..." strings (spaces don't affect order).

    The columns compared are ``H[:, reprj[uid_a]]`` and ``H[:, reprj[uid_b]]``,
    read top to bottom; equal columns return ``False``.
    """
    L = H.shape[0]
    ca = reprj[uid_a]
    cb = reprj[uid_b]
    for i in range(L):
        va = H[i, ca]
        vb = H[i, cb]
        if va < vb:
            return True
        elif va > vb:
            return False
    return False  # equal


@njit(cache=True)
def _argsort_by_count_then_lex(cnts, H, reprj):
    """
    Return indices 0..m-1 sorted by (count desc, lex asc on H[:, reprj[uid]]).
    Implemented via two-stage: count sort (argsort desc) then in-place segment
    lex insertion-sort per count tier.
    """
    m = cnts.size
    order = np.argsort(cnts)  # ascending
    # reverse to descending
    for i in range(m // 2):
        tmp = order[i]
        order[i] = order[m - 1 - i]
        order[m - 1 - i] = tmp
    # walk segments of equal count and lex-sort within each
    # (the order argsort leaves inside a tier is irrelevant: it is re-sorted)
    i = 0
    while i < m:
        c = cnts[order[i]]
        j = i + 1
        while j < m and cnts[order[j]] == c:
            j += 1
        # insertion sort order[i:j] by lex
        k = i + 1
        while k < j:
            x = order[k]
            p = k - 1
            while (p >= i) and _lex_less_cols(H, reprj, x, order[p]):
                order[p + 1] = order[p]
                p -= 1
            order[p + 1] = x
            k += 1
        i = j
    return order
@njit(cache=True)
def h12_enard(
    hap,
    positions,
    focal_coord=600000,
    n_snps=None,
    window_size=int(5e5),
    min_derived_freq=0.05,
    similarity_threshold=0.8,
    top_k=10,
):
    """
    Estimate Garud's ``H12, H2/H1, H1, H2`` around a focal coordinate, grouping
    haplotypes that are at least a given identity threshold (default 80%).

    The method selects the SNPs inside a physical window of ``window_size``
    base pairs around ``focal_coord`` (in the legacy visiting order produced by
    ``_legacy_row_order_indices``), builds a haplotype matrix ``H`` for them,
    collapses identical haplotypes (columns), orders the unique haplotypes by
    frequency (descending) and lexicographic order (ascending), selects the
    leading haplotypes as representatives, and then **merges representatives
    into similarity groups** whenever the **column-wise identity** meets or
    exceeds ``similarity_threshold``. Group frequencies are then used to
    compute the statistics.

    Identity between two haplotype columns is defined as:

    .. math::

       \\text{identity} = \\frac{\\#(1,1)}{\\#(1,1) + \\#(1,0) + \\#(0,1)}

    (sites where both haplotypes carry the same non-derived value are
    ignored).

    :param numpy.ndarray hap:
        Haplotype matrix of shape ``(L_total, n)`` with 0/1 values
        (ancestral/derived). Rows are SNPs; columns are haplotypes.
    :param numpy.ndarray positions:
        1D array (length ``L_total``) of genomic coordinates aligned to
        ``hap`` rows.
    :param int focal_coord:
        Genomic coordinate used to center the window. Default ``600000``.
    :param int n_snps:
        Accepted for interface compatibility; not used by the current
        implementation. Default ``None``.
    :param int window_size:
        Width in base pairs of the window around ``focal_coord``. Default
        ``500000``.
    :param float min_derived_freq:
        Minimum derived-allele frequency required for a SNP to enter the
        window (inclusive; upper bound is ``1.0``). Default ``0.05``.
    :param float similarity_threshold:
        Two representative columns are merged if their identity fraction
        (formula above) is **>= this value**. Default ``0.8``.
    :param int top_k:
        Threshold of the legacy selection accumulator: unique haplotypes are
        added one whole count tier at a time, a running total is increased by
        the number selected so far after each addition, and selection stops
        once that total reaches ``top_k``. Default ``10``.

    :returns:
        Tuple ``(H12, H2_H1, H1, H2)`` as floats. If no usable SNPs or groups
        are found, returns ``(0.0, 0.0, 0.0, 0.0)``.
    :rtype: tuple[float, float, float, float]

    .. math::

       H1 = \\sum_g p_g^2,\\qquad
       H12 = (p_1 + p_2)^2,\\qquad
       H2 = H1 - p_1^2,\\qquad
       \\frac{H2}{H1} =
       \\begin{cases}
         (H1 - p_1^2)/H1, & H1 \\ne 0,\\\\
         0, & H1 = 0~.
       \\end{cases}

    Notes
    -----
    - :math:`p_1, p_2` are the frequencies of the first and second groups in
      creation order; groups are **not** re-sorted by frequency.
    - Group frequencies use ``n`` (all haplotypes) as denominator, while only
      haplotypes belonging to a selected representative are counted.
    - Unique haplotypes are detected via hashing of ``H`` columns.
    """
    n = hap.shape[1]

    # 1) rows of the window, in legacy visiting order
    rows = _legacy_row_order_indices(
        hap,
        positions,
        np.int64(focal_coord),
        np.int64(window_size),
        np.float64(min_derived_freq),
    )
    if rows.size == 0:
        return 0.0, 0.0, 0.0, 0.0

    # 2) window matrix in exact order
    L = rows.size
    H = np.empty((L, n), dtype=np.uint8)
    for r in range(L):
        i = rows[r]
        for j in range(n):
            H[r, j] = hap[i, j]

    # 3) unique haplotypes + per-sample assignment
    cnts, reprj, assign = _unique_hash_counts_reprs_and_assign(H)
    m = cnts.size
    if m == 0:
        return 0.0, 0.0, 0.0, 0.0

    # 4) global order by (count desc, lex asc) and legacy accumulator
    order = _argsort_by_count_then_lex(cnts, H, reprj)  # length m

    # Accumulator quirk: `counter_rev` grows by the running number of selected
    # haplotypes, and is only checked after a whole count tier was added.
    done_rev = 0
    counter_rev = 0
    i = 0
    while i < m:
        c = cnts[order[i]]
        j = i + 1
        while (j < m) and (cnts[order[j]] == c):
            j += 1
        for t in range(i, j):
            done_rev += 1
            counter_rev += done_rev
        if counter_rev >= top_k:
            break
        i = j

    # Tiers are consumed from the front of `order`, so the selected uids are
    # exactly its first `done_rev` entries, already in selection order.
    keep_m = done_rev
    sel = order[:keep_m]

    # 5) forward-only similarity on selected; exclude-as-you-go groups
    #    group_of[a] = group id; creation order defines p1, p2,...
    group_of = np.full(keep_m, -1, dtype=int64)
    groups = 0
    thr = similarity_threshold

    for ai in range(keep_m):
        if group_of[ai] != -1:
            continue
        g = groups
        groups += 1
        group_of[ai] = g
        ua_col = reprj[sel[ai]]
        for bi in range(ai + 1, keep_m):
            if group_of[bi] != -1:
                continue
            ub_col = reprj[sel[bi]]
            ident, _, tot = _compare_hap_cols(H, ua_col, ub_col)
            if tot == 0:
                continue
            if (ident / tot) >= thr:
                group_of[bi] = g

    # 6) uid -> group id map (only for selected)
    uid_to_group = np.full(m, -1, dtype=int64)
    for ai in range(keep_m):
        uid_to_group[sel[ai]] = group_of[ai]

    # 7) count only selected uids; denominator is n
    if groups == 0:
        return 0.0, 0.0, 0.0, 0.0

    freq_counts = np.zeros(groups, dtype=int64)
    for j in range(n):
        uid = assign[j]
        g = uid_to_group[uid]
        if g != -1:
            freq_counts[g] += 1

    # 8) legacy stats (NO resort): group order is creation order
    toto = float(n)
    H1 = 0.0
    p1 = freq_counts[0] / toto if groups > 0 else 0.0
    p2 = freq_counts[1] / toto if groups > 1 else 0.0

    for g in range(groups):
        f = freq_counts[g] / toto
        H1 += f * f

    H12 = (p1 + p2) * (p1 + p2)
    H2 = H1 - p1 * p1
    H2_H1 = (H2 / H1) if H1 != 0.0 else 0.0

    return H12, H2_H1, H1, H2
def pairwise_diffs(hap, missing=False):
    """
    Pairwise mismatch counts between samples (columns) of a site x sample matrix.

    Parameters
    ----------
    hap : (S, n) array-like, integer/boolean
        Haplotype/allele matrix with samples in columns.
        - If missing=False: entries must be exactly {0,1}.
        - If missing=True: entries may be {-1,0,1}, where -1 is treated as missing.
    missing : bool, default False
        If False, assumes data are strictly 0/1 and uses a single dot product.
        If True, only entries equal to 0 or 1 count as valid calls and two dot
        products with a validity mask are used, so a site contributes to a
        pair only when both samples are called at that site.

    Returns
    -------
    diff_ls : (n*(n-1)//2,) float64
        Pairwise mismatch counts in i<j order (same as the C implementation).
    """
    hap = np.asarray(hap)
    n = hap.shape[1]

    if missing:
        # valid = 1 where the call is 0/1, derived = 1 where the allele is 1.
        # (derived^T valid)[i, j] counts sites where i carries 1 and j is called;
        # adding the transpose and removing twice the joint-1 count leaves the
        # 10 + 01 sites restricted to jointly called positions.
        valid = ((hap == 0) | (hap == 1)).astype(np.float64)
        derived = (hap == 1).astype(np.float64)
        one_vs_called = derived.T @ valid
        both_one = derived.T @ derived
        D = one_vs_called + one_vs_called.T - 2.0 * both_one
    else:
        Y = hap.astype(np.float64, copy=False)
        s = Y.sum(axis=0)
        # (n, n) dot products
        G = Y.T @ Y
        # (n, n) Hamming counts for 0/1 data: |a| + |b| - 2 * |a & b|
        D = s[None, :] + s[:, None] - 2.0 * G

    # Extract upper triangle (i<j)
    iu, ju = np.triu_indices(n, k=1)
    return D[iu, ju].astype(np.float64, copy=False)


def pairwise_diffs_precomp(hap, iu, ju, use_float32=False):
    """Pairwise mismatch counts reusing pre-computed triu indices.

    Call np.triu_indices once before the window loop, pass here each iteration.
    For best performance, pre-cast hap to float64 (or float32 if use_float32=True)
    before the window loop — the astype(copy=False) call becomes a no-op on slices.

    Parameters
    ----------
    hap : (S, n) ndarray
        Site x sample matrix; the Hamming formula used here is only valid for
        0/1 entries (no missing-data handling).
    iu, ju : index arrays
        Row/column indices of the sample pairs to extract from the (n, n)
        difference matrix.
    use_float32 : bool, default False
        Do the matrix product in float32 instead of float64.

    Returns
    -------
    float64 ndarray with one mismatch count per (iu, ju) pair.
    """
    dtype = np.float32 if use_float32 else np.float64
    Y = hap.astype(dtype, copy=False)
    s = Y.sum(axis=0)
    G = Y.T @ Y
    D = s[None, :] + s[:, None] - 2.0 * G
    return D[iu, ju].astype(np.float64, copy=False)


@njit(parallel=False, cache=True)
def _hscan_single(hap, pos, x, max_gap, dist_mode):
    """Mean pairwise identity-tract length around the focal site index ``x``.

    For every pair of columns (i, j) the tract is extended right and left from
    ``x`` while the two columns carry the same value at the next site and the
    distance between consecutive positions stays below ``max_gap``.

    ``dist_mode`` selects what is accumulated per pair:
    0 -> ``pos[x_r] - pos[x_l] - 1``, 1 -> ``x_r - x_l - 1`` (index span),
    anything else -> the product of the two.
    """
    S, N = hap.shape
    h_sum = float64(0.0)

    # parallel=False, so prange behaves like a plain range here.
    for i in prange(N - 1):
        for j in range(i + 1, N):
            # Extend RIGHT
            x_r = x
            while (
                x_r + 1 < S
                and (pos[x_r + 1] - pos[x_r]) < max_gap
                and hap[x_r + 1, i] == hap[x_r + 1, j]
            ):
                x_r += 1

            # Extend LEFT
            x_l = x
            while (
                x_l - 1 >= 0
                and hap[x_l - 1, i] == hap[x_l - 1, j]
                and (pos[x_l] - pos[x_l - 1]) < max_gap
            ):
                x_l -= 1

            # Pairs whose tract never left the focal site contribute nothing.
            if x_r != x_l:
                if dist_mode == 0:
                    h_sum += float64(pos[x_r] - pos[x_l] - 1)
                elif dist_mode == 1:
                    h_sum += float64(x_r - x_l - 1)
                else:
                    h_sum += float64(pos[x_r] - pos[x_l] - 1) * float64(x_r - x_l - 1)

    # Average over the N * (N - 1) / 2 unordered pairs.
    return 2.0 * h_sum / (float64(N) * float64(N - 1))


@njit(cache=True)
def _hscan_all(hap, pos, indices, max_gap, dist_mode):
    """Evaluate ``_hscan_single`` at every focal index in ``indices``."""
    n_focal = len(indices)
    h_means = np.empty(n_focal, dtype=float64)
    for k in range(n_focal):
        h_means[k] = _hscan_single(hap, pos, indices[k], max_gap, dist_mode)
    return h_means
def hscan(
    hap,
    pos,
    focal_pos=None,
    max_gap=int(1e9),
    dist_mode=0,
    step=1,
    left_bound=0,
    right_bound=int(1e9),
    return_pairs=False,
):
    """
    Compute the mean pairwise identity-tract statistic at one or many focal SNPs.

    :param numpy.ndarray hap: Matrix ``(n_snps, n_samples)``; two samples extend
        a tract while their entries are equal.
    :param numpy.ndarray pos: Position of each SNP (row of ``hap``).
    :param focal_pos: If given, only the SNP whose position is closest to this
        value is evaluated. Default ``None`` (full scan).
    :param int max_gap: Tract extension stops at the first pair of consecutive
        SNPs whose distance is not below this value. Default ``1e9``.
    :param int dist_mode: Passed to the kernel: ``0`` position span, ``1`` index
        span, otherwise their product. Default ``0``.
    :param int step: Stride between focal SNP indices in a full scan. Default ``1``.
    :param left_bound: Smallest focal position kept in a full scan (inclusive).
    :param right_bound: Largest focal position kept in a full scan (inclusive).
    :param bool return_pairs: Currently unused.

    :returns: ``float`` when ``focal_pos`` is given, otherwise a tuple
        ``(positions, h_means)`` for the scanned focal SNPs.
    """
    # F-order: haplotype columns contiguous, matching the kernel's per-column
    # walk along sites. Both code paths use this copy.
    hap_f = np.asfortranarray(hap)
    n_sites = hap_f.shape[0]

    # Single focal SNP
    if focal_pos is not None:
        focal_idx = int(np.argmin(np.abs(pos - int(focal_pos))))
        h = _hscan_single(hap_f, pos, int64(focal_idx), int64(max_gap), dist_mode)
        return float(h)

    # Full scan
    all_indices = np.arange(0, n_sites, step, dtype=np.int64)
    pos_at_idx = pos[all_indices]
    in_bounds = (pos_at_idx >= left_bound) & (pos_at_idx <= right_bound)
    indices = all_indices[in_bounds]

    h_means = _hscan_all(hap_f, pos, indices, int64(max_gap), dist_mode)
    return pos[indices], h_means
################## FS stats
@njit(cache=True, parallel=False)
def fast_sq_freq_pairs(
    hap,
    ac,
    rec_map,
    min_focal_freq=0.25,
    max_focal_freq=0.95,
    window_size=50000,
):
    """
    Collect, for every focal SNP, per-neighbor frequencies on the two focal partitions.

    A SNP is focal when ``ac[i, 1] / (ac[i, 0] + ac[i, 1])`` lies in
    ``[min_focal_freq, max_focal_freq]``. Its neighbors are the SNPs whose
    ``rec_map[:, -2]`` coordinate is within ``window_size / 2`` of the focal
    coordinate, excluding the focal SNP itself.

    Returns ``(sq_out_list, info, snp_indices_list)``:

    - ``sq_out_list[j]``: ``(n_neighbors, 3)`` float64 array with columns
      ``overlap / focal ac[:, 1]``, ``(row sum - overlap) / focal ac[:, 0]`` and
      the neighbor's own frequency, where ``overlap`` is the dot product of the
      focal row and the neighbor row of ``hap``.
    - ``info``: ``(n_focal, 4)`` float64 array with focal coordinate, focal
      frequency, ``ac[focal, 1]`` and ``ac[focal, 0]``.
    - ``snp_indices_list[j]``: row indices of the neighbors, left side first
      (nearest to farthest), then right side (nearest to farthest).
    """
    n_snps, n_samples = hap.shape
    rec_pos = rec_map[:, -2]
    half_window = window_size * 0.5

    # ── single pass: compute freqs + collect focal indices ──────────────
    freqs = np.empty(n_snps, dtype=np.float64)
    focal_indices = np.empty(n_snps, dtype=np.int64)
    n_focal = 0
    for i in range(n_snps):
        f = ac[i, 1] / (ac[i, 0] + ac[i, 1])
        freqs[i] = f
        if min_focal_freq <= f <= max_focal_freq:
            focal_indices[n_focal] = i
            n_focal += 1
    focal_indices = focal_indices[:n_focal]

    # ── output containers ────────────────────────────────────────────────
    # Every slot starts as an empty array; focal SNPs without neighbors keep it.
    sq_out_list = [np.empty((0, 3), dtype=np.float64) for _ in range(n_focal)]
    snp_indices_list = [np.empty(0, dtype=np.int64) for _ in range(n_focal)]
    info = np.empty((n_focal, 4), dtype=np.float64)

    # ── main loop ────────────────────────────────────────────────────────
    # parallel=False, so prange behaves like a plain range here.
    for j in prange(n_focal):
        focal_idx = focal_indices[j]
        center = rec_pos[focal_idx]

        # Window bounds computed inline.
        # NOTE(review): np.searchsorted presumes rec_pos is sorted ascending —
        # confirm against callers.
        x_l = np.searchsorted(rec_pos, center - half_window, side="left")
        y_r = np.searchsorted(rec_pos, center + half_window, side="right") - 1
        y_l = focal_idx - 1
        x_r = focal_idx + 1

        # info is written unconditionally, before the empty-window early exit.
        focal_d_count = ac[focal_idx, 1]
        focal_a_count = ac[focal_idx, 0]
        info[j, 0] = center
        info[j, 1] = freqs[focal_idx]
        info[j, 2] = np.float64(focal_d_count)
        info[j, 3] = np.float64(focal_a_count)

        len_l = max(0, y_l - x_l + 1)
        len_r = max(0, y_r - x_r + 1)
        total_len = len_l + len_r

        if total_len == 0:
            continue  # sentinels already set above

        out = np.empty((total_len, 3), dtype=np.float64)
        indices_out = np.empty(total_len, dtype=np.int64)

        # A zero count yields a zero scale instead of a division by zero.
        inv_d = 1.0 / focal_d_count if focal_d_count > 0 else 0.0
        inv_a = 1.0 / focal_a_count if focal_a_count > 0 else 0.0

        hap_f = hap[focal_idx]
        out_idx = 0

        # LEFT window (reverse) — plain 0 inits, Numba infers int64
        for k in range(y_l, x_l - 1, -1):
            hap_k = hap[k]
            overlap_d = 0
            sum_k = 0
            for m in range(n_samples):
                hk = hap_k[m]
                overlap_d += hap_f[m] * hk
                sum_k += hk
            out[out_idx, 0] = overlap_d * inv_d
            out[out_idx, 1] = (sum_k - overlap_d) * inv_a
            out[out_idx, 2] = freqs[k]
            indices_out[out_idx] = k
            out_idx += 1

        # RIGHT window (forward)
        for k in range(x_r, y_r + 1):
            hap_k = hap[k]
            overlap_d = 0
            sum_k = 0
            for m in range(n_samples):
                hk = hap_k[m]
                overlap_d += hap_f[m] * hk
                sum_k += hk
            out[out_idx, 0] = overlap_d * inv_d
            out[out_idx, 1] = (sum_k - overlap_d) * inv_a
            out[out_idx, 2] = freqs[k]
            indices_out[out_idx] = k
            out_idx += 1

        sq_out_list[j] = out
        snp_indices_list[j] = indices_out

    return sq_out_list, info, snp_indices_list
def s_ratio(
    hap,
    ac,
    rec_map,
    max_ancest_freq=1,
    min_tot_freq=0,
    min_focal_freq=0.25,
    max_focal_freq=0.95,
    window_size=50000,
    genetic_distance=False,
):
    """
    Compute the S-ratio statistic for each focal SNP.

    Every neighbor of a focal SNP is flagged as segregating on the derived
    partition (``1e-7 < f_d < 1``) and on the ancestral partition
    (``1e-7 < f_a < 1``). The statistic is
    ``(n_ancestral_flags + n_neighbors) / (n_derived_flags + n_neighbors)``;
    the added neighbor count keeps the ratio from blowing up when few
    neighbors are flagged.

    :param numpy.ndarray hap: Haplotype matrix ``(n_snps, n_samples)`` with 0/1 values.
    :param numpy.ndarray ac: Allele counts ``(n_snps, 2)`` as ``[ancestral, derived]``.
    :param numpy.ndarray rec_map: Map array; penultimate column is the window coordinate.
    :param float max_ancest_freq: Unused here (kept for API symmetry). Default ``1``.
    :param float min_tot_freq: Unused here (kept for API symmetry). Default ``0``.
    :param float min_focal_freq: Minimum focal derived frequency. Default ``0.25``.
    :param float max_focal_freq: Maximum focal derived frequency. Default ``0.95``.
    :param int window_size: Window size in coordinate units of ``rec_map[:, -2]``.
        Default ``50000``.
    :param bool genetic_distance: Unused here (kept for API symmetry).

    :returns: DataFrame with columns ``positions``, ``daf``, ``s_ratio``.
    :rtype: polars.DataFrame
    """
    pair_freqs, info, _ = fast_sq_freq_pairs(
        hap, ac, rec_map, min_focal_freq, max_focal_freq, window_size
    )

    ratios = np.empty((len(pair_freqs), 1), dtype=np.float64)
    for row, neighbors in enumerate(pair_freqs):
        derived_frac = neighbors[:, 0]
        ancestral_frac = neighbors[:, 1]

        # 0/1 indicator arrays (float64) for "segregating on this partition".
        derived_flag = ((derived_frac > 0.0000001) & (derived_frac < 1)).astype(
            np.float64
        )
        ancestral_flag = ((ancestral_frac > 0.0000001) & (ancestral_frac < 1)).astype(
            np.float64
        )

        # +1 per neighbor on both sides; an empty window gives 0/0 -> nan.
        numerator = (ancestral_flag + 1).sum()
        denominator = (derived_flag + 1).sum()
        ratios[row] = numerator / denominator

    tmp_schema = {
        "positions": pl.Int64,
        "daf": pl.Float64,
        "s_ratio": pl.Float64,
    }

    try:
        stacked = np.hstack([info[:, :2], ratios])
        df_out = pl.DataFrame(stacked, schema=tmp_schema)
    except Exception:
        # Best-effort fallback: an empty frame with the expected columns.
        df_out = pl.DataFrame([[], [], []], schema=tmp_schema)

    return df_out
def hapdaf_o(
    hap,
    ac,
    rec_map,
    max_ancest_freq=0.25,
    min_tot_freq=0.25,
    min_focal_freq=0.25,
    max_focal_freq=0.95,
    window_size=50000,
    genetic_distance=False,
):
    """
    Compute hapDAF-o for each focal SNP.

    For each focal SNP the neighbors satisfying
    ``(f_d > f_a) & (f_a <= max_ancest_freq) & (f_tot >= min_tot_freq)`` are
    kept and the mean of ``f_d^2 - f_a^2`` over them is reported. A focal SNP
    with no qualifying neighbor gets ``NaN``.

    :param numpy.ndarray hap: Haplotype matrix ``(n_snps, n_samples)`` with 0/1 values.
    :param numpy.ndarray ac: Allele counts ``(n_snps, 2)`` as ``[ancestral, derived]``.
    :param numpy.ndarray rec_map: Map array; penultimate column is the window coordinate.
    :param float max_ancest_freq: Ancestral partition frequency threshold. Default ``0.25``.
    :param float min_tot_freq: Minimum neighbor total derived frequency. Default ``0.25``.
    :param float min_focal_freq: Minimum focal derived frequency. Default ``0.25``.
    :param float max_focal_freq: Maximum focal derived frequency. Default ``0.95``.
    :param int window_size: Window size in coordinate units of ``rec_map[:, -2]``.
        Default ``50000``.
    :param bool genetic_distance: Unused here (kept for API symmetry).

    :returns: DataFrame with columns ``positions``, ``daf``, ``hapdaf_o``.
    :rtype: polars.DataFrame
    """
    pair_freqs, info, _ = fast_sq_freq_pairs(
        hap, ac, rec_map, min_focal_freq, max_focal_freq, window_size
    )

    values = np.empty((len(pair_freqs), 1), dtype=np.float64)
    for row, neighbors in enumerate(pair_freqs):
        derived_frac = neighbors[:, 0]
        ancestral_frac = neighbors[:, 1]
        total_frac = neighbors[:, 2]

        keep = (
            (derived_frac > ancestral_frac)
            & (ancestral_frac <= max_ancest_freq)
            & (total_frac >= min_tot_freq)
        )
        derived_sq = derived_frac[keep] ** 2
        ancestral_sq = ancestral_frac[keep] ** 2

        if derived_sq.shape[0] == 0:
            values[row] = np.nan
        else:
            values[row] = (derived_sq - ancestral_sq).sum() / derived_sq.shape[0]

    tmp_schema = {
        "positions": pl.Int64,
        "daf": pl.Float64,
        "hapdaf_o": pl.Float64,
    }

    try:
        stacked = np.hstack([info[:, :2], values])
        df_out = pl.DataFrame(stacked, schema=tmp_schema)
    except Exception:
        # Best-effort fallback: an empty frame with the expected columns.
        df_out = pl.DataFrame([[], [], []], schema=tmp_schema)

    return df_out
def hapdaf_s(
    hap,
    ac,
    rec_map,
    max_ancest_freq=0.1,
    min_tot_freq=0.1,
    min_focal_freq=0.25,
    max_focal_freq=0.95,
    window_size=50000,
    genetic_distance=False,
):
    """
    Compute hapDAF-s for each focal SNP.

    hapDAF-s is the same construction as hapDAF-o but uses more stringent
    default thresholds (smaller ``max_ancest_freq`` and ``min_tot_freq``): the
    mean of ``f_d^2 - f_a^2`` over neighbors satisfying
    ``(f_d > f_a) & (f_a <= max_ancest_freq) & (f_tot >= min_tot_freq)``.
    A focal SNP with no qualifying neighbor gets ``NaN``.

    :param numpy.ndarray hap: Haplotype matrix ``(n_snps, n_samples)`` with 0/1 values.
    :param numpy.ndarray ac: Allele counts ``(n_snps, 2)`` as ``[ancestral, derived]``.
    :param numpy.ndarray rec_map: Map array; penultimate column is the window coordinate.
    :param float max_ancest_freq: Ancestral partition frequency threshold. Default ``0.1``.
    :param float min_tot_freq: Minimum neighbor total derived frequency. Default ``0.1``.
    :param float min_focal_freq: Minimum focal derived frequency. Default ``0.25``.
    :param float max_focal_freq: Maximum focal derived frequency. Default ``0.95``.
    :param int window_size: Window size in coordinate units of ``rec_map[:, -2]``.
        Default ``50000``.
    :param bool genetic_distance: Unused here (kept for API symmetry).

    :returns: DataFrame with columns ``positions``, ``daf``, ``hapdaf_s``.
    :rtype: polars.DataFrame
    """
    sq_freqs, info, _ = fast_sq_freq_pairs(
        hap, ac, rec_map, min_focal_freq, max_focal_freq, window_size
    )

    n_rows = len(sq_freqs)
    results = np.empty((n_rows, 1), dtype=np.float64)

    for i, v in enumerate(sq_freqs):
        f_d = v[:, 0]
        f_a = v[:, 1]
        f_tot = v[:, 2]

        # One mask serves both partitions, so the two selections always have
        # the same length.
        keep = (f_d > f_a) & (f_a <= max_ancest_freq) & (f_tot >= min_tot_freq)
        f_d2 = f_d[keep] ** 2
        f_a2 = f_a[keep] ** 2

        if len(f_d2) != 0:
            results[i] = (f_d2 - f_a2).sum() / f_d2.shape[0]
        else:
            results[i] = np.nan

    tmp_schema = {
        "positions": pl.Int64,
        "daf": pl.Float64,
        "hapdaf_s": pl.Float64,
    }

    try:
        out = np.hstack([info[:, :2], results])
        df_out = pl.DataFrame(out, schema=tmp_schema)
    except Exception:
        # Best-effort fallback: an empty frame built from the schema itself, so
        # the number of columns can never disagree with the schema (the former
        # fallback supplied four empty columns for this three-column schema).
        df_out = pl.DataFrame(schema=tmp_schema)

    return df_out
def dind_high_low(
    hap,
    ac,
    rec_map,
    max_ancest_freq=0.25,
    min_tot_freq=0,
    min_focal_freq=0.25,
    max_focal_freq=0.95,
    window_size=50000,
    genetic_distance=False,
):
    """
    Compute DIND, highfreq, and lowfreq statistics per focal SNP.

    - ``dind``: ratio of the summed ``f_a * (1 - f_a)`` term (scaled by the
      focal ancestral count) to the summed ``f_d * (1 - f_d)`` term (scaled by
      the focal derived count) plus ``0.001``; infinite ratios become ``NaN``.
    - ``high_freq``: mean of ``f_d^2`` over neighbors with ``f_d > max_ancest_freq``.
    - ``low_freq``: mean of ``(1 - f_d)^2`` over neighbors with
      ``f_d < max_ancest_freq``.

    Both means use ``max(count, 1)`` as divisor, so an empty selection yields 0.

    :param numpy.ndarray hap: Haplotype matrix ``(n_snps, n_samples)`` with 0/1 values.
    :param numpy.ndarray ac: Allele counts ``(n_snps, 2)`` as ``[ancestral, derived]``.
    :param numpy.ndarray rec_map: Map array; penultimate column is the window coordinate.
    :param float max_ancest_freq: Threshold used in high/low frequency components.
        Default ``0.25``.
    :param float min_tot_freq: Unused here (kept for API symmetry). Default ``0``.
    :param float min_focal_freq: Minimum focal derived frequency. Default ``0.25``.
    :param float max_focal_freq: Maximum focal derived frequency. Default ``0.95``.
    :param int window_size: Window size in coordinate physical units from
        ``rec_map[:, -2]``. Default ``50000``.
    :param bool genetic_distance: Unused here (kept for API symmetry).

    :returns: DataFrame with columns ``positions``, ``daf``, ``dind``,
        ``high_freq``, ``low_freq``.
    :rtype: polars.DataFrame
    """
    sq_freqs, info, _ = fast_sq_freq_pairs(
        hap, ac, rec_map, min_focal_freq, max_focal_freq, window_size
    )

    # Columns 2 and 3 of info hold the focal derived / ancestral counts.
    focal_counts = info[:, 2:]

    n_rows = len(sq_freqs)
    results_dind = np.empty((n_rows, 1), dtype=np.float64)
    results_high = np.empty((n_rows, 1), dtype=np.float64)
    results_low = np.empty((n_rows, 1), dtype=np.float64)

    for i, v in enumerate(sq_freqs):
        f_d = v[:, 0]
        f_a = v[:, 1]

        focal_derived_count = focal_counts[i][0]
        focal_ancestral_count = focal_counts[i][1]

        f_d2 = f_d * (1 - f_d) * focal_derived_count / (focal_derived_count - 1)
        f_a2 = (
            f_a
            * (1 - f_a)
            * focal_ancestral_count
            / (focal_ancestral_count - 1 + 0.001)
        )

        # The `x - x` terms are kept on purpose: they are zero for finite
        # values but propagate NaN/inf from the other partition into the sum.
        num = (f_d2 - f_d2 + f_a2).sum()
        den = (f_a2 - f_a2 + f_d2).sum() + 0.001
        ratio = num / den
        results_dind[i] = np.nan if np.isinf(ratio) else ratio

        high = f_d[f_d > max_ancest_freq]
        results_high[i] = (high**2).sum() / max(len(high), 1)

        low = f_d[f_d < max_ancest_freq]
        results_low[i] = ((1 - low) ** 2).sum() / max(len(low), 1)

    tmp_schema = {
        "positions": pl.Int64,
        "daf": pl.Float64,
        "dind": pl.Float64,
        "high_freq": pl.Float64,
        "low_freq": pl.Float64,
    }

    try:
        out = np.hstack([info[:, :2], results_dind, results_high, results_low])
        df_out = pl.DataFrame(out, schema=tmp_schema)
    except Exception:
        # Best-effort fallback: an empty frame built from the schema itself (the
        # former fallback supplied eight empty columns for this five-column
        # schema, a leftover of the removed *_flip outputs).
        df_out = pl.DataFrame(schema=tmp_schema)

    return df_out
@njit(parallel=False, cache=True)
def s_ratio_from_pairs(sq_freqs, max_ancest_freq=1, min_tot_freq=0):
    """
    S-ratio per focal SNP from pre-computed neighbor frequency arrays.

    ``sq_freqs[i]`` is indexed as ``[:, 0]`` (derived-partition frequency) and
    ``[:, 1]`` (ancestral-partition frequency). Returns an ``(n, 1)`` float64
    array; a focal SNP without neighbors gets ``NaN``.
    ``max_ancest_freq`` and ``min_tot_freq`` are not used by the computation.
    """
    n_focal = len(sq_freqs)
    ratios = np.zeros((n_focal, 1))

    for idx in prange(n_focal):
        derived_frac = sq_freqs[idx][:, 0]
        ancestral_frac = sq_freqs[idx][:, 1]

        # 0/1 indicators of "segregating on this partition".
        derived_flag = np.zeros(derived_frac.shape)
        ancestral_flag = np.zeros(ancestral_frac.shape)
        derived_flag[(derived_frac > 0.0000001) & (derived_frac < 1)] = 1
        ancestral_flag[(ancestral_frac > 0.0000001) & (ancestral_frac < 1)] = 1

        # +1 per neighbor on both sides keeps the ratio bounded.
        numerator = (ancestral_flag + 1).sum()
        denominator = (derived_flag + 1).sum()

        # Only an empty window gives a zero denominator.
        if denominator == 0:
            ratios[idx] = np.nan
        else:
            ratios[idx] = numerator / denominator

    return ratios
@njit(cache=True, parallel=False)
def hapdaf_from_pairs(sq_freqs, max_ancest_freq, min_tot_freq):
    """
    hapDAF per focal SNP from pre-computed neighbor frequency arrays.

    One kernel serves both hapDAF flavours; only the thresholds differ.
    Call as::

        hapdaf_o = hapdaf_from_pairs(sq_freqs, max_ancest_freq=0.25, min_tot_freq=0.25)
        hapdaf_s = hapdaf_from_pairs(sq_freqs, max_ancest_freq=0.10, min_tot_freq=0.10)

    Returns an ``(n, 1)`` float64 array holding the mean of ``f_d^2 - f_a^2``
    over neighbors with ``(f_d > f_a) & (f_a <= max_ancest_freq) &
    (f_tot >= min_tot_freq)``, or ``NaN`` when no neighbor qualifies.
    """
    n_focal = len(sq_freqs)
    values = np.zeros((n_focal, 1), dtype=np.float64)

    for idx in prange(n_focal):
        neighbors = sq_freqs[idx]
        derived_frac = neighbors[:, 0]
        ancestral_frac = neighbors[:, 1]
        total_frac = neighbors[:, 2]

        keep = (
            (derived_frac > ancestral_frac)
            & (ancestral_frac <= max_ancest_freq)
            & (total_frac >= min_tot_freq)
        )
        derived_sq = derived_frac[keep] ** 2
        ancestral_sq = ancestral_frac[keep] ** 2

        n_kept = len(derived_sq)
        if n_kept > 0:
            values[idx, 0] = (derived_sq - ancestral_sq).sum() / float64(n_kept)
        else:
            values[idx, 0] = np.nan

    return values
@njit(parallel=False, cache=True)
def dind_high_low_from_pairs(sq_freqs, info, max_ancest_freq=0.25, min_tot_freq=0):
    """
    DIND, high-frequency and low-frequency statistics from pre-computed pairs.

    ``sq_freqs[i]`` is indexed as ``[:, 0]`` (f_d), ``[:, 1]`` (f_a) and
    ``[:, 2]`` (f_tot); the last two columns of ``info[i]`` are read as the
    focal derived and ancestral counts. Returns three ``(n, 1)`` float64
    arrays: ``(results_dind, results_high, results_low)``.
    """
    # Pre-allocate arrays for results to avoid growing lists
    n_rows = len(sq_freqs)
    results_dind = np.zeros((n_rows, 1))
    results_high = np.zeros((n_rows, 1))
    results_low = np.zeros((n_rows, 1))

    # Main computation loop (parallel=False, so prange behaves like range)
    for i in prange(len(sq_freqs)):
        f_d = sq_freqs[i][:, 0]
        f_a = sq_freqs[i][:, 1]
        f_tot = sq_freqs[i][:, 2]

        focal_derived_count = info[i][-2]
        focal_ancestral_count = info[i][-1]

        # Calculate derived and ancestral components
        f_d2 = f_d * (1 - f_d) * focal_derived_count / (focal_derived_count - 1)
        f_a2 = (
            f_a
            * (1 - f_a)
            * focal_ancestral_count
            / (focal_ancestral_count - 1 + 0.001)
        )

        # Calculate dind values. The `x - x` terms are zero for finite values
        # but propagate NaN/inf from the other partition into the sum.
        num = (f_d2 - f_d2 + f_a2).sum()
        den = (f_a2 - f_a2 + f_d2).sum() + 0.001

        if den != 0.0:
            dind_v = num / den
        else:
            dind_v = np.nan

        results_dind[i] = dind_v

        # Calculate high and low frequency values
        fd_h_mask = (f_d > max_ancest_freq) & (f_tot >= min_tot_freq)
        # NOTE(review): fd_l_mask is assigned three times and only the last
        # assignment is live. That final mask is identical to fd_h_mask, so
        # low_freq is averaged over the same neighbors as high_freq, whereas
        # dind_high_low() selects `f_d < max_ancest_freq` for low_freq.
        # Confirm which definition is intended before changing it: downstream
        # feature values depend on the current behavior.
        fd_l_mask = (f_d < max_ancest_freq) & (f_tot >= min_tot_freq)
        fd_l_mask = ((f_d > max_ancest_freq) & (f_d < 0.8)) & (f_tot >= min_tot_freq)
        fd_l_mask = (f_d > max_ancest_freq) & (f_tot >= min_tot_freq)

        # max(count, 1) as divisor: an empty selection yields 0 rather than 0/0.
        hf_v = (f_d[fd_h_mask] ** 2).sum() / max(len(f_d[fd_h_mask]), 1)
        results_high[i] = hf_v

        lf_v = ((1 - f_d[fd_l_mask]) ** 2).sum() / max(len(f_d[fd_l_mask]), 1)
        results_low[i] = lf_v

    return results_dind, results_high, results_low
def _build_stat_frame(info, results, schema):
    """Stack ``info[:, :2]`` with the result columns into a typed DataFrame.

    Any failure while stacking or building the frame (for example empty or
    misshaped inputs) yields an empty DataFrame with the requested schema.
    """
    try:
        stacked = np.hstack([info[:, :2]] + [np.asarray(r) for r in results])
        return pl.DataFrame(stacked, schema=schema)
    except Exception:
        # Deliberate best-effort: keep the schema so downstream code still works.
        return pl.DataFrame([[] for _ in schema], schema=schema)


def fs_stats_dataframe(
    info,
    results_dind,
    results_high,
    results_low,
    results_s_ratio,
    results_hapdaf_o,
    results_hapdaf_s,
    _iter=None,
):
    """
    Assemble the per-SNP Flex-sweep statistics into Polars DataFrames.

    Each frame starts with ``positions`` and ``daf`` (the first two columns of
    ``info``) followed by its statistic column(s). NaN values are converted to
    nulls before returning.

    :param numpy.ndarray info: 2D array; its first two columns are used as
        ``positions`` and ``daf``.
    :param results_dind: Column array with the DIND values.
    :param results_high: Column array with the high-frequency values.
    :param results_low: Column array with the low-frequency values.
    :param results_s_ratio: Column array with the S-ratio values.
    :param results_hapdaf_o: Column array (or array-like) with hapDAF-o values.
    :param results_hapdaf_s: Column array (or array-like) with hapDAF-s values.
    :param _iter: Optional identifier. When not ``None`` it is added to every
        frame as a literal ``iter`` column. Previously the result of
        ``with_columns`` was discarded, so the column was never added.
    :returns: ``(df_dind_high_low, df_s_ratio, df_hapdaf_o, df_hapdaf_s)``.
    :rtype: tuple
    """
    df_dind_high_low = _build_stat_frame(
        info,
        (results_dind, results_high, results_low),
        {
            "positions": pl.Int64,
            "daf": pl.Float64,
            "dind": pl.Float64,
            "high_freq": pl.Float64,
            "low_freq": pl.Float64,
        },
    )
    df_s_ratio = _build_stat_frame(
        info,
        (results_s_ratio,),
        {"positions": pl.Int64, "daf": pl.Float64, "s_ratio": pl.Float64},
    )
    df_hapdaf_s = _build_stat_frame(
        info,
        (results_hapdaf_s,),
        {"positions": pl.Int64, "daf": pl.Float64, "hapdaf_s": pl.Float64},
    )
    df_hapdaf_o = _build_stat_frame(
        info,
        (results_hapdaf_o,),
        {"positions": pl.Int64, "daf": pl.Float64, "hapdaf_o": pl.Float64},
    )

    frames = (df_dind_high_low, df_s_ratio, df_hapdaf_o, df_hapdaf_s)
    if _iter is not None:
        # Polars frames are immutable: with_columns returns a new frame that
        # must be kept.
        frames = tuple(
            frame.with_columns(pl.lit(_iter).alias("iter")) for frame in frames
        )

    return tuple(frame.fill_nan(None) for frame in frames)


################## iSAFE


@njit(cache=True)
def rank_with_duplicates(x):
    """Return, for each value of ``x``, its rank among the distinct values.

    Values are ranked in descending order starting at 1; equal values share
    the same rank and ranks increase by one per distinct value.
    """
    # sorted_arr = sorted(x, reverse=True)
    sorted_arr = np.sort(x)[::-1]
    rank_dict = {}
    rank = 1
    prev_value = -1
    for value in sorted_arr:
        if value != prev_value:
            rank_dict[value] = rank
            rank += 1
            prev_value = value
    return np.array([rank_dict[value] for value in x])


# @njit("float64[:,:](float64[:,:])", cache=True)
@njit(parallel=False, cache=True)
def dot_nb(hap):
    """Return ``hap.T @ hap``."""
    return np.dot(hap.T, hap)


@njit(cache=True)
def neutrality_divergence_proxy(kappa, phi, freq, method=3):
    """
    Combine ``phi`` and ``kappa`` into the SAFE score.

    ``p1`` standardizes ``phi - kappa`` by ``sqrt(kappa * (1 - kappa))`` and
    ``p2`` by ``sqrt(freq * (1 - freq))``; zero variances are replaced by 1.
    ``method`` selects the output: 1 returns ``p1``, 2 returns ``p2`` and 3
    returns ``p1 * (1 - nu) + p2 * nu`` where ``nu`` is the frequency of the
    SNP that maximizes ``p1``. Any other value falls through without an
    explicit return.
    """
    sigma1 = (kappa) * (1 - kappa)
    sigma1[sigma1 == 0] = 1.0
    sigma1 = sigma1**0.5
    p1 = (phi - kappa) / sigma1
    sigma2 = (freq) * (1 - freq)
    sigma2[sigma2 == 0] = 1.0
    sigma2 = sigma2**0.5
    p2 = (phi - kappa) / sigma2
    nu = freq[np.argmax(p1)]
    p = p1 * (1 - nu) + p2 * nu
    if method == 1:
        return p1
    elif method == 2:
        return p2
    elif method == 3:
        return p


@njit(cache=True)
def calc_H_K(hap, haf):
    """
    Per-SNP HAF summaries.

    :param hap: Binary SNP matrix of shape ``(num_snps, num_haplotypes)``.
    :param haf: HAF score of each haplotype (length ``num_haplotypes``).
    :return: H: Sum of HAF-score of carriers of each mutation.
    :return: K: Number of distinct positive HAF scores among the carriers of
        each mutation.
    """
    num_snps, num_haplotypes = hap.shape
    haf_matrix = haf * hap
    K = np.zeros((num_snps))
    for j in range(num_snps):
        ar = haf_matrix[j, :]
        K[j] = len(np.unique(ar[ar > 0]))
    H = np.sum(haf_matrix, 1)
    return (H, K)


def safe(hap):
    """
    Compute SAFE scores for every SNP of a haplotype matrix.

    :param numpy.ndarray hap: Matrix of shape ``(num_snps, num_haplotypes)``.
    :returns: ``(haf, safe_values, rank, phi, kappa, freq)`` where ``rank`` is
        the descending ``min``-method rank of ``safe_values``.
    :rtype: tuple
    """
    num_snps, num_haplotypes = hap.shape

    haf = dot_nb(hap.astype(np.float64)).sum(1)
    H, K = calc_H_K(hap, haf)

    phi = 1.0 * H / haf.sum()
    kappa = 1.0 * K / (np.unique(haf).shape[0])
    freq = hap.sum(1) / num_haplotypes
    safe_values = neutrality_divergence_proxy(kappa, phi, freq)

    rank = (
        pl.DataFrame({"safe": safe_values})
        .select(pl.col("safe").rank(method="min", descending=True))
        .to_numpy()
        .flatten()
    )

    return haf, safe_values, rank, phi, kappa, freq


def creat_windows_summary_stats_nb(hap, pos, w_size=300, w_step=150):
    """
    Run SAFE on every rolling SNP window.

    :param numpy.ndarray hap: Matrix of shape ``(num_snps, num_haplotypes)``.
    :param pos: Positions aligned with the rows of ``hap``.
    :param int w_size: Window size in SNPs.
    :param int w_step: Step between window starts in SNPs.
    :returns: ``(windows_stats, windows_haf, combined_df)``: a dict mapping the
        window index to ``{"safe": DataFrame}``, the list of per-window HAF
        vectors, and the concatenation of all per-window frames with
        ``ordinal_pos`` cast to ``Float64``.
    :rtype: tuple
    """
    num_snps, num_haplotypes = hap.shape
    rolling_indices = create_rolling_indices_nb(num_snps, w_size, w_step)
    windows_stats = {}
    windows_haf = []
    snp_summary = []

    for i, I_rolling in enumerate(rolling_indices):
        start, end = I_rolling[0], I_rolling[1]
        haf, safe_values, rank, phi, kappa, freq = safe(hap[start:end, :])

        tmp = pl.DataFrame(
            {
                "safe": safe_values,
                "rank": rank,
                "phi": phi,
                "kappa": kappa,
                "freq": freq,
                "pos": pos[start:end],
                "ordinal_pos": np.arange(start, end),
                "window": np.repeat(i, end - start),
            }
        )
        windows_haf.append(haf)
        windows_stats[i] = {"safe": tmp}
        snp_summary.append(tmp)

    combined_df = pl.concat(snp_summary).with_columns(
        pl.col("ordinal_pos").cast(pl.Float64)
    )
    return windows_stats, windows_haf, combined_df


@njit(cache=True)
def create_rolling_indices_nb(total_variant_count, w_size, w_step):
    """
    Return ``[start, end]`` pairs of rolling windows over the SNP indices.

    Windows start at 0 and advance by ``w_step``. Generation stops at the
    first window whose end reaches ``total_variant_count``; that window is not
    included in the output.
    """
    assert total_variant_count < w_size or w_size > 0
    rolling_indices = []
    w_start = 0
    while True:
        w_end = min(w_start + w_size, total_variant_count)
        if w_end >= total_variant_count:
            break
        rolling_indices.append([w_start, w_end])
        w_start += w_step
    return rolling_indices
def run_isafe(
    hap,
    positions,
    max_freq=1,
    min_region_size_bp=49000,
    min_region_size_ps=300,
    ignore_gaps=True,
    window=300,
    step=150,
    top_k=1,
    max_rank=15,
):
    """
    Estimate iSAFE or SAFE on a genomic region following Flex-sweep default values.

    The function removes monomorphic SNPs, then checks region size. If
    ``num_snps <= min_region_size_ps`` or
    ``positions.max() - positions.min() < min_region_size_bp``, it computes
    **SAFE**; otherwise it computes **iSAFE** using the provided sliding-window
    settings. In both cases variants with ``daf >= max_freq`` are filtered out
    (the SAFE branch previously skipped this filter).

    :param numpy.ndarray hap:
        Haplotype matrix of shape ``(n_snps, n_haplotypes)`` with 0/1 values
        (ancestral/derived).
    :param numpy.ndarray positions:
        1D array of physical coordinates (length ``n_snps``) aligned to ``hap`` rows.
        The region span is measured on the unfiltered positions.
    :param float max_freq:
        Maximum allowed derived allele frequency in the output (``daf < max_freq``).
        Default ``1`` (no filter, since monomorphic sites are already removed).
    :param int min_region_size_bp:
        Minimum region span in base pairs required to run iSAFE. Default ``49000``.
    :param int min_region_size_ps:
        Minimum number of polymorphic SNPs required to run iSAFE. Default ``300``.
    :param bool ignore_gaps:
        Reserved for gap handling; currently not used. Default ``True``.
    :param int window:
        iSAFE sliding window size, passed to :func:`isafe` as ``w_size``. Default ``300``.
    :param int step:
        iSAFE step between windows, passed to :func:`isafe` as ``w_step``. Default ``150``.
    :param int top_k:
        iSAFE parameter controlling the number of top candidates per window. Default ``1``.
    :param int max_rank:
        iSAFE parameter controlling the maximum rank to track. Default ``15``.

    :returns:
        Polars DataFrame with columns ``positions``, ``daf`` and ``isafe``,
        filtered to ``daf < max_freq``. The SAFE branch sorts by ``positions``;
        the iSAFE branch sorts by the SNP's row index. If the region is small,
        the ``isafe`` column contains SAFE scores.
    :rtype: polars.DataFrame

    .. note::
       Monomorphic sites are removed using ``(1 - f) * f > 0``, where ``f`` is the
       derived allele frequency per SNP.
    """
    total_window_size = positions.max() - positions.min()

    # Keep only polymorphic sites.
    f = hap.mean(1)
    freq_filter = ((1 - f) * f) > 0
    hap_filtered = hap[freq_filter, :]
    positions_filtered = positions[freq_filter]
    num_snps = hap_filtered.shape[0]

    if num_snps <= min_region_size_ps or total_window_size < min_region_size_bp:
        # Region too small for the windowed procedure: fall back to SAFE.
        _, safe_values, _, _, _, freq = safe(hap_filtered)
        return (
            pl.DataFrame(
                {
                    "positions": positions_filtered,
                    "daf": freq,
                    "isafe": safe_values,
                }
            )
            .filter(pl.col("daf") < max_freq)
            .sort("positions")
        )

    df_isafe = isafe(hap_filtered, positions_filtered, window, step, top_k, max_rank)
    return (
        df_isafe.filter(pl.col("freq") < max_freq)
        .sort("ordinal_pos")
        .rename({"id": "positions", "freq": "daf"})
        .select(["positions", "daf", "isafe"])
    )
def isafe(hap, pos, w_size=300, w_step=150, top_k=1, max_rank=15):
    """
    Compute iSAFE scores from per-window SAFE results.

    SAFE is run on rolling windows. The ``top_k`` best SNPs of each window form
    tier 1 and the remaining SNPs among the ``max_rank`` best of each window
    form tier 2. Each candidate is scored against the HAF vector of every
    window, negative scores are clipped to zero, and the per-window weights
    ``alpha`` (derived from tier 1 only) are used to average the scores.

    :param numpy.ndarray hap: Matrix of shape ``(n_snps, n_haplotypes)``.
    :param pos: Positions aligned with the rows of ``hap``.
    :param int w_size: Window size in SNPs.
    :param int w_step: Step between windows in SNPs.
    :param int top_k: Number of top SNPs per window used for tier 1.
    :param int max_rank: Number of top SNPs per window considered for tier 2.
    :returns: DataFrame with columns ``ordinal_pos`` (row index into ``hap``),
        ``id`` (position), ``isafe``, ``freq`` and ``tier``; tier-1 rows come
        first.
    :rtype: polars.DataFrame
    """
    _, windows_haf, snps_summary = creat_windows_summary_stats_nb(
        hap, pos, w_size, w_step
    )

    df_top_k1 = get_top_k_snps_in_each_window(snps_summary, k=top_k)
    ordinal_pos_snps_k1 = np.sort(df_top_k1["ordinal_pos"].unique()).astype(np.int64)
    psi_k1 = step_function(creat_matrix_Psi_k_nb(hap, windows_haf, ordinal_pos_snps_k1))

    df_top_k2 = get_top_k_snps_in_each_window(snps_summary, k=max_rank)
    candidates_k2 = np.sort(df_top_k2["ordinal_pos"].unique())
    # setdiff1d already returns sorted unique values.
    ordinal_pos_snps_k2 = np.setdiff1d(candidates_k2, ordinal_pos_snps_k1).astype(
        np.int64
    )
    psi_k2 = step_function(creat_matrix_Psi_k_nb(hap, windows_haf, ordinal_pos_snps_k2))

    # Window weights come from the tier-1 matrix only.
    alpha = psi_k1.sum(0) / psi_k1.sum()

    ordinal_pos = np.concatenate([ordinal_pos_snps_k1, ordinal_pos_snps_k2])
    scores = np.concatenate([np.dot(psi_k1, alpha), np.dot(psi_k2, alpha)])
    tier = np.concatenate(
        [
            np.repeat(1, ordinal_pos_snps_k1.size),
            np.repeat(2, ordinal_pos_snps_k2.size),
        ]
    )

    # Vectorized lookups instead of one Python callback per row.
    freq = hap.mean(1)
    return pl.DataFrame(
        {
            "ordinal_pos": ordinal_pos,
            "id": np.asarray(pos)[ordinal_pos],
            "isafe": scores,
            "freq": freq[ordinal_pos],
            "tier": tier,
        }
    ).with_columns(pl.col("id").cast(pl.Int64), pl.col("freq").cast(pl.Float64))
@njit(cache=True)
def isafe_kernel_nb(haf, snp):
    """
    Score a single SNP against a single HAF vector.

    ``phi`` is the share of the total HAF carried by haplotypes with
    ``snp == 1`` and ``kappa`` the share of distinct HAF values found among
    those haplotypes. Their difference is standardized by
    ``sqrt(f * (1 - f))`` (replaced by 1 when zero), ``f`` being the SNP mean.
    """
    carriers = haf[snp == 1]
    phi = carriers.sum() * 1.0 / haf.sum()
    kappa = len(np.unique(carriers)) / (1.0 * len(np.unique(haf)))

    f = np.mean(snp)
    variance = f * (1 - f)
    if variance == 0:
        variance = 1.0
    return (phi - kappa) / variance**0.5


@njit(cache=True)
def creat_matrix_Psi_k_nb(hap, hafs, Ifp):
    """
    Build the matrix of SNP-by-window scores.

    Entry ``[i, j]`` holds the same score as ``isafe_kernel_nb`` for SNP
    ``hap[Ifp[i]]`` and HAF vector ``hafs[j]``. The per-window HAF totals and
    distinct counts, and the per-SNP carrier indices and standard deviation,
    are computed once and reused.

    :param hap: Matrix of shape ``(n_snps, n_haplotypes)``.
    :param hafs: Sequence of HAF vectors, one per window.
    :param Ifp: Row indices into ``hap`` of the SNPs to score.
    :returns: Array of shape ``(len(Ifp), len(hafs))``.
    """
    n_rows = len(Ifp)
    n_windows = len(hafs)
    psi = np.zeros((n_rows, n_windows))

    # Per-window invariants.
    totals = np.zeros(n_windows)
    n_distinct = np.zeros(n_windows)
    for w in range(n_windows):
        totals[w] = hafs[w].sum()
        n_distinct[w] = len(np.unique(hafs[w]))

    for r in range(n_rows):
        snp = hap[Ifp[r], :]

        # Per-SNP invariants.
        f = np.mean(snp)
        variance = f * (1 - f)
        if variance == 0:
            variance = 1.0
        sigma = variance**0.5
        carrier_idx = np.where(snp == 1)[0]

        for w in range(n_windows):
            carried = hafs[w][carrier_idx]
            phi = carried.sum() / totals[w]
            kappa = len(np.unique(carried)) / n_distinct[w]
            psi[r, w] = (phi - kappa) / sigma

    return psi


def step_function(P0):
    """Return a copy of ``P0`` in which negative entries are set to zero."""
    clipped = P0.copy()
    np.putmask(clipped, clipped < 0, 0)
    return clipped


def get_top_k_snps_in_each_window(df_snps, k=1):
    """
    Select the ``k`` SNPs with the highest SAFE score in each window.

    :param df_snps: DataFrame that must contain at least the columns
        ``["safe", "ordinal_pos", "window"]``.
    :param k: Number of SNPs to keep per window.
    :return: DataFrame with the top ``k`` SNPs of each window, sorted by
        window, with ``window`` as the last column.
    """
    best_first = pl.all().sort_by("safe", descending=True).head(k)
    per_window = df_snps.group_by("window").agg(best_first)
    flattened = per_window.explode(pl.all().exclude("window")).sort("window")
    return flattened.select(pl.all().exclude("window"), pl.col("window"))


################## LD stats
@njit(parallel=False, cache=True)
def r2(locus_A: np.ndarray, locus_B: np.ndarray) -> float:
    r"""
    Compute the squared correlation coefficient :math:`r^2` between two biallelic loci.

    Given two 0/1 vectors of equal length (haplotypes across samples), this
    function computes:

    .. math::

        D = P_{11} - p_A p_B,\quad
        r^2 = \frac{D^2}{p_A (1-p_A)\, p_B (1-p_B)},

    where :math:`p_A` and :math:`p_B` are the allele-1 frequencies at loci A and B,
    and :math:`P_{11}` is the empirical joint frequency that both loci equal 1.

    :param numpy.ndarray locus_A:
        1D array of 0/1 alleles for locus A.
    :param numpy.ndarray locus_B:
        1D array of 0/1 alleles for locus B (same length as ``locus_A``).

    :returns:
        The :math:`r^2` value, or ``nan`` when the inputs are empty or either
        locus is monomorphic (zero denominator).
    :rtype: float
    """
    n = locus_A.size
    if n == 0:
        return np.nan

    # Single pass: allele-1 counts at each locus and joint 1/1 count.
    count_a = 0
    count_b = 0
    count_ab = 0
    for i in range(n):
        count_a += locus_A[i]
        count_b += locus_B[i]
        count_ab += locus_A[i] * locus_B[i]

    p_a = count_a / n
    p_b = count_b / n
    p_ab = count_ab / n

    D = p_ab - p_a * p_b
    denom = p_a * (1 - p_a) * p_b * (1 - p_b)
    # Scalar division by zero raises in nopython mode; report nan instead.
    if denom == 0.0:
        return np.nan
    return (D**2) / denom
def compute_r2_matrix_upper(hap, as_float32=False):
    """
    Pairwise r² between the rows of ``hap`` via a single matrix product.

    Every row is centred on its mean and divided by its standard deviation,
    so the scaled matrix times its transpose, divided by the number of
    columns, gives the correlation matrix; it is then squared in place.
    Rows with zero variance get an infinite scale and therefore r² = 0 with
    every other row.

    :param numpy.ndarray hap: 2D array ``(S, N)``. Must be ``float64`` unless
        ``as_float32`` is true.
    :param bool as_float32: Convert ``hap`` to ``float32`` and compute in that
        precision.
    :returns: ``(S, S)`` array with r² strictly above the diagonal and zeros
        elsewhere.
    :rtype: numpy.ndarray
    :raises TypeError: If ``as_float32`` is false and ``hap`` is not ``float64``.
    """
    if as_float32:
        hap = hap.astype(np.float32, copy=False)
    elif hap.dtype != np.float64:
        raise TypeError(
            f"compute_r2_matrix_upper: hap must be float64, got {hap.dtype}. "
            "Pass as_float32=True to use float32 internally."
        )

    _, n_cols = hap.shape

    row_mean = hap.mean(axis=1)
    row_var = row_mean * (1.0 - row_mean)
    # Infinite variance maps constant rows to all-zero standardized rows.
    row_var[row_var == 0] = np.inf
    row_sd = np.sqrt(row_var)

    standardized = (hap - row_mean[:, None]) / row_sd[:, None]
    corr = (standardized @ standardized.T) / n_cols
    np.square(corr, out=corr)
    return np.triu(corr, k=1)
@njit(parallel=False, cache=True)
def omega_linear_correct(r2_matrix):
    r"""
    Compute Kelly's ZnS and :math:`\omega_\text{max}` from an :math:`r^2` matrix.

    Only the strict upper triangle (``i < j``) of ``r2_matrix`` is read. For a
    split index :math:`\ell`, with ``L`` the sum of :math:`r^2` over pairs left
    of the split, ``R`` the sum over pairs right of it and ``LR`` the sum over
    pairs spanning it, the score is:

    .. math::

        \omega(\ell) =
        \frac{(L + R) / \left(\binom{\ell}{2} + \binom{S-\ell}{2}\right)}
             {LR / (\ell (S - \ell))}.

    Splits :math:`\ell = 3, \dots, S-3` are scanned and those with
    ``LR <= 0`` are skipped. All sums are obtained from row/column totals and
    prefix/suffix accumulations.

    :param numpy.ndarray r2_matrix:
        Square matrix (``S`` x ``S``) of pairwise :math:`r^2` values.

    :returns:
        ``(zns, omega_max)``: the mean of the upper-triangle entries and the
        largest omega found (``0.0`` if no split qualifies). Both are ``0.0``
        when ``S < 3``.
    :rtype: tuple

    :notes:
        Modification from
        https://github.com/kr-colab/diploSHIC/blob/master/diploshic/utils.c
    """
    n_snps = r2_matrix.shape[0]
    if n_snps < 3:
        return 0.0, 0.0

    # right_of[i] = sum_{j>i} r2[i, j]; left_of[j] = sum_{i<j} r2[i, j].
    right_of = np.zeros(n_snps, np.float64)
    left_of = np.zeros(n_snps, np.float64)
    grand_total = 0.0
    for i in range(n_snps):
        acc = 0.0
        for j in range(i + 1, n_snps):
            val = r2_matrix[i, j]
            acc += val
            left_of[j] += val
        right_of[i] = acc
        grand_total += acc

    # Kelly's ZnS: mean over all pairs.
    n_pairs = (n_snps * (n_snps - 1)) / 2.0
    zns = grand_total / n_pairs if n_pairs > 0 else 0.0

    # within_left[k] = sum_{i<j<k} r2[i, j]
    within_left = np.zeros(n_snps, np.float64)
    for k in range(1, n_snps):
        within_left[k] = within_left[k - 1] + left_of[k - 1]

    # within_right[k] = sum_{k<=i<j} r2[i, j]; the last slot is the zero sentinel.
    within_right = np.zeros(n_snps + 1, np.float64)
    for k in range(n_snps - 1, -1, -1):
        within_right[k] = within_right[k + 1] + right_of[k]

    omega_max = 0.0
    for k in range(3, n_snps - 2):
        inside_l = within_left[k]
        inside_r = within_right[k]
        across = grand_total - inside_l - inside_r
        if across > 0.0:
            pairs_within = (k * (k - 1) / 2.0) + (
                (n_snps - k) * (n_snps - k - 1) / 2.0
            )
            pairs_across = k * (n_snps - k)
            candidate = ((inside_l + inside_r) / pairs_within) / (
                across / pairs_across
            )
            if candidate > omega_max:
                omega_max = candidate

    return zns, omega_max
def Ld(hap, as_float32=True) -> tuple:
    """
    Compute **Kelly's ZnS** and **omega_max** for a haplotype matrix.

    The pairwise :math:`r^2` matrix of the rows of ``hap`` is built with
    :func:`compute_r2_matrix_upper` and summarized by
    :func:`omega_linear_correct`.

    :param numpy.ndarray hap:
        2D array with one row per SNP and one column per haplotype.
    :param bool as_float32:
        Forwarded to :func:`compute_r2_matrix_upper`. Default ``True``.

    :returns:
        Tuple ``(zns, omega_max)``.
    :rtype: tuple
    """
    upper_r2 = compute_r2_matrix_upper(hap, as_float32=as_float32)
    zns, omega_max = omega_linear_correct(upper_r2)
    return zns, omega_max
################## Site Frequency Spectrum stats


@njit(cache=True)
def _harmonic_sums(n):
    """
    Harmonic sums over ``i = 1 .. n-1``.

    :param int n: Sample size (number of chromosomes).
    :returns: ``float64`` array ``[sum 1/i, sum 1/i^2]``.
    :rtype: numpy.ndarray
    """
    first_order = 0.0
    second_order = 0.0
    for denom in range(1, int(n)):
        recip = 1.0 / denom
        first_order += recip
        second_order += recip * recip
    return np.array((first_order, second_order), dtype=np.float64)
@njit(cache=True)
def theta_watterson(ac, positions):
    """
    Watterson's theta, absolute and per base.

    :param numpy.ndarray ac: Allele counts, one row per variant. The number of
        rows is used as the number of segregating sites and the sum of the
        first row as the sample size.
    :param numpy.ndarray positions: Variant positions; the first and last
        entries define the span ``last - first + 1``.
    :returns: ``(theta_abs, theta_per_base)``. The per-base value is ``nan``
        when fewer than two positions are given.
    :rtype: tuple
    """
    n_segregating = ac.shape[0]
    n_chrom = ac[0].sum()
    theta_abs = n_segregating / _harmonic_sums(n_chrom)[0]

    if positions.size < 2:
        # A single position gives no span to normalize by.
        return theta_abs, np.nan

    span = (positions[-1] - positions[0]) + 1
    return theta_abs, theta_abs / span
@njit(cache=True)
def sfs_nb(dac, n):
    """
    Site-frequency spectrum (SFS) from derived-allele counts.

    :param numpy.ndarray dac:
        1D array of derived allele counts per site.
    :param int n:
        Total number of chromosomes. If ``n <= 0``, the largest value of
        ``dac`` is used instead.

    :returns:
        ``int64`` array of length ``n + 1``; entry ``k`` is the number of sites
        with ``k`` derived copies. Counts outside ``[0, n]`` are ignored.
    :rtype: numpy.ndarray
    """
    n_sites = dac.shape[0]

    if n <= 0:
        largest = 0
        for idx in range(n_sites):
            if dac[idx] > largest:
                largest = dac[idx]
        n = largest

    spectrum = np.zeros(n + 1, dtype=np.int64)
    for idx in range(n_sites):
        count = dac[idx]
        if count >= 0 and count <= n:
            spectrum[count] += 1
    return spectrum
@njit(cache=True)
def theta_pi(ac):
    """
    Per-site nucleotide diversity from allele counts.

    For site ``j`` with derived count ``a_j`` the value is
    ``2 * a_j * (n - a_j) / (n * (n - 1))``. The sample size ``n`` is taken
    from the first row and assumed constant across sites.

    :param numpy.ndarray ac:
        Allele counts of shape ``(S, 2)`` (ancestral, derived).

    :returns:
        Array with one value per site.
    :rtype: numpy.ndarray
    """
    n_sites = ac.shape[0]
    n_chrom = ac[0].sum()
    ordered_pairs = n_chrom * (n_chrom - 1.0)

    per_site = np.zeros(n_sites)
    for site in range(n_sites):
        derived = ac[site, 1]
        per_site[site] = 2.0 * derived * (n_chrom - derived) / ordered_pairs
    return per_site
@njit(cache=True)
def tajima_d(ac, min_sites=3):
    """
    Tajima's D from allele counts.

    Compares the mean pairwise difference (sum of per-site pi) to the
    Watterson estimator based on the number of segregating sites.

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), with
        constant ``n`` across sites.
    :param int min_sites:
        Minimum required number of segregating sites. Default ``3``.

    :returns:
        Tajima's D. ``nan`` is returned when ``ac`` is empty, when there are
        fewer than ``min_sites`` sites, when the sample size is below 2, or
        when the standard deviation of the difference is zero.
    :rtype: float
    """
    # count segregating variants; check before touching ac[0] so that an
    # empty input is never indexed
    S = ac.shape[0]
    if S == 0 or S < min_sites:
        return np.nan

    # assume number of chromosomes sampled is constant for all variants
    n = ac[0].sum()
    if n < 2:
        # harmonic number and pair count are both zero
        return np.nan

    # (n-1)th harmonic number
    an = _harmonic_sums(n)[0]

    # Watterson's theta (absolute value)
    theta_hat_w_abs = S / an

    # theta_hat pi (sum of per-site differences over variants)
    theta_hat_pi_abs = theta_pi(ac).sum()

    # N.B., both theta estimates are usually divided by the number of
    # (accessible) bases but here we want the absolute difference
    d = theta_hat_pi_abs - theta_hat_w_abs

    # denominator (standard deviation)
    a2 = np.sum(1 / (np.arange(1, n) ** 2))
    b1 = (n + 1) / (3 * (n - 1))
    b2 = 2 * (n**2 + n + 3) / (9 * n * (n - 1))
    c1 = b1 - (1 / an)
    c2 = b2 - ((n + 2) / (an * n)) + (a2 / (an**2))
    e1 = c1 / an
    e2 = c2 / (an**2 + a2)
    d_stdev = np.sqrt((e1 * S) + (e2 * S * (S - 1)))

    # scalar division by zero raises in nopython mode (e.g. n == 2)
    if d_stdev == 0.0:
        return np.nan

    return d / d_stdev
@njit(cache=True)
def achaz_y(ac):
    """
    Achaz's Y neutrality test (standardized).

    Unfolded/polarized form: requires ancestral-state information. Derived
    singletons (count 1) and sites with count 0 or ``n`` are excluded from
    both estimators.

    Reference: Achaz 2008, Appendix B, Equations B28-B30.
    f = (n-2) / (n*(a_n-1));  Var[Y] = alpha_n*theta + beta_n*theta^2

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
        Derived allele counts in ``ac[:,1]`` define the unfolded SFS.

    :returns:
        Standardized Achaz's Y as a float; ``nan`` if ``ac`` is empty, if
        ``n < 3``, if no site remains after the exclusions, or if the
        estimated variance is not positive.
    :rtype: float
    """
    if ac.shape[0] == 0:
        return np.nan

    n = int(ac[0, 0] + ac[0, 1])
    if n < 3:
        return np.nan

    fs = sfs_nb(ac[:, 1], n)
    a1, a2 = _harmonic_sums(n)
    a1m1 = a1 - 1.0
    ff = (n - 2.0) / (n * a1m1)

    inv_n = 1.0 / n
    inv_n1 = 1.0 / (n - 1.0)
    inv_n2 = 1.0 / (n - 2.0)
    n2 = n * n

    alpha = (
        ff * ff * a1m1
        + ff
        * (
            a1 * (4.0 * (n + 1.0) * inv_n1 * inv_n1)
            - 2.0 * (n + 1.0) * (n + 2.0) * inv_n * inv_n1
        )
        - a1 * 8.0 * (n + 1.0) * inv_n * inv_n1 * inv_n1
        + (n * n2 + n2 + 60.0 * n + 12.0) * (inv_n * inv_n) * (1.0 / 3.0) * inv_n1
    )

    beta = (
        ff * ff * (a2 + a1 * (4.0 * inv_n1 * inv_n2) - 4.0 * inv_n2)
        + ff
        * (
            -a1 * (4.0 * (n + 2.0) * inv_n * inv_n1 * inv_n2)
            - ((n * n2 - 3.0 * n2 - 16.0 * n + 20.0) * inv_n * inv_n1 * inv_n2)
        )
        + a1 * (8.0 * inv_n * inv_n1 * inv_n2)
        + (2.0 * (2.0 * n2 * n2 - n2 * n - 17.0 * n2 - 42.0 * n + 72.0))
        * (inv_n * inv_n)
        * (inv_n1 * inv_n2)
        * (1.0 / 9.0)
    )

    # Only frequency classes 2..n-1 contribute.
    S = 0.0
    pi_sum = 0.0
    for i in range(2, n):
        xi = fs[i]
        S += xi
        pi_sum += xi * i * (n - i)

    # At least 1 seg site
    if S < 1:
        return np.nan

    pi_hat = pi_sum / (n * (n - 1.0) * 0.5)
    that = S / a1m1
    that_sq = S * (S - 1.0) / (a1m1 * a1m1)

    # Same guard as achaz_y_star: avoids dividing by zero in nopython mode.
    variance = alpha * that + beta * that_sq
    if variance <= 0.0:
        return np.nan

    return (pi_hat - ff * S) / np.sqrt(variance)
@njit(cache=True)
def achaz_y_star(ac):
    """
    Achaz's Y* neutrality test (standardized).

    Folded form: does not require ancestral-state polarization. Sites whose
    derived count is below 2 or above ``n - 2`` are left out of both the pi
    and S estimators.

    Reference: Achaz 2008, Appendix B, Equations B19-B21.
    f* = (n-3) / (a_n*(n-1) - n);  Var[Y*] = alpha*_n*theta + beta*_n*theta^2

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant n.

    :returns:
        Standardized Achaz Y* as a float; ``nan`` if ``n < 4``, if
        ``a_n - n/(n-1)`` is not positive, if no site is retained, or if the
        estimated variance is not positive.
    :rtype: float
    """
    n = int(ac[0, 0] + ac[0, 1])
    if n < 4:
        return np.nan

    a1, a2 = _harmonic_sums(n)
    inv_n1 = 1.0 / (n - 1.0)

    # gamma_n = a_n - n/(n-1): scaling between the retained site count and theta
    gamma_n = a1 - n * inv_n1
    if gamma_n <= 0.0:
        return np.nan

    # f* = (n-3) / ((n-1) * gamma_n)
    fstar = (n - 3.0) / ((n - 1.0) * gamma_n)

    inv_n = 1.0 / n
    n2 = n * n

    # alpha*_n
    alpha_star = (
        fstar * fstar * (a1 - n * inv_n1)
        + fstar * (a1 * (4.0 * (n + 1.0) * inv_n1 * inv_n1) - 2.0 * (n + 3.0) * inv_n1)
        - a1 * 8.0 * (n + 1.0) * inv_n * inv_n1 * inv_n1
        + (n2 + n + 60.0) * inv_n * inv_n1 / 3.0
    )

    # beta*_n
    beta_star = (
        fstar * fstar * (a2 - (2.0 * n - 1.0) * inv_n1 * inv_n1)
        + fstar
        * (
            a2 * 8.0 * inv_n1
            - a1 * 4.0 * inv_n * inv_n1
            - (n * n2 + 12.0 * n2 - 35.0 * n + 18.0) * inv_n * inv_n1 * inv_n1
        )
        - a2 * 16.0 * inv_n * inv_n1
        + a1 * 8.0 * inv_n * inv_n * inv_n1
        + (2.0 * (2.0 * n2 * n2 + 110.0 * n2 - 255.0 * n + 126.0))
        * (inv_n * inv_n)
        * (inv_n1 * inv_n1)
        / 9.0
    )

    # Accumulate over sites with 2 <= derived count <= n-2.
    kept_sites = 0.0
    pair_diffs = 0.0
    for row in range(ac.shape[0]):
        derived = int(ac[row, 1])
        if derived < 2 or derived > n - 2:
            continue
        kept_sites += 1.0
        pair_diffs += derived * (n - derived)

    if kept_sites < 1.0:
        return np.nan

    pi_hat = pair_diffs / (n * (n - 1.0) * 0.5)

    # theta and theta^2 estimates from the retained site count
    theta_hat = kept_sites / gamma_n
    theta_sq_hat = kept_sites * (kept_sites - 1.0) / (gamma_n * gamma_n)

    variance = alpha_star * theta_hat + beta_star * theta_sq_hat
    if variance <= 0.0:
        return np.nan

    return (pi_hat - fstar * kept_sites) / np.sqrt(variance)
@lru_cache(maxsize=16)
def _achaz_t_coeffs(n, decay=0.9):
    """Return the variance coefficients ``(alpha_n, beta_n)`` used by ``achaz_t``.

    Results are memoized per ``(n, decay)`` through ``lru_cache``.

    With weights ``Omega_i = w1_i / sum(w1) - w2_i / sum(w2)`` where
    ``w1_i = exp(-decay * i)`` and ``w2_i = 1`` for ``i = 1..n-1``:

    * ``alpha_n = sum_i i * Omega_i**2``
    * ``beta_n  = v^T . Sigma . v`` with ``v_i = i * Omega_i`` and ``Sigma``
      the ``(n-1) x (n-1)`` matrix returned by the module-level ``sigma``
      helper for every ``(i, j)`` pair.

    :param int n: Sample size.
    :param float decay: Exponential decay rate of the first weight vector.
    :returns: ``(alpha_n, beta_n)`` as Python floats.
    :rtype: tuple[float, float]
    """
    size = n - 1
    freq_class = np.arange(1, n, dtype=np.float64)  # i = 1..n-1

    exp_w = np.exp(-decay * freq_class)
    flat_w = np.ones(size)
    omega = exp_w / exp_w.sum() - flat_w / flat_w.sum()

    alpha_n = float(np.sum(freq_class * omega**2))

    # Every (i, j) index pair, row-major, handed to ``sigma`` in one call.
    idx = freq_class.astype(np.int64)
    row_idx, col_idx = np.meshgrid(idx, idx, indexing="ij")
    pairs = np.column_stack([row_idx.ravel(), col_idx.ravel()])
    cov = sigma(n, pairs).reshape(size, size)

    weighted = freq_class * omega
    beta_n = float(weighted @ cov @ weighted)
    return alpha_n, beta_n
def achaz_t(ac, decay=0.9):
    """Achaz's T_Omega neutrality test on the unfolded SFS.

    The statistic contrasts an exponentially decaying weight vector
    ``exp(-decay * i)`` with a uniform one (both normalized to sum to 1).
    Variance coefficients come from ``_achaz_t_coeffs`` (memoized per sample
    size and decay); theta is estimated as ``S / a1`` and theta squared as
    ``S * (S - 1) / (a1**2 + a2)``.

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived). The
        sample size ``n`` is read from the first row; ``ac[:, 1]`` defines the
        unfolded SFS.
    :param float decay: Exponential decay rate of the first weight vector.
    :returns:
        Standardized statistic as a float; ``nan`` if ``n < 3``, if no site has
        a derived count in ``[1, n-1]``, or if the variance is not positive.
    :rtype: float
    """
    n = int(ac[0, 0] + ac[0, 1])
    if n < 3:
        return np.nan

    alpha_n, beta_n = _achaz_t_coeffs(n, decay)
    a1, a2 = _harmonic_sums(n)

    # Unfolded SFS xi_1..xi_{n-1}: tally derived counts that are segregating.
    derived = np.asarray(ac[:, 1]).astype(np.int64)
    segregating = (derived >= 1) & (derived <= n - 1)
    xi = np.bincount(derived[segregating] - 1, minlength=n - 1).astype(np.float64)

    S = xi.sum()
    if S < 1.0:
        return np.nan

    theta_hat = S / a1
    theta_sq_hat = S * (S - 1.0) / (a1 * a1 + a2)
    spread = alpha_n * theta_hat + beta_n * theta_sq_hat
    if spread <= 0.0:
        return np.nan

    # Rebuild Omega for the numerator sum_i Omega_i * i * xi_i.
    freq_class = np.arange(1, n, dtype=np.float64)
    exp_w = np.exp(-decay * freq_class)
    flat_w = np.ones(n - 1)
    omega = exp_w / exp_w.sum() - flat_w / flat_w.sum()
    contrast = float(np.sum(omega * freq_class * xi))

    return contrast / np.sqrt(spread)
@njit(cache=True)
def fay_wu_h_norm(ac, positions=None):
    """
    Fay & Wu's H and its normalized form (single-population, infinite sites).

    Computes:

    - ``theta_h``: estimator that upweights high-frequency derived alleles.
    - ``h = pi - theta_h`` (Fay & Wu's H).
    - ``h_norm``: normalized H of Zeng et al. (2006),
      ``(pi - theta_L) / sqrt(Var[pi - theta_L])``. Because
      ``pi - theta_h = 2 * (pi - theta_L)``, this is the standardized version
      of ``h``.

    The variance is
    ``(n-2)/(6(n-1)) * theta + [18 n^2 (3n+2) b_{n+1} - (88 n^3 + 9 n^2 - 13 n + 6)]
    / (9 n (n-1)^2) * theta^2`` with ``theta = S / a_n`` and
    ``theta^2 = S (S-1) / (a_n^2 + b_n)``.

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
    :param numpy.ndarray positions:
        Optional positions (length ``S``). If provided, ``theta_h`` is divided by the
        span ``positions[-1] - (positions[0] - 1)``.
    :returns: Tuple ``(theta_h, h, h_norm)`` as floats; ``h_norm`` is ``nan``
        when the variance is not strictly positive.
    :rtype: tuple[float, float, float]
    """
    # count segregating variants
    S = ac.shape[0]

    # assume number of chromosomes sampled is constant for all variants
    n = ac[0].sum()

    # interior of the unfolded SFS: fs[k-1] = number of sites with derived count k
    fs = sfs_nb(ac[:, 1], n)[1:-1]

    i_arr = np.arange(1, int(n))
    a1 = np.sum(1.0 / i_arr)
    bn = np.sum(1.0 / (i_arr * i_arr))  # b_n   = sum_{i<n}  1/i^2
    bn1 = bn + 1.0 / (n * n)  # b_{n+1} = sum_{i<=n} 1/i^2

    theta_w = S / a1

    pairs = n * (n - 1.0)
    pi = 0.0
    theta_h = 0.0
    theta_l = 0.0
    for k in range(1, int(n)):
        si = fs[k - 1]
        pi += (2 * si * k * (n - k)) / pairs
        theta_h += (2 * si * k * k) / pairs
        theta_l += k * si
    theta_l /= n - 1.0

    # Var[pi - theta_L] = var1 + var2 (Zeng et al. 2006).
    var1 = (n - 2.0) / (6.0 * (n - 1.0)) * theta_w
    # theta^2 is estimated with b_n (not b_{n+1}).
    theta_sq = S * (S - 1.0) / (a1 * a1 + bn)
    # The constant term is +6: with it the coefficient vanishes at n = 2,
    # where pi and theta_L are identical and the variance must be zero.
    var2 = (
        ((18 * n * n * (3 * n + 2) * bn1) - (88 * n * n * n + 9 * n * n - 13 * n + 6))
        / (9.0 * n * (n - 1.0) * (n - 1.0))
    ) * theta_sq

    h = pi - theta_h

    # Normalize with the quantity whose variance was computed (pi - theta_L);
    # guard the square root / division so degenerate windows yield nan.
    variance = var1 + var2
    if variance > 0.0:
        h_norm = (pi - theta_l) / np.sqrt(variance)
    else:
        h_norm = np.nan

    if positions is not None:
        theta_h = theta_h / (positions[-1] - (positions[0] - 1))

    return theta_h, h, h_norm
@njit(cache=True)
def zeng_e(ac):
    """
    Zeng's E statistic (single-population, infinite sites), standardized.

    Returns ``(theta_L - theta_W) / sqrt(var1 + var2)`` where ``theta_L`` is the
    SFS-weighted sum ``sum_k k * xi_k / (n - 1)`` and ``theta_W = S / a_n``.

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
    :returns: Standardized Zeng's E as a float.
    :rtype: float
    """
    # number of variants and (assumed constant) sample size
    S = ac.shape[0]
    n = ac[0].sum()

    # interior of the unfolded SFS: spectrum[k-1] = sites with derived count k
    spectrum = sfs_nb(ac[:, 1], n)[1:-1]
    a1, bn = _harmonic_sums(n)

    theta_w = S / a1

    theta_l = 0.0
    for k in range(1, int(n)):
        theta_l += k * spectrum[k - 1]
    theta_l /= n - 1.0

    theta_sq = S * (S - 1.0) / (a1 * a1 + bn)

    # theta and theta^2 parts of Var[theta_L - theta_W]
    var_lin = (n / (2.0 * (n - 1.0)) - 1.0 / a1) * theta_w
    var_quad = (
        bn / a1 / a1
        + 2 * (n / (n - 1.0)) * (n / (n - 1.0)) * bn
        - 2 * (n * bn - n + 1) / ((n - 1.0) * a1)
        - (3 * n + 1) / (n - 1.0)
    ) * theta_sq

    return (theta_l - theta_w) / np.sqrt(var_lin + var_quad)
@njit(cache=True)
def fuli_f_star(ac):
    """
    Fu and Li's F* (starred) statistic (no outgroup required).

    Contrasts nucleotide diversity (pi) with the number of folded singletons,
    i.e. sites whose second-column count equals 1 or ``n - 1``.

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
    :returns: Fu & Li's F* as a float.
    :rtype: float
    """
    S = ac.shape[0]  # number of variants
    n = ac[0].sum()  # sample size, assumed constant across variants

    an, bn = _harmonic_sums(n)
    an1 = an + np.true_divide(1, n)

    # One pass: accumulate pi and count folded singletons.
    pairs = n * (n - 1.0)
    pi = 0.0
    singletons = 0
    for j in range(S):
        d = ac[j, 1]
        pi += 2.0 * d * (n - d) / pairs
        if d == 1 or d == n - 1:
            singletons += 1

    # Coefficient of S^2 in the variance.
    v_poly = (2 * (n**3.0) + 110.0 * (n**2.0) - 255.0 * n + 153) / (
        9 * (n**2.0) * (n - 1.0)
    )
    v_a = (2 * (n - 1.0) * an) / (n**2.0)
    v_b = (8.0 * bn) / n
    vfs = (v_poly + v_a - v_b) / ((an**2.0) + bn)

    # Coefficient of S in the variance.
    ufs = (
        (
            n / (n + 1.0)
            + (n + 1.0) / (3 * (n - 1.0))
            - 4.0 / (n * (n - 1.0))
            + ((2 * (n + 1.0)) / ((n - 1.0) ** 2)) * (an1 - ((2.0 * n) / (n + 1.0)))
        )
        / an
    ) - vfs

    numerator = pi - ((n - 1.0) / n) * singletons
    scale = np.sqrt(ufs * S + vfs * (S * S))
    return numerator / scale
@njit(cache=True)
def fuli_f(ac):
    """
    Fu and Li's F statistic (polarized).

    Contrasts nucleotide diversity (pi) with the number of derived singletons
    (sites whose second-column count equals 1).

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
    :returns: Fu & Li's F as a float.
    :rtype: float
    """
    S = ac.shape[0]  # number of variants
    n = ac[0].sum()  # sample size, assumed constant across variants

    an, bn = _harmonic_sums(n)
    an1 = an + 1.0 / n

    # One pass: accumulate pi and count derived singletons.
    pairs = n * (n - 1.0)
    pi = 0.0
    singletons = 0
    for j in range(S):
        d = ac[j, 1]
        pi += 2.0 * d * (n - d) / pairs
        if d == 1:
            singletons += 1

    # c_n has a removable singularity at n = 2; the code pins it to 1 there.
    if n == 2:
        cn = 1
    else:
        cn = 2.0 * (n * an - 2.0 * (n - 1.0)) / ((n - 1.0) * (n - 2.0))

    v_coef = (
        cn + 2.0 * (np.power(n, 2) + n + 3.0) / (9.0 * n * (n - 1.0)) - 2.0 / (n - 1.0)
    ) / (np.power(an, 2) + bn)
    u_coef = (
        1.0
        + (n + 1.0) / (3.0 * (n - 1.0))
        - 4.0 * (n + 1.0) / np.power(n - 1, 2) * (an1 - 2.0 * n / (n + 1.0))
    ) / an - v_coef

    return (pi - singletons) / np.sqrt(u_coef * S + v_coef * np.power(S, 2))
@njit(cache=True)
def fuli_d_star(ac):
    """
    Fu and Li's D* (starred) statistic (no outgroup required).

    Contrasts the number of segregating sites with the number of folded
    singletons (sites whose second-column count equals 1 or ``n - 1``).

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
    :returns: Fu & Li's D* as a float.
    :rtype: float
    """
    S = ac.shape[0]  # number of variants
    n = ac[0].sum()  # sample size, assumed constant across variants

    an, bn = _harmonic_sums(n)
    an1 = an + np.true_divide(1, n)

    cn = 2 * ((n * an) - 2 * (n - 1)) / ((n - 1) * (n - 2))
    dn = (
        cn
        + np.true_divide((n - 2), ((n - 1) ** 2))
        + np.true_divide(2, (n - 1)) * (3.0 / 2 - (2 * an1 - 3) / (n - 2) - 1.0 / n)
    )

    # Coefficients of S^2 (v) and S (u) in the variance.
    v_coef = (
        ((n / (n - 1.0)) ** 2) * bn
        + (an**2) * dn
        - 2 * (n * an * (an + 1)) / ((n - 1.0) ** 2)
    ) / (an**2 + bn)
    u_coef = ((n / (n - 1.0)) * (an - n / (n - 1.0))) - v_coef

    derived = ac[:, 1]
    singletons = ((derived == 1) | (derived == n - 1)).sum()

    numerator = (n / (n - 1.0)) * S - (an * singletons)
    return numerator / (u_coef * S + v_coef * (S**2)) ** 0.5
@njit(cache=True)
def fuli_d(ac):
    """
    Fu and Li's D statistic (polarized form).

    Contrasts the number of segregating sites with the number of derived
    singletons (sites whose second-column count equals 1).

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived), constant ``n``.
    :returns: Fu & Li's D as a float.
    :rtype: float
    """
    S = ac.shape[0]  # number of variants
    n = ac[0].sum()  # sample size, assumed constant across variants

    an, bn = _harmonic_sums(n)
    singletons = (ac[:, 1] == 1).sum()

    # c_n has a removable singularity at n = 2; the code pins it to 1 there.
    if n == 2:
        cn = 1
    else:
        cn = 2.0 * (n * an - 2.0 * (n - 1.0)) / ((n - 1.0) * (n - 2.0))

    # Coefficients of S^2 (v) and S (u) in the variance.
    v_coef = 1.0 + (np.power(an, 2) / (bn + np.power(an, 2))) * (
        cn - (n + 1.0) / (n - 1.0)
    )
    u_coef = an - 1.0 - v_coef

    return (S - singletons * an) / np.sqrt(u_coef * S + v_coef * np.power(S, 2))
@njit(cache=True)
def neutrality_stats(ac, positions):
    """
    Compute several SFS neutrality statistics in a single Numba call.

    Bundling the statistics avoids paying the Numba dispatch overhead once per
    statistic.

    Returns a length-8 ``float64`` array:

    - ``[0]`` tajima_d
    - ``[1]`` theta_h, divided by the span ``positions[-1] - positions[0] + 1``
      when at least two positions are given (unscaled otherwise)
    - ``[2]`` h_raw (Fay & Wu's H = pi - theta_h, unscaled)
    - ``[3]`` h_norm (normalized H of Zeng et al. 2006:
      ``(pi - theta_L) / sqrt(Var[pi - theta_L])``)
    - ``[4]`` pi (mean pairwise diversity, absolute)
    - ``[5]`` theta_w (Watterson's theta, absolute)
    - ``[6]`` theta_w_per_base (nan if fewer than 2 positions)
    - ``[7]`` pi_per_base (nan if fewer than 2 positions)

    All entries are ``nan`` when fewer than 3 variants are supplied or the
    sample size is below 2.

    :param numpy.ndarray ac:
        Allele counts array of shape ``(S, 2)`` (ancestral, derived); the sample
        size is read from the first row and assumed constant.
    :param numpy.ndarray positions:
        Variant positions; only the first and last entries are used.
    :returns: Array of 8 statistics as described above.
    :rtype: numpy.ndarray
    """
    out = np.empty(8, dtype=np.float64)
    out[:] = np.nan

    S = ac.shape[0]
    if S < 3:
        return out

    # Constant sample size taken from the first site.
    dac = ac[:, 1]
    n = int(ac[0, 0] + ac[0, 1])
    if n < 2:
        # No harmonic sums / SFS can be formed; keep the all-nan answer.
        return out
    n_f = float(n)

    # Unfolded SFS over segregating classes 1..n-1.
    sfs = np.zeros(n - 1, dtype=np.int64)
    for i in range(S):
        k = int(dac[i])
        if 0 < k < n:
            sfs[k - 1] += 1

    # Harmonic sums a_n and b_n.
    an = 0.0
    bn = 0.0
    for i in range(1, n):
        inv = 1.0 / i
        an += inv
        bn += inv * inv

    # pi, theta_h, theta_l
    theta_w = S / an
    pi = 0.0
    theta_h = 0.0
    theta_l = 0.0
    denom = n_f * (n_f - 1.0)
    for k_idx in range(n - 1):
        k = k_idx + 1
        count = sfs[k_idx]
        pi += (2.0 * count * k * (n_f - k)) / denom
        theta_h += (2.0 * count * k * k) / denom
        theta_l += k * count
    theta_l /= n_f - 1.0

    # Per-base values
    if positions.size >= 2:
        n_bases = float(positions[-1] - positions[0] + 1)
        theta_w_per_base = theta_w / n_bases
        pi_per_base = pi / n_bases
        theta_h_out = theta_h / n_bases
    else:
        theta_w_per_base = np.nan
        pi_per_base = np.nan
        # No span available: report theta_h unscaled.
        theta_h_out = theta_h

    # Tajima's D
    e1_d = ((n_f + 1.0) / (3.0 * (n_f - 1.0)) - (1.0 / an)) / an
    e2_d = (
        (2.0 * (n_f**2 + n_f + 3.0) / (9.0 * n_f * (n_f - 1.0)))
        - ((n_f + 2.0) / (an * n_f))
        + (bn / (an**2))
    ) / (an**2 + bn)
    d_stdev = np.sqrt(e1_d * S + e2_d * S * (S - 1.0))
    tajima_d_val = (pi - theta_w) / d_stdev if d_stdev > 0.0 else np.nan

    # Fay & Wu's H and its normalized form.
    h_raw = pi - theta_h
    bn_h = bn + 1.0 / (n_f * n_f)  # b_{n+1}
    var1_h = (n_f - 2.0) / (6.0 * (n_f - 1.0)) * theta_w
    th_sq = S * (S - 1.0) / (an**2 + bn)
    # Constant term is +6 (Zeng et al. 2006): with it the coefficient is exactly
    # zero at n = 2, where pi and theta_L coincide.
    var2_h = (
        (
            (18.0 * n_f**2 * (3.0 * n_f + 2.0) * bn_h)
            - (88.0 * n_f**3 + 9.0 * n_f**2 - 13.0 * n_f + 6.0)
        )
        / (9.0 * n_f * (n_f - 1.0) ** 2)
        * th_sq
    )
    denom_h = var1_h + var2_h
    # The variance above is Var[pi - theta_L], so that is what gets standardized
    # (pi - theta_h equals 2 * (pi - theta_L)).
    h_norm = (pi - theta_l) / np.sqrt(denom_h) if denom_h > 0.0 else np.nan

    # Output
    out[0] = tajima_d_val
    out[1] = theta_h_out
    out[2] = h_raw
    out[3] = h_norm
    out[4] = pi
    out[5] = theta_w
    out[6] = theta_w_per_base
    out[7] = pi_per_base
    return out
################## LASSI


def get_empir_freqs_np_fast(hap):
    """
    Count distinct haplotypes and return their sorted counts and frequencies.

    Parameters:
    - hap (numpy.ndarray): Shape (S, n); each column is one haplotype over S SNPs.

    Returns:
    - k_counts (numpy.ndarray): Counts of each distinct haplotype, descending.
    - h_f (numpy.ndarray): ``k_counts`` divided by the number of haplotypes n.
    """
    # One haplotype per row, laid out contiguously so a row can be reinterpreted.
    rows = np.ascontiguousarray(hap.T)
    n_hap, n_snp = rows.shape

    # Reinterpret every row as a single opaque (void) scalar of row-size bytes,
    # so np.unique compares whole haplotypes at once.
    row_bytes = rows.dtype.itemsize * n_snp
    keys = rows.view(np.dtype((np.void, row_bytes)))

    multiplicity = np.unique(keys, return_counts=True)[1]

    k_counts = np.sort(multiplicity)[::-1]
    h_f = k_counts / n_hap
    return k_counts, h_f


def process_spectra(k: np.ndarray, h_f: np.ndarray, K_truncation: int, n_ind: int):
    """
    Truncate, renormalize and zero-pad haplotype count and frequency spectra.

    Parameters:
    - k (numpy.ndarray): Counts of each distinct haplotype.
    - h_f (numpy.ndarray): Frequencies of each distinct haplotype.
    - K_truncation (int): Number of leading classes to keep.
    - n_ind (int): Total the truncated counts are rescaled to sum to.

    Returns:
    - Kcount (numpy.ndarray): Truncated counts rescaled to sum to ``n_ind``,
      padded with zeros to length ``K_truncation``.
    - Kspect (numpy.ndarray): Truncated frequencies rescaled to sum to 1,
      padded with zeros to length ``K_truncation``.
    """
    head_counts = k[:K_truncation]
    head_freqs = h_f[:K_truncation]

    Kcount = head_counts / head_counts.sum() * n_ind
    Kspect = head_freqs / head_freqs.sum()

    # Fewer distinct haplotypes than requested classes: fill the tail with zeros.
    if Kcount.size < K_truncation:
        count_pad = np.zeros(K_truncation - Kcount.size)
        Kcount = np.concatenate([Kcount, count_pad])
        spect_pad = np.zeros(K_truncation - Kspect.size)
        Kspect = np.concatenate([Kspect, spect_pad])

    return Kcount, Kspect
def LASSI_spectrum_and_Kspectrum(input_data, K_truncation=10, window=110, step=5):
    """
    Compute truncated haplotype count and frequency spectra in sliding SNP windows.

    Parameters:
    - input_data: Either a ``(hap, positions)`` list/tuple, where ``hap`` is an
      (S, n) array (rows are SNPs, columns are haplotypes) and ``positions`` the
      SNP positions, or a path string. A path is first handed to
      ``genome_reader`` and, if that raises, to ``parse_ms_numpy``.
    - K_truncation (int): Number of haplotype classes kept per window.
    - window (int): Window size in SNPs.
    - step (int): Step between window starts, in SNPs.

    Returns:
    - K_count (numpy.ndarray): Count spectrum for each window.
    - K_spectrum (numpy.ndarray): Frequency spectrum for each window.
    - windows_centers (numpy.ndarray): Median SNP position of each window.

    Returns ``None`` when ``input_data`` is of an unsupported type or when both
    readers fail on a path.
    """
    filterwarnings(
        "ignore",
        category=RuntimeWarning,
        message="invalid value encountered in scalar divide",
    )
    np.seterr(divide="ignore", invalid="ignore")

    if isinstance(input_data, (list, tuple)):
        hap_int, position_masked = input_data
    elif isinstance(input_data, str):
        try:
            (
                hap_int,
                _rec_map,
                _ac,
                _mask,
                position_masked,
                _genetic_pos,
            ) = genome_reader(input_data)
        except Exception:
            # Not readable as a genome/VCF input: retry as ms-formatted text.
            try:
                (
                    hap_int,
                    _rec_map,
                    _ac,
                    _mask,
                    position_masked,
                    _genetic_pos,
                ) = parse_ms_numpy(input_data)
            except Exception:
                return None
    else:
        return None

    n_snps, n_hap = hap_int.shape

    count_rows = []
    freq_rows = []
    centers = []
    for start in range(0, n_snps, step):
        stop = start + window
        hap_window = hap_int[start:stop, :]

        # Window coordinate: median position of the SNPs it contains.
        centers.append(np.median(position_masked[start:stop]))

        counts, freqs = get_empir_freqs_np_fast(hap_window)
        kc, ks = process_spectra(counts, freqs, K_truncation, n_hap)
        count_rows.append(kc)
        freq_rows.append(ks)

        # The first window that runs off the end is kept, then the scan stops.
        if hap_window.shape[0] < window:
            break

    return np.array(count_rows), np.array(freq_rows), np.array(centers)
def neut_average(K_spectrum: np.ndarray) -> np.ndarray:
    """
    Compute the neutral (average) haplotype frequency spectrum.

    Parameters:
    - K_spectrum (numpy.ndarray): 2-D array of per-window frequency spectra.

    Returns:
    - out (numpy.ndarray): Column-wise mean spectrum, passed through a
      single-element weighted average.
    """
    n_rows, _ = K_spectrum.shape
    gwide_K = np.mean(K_spectrum, axis=0)

    # Weight is the row count modulo 5e4, or 5e4 when that remainder is zero.
    leftover = n_rows % 5e4
    weight = 5e4 if leftover == 0 else leftover

    return np.average([gwide_K], axis=0, weights=[weight])


@njit(cache=True)
def easy_likelihood(K_neutral, K_count, K_truncation):
    """
    Log-likelihood ``sum_i K_count[i] * log(K_neutral[i])`` over the first
    ``K_truncation`` classes. Used directly for the neutral model and as the
    final step of the sweep model.
    """
    total = 0.0
    for idx in range(K_truncation):
        total += K_count[idx] * np.log(K_neutral[idx])
    return total


@njit(cache=True)
def sweep_likelihood(
    K_neutral, K_count, K_truncation, m_val, epsilon, epsilon_max, sweep_mode=4
):
    """
    Log-likelihood of ``K_count`` under a sweep-distorted spectrum.

    When ``m_val == K_truncation`` the neutral spectrum is used unchanged.
    Otherwise the classes after the first ``m_val`` are replaced by values
    running linearly from ``epsilon_max`` down to ``epsilon``, and the mass they
    lose relative to ``K_neutral`` is added to the first ``m_val`` classes
    according to ``sweep_mode``:

        1 — proportional to 1/j
        2 — proportional to 1/j²
        3 — proportional to exp(-j)
        4 — proportional to exp(-j²)   [default]
        any other value — uniform (1/m)
    """
    if m_val == K_truncation:
        return easy_likelihood(K_neutral, K_count, K_truncation)

    n_tail = K_truncation - m_val
    ramp_span = n_tail - 1

    # Tail classes and the mass each one gives up relative to neutrality.
    tail_vals = np.zeros(n_tail)
    surplus_parts = np.zeros(n_tail)
    for t in range(n_tail):
        if ramp_span != 0:
            level = epsilon_max - (t / ramp_span) * (epsilon_max - epsilon)
        else:
            level = epsilon
        tail_vals[t] = level
        surplus_parts[t] = K_neutral[m_val + t] - level
    surplus = np.sum(surplus_parts)

    # Head classes start from their neutral values.
    altspect = np.zeros(K_truncation)
    altspect[:m_val] = K_neutral[:m_val]

    is_normalized = (
        sweep_mode == 1 or sweep_mode == 2 or sweep_mode == 3 or sweep_mode == 4
    )
    if is_normalized:
        # Unnormalized weight of each head class j = 1..m_val.
        shape_w = np.zeros(m_val)
        for h in range(m_val):
            j = h + 1
            if sweep_mode == 1:
                shape_w[h] = 1.0 / float(j)
            elif sweep_mode == 2:
                shape_w[h] = 1.0 / float(j * j)
            elif sweep_mode == 3:
                shape_w[h] = np.exp(-float(j))
            else:
                shape_w[h] = np.exp(-float(j * j))

        norm = 0.0
        for h in range(m_val):
            norm += shape_w[h]

        for h in range(m_val):
            altspect[h] += (shape_w[h] / norm) * surplus
    else:
        for h in range(m_val):
            altspect[h] += (1.0 / float(m_val)) * surplus

    altspect[m_val:] = tail_vals
    return easy_likelihood(altspect, K_count, K_truncation)


@njit(cache=True)
def compute_epsilon_values(K_truncation, K_neutral_last):
    """
    Return the grid ``i / (100 * K_truncation)`` for ``i = 1..100``, keeping
    only the values that do not exceed ``K_neutral_last``.
    """
    epsilon_min = 1 / (K_truncation * 100)
    grid = np.empty(100)
    kept = 0
    for i in range(1, 101):
        candidate = i * epsilon_min
        if candidate <= K_neutral_last:
            grid[kept] = candidate
            kept += 1
    return grid[:kept].copy()


@njit(cache=True)
def T_m_statistic_core(K_counts, K_neutral, windows, K_truncation, sweep_mode=4):
    """
    Score every window with the best sweep-vs-neutral likelihood ratio.

    For each window the (epsilon, m) grid is searched for the largest
    ``2 * (sweep log-likelihood - neutral log-likelihood)``.

    Returns an array with one row per window and columns:
    window coordinate, best ratio, best m, best epsilon, ``K_neutral[-1]``,
    ``sweep_mode``, followed by the window's count spectrum.
    """
    n_win = len(windows)
    eps_cap = K_neutral[-1]
    eps_grid = compute_epsilon_values(K_truncation, eps_cap)
    m_stop = K_truncation + 1

    # 6 summary columns followed by the count spectrum itself.
    output = np.zeros((n_win, 6 + len(K_counts[0])))

    for j in range(n_win):
        spectrum = K_counts[j]
        ll_null = easy_likelihood(K_neutral, spectrum, K_truncation)

        top_ratio = -np.inf
        top_m = 0
        top_eps = 0.0
        # Epsilon outer, m inner; strict ">" keeps the first maximum found.
        for eps in eps_grid:
            for m in range(1, m_stop):
                ll_alt = sweep_likelihood(
                    K_neutral, spectrum, K_truncation, m, eps, eps_cap, sweep_mode
                )
                ratio = 2 * (ll_alt - ll_null)
                if ratio > top_ratio:
                    top_ratio = ratio
                    top_m = m
                    top_eps = eps

        output[j, 0] = windows[j]
        output[j, 1] = top_ratio
        output[j, 2] = top_m
        output[j, 3] = top_eps
        output[j, 4] = eps_cap
        output[j, 5] = sweep_mode
        output[j, 6:] = spectrum

    return output
def T_m_statistic_fast(
    K_counts, K_neutral, windows, K_truncation, sweep_mode=4, _iter=0
):
    """
    Run ``T_m_statistic_core`` and wrap its output in a Polars DataFrame.

    The six summary columns are named ``window_lassi``, ``T``, ``m``,
    ``frequency``, ``e`` and ``model``; the spectrum columns are named
    ``Kcounts_1`` .. ``Kcounts_<K_truncation>``. A constant ``iter`` column
    holding ``_iter`` is appended.

    :returns: One row per window.
    :rtype: polars.DataFrame
    """
    scores = T_m_statistic_core(K_counts, K_neutral, windows, K_truncation, sweep_mode)

    schema = {
        "window_lassi": pl.Int64,
        "T": pl.Float64,
        "m": pl.Float64,
        "frequency": pl.Float64,
        "e": pl.Float64,
        "model": pl.Float64,
    }
    for rank in range(1, K_truncation + 1):
        schema["Kcounts_" + str(rank)] = pl.Float64

    frame = pl.DataFrame(scores, schema=pl.Schema(schema))
    return frame.with_columns(pl.lit(_iter).cast(pl.Int64).alias("iter"))
def compute_t_m(
    sim_list,
    K_truncation=10,
    w_size=201,
    w_step=10,
    K_neutral=None,
    sweep_mode=4,
    center=[5e4, 1.2e6 - 5e4],
    windows=[100000],
    step=int(1e5),
    nthreads=1,
    params=None,
    parallel_manager=None,
):
    """
    Compute LASSI-style T and m over a set of simulations.

    Pipeline: build truncated haplotype spectra per SNP window
    (``LASSI_spectrum_and_Kspectrum``), estimate the neutral spectrum with
    ``neut_average`` when none is given, score each window
    (``T_m_statistic_fast``), then keep the best-scoring window inside each
    physical window/center combination (``cut_t_m_argmax``).

    :param sim_list: Sliceable sequence of items accepted by
        ``LASSI_spectrum_and_Kspectrum``.
    :type sim_list: sequence
    :param K_truncation: Number of haplotype classes kept. Default 10.
    :type K_truncation: int
    :param w_size: Sliding window size in SNPs. Default 201.
    :type w_size: int
    :param w_step: Step in SNPs between consecutive windows. Default 10.
    :type w_step: int
    :param K_neutral: Precomputed neutral spectrum; estimated when None.
    :type K_neutral: array-like or None
    :param sweep_mode: Redistribution model forwarded to the scoring step.
    :type sweep_mode: int
    :param center: ``[first, last]`` center coordinates forwarded to
        ``cut_t_m_argmax``.
    :type center: list
    :param windows: Physical window widths forwarded to ``cut_t_m_argmax``.
    :type windows: list[int]
    :param step: Spacing between centers forwarded to ``cut_t_m_argmax``.
    :type step: int
    :param nthreads: Worker count used when a new ``Parallel`` is created.
    :type nthreads: int
    :param params: Optional parameter matrix with four columns, labelled
        ``s``, ``t``, ``f_i``, ``f_t``; each row is repeated once per unique
        (center, window) pair and joined horizontally before pivoting.
    :type params: array-like or None
    :param parallel_manager: Existing parallel executor to reuse.
    :type parallel_manager: joblib.Parallel or None
    :returns: ``(t_m_cut, K_neutral)``
    :rtype: tuple
    """
    from . import Parallel, delayed

    if parallel_manager is None:
        parallel_manager = Parallel(n_jobs=nthreads, verbose=1)

    # Stage 1: per-simulation spectra.
    spectra_jobs = (
        delayed(LASSI_spectrum_and_Kspectrum)(hap_data, K_truncation, w_size, w_step)
        for hap_data in sim_list[:]
    )
    hfs_stats = parallel_manager(spectra_jobs)
    K_counts, K_spectrum, windows_lassi = zip(*hfs_stats)

    if K_neutral is None:
        K_neutral = neut_average(np.vstack(K_spectrum))

    # Stage 2: score every SNP window of every simulation.
    scoring_jobs = (
        delayed(T_m_statistic_fast)(
            kc,
            K_neutral,
            windows_lassi[_iter - 1],
            K_truncation,
            sweep_mode=sweep_mode,
            _iter=_iter,
        )
        for _iter, kc in enumerate(K_counts, 1)
    )
    t_m = parallel_manager(scoring_jobs)

    # Stage 3: reduce to one row per (center, window) pair.
    cut_jobs = (
        delayed(cut_t_m_argmax)(
            scored, windows=windows, center=center, step=step, _iter=_iter
        )
        for _iter, scored in enumerate(t_m, 1)
    )
    t_m_cut = pl.concat(parallel_manager(cut_jobs))

    key_cols = ("iter", "window", "center")
    other_cols = [name for name in t_m_cut.columns if name not in key_cols]
    t_m_cut = t_m_cut.select(["iter", "window", "center", *other_cols])

    if params is not None:
        n_combos = t_m_cut.select(["center", "window"]).unique().shape[0]
        params_frame = pl.DataFrame(
            np.repeat(params, n_combos, axis=0),
            schema=["s", "t", "f_i", "f_t"],
        )
        t_m_cut = pivot_feature_vectors(
            pl.concat([params_frame, t_m_cut], how="horizontal")
        )

    return t_m_cut, K_neutral
def cut_t_m_argmax(
    df_t_m,
    center=[5e4, 1.2e6 - 5e4],
    windows=[100000],
    step=1e5,
    _iter=1,
):
    """
    Keep, for every (center, window) pair, the row with the largest ``T``.

    Centers run from ``center[0]`` to ``center[1]`` in increments of ``step``.
    A row qualifies for a pair when its ``window_lassi`` lies strictly inside
    ``(center - window/2, center + window/2)``. If selecting the best row
    raises (for example because no row qualifies), a placeholder row with
    nulls is emitted that only carries ``iter``, ``center`` and ``window``.

    :returns: One row per (center, window) pair with columns ``T``, ``m``,
        the ``Kcounts_*`` columns, ``iter``, ``window`` and ``center``.
    :rtype: polars.DataFrame
    """
    k_schema = df_t_m.select("^Kcounts_.*$").schema
    t_schema = OrderedDict(
        {
            "T": pl.Float64,
            "m": pl.Float64,
            **k_schema,
            "iter": pl.Int64,
            "window": pl.Int64,
            "center": pl.Int64,
        }
    )

    centers = np.arange(center[0], center[1] + step, step).astype(int)

    pieces = []
    for c, w in list(product(centers, windows)):
        half = w / 2
        lower = c - half
        upper = c + half
        in_range = df_t_m.filter(
            (pl.col("window_lassi") > lower) & (pl.col("window_lassi") < upper)
        )
        try:
            best = in_range["T"].arg_max()
            top_row = in_range[best : best + 1, :]
            top_row = top_row.select(
                pl.exclude(["frequency", "e", "model", "window_lassi"])
            ).with_columns(
                pl.lit(w).cast(pl.Int64).alias("window"),
                pl.lit(c).cast(pl.Int64).alias("center"),
            )
            pieces.append(top_row)
        except Exception:
            # Placeholder: nulls everywhere except the identifying columns.
            filler = {name: [None] for name in t_schema.keys()}
            filler["iter"] = [_iter]
            filler["center"] = [c]
            filler["window"] = [w]
            pieces.append(pl.DataFrame(filler, schema=t_schema))

    return pl.concat(pieces)


def run_lassi(
    hap_data, K_truncation=10, w_size=201, step=10, K_neutral=None, sweep_mode=4
):
    """
    Run the LASSI T/m scan on a single dataset.

    ``hap_data`` is read with ``genome_reader``; if that raises, it is read
    with ``parse_ms_numpy`` instead. Spectra are built in SNP windows of
    ``w_size`` with stride ``step``; the neutral spectrum is estimated from
    the data when ``K_neutral`` is None.

    :returns: Output of ``T_m_statistic_fast`` without its last column.
    :rtype: polars.DataFrame
    """
    try:
        (
            hap_int,
            _rec_map,
            _ac,
            _mask,
            position_masked,
            _genetic_pos,
        ) = genome_reader(hap_data)
    except Exception:
        (
            hap_int,
            _rec_map,
            _ac,
            _mask,
            position_masked,
            _genetic_pos,
        ) = parse_ms_numpy(hap_data)

    K_counts, K_spectrum, windows_lassi = LASSI_spectrum_and_Kspectrum(
        [hap_int, position_masked], K_truncation, w_size, int(step)
    )

    if K_neutral is None:
        K_neutral = neut_average(np.vstack(K_spectrum))

    scored = T_m_statistic_fast(
        K_counts, K_neutral, windows_lassi, K_truncation, sweep_mode=sweep_mode
    )
    # Drop the trailing column appended by T_m_statistic_fast.
    return scored[:, :-1]
################## saltiLASSI


@njit(parallel=False, cache=True)
def _lassip_precompute(K_counts, K_neutral, K_truncation, sweep_mode=4):
    """
    saltiLASSI precomputation: tabulate per-window likelihood differences.

    ``dF[mi, ei, i] = sweep_likelihood(window i, m = mi + 1, epsilon_values[ei])
    - easy_likelihood(window i)``, so that a composite ratio reduces to a
    weighted sum over ``dF[mi, ei, :]``.

    Returns ``(dF, null_composite, epsilon_values)`` where ``dF`` has shape
    ``(K_truncation, n_epsilon, n_windows)`` and ``null_composite`` is the sum
    of the per-window neutral log-likelihoods.
    """
    n_win = K_counts.shape[0]
    eps_cap = K_neutral[-1]
    epsilon_values = compute_epsilon_values(K_truncation, eps_cap)
    n_eps = len(epsilon_values)

    # Neutral log-likelihood of every window.
    ll_null = np.zeros(n_win)
    for w in range(n_win):
        ll_null[w] = easy_likelihood(K_neutral, K_counts[w], K_truncation)
    null_composite = np.sum(ll_null)

    dF = np.zeros((K_truncation, n_eps, n_win))
    for mi in range(K_truncation):
        for ei in range(n_eps):
            for w in range(n_win):
                ll_sweep = sweep_likelihood(
                    K_neutral,
                    K_counts[w],
                    K_truncation,
                    mi + 1,
                    epsilon_values[ei],
                    eps_cap,
                    sweep_mode,
                )
                dF[mi, ei, w] = ll_sweep - ll_null[w]

    return dF, null_composite, epsilon_values


@njit(parallel=False, cache=True)
def _lassip_chunk(
    j_start, j_end, positions, K_truncation, A_grid, dF, epsilon_values, max_extend=1e5
):
    """
    saltiLASSI optimization of (m, A, epsilon) for target windows
    ``j_start .. j_end - 1``.

    For a target position z*, every window i receives the weight
    ``exp(-A * |z_i - z*|)`` when ``|z_i - z*| <= max_extend`` and 0 otherwise.
    The score of a parameter triple is ``2 * sum_i weight_i * dF[mi, ei, i]``;
    the first triple reaching the maximum (A outermost, then m, then epsilon)
    is kept.

    Returns an array of shape ``(j_end - j_start, 5)`` with columns:
    target position, best score, best m, ``1 / best_A``, best epsilon.
    """
    n_win = positions.shape[0]
    n_targets = j_end - j_start
    n_eps = len(epsilon_values)
    output = np.zeros((n_targets, 5))

    for t in range(n_targets):
        target = j_start + t
        z_star = positions[target]

        top_score = -np.inf
        top_m = 1
        top_A = A_grid[0]
        top_eps = epsilon_values[0]

        # Distances to the target: computed once, reused for every A.
        dist = np.empty(n_win)
        for i in range(n_win):
            dist[i] = np.abs(positions[i] - z_star)

        for A_val in A_grid:
            # Spatial weights; windows beyond max_extend keep weight 0.
            weight = np.zeros(n_win)
            for i in range(n_win):
                if dist[i] <= max_extend:
                    weight[i] = np.exp(-A_val * dist[i])

            for mi in range(K_truncation):
                for ei in range(n_eps):
                    acc = 0.0
                    for i in range(n_win):
                        acc += weight[i] * dF[mi, ei, i]
                    score = 2.0 * acc
                    if score > top_score:
                        top_score = score
                        top_m = mi + 1
                        top_A = A_val
                        top_eps = epsilon_values[ei]

        output[t, 0] = positions[target]
        output[t, 1] = top_score
        output[t, 2] = top_m
        # Stored as the reciprocal of the selected decay rate.
        output[t, 3] = 1.0 / top_A
        output[t, 4] = top_eps

    return output
def Lambda_statistic_fast(
    K_counts,
    K_neutral,
    positions,
    K_truncation,
    n_A=101,
    sweep_mode=4,
    nthreads=1,
    max_extend=1e5,
    _iter=0,
):
    """
    Compute saltiLASSI Λ statistics for all windows as a Polars DataFrame.

    The likelihood-difference table is built once with ``_lassip_precompute``;
    target windows are then split into contiguous chunks, each handed to
    ``_lassip_chunk`` through a ``Parallel(backend="loky")`` executor.

    The decay grid holds ``n_A`` geometrically spaced values between
    ``-log(0.99999) / d_min`` and ``-log(0.00001) / d_min``, where ``d_min`` is
    the smallest gap between consecutive sorted positions.

    Parameters
    ----------
    K_counts : np.ndarray, shape (n_windows, K_truncation)
    K_neutral : np.ndarray, shape (K_truncation,)
    positions : np.ndarray, shape (n_windows,)
    K_truncation : int
    n_A : int
        Number of decay values in the grid. Default 101.
    sweep_mode : int
        Redistribution model forwarded to ``_lassip_precompute``. Default 4.
    nthreads : int
        Number of workers, also used to size the chunks. Default 1.
    max_extend : float
        Maximum distance from the target window for a window to contribute.
        Default 1e5.
    _iter : int
        Replicate index attached as the 'iter' column.

    Returns
    -------
    pl.DataFrame with columns: window_lassip, Lambda, m, A, frequency, iter.
    The A column holds the reciprocal decay rate written by ``_lassip_chunk``.
    """
    from . import Parallel, delayed

    n_windows = len(positions)

    # Decay grid scaled by the tightest spacing between windows.
    d_min = float(np.min(np.diff(np.sort(positions))))
    A_lo = -np.log(0.99999) / d_min
    A_hi = -np.log(0.00001) / d_min
    A_grid = np.geomspace(A_lo, A_hi, n_A)

    dF, _null_composite, epsilon_values = _lassip_precompute(
        K_counts, K_neutral, K_truncation, sweep_mode
    )

    # Contiguous [start, end) ranges of target windows, one per worker.
    chunk_size = max(1, ceil(n_windows / nthreads))
    bounds = []
    for start in range(0, n_windows, chunk_size):
        bounds.append((start, min(start + chunk_size, n_windows)))

    runner = Parallel(n_jobs=nthreads, backend="loky")
    parts = runner(
        delayed(_lassip_chunk)(
            lo,
            hi,
            positions,
            K_truncation,
            A_grid,
            dF,
            epsilon_values,
            max_extend,
        )
        for lo, hi in bounds
    )
    stacked = np.vstack(parts)

    schema = {
        "window_lassip": pl.Int64,
        "Lambda": pl.Float64,
        "m": pl.Float64,
        "A": pl.Float64,
        "frequency": pl.Float64,
    }
    frame = pl.DataFrame(stacked, schema=schema)
    return frame.with_columns(pl.lit(_iter).cast(pl.Int64).alias("iter"))
def run_lassip(
    hap_data,
    K_truncation=10,
    w_size=201,
    step=10,
    K_neutral=None,
    n_A=100,
    sweep_mode=4,
    nthreads=1,
    max_extend=1e5,
):
    """
    Run saltiLASSI on a single haplotype dataset (VCF or ms format).

    Extends run_lassi by computing the spatially-aware Λ statistic
    (DeGiorgio & Szpiech 2022) instead of the per-window T statistic.
    Reuses LASSI_spectrum_and_Kspectrum and neut_average unchanged, and
    delegates the scan itself to :func:`Lambda_statistic_fast`.

    **Algorithm and performance vs C++ lassip**

    saltiLASSI composites LASSI per-window log-likelihoods with a spatial decay kernel:

        L1(m,ε,A,z*) = Σ_i  α_i · sweep_lik(i,m,ε) + (1−α_i) · null_lik(i)
        L1 − L0      = Σ_i  α_i · dF[mi,ei,i]          ← dot product, no log()

    where α_i = exp(−A · abs(z_i − z*)) and dF[mi,ei,i] = sweep_lik(i,m,ε) − null_lik(i)
    is precomputed once in O(n_windows × K × n_ε) before the hot path.

    +---------------------------+----------------------------------------+----------------------------------+
    | Aspect                    | C++ lassip                             | This implementation              |
    +---------------------------+----------------------------------------+----------------------------------+
    | Mixture formula           | Geometric (paper uses arithmetic)      | Geometric — identical ✓          |
    | Hot-path inner work       | K log() + K mul-add per source window  | 1 mul-add per source window      |
    | log() calls total         | ~23 B for 480 windows                  | ~48 K (precompute only)          |
    | q / dF working set        | n_win × n_ε × m × K × 8B ≈ 38 MB       | K × n_ε × n_win × 8B ≈ 384 KB    |
    | Cache behaviour           | L3 misses (38 MB)                      | L2 resident (384 KB)             |
    | sweep_mode                | Fixed at compile time (oopt flag)      | Runtime parameter 1–5            |
    | Spatial cutoff            | MAX_EXTEND hard distance limit         | max_extend (default 100 kb)      |
    | Parallelism               | OpenMP (process forks)                 | joblib loky (separate processes) |
    | Observed 1-thread speed   | ~120 s for 480 windows (39K loci)      | ~2–3 s for 480 windows           |
    +---------------------------+----------------------------------------+----------------------------------+

    The ~40–60× single-thread speedup comes from three compounding effects:

    1. dF precomputation collapses K from the hot path → 10× fewer operations per inner step.
    2. 384 KB dF fits in L2 cache; C++ 38 MB q array thrashes L3.
    3. Pure multiply-add hot path is SIMD-vectorizable; log() calls are not.

    **Complexity**: O(n_windows² × n_A × K × n_ε). For n_windows≈480 on a 1.2 Mb locus,
    this is ~2.3 B multiply-adds ≈ 2–3 s single-thread; ~0.3 s with 8 workers.
    Use a larger step (e.g. step=50) to reduce n_windows and quadratically cut runtime.

    Parameters
    ----------
    hap_data : str or tuple
        Path to VCF/ms file, or (hap_int, position_masked) tuple.
    K_truncation : int
        Number of top haplotypes in truncated spectrum. Default 10.
    w_size : int
        Sliding window size in SNPs for building K-spectra. Default 201.
    step : int
        Step in SNPs between consecutive K-spectrum windows. Default 10.
    K_neutral : np.ndarray or None
        Pre-computed neutral spectrum. Estimated via neut_average if None.
    n_A : int
        Number of log-spaced A grid points (spatial decay rates). Default 100.
    sweep_mode : int
        Sweep haplotype redistribution model passed to sweep_likelihood (1–5). Default 4.
    nthreads : int
        Number of joblib (loky) workers for the per-target phase. Default 1.
    max_extend : float
        Maximum bp distance from the target window included in the composite,
        forwarded to :func:`Lambda_statistic_fast`. Default 1e5, which is also
        that function's own default, so omitting it keeps the previous
        behaviour. Use np.inf to sum over all windows.

    Returns
    -------
    pl.DataFrame with columns: window_lassip, Lambda, m, A, frequency, iter
    """
    # Try the VCF reader first; on any failure treat the input as ms text.
    # Only the haplotype matrix and the physical positions are used here.
    try:
        hap_int, _, _, _, position_masked, _ = genome_reader(hap_data)
    except Exception:
        hap_int, _, _, _, position_masked, _ = parse_ms_numpy(hap_data)

    K_counts, K_spectrum, windows_centers = LASSI_spectrum_and_Kspectrum(
        [hap_int, position_masked], K_truncation, w_size, int(step)
    )

    if K_neutral is None:
        K_neutral = neut_average(np.vstack(K_spectrum))

    return Lambda_statistic_fast(
        K_counts,
        K_neutral,
        windows_centers,
        K_truncation,
        n_A=n_A,
        sweep_mode=sweep_mode,
        nthreads=nthreads,
        max_extend=max_extend,
    )
################## RAISD


@njit(cache=True)
def compute_mu_var(start_idx, end_idx, snp_positions, D_ln, W_sz):
    """Variation component of the RAiSD mu statistic (paper Eq. 1).

    Computes ``(window physical span) / (D_ln * W_sz)`` scaled by the total
    number of SNPs in ``snp_positions``. The window span is the distance
    between the first and last SNP of the window.

    Args:
        start_idx: Index of the first SNP in the window (into snp_positions).
        end_idx: One-past index of the last SNP in the window.
        snp_positions: Full array of SNP physical positions (bp), shape (S,).
        D_ln: Physical span of the entire region being scanned (bp).
            ``mu_stat`` computes it once as ``(positions[-1] + 1) - positions[0]``.
        W_sz: Number of SNPs in the window (``mu_stat`` passes
            ``end_idx - start_idx``).

    Returns:
        float: The scaled span ratio. With sorted positions it is >= 0.
    """
    l_start = snp_positions[start_idx]
    l_end = snp_positions[end_idx - 1]
    # (window physical span) / (expected span if SNPs were uniformly distributed)
    # multiplied by total SNP count to normalise for region-wide density
    return ((l_end - l_start) / (D_ln * W_sz)) * snp_positions.shape[0]


@njit(cache=True)
def compute_mu_sfs(window, n, theta_W):
    """Site-frequency-spectrum component of the RAiSD mu statistic (paper Eq. 2).

    Counts "edge" SNPs in the window — rows whose allele-1 count equals 1
    (singletons) or n-1 — and returns the edge fraction multiplied by
    ``theta_W``. No ancestral-state polarization is used: only the row sums of
    the 0/1 matrix are inspected.

    Args:
        window: Haplotype sub-matrix for the current SNP window, shape (W_sz, n),
            with 0/1 entries (rows = SNPs, cols = haplotypes).
        n: Number of haplotypes (= window.shape[1]).
        theta_W: Scaling factor. ``mu_stat`` passes ``_harmonic_sums(n)[0]``,
            computed once for all windows.

    Returns:
        float: ``(n_edges / W_sz) * theta_W``. Returns np.nan for a window
        with zero rows.
    """
    if window.shape[0] == 0:
        return np.nan
    # Count allele-1 copies per SNP (row sums over haplotypes)
    derived_counts = np.sum(window, axis=1)
    # Edge SNPs: singletons (count==1) or near-fixed alleles (count==n-1)
    edge_mask = (derived_counts == 1) | (derived_counts == n - 1)
    n_edges = np.sum(edge_mask)
    W_sz = window.shape[0]
    # Normalise edge fraction by the supplied sample-size correction
    return (n_edges / W_sz) * theta_W


@njit(cache=True)
def pack_snp_row(row, n_samples):
    """Pack one SNP's 0/1 haplotype vector into an array of uint64 bit-words.

    Converts a row of haplotype alleles (only the lowest bit of each entry is
    used) into ceil(n_samples/64) uint64 words. Within a full word, bits are
    shifted in from the LSB side, so the first sample of the word ends up in
    the most-significant bit. A partial final word holding ``r`` samples is
    right-aligned: its samples occupy the low ``r`` bits and the upper
    ``64 - r`` bits are zero, which is the layout ``last_word_mask`` expects.

    Args:
        row: 1-D array of length n_samples with 0/1 alleles for one SNP.
        n_samples: Number of haplotypes (must equal len(row)).

    Returns:
        np.ndarray: uint64 array of length ceil(n_samples/64) holding the
        packed bit representation.
    """
    words = (n_samples + 63) // 64
    packed = np.zeros(words, dtype=np.uint64)
    word = np.uint64(0)
    lcnt = 0  # bits accumulated in the current word
    w = 0  # index of the current word
    for j in range(n_samples):
        b = np.uint64(row[j] & 1)
        word = (word << np.uint64(1)) | b  # shift left, insert new bit at LSB
        lcnt += 1
        if lcnt == 64:
            # Word is full: store it and start a fresh one
            packed[w] = word
            word = np.uint64(0)
            lcnt = 0
            w += 1
    if lcnt != 0:
        # Shift up then back down to clear everything above the low lcnt bits,
        # leaving the partial word right-aligned with zero padding on top
        shift = np.uint64(64 - lcnt)
        tmp = (word << shift) >> shift
        packed[w] = tmp
    return packed


@njit(cache=True)
def last_word_mask(n_bits):
    """Bitmask for the valid bits in the final packed uint64 word.

    When n_bits is not a multiple of 64, the last uint64 word produced by
    ``pack_snp_row`` has ``r = n_bits % 64`` meaningful bits (at the LSB end)
    and 64-r zero-padding bits. This function returns the mask with exactly
    those r bits set so that complement operations and hash computations can
    ignore the padding.

    Args:
        n_bits: Total number of bits (== n_samples).

    Returns:
        np.uint64: Mask with the low ``n_bits % 64`` bits set. If n_bits is a
        multiple of 64, all 64 bits are set (full word).
    """
    r = n_bits % 64
    if r == 0:
        # All 64 bits are valid — full all-ones mask
        return np.uint64(~np.uint64(0))
    # Set only the low r bits: (1 << r) - 1
    return (np.uint64(1) << np.uint64(r)) - np.uint64(1)


@njit(cache=True)
def equal_or_complement(a, b, n_samples):
    """Return True if packed patterns a and b are equal or bitwise complements.

    A SNP pattern is treated as equivalent to its complement, i.e. the result
    does not depend on which allele is labelled 0 and which 1.

    The test is done in two passes over the packed uint64 words:

    Pass 1 — exact equality: a[i] == b[i] for all words.
    Pass 2 — complement equality: a[i] == ~b[i] for full words, and
              (a[lw] ^ ~b[lw]) & last_word_mask == 0 for the partial last word
              (so padding bits don't cause a false mismatch).

    Args:
        a, b: Packed uint64 arrays from ``pack_snp_row``, same length.
        n_samples: Number of haplotypes (needed to mask the last word).

    Returns:
        bool: True if the two SNP patterns are identical or complementary.
    """
    words = a.shape[0]
    # Pass 1: exact equality; bail out at the first differing word
    eq = True
    for i in range(words):
        if a[i] != b[i]:
            eq = False
            break
    if eq:
        return True
    # Pass 2: complement equality — check ~b[i] == a[i] word by word
    full_words = n_samples // 64
    last_mask = last_word_mask(n_samples)
    for i in range(full_words):
        if a[i] != (~b[i]):
            return False
    if (n_samples % 64) != 0:
        # For the partial last word, only compare the valid bits via the mask
        lw = full_words
        if (a[lw] ^ (~b[lw])) & last_mask != 0:
            return False
    return True


@njit(cache=True)
def hash_pattern(packed, n_samples):
    """Compute a complement-symmetric hash key for a packed SNP bit-pattern.

    Returns the same hash value for a pattern and its bitwise complement, so
    that ``get_pattern_ids`` can use the hash as a cheap pre-filter before the
    exact ``equal_or_complement`` comparison.

    Two hashes are computed by XOR-folding each word multiplied by a fixed
    constant (uint64 arithmetic, overflow discarded):
      h1 = hash of the original packed words
      h2 = hash of the bitwise complement, where the last word is masked with
           ``last_word_mask`` after flipping so padding bits stay zero

    The key is ``min(h1, h2)``, which is the same whichever of the two forms
    is passed in.

    The multiplier 11400714819323198485 is 0x9E3779B97F4A7C15, the 64-bit
    golden-ratio constant used in Fibonacci multiplicative hashing.

    Note: the complement step indexes ``packed[-1]``, so ``packed`` must hold
    at least one word (n_samples >= 1).

    Args:
        packed: uint64 array from ``pack_snp_row``.
        n_samples: Number of haplotypes (needed to mask padding bits).

    Returns:
        np.uint64: Canonical symmetric hash key.
    """
    # h1: hash of the original pattern
    h1 = np.uint64(0)
    for w in packed:
        # Multiplicative hash per word, XOR-folded into the accumulator
        h1 ^= w * np.uint64(11400714819323198485)
    # h2: hash of the bitwise complement (label-flipped pattern)
    mask = last_word_mask(n_samples)
    comp = packed.copy()
    for i in range(packed.shape[0] - 1):
        comp[i] = ~packed[i]  # flip all 64 bits of the leading words
    comp[-1] = (~packed[-1]) & mask  # flip only valid bits of the last word
    h2 = np.uint64(0)
    for w in comp:
        h2 ^= w * np.uint64(11400714819323198485)
    # Symmetric key: min(h1, h2) is the same whichever form is seen first
    return h1 if h1 < h2 else h2


@njit(cache=True)
def get_pattern_ids(hap):
    """Assign a canonical integer ID to each SNP's haplotype pattern.

    Two SNPs receive the same ID if their haplotype patterns are identical or
    bitwise complements, as tested by ``equal_or_complement``. IDs are dense
    integers starting at 0, assigned in order of first appearance.

    The deduplication table uses two parallel arrays:
      ``uniq_packed[u]`` — packed bit-pattern of the u-th unique pattern
      ``hashes[u]``      — its symmetric hash (from ``hash_pattern``)

    For each incoming SNP:
      1. Compute the symmetric hash (O(words)).
      2. Linear-scan ``hashes`` for an equal hash (rejects non-matching
         patterns without a full bit comparison).
      3. On a hash match, call ``equal_or_complement`` to confirm.
      4. If confirmed, reuse the existing ID; otherwise register a new pattern.

    Note: the lookup is a manual linear scan over every pattern registered so
    far, so the total cost grows as O(S × number of unique patterns). The
    table is sized for the worst case of S unique patterns.

    Args:
        hap: Haplotype matrix (S, n) with 0/1 entries (rows=SNPs, cols=haps).

    Returns:
        np.ndarray: int32 array of length S. ``ids[i]`` is the canonical
        pattern ID for SNP i; equal IDs mean the patterns are the same or
        complementary.
    """
    snps, n_samples = hap.shape
    words_per_snp = (n_samples + 63) // 64
    ids = np.empty(snps, dtype=np.int32)
    # Deduplication table: stores one packed pattern per unique ID
    uniq_packed = np.zeros((snps, words_per_snp), dtype=np.uint64)
    hashes = np.zeros(snps, dtype=np.uint64)
    uniq_count = 0
    for i in range(snps):
        packed = pack_snp_row(hap[i], n_samples)
        h = hash_pattern(packed, n_samples)
        found = -1
        # Linear scan: check hash first (cheap), then full bit comparison
        for u in range(uniq_count):
            if hashes[u] == h:
                if equal_or_complement(uniq_packed[u], packed, n_samples):
                    found = u
                    break
        if found == -1:
            # New unique pattern: register it
            uniq_packed[uniq_count, :] = packed
            hashes[uniq_count] = h
            ids[i] = uniq_count
            uniq_count += 1
        else:
            ids[i] = found
    return ids


@njit(cache=True)
def compute_mu_ld(start_idx: int, end_idx: int, pattern_ids: np.ndarray) -> float32:
    """LD contrast component of the RAiSD mu statistic (paper Eq. 3).

    Splits the SNP window [start_idx, end_idx) in half and counts:
      pcntl      — unique patterns in the left half  [p0, p1]
      pcntr      — unique patterns in the right half [p2, p3]
      excl_left  — unique patterns found ONLY in the left half
      excl_right — unique patterns found ONLY in the right half
      excntsnpsl — individual left-half SNPs whose pattern is exclusive to left
      excntsnpsr — individual right-half SNPs whose pattern is exclusive to right

    The statistic is:
      mu_ld = (excl_left * excntsnpsl + excl_right * excntsnpsr) / (pcntl * pcntr)

    Window pointers (using ``mid = length // 2``):
      p0 = start_idx            (first SNP of left half)
      p1 = start_idx + mid - 1  (last SNP of left half)
      p2 = start_idx + mid      (first SNP of right half)
      p3 = end_idx - 1          (last SNP of right half)

    Both unique-pattern lists are seeded unconditionally with the pattern at
    p0 / p2. For a one-SNP window (mid == 0) p0 == p2, so that single SNP is
    counted in both halves.

    Args:
        start_idx: First SNP index (into ``pattern_ids``).
        end_idx: One-past index of the last SNP in the window.
        pattern_ids: Array of pattern IDs for all SNPs, as produced by
            ``get_pattern_ids``.

    Returns:
        float32: mu_ld >= 0. Returns 0.0 for an empty window
        (``end_idx <= start_idx``). Otherwise returns a floor of 1e-10
        instead of 0 when the numerator or the denominator is zero, so the
        product ``mu_var * mu_sfs * mu_ld`` is not forced to exactly zero.
    """
    if end_idx <= start_idx:
        return float32(0.0)
    length = end_idx - start_idx
    # Redundant with the guard above; kept as a defensive check
    if length == 0:
        return float32(0.0)
    # Split window into left [p0..p1] and right [p2..p3] halves
    mid = length // 2
    p0 = int32(start_idx)
    p1 = int32(start_idx + mid - 1)
    p2 = int32(start_idx + mid)
    p3 = int32(end_idx - 1)
    # Build unique-pattern list for the left half
    # list_left holds distinct pattern IDs; list_left_cnt tracks their counts
    list_left = np.empty(length, dtype=int32)
    list_left_cnt = np.zeros(length, dtype=int32)
    list_left_size = int32(0)
    list_left[0] = int32(pattern_ids[p0])
    list_left_cnt[0] = 1
    list_left_size += 1
    for i in range(p0 + 1, p1 + 1):
        pid = int32(pattern_ids[i])
        match = 0
        for j in range(list_left_size):
            if list_left[j] == pid:
                list_left_cnt[j] += 1
                match = 1
                break
        if match == 0:
            list_left[list_left_size] = pid
            list_left_cnt[list_left_size] = 1
            list_left_size += 1
    pcntl = int32(list_left_size)  # total unique patterns in left half
    # Build unique-pattern list for the right half
    list_right = np.empty(length, dtype=int32)
    list_right_cnt = np.zeros(length, dtype=int32)
    list_right_size = int32(0)
    list_right[0] = int32(pattern_ids[p2])
    list_right_cnt[0] = 1
    list_right_size += 1
    for i in range(p2 + 1, p3 + 1):
        pid = int32(pattern_ids[i])
        match = 0
        for j in range(list_right_size):
            if list_right[j] == pid:
                list_right_cnt[j] += 1
                match = 1
                break
        if match == 0:
            list_right[list_right_size] = pid
            list_right_cnt[list_right_size] = 1
            list_right_size += 1
    pcntr = int32(list_right_size)  # total unique patterns in right half
    # Exclusive unique patterns (pattern-level exclusivity)
    # excl_left: unique patterns in left that do NOT appear in right
    excl_left = int32(list_left_size)
    for i in range(list_left_size):
        for j in range(list_right_size):
            if list_left[i] == list_right[j]:
                excl_left -= 1
                break
    # excl_right: unique patterns in right that do NOT appear in left
    excl_right = int32(list_right_size)
    for i in range(list_right_size):
        for j in range(list_left_size):
            if list_right[i] == list_left[j]:
                excl_right -= 1
                break
    # SNP-level exclusivity
    # excntsnpsl: number of left-half SNPs whose pattern is not in the right half
    excntsnpsl = int32(0)
    for i in range(p0, p1 + 1):
        pid_i = int32(pattern_ids[i])
        match = 0
        for j in range(p2, p3 + 1):
            if pid_i == int32(pattern_ids[j]):
                match = 1
                break
        if match == 0:
            excntsnpsl += 1
    # excntsnpsr: number of right-half SNPs whose pattern is not in the left half
    excntsnpsr = int32(0)
    for i in range(p2, p3 + 1):
        pid_i = int32(pattern_ids[i])
        match = 0
        for j in range(p0, p1 + 1):
            if pid_i == int32(pattern_ids[j]):
                match = 1
                break
        if match == 0:
            excntsnpsr += 1
    # Cross-products: exclusive unique patterns × exclusive SNP count per half
    pcntexll = int32(excl_left * excntsnpsl)
    pcntexlr = int32(excl_right * excntsnpsr)
    denom = int32(pcntl * pcntr)
    if (pcntexll + pcntexlr) == 0 or denom == 0:
        # Return a small non-zero floor so mu_total is never exactly zero
        return float32(1e-10)
    return float32((pcntexll + pcntexlr) / float32(denom))
def mu_stat(hap, snp_positions, window_size=50):
    """
    Compute RAiSD composite sweep score :math:`\\mu` over overlapping SNP windows.

    For each sliding window of ``window_size`` consecutive SNPs (step = 1 SNP),
    this routine evaluates three components and their product:

    * **mu_var** - reduction-of-variation component (computed by
      :func:`compute_mu_var`), scaled by the region length.
    * **mu_sfs** - site-frequency-spectrum skew component (from
      :func:`compute_mu_sfs`), scaled by ``_harmonic_sums(n)[0]``.
    * **mu_ld** - linkage-disequilibrium contrast component (from
      :func:`compute_mu_ld`), computed from the complement-invariant SNP
      pattern IDs returned by :func:`get_pattern_ids`.
    * **mu_total** - composite statistic ``mu_var * mu_sfs * mu_ld``.

    The window center coordinate is recorded as the midpoint between the first and
    last SNP positions in the window. Results are returned as a Polars DataFrame
    with one row per window.

    :param numpy.ndarray hap:
        Haplotype matrix of shape ``(S, n)`` with 0/1 alleles
        (rows = SNPs, columns = haplotypes or chromosomes).
    :param numpy.ndarray snp_positions:
        Monotonically increasing physical positions of length ``S`` (aligned to rows of ``hap``).
    :param int window_size:
        Number of consecutive SNPs per sliding window; defaults to ``50``
        (RAiSD's ``-w`` default). Must be at least 1.

    :returns:
        A Polars DataFrame with columns

        * ``positions`` (int): window center (bp, midpoint of first/last SNP in window)
        * ``mu_var`` (float): variation component
        * ``mu_sfs`` (float): SFS component
        * ``mu_ld`` (float): LD component
        * ``mu_total`` (float): composite score ``mu_var * mu_sfs * mu_ld``
    :rtype: polars.DataFrame

    :raises ValueError: if ``window_size`` is smaller than 1.

    :notes:
        * ``D_ln = (snp_positions[-1] + 1) - snp_positions[0]`` is the physical
          span of the full input region (bp), used to scale mu_var.
        * ``get_pattern_ids`` is called once on the full haplotype matrix before
          the window loop; pattern IDs are indexed per-window by start/end.
        * Windows advance by one SNP (maximally overlapping). For S SNPs and
          window size W, the output has S - W + 1 rows; when ``S < W``
          (including an empty matrix) the result is an empty DataFrame with
          the same columns.

    :see also:
        :func:`compute_mu_var`, :func:`compute_mu_sfs`, :func:`compute_mu_ld`,
        :func:`get_pattern_ids`
    """
    # A non-positive window would hand out-of-range indices to the njit
    # kernels, which do not bounds-check; reject it up front.
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    S, n = hap.shape
    n_windows = max(S - window_size + 1, 0)

    mu_var_np = np.zeros(n_windows)
    mu_sfs_np = np.zeros(n_windows)
    mu_ld_np = np.zeros(n_windows)
    mu_total_np = np.zeros(n_windows)
    center_np = np.zeros(n_windows)

    if n_windows > 0:
        # Physical span of the full region; used to normalise mu_var across regions
        D_ln = (snp_positions[-1] + 1) - snp_positions[0]
        # Sample-size correction for mu_sfs, computed once for all windows
        theta_w_correction = _harmonic_sums(n)[0]
        # Pattern IDs for the whole matrix; each window indexes into this array
        pattern_ids = get_pattern_ids(hap)

        for start_idx in range(n_windows):
            end_idx = start_idx + window_size
            window = hap[start_idx:end_idx]

            mu_var = compute_mu_var(start_idx, end_idx, snp_positions, D_ln, window_size)
            mu_sfs = compute_mu_sfs(window, n, theta_w_correction)
            mu_ld = compute_mu_ld(start_idx, end_idx, pattern_ids)

            mu_var_np[start_idx] = mu_var
            mu_sfs_np[start_idx] = mu_sfs
            mu_ld_np[start_idx] = mu_ld
            mu_total_np[start_idx] = mu_var * mu_sfs * mu_ld
            # Center = midpoint between first and last SNP positions in this window
            center_np[start_idx] = (
                snp_positions[start_idx] + snp_positions[end_idx - 1]
            ) / 2

    return pl.DataFrame(
        {
            "positions": center_np.astype(int),
            "mu_var": mu_var_np,
            "mu_sfs": mu_sfs_np,
            "mu_ld": mu_ld_np,
            "mu_total": mu_total_np,
        }
    )
def run_raisd(hap_data, window_size=50):
    """Compute RAiSD mu scores for a VCF or an ms/discoal file.

    The input is first handed to ``genome_reader``. If that call (or the
    unpacking of its six-element result) raises any exception, the same input
    is parsed with ``parse_ms_numpy`` instead. Only the haplotype matrix and
    the physical positions from the reader output are used; they are passed
    to ``mu_stat``.

    Args:
        hap_data: Path to a VCF/VCF.gz file or an ms/discoal output file.
        window_size: Number of consecutive SNPs per sliding window (default 50,
            matching RAiSD's ``-w`` default).

    Returns:
        polars.DataFrame: Whatever ``mu_stat`` returns — columns
        ``[positions, mu_var, mu_sfs, mu_ld, mu_total]``, one row per window.
    """
    try:
        haplotypes, _, _, _, positions, _ = genome_reader(hap_data)
    except Exception:
        # Deliberate catch-all: anything the VCF path cannot handle is retried
        # as ms/discoal text.
        haplotypes, _, _, _, positions, _ = parse_ms_numpy(hap_data)

    return mu_stat(haplotypes, positions, window_size=window_size)
################## Balancing stats


@njit(cache=True)
def calc_d(freq, core_freq, p):
    """Calculates the value of d, the similarity measure

    Both frequencies are folded (``min(x, 1 - x)``) before comparison. The
    result is ``((maxdiff - |xf - f|) / maxdiff) ** p`` where ``xf`` is the
    folded core frequency, ``f`` the folded SNP frequency and
    ``maxdiff = max(xf, 0.5 - xf)``.

    Parameters:
        freq: freq of SNP under consideration, ranges from 0 to 1
            (scalar or array; numpy element-wise functions are applied)
        core_freq: freq of coresite, ranges from 0 to 1 (scalar)
        p: the p parameter specifying sharpness of peak
    """
    xf = min(core_freq, 1.0 - core_freq)
    f = np.minimum(freq, 1.0 - freq)
    # Largest folded-frequency distance possible from xf, used to normalise
    maxdiff = np.maximum(xf, 0.5 - xf)
    corr = ((maxdiff - np.abs(xf - f)) / maxdiff) ** p
    return corr


@njit(cache=True)
def omegai_nb(freqs, core_freq, n, p):
    """Calculates 9a

    Returns ``n1 - n2`` where ``n1`` is ``calc_d(freqs)`` normalised by the
    sum of ``calc_d`` over the frequencies ``1/n .. (n-1)/n``, and ``n2`` is
    ``1 / (freqs * n)`` normalised by the harmonic sum ``sum(1/k, k=1..n-1)``.

    Parameters:
        freqs: freq of SNP(s) under consideration, ranges between 0 and 1.
            A frequency of exactly 0 makes ``1 / (freqs * n)`` divide by zero.
        core_freq: freq of coresite, ranges from 0 to 1
        n: number of chromosomes (sample size)
        p: the p parameter specifying sharpness of peak
    """
    n1num = calc_d(freqs, core_freq, p)
    n1denom = np.sum(calc_d(np.arange(1.0, n) / n, core_freq, p))
    n1 = n1num / n1denom
    n2 = (1.0 / (freqs * n)) / (np.sum(1.0 / np.arange(1.0, n)))
    return n1 - n2


@njit(cache=True)
def an_nb(n, core_freq, p):
    """
    Calculates alpha_n from Achaz 2009, eq 9b

    Computed as ``sum(i * omega_i ** 2)`` for ``i = 1 .. n-1`` with
    ``omega_i = omegai_nb(i / n, core_freq, n, p)``.

    n: Sample size
    core_freq: frequency of the core site, ranges from 0 to 1
    p: value of p parameter
    """
    i = np.arange(1, n)
    return np.sum(i * omegai_nb(i / n, core_freq, n, p) ** 2.0)


@njit(cache=True)
def fu_an_vec(n):
    """Calculates a_n from Fu 1995, eq 4 for a single integer value

    Returns ``sum(1/k, k=1..n-1)``, or 0.0 when ``n <= 1``.
    """
    if n <= 1:
        return 0.0
    return np.sum(1.0 / np.arange(1.0, n))


@njit(cache=True)
def fu_Bn(n, i):
    """Calculates Beta_n(i) from Fu 1995, eq 5

    Divides by ``(n - i)``, so ``i`` must differ from ``n``.
    """
    r = 2.0 * n / ((n - i + 1.0) * (n - i)) * (fu_an_vec(n + 1) - fu_an_vec(i)) - (
        2.0 / (n - i)
    )
    return r


@njit(cache=True)
def sigma(n, ij):
    """Evaluate the sigma term for every index pair in ``ij``.

    Each row of ``ij`` is a pair; the larger entry is taken as ``i`` and the
    smaller as ``j``, so the order within a row does not matter. The value is
    selected by case (i == j vs i > j, and how ``i + j`` or ``2 * i`` compares
    with ``n``) and built from ``fu_an_vec`` and ``fu_Bn``.
    NOTE(review): presumably sigma_ij of Fu 1995 — confirm against the paper.

    Parameters:
        n: sample size
        ij: integer array of shape (m, 2) of index pairs

    Returns:
        float array of length m; entries whose pair matches no case stay 0.
    """
    res = np.zeros(ij.shape[0])
    for k in range(ij.shape[0]):
        i = max(ij[k, 0], ij[k, 1])
        j = min(ij[k, 0], ij[k, 1])
        if i == j and 2 * i == n:
            res[k] = 2.0 * ((fu_an_vec(n) - fu_an_vec(i)) / (n - i)) - (1.0 / (i * i))
        elif i == j and i < n / 2.0:
            # FIXED: removed 2*
            res[k] = fu_Bn(n, i + 1)
        elif i == j and i > n / 2.0:
            res[k] = fu_Bn(n, i) - (1.0 / (i * i))
        elif i > j and (i + j == n):
            an_n = fu_an_vec(n)
            an_i = fu_an_vec(i)
            an_j = fu_an_vec(j)
            term1 = (an_n - an_i) / (n - i)
            term2 = (an_n - an_j) / (n - j)
            term3 = (fu_Bn(n, i) + fu_Bn(n, j + 1)) / 2.0
            term4 = 1.0 / (i * j)
            res[k] = (term1 + term2) - (term3 + term4)
        elif i > j and (i + j < n):
            res[k] = (fu_Bn(n, i + 1) - fu_Bn(n, i)) / 2.0
        elif i > j and (i + j > n):
            res[k] = (fu_Bn(n, j) - fu_Bn(n, j + 1)) / 2.0 - (1.0 / (i * j))
    return res


@njit(cache=True)
def Bn_nb(n, core_freq, p):
    """
    Returns Beta_N from Achaz 2009, eq 9c

    The result is the diagonal sum ``sum(i^2 * omega_i^2 * sigma_ii)`` plus
    twice the off-diagonal sum ``sum(i * j * omega_i * omega_j * sigma_ij)``
    over all pairs ``1 <= j < i <= n-1``.

    Parameters:
        n: Sample size
        core_freq: frequency of the core site, ranges from 0 to 1
        p: value of p parameter
    """
    i = np.arange(1, n)
    n1 = np.sum(
        i**2.0
        * omegai_nb(i / n, core_freq, n, p) ** 2.0
        * sigma(n, np.column_stack((i, i)))
    )
    # Enumerate all (j, i) pairs with 1 <= j < i <= n-1; there are
    # (n-1)(n-2)/2 of them. Filled with explicit loops inside the njit body.
    m = (n - 1) * (n - 2) // 2
    coords = np.empty((m, 2), dtype=np.int64)
    idx = 0
    for i in range(1, n):
        for j in range(1, i):
            coords[idx, 0] = j
            coords[idx, 1] = i
            idx += 1
    s2 = np.sum(
        coords[:, 0]
        * coords[:, 1]
        * omegai_nb(coords[:, 0] / n, core_freq, n, p)
        * omegai_nb(coords[:, 1] / n, core_freq, n, p)
        * sigma(n, coords)
    )
    # Each unordered pair appears once in coords, hence the factor of two
    n2 = 2.0 * s2
    return n1 + n2


def calc_thetaw_unfolded(snp_freq_list, num_ind):
    """Calculates watterson's theta

    Returns the number of rows of ``snp_freq_list`` divided by the harmonic
    sum ``sum(1/k, k=1..num_ind-1)``, or 0 when the array is empty.

    Parameters:
        snp_freq_list: 2-D array with one row per SNP in the window; only the
            number of rows is used here
        num_ind: number of individuals used to calculate the core site frequency
    """
    if snp_freq_list.size == 0:
        return 0
    a1 = np.sum(1.0 / np.arange(1, num_ind))
    thetaW = len(snp_freq_list[:, 0]) / a1
    return thetaW


def calc_t_unfolded(freqs, core_freq, n, p, theta, var_dic):
    """
    Using equation 8 from Achaz 2009

    Returns ``num / denom`` with ``num = sum(freqs * n * omega(freqs))`` and
    ``denom = sqrt(alpha_n * theta + beta_n * theta^2)``. The denominator is
    memoised in ``var_dic`` under the key ``(n, core_freq, theta)``; the dict
    is mutated in place.

    Parameters:
        freqs: frequencies of the SNPs in the window, each between 0 and 1
        core_freq: freq of the core SNP, ranges from 0 to 1
        n: sample size of core SNP
        p: the p parameter specifying sharpness of peak
        theta: genome-wide estimate of the mutation rate
        var_dic: dict used as a cache for the denominator
    """
    num = np.sum(freqs * n * omegai_nb(freqs, core_freq, n, p))
    if (n, core_freq, theta) not in var_dic:
        denom = np.sqrt(
            an_nb(n, core_freq, p) * theta + Bn_nb(n, core_freq, p) * theta**2.0
        )
        var_dic[(n, core_freq, theta)] = denom
    else:
        denom = var_dic[(n, core_freq, theta)]
    return num / denom


@njit(cache=True)
def calc_t_unfolded_cached(freqs, denom, core_freq, n, p, theta):
    """Numerator and standardised statistic using a precomputed denominator.

    Parameters:
        freqs: frequencies of the SNPs in the window, each between 0 and 1
        denom: precomputed denominator (e.g. an entry of the array returned
            by ``precompute_denoms``)
        core_freq: freq of the core SNP, ranges from 0 to 1
        n: sample size
        p: the p parameter specifying sharpness of peak
        theta: accepted for signature compatibility; not used in the body

    Returns:
        tuple ``(num, num / denom)`` with
        ``num = sum(freqs * n * omegai_nb(freqs, core_freq, n, p))``.
    """
    num = np.sum(freqs * n * omegai_nb(freqs, core_freq, n, p))
    return num, num / denom


@njit(cache=True)
def precompute_denoms(n, p, theta, omega_func):
    """Precompute the denominator for every possible core allele count.

    For each core count ``cf`` in ``1 .. n`` the core frequency is
    ``cf / n`` and the stored value is ``sqrt(a_n * theta + b_n * theta^2)``,
    using the same diagonal / off-diagonal sums as ``an_nb`` and ``Bn_nb``.
    The sigma terms do not depend on the core frequency, so they are
    evaluated once and reused for every ``cf``.

    Parameters:
        n: sample size
        p: the p parameter specifying sharpness of peak
        theta: mutation-rate estimate
        omega_func: function with the signature of ``omegai_nb``
            ``(freqs, core_freq, n, p)``

    Returns:
        tuple ``(denom_array, diag_sigma, off_diag_sigma)``; ``denom_array``
        has length ``n + 1`` and is indexed by core count (index 0 stays 0).
    """
    denom_array = np.zeros(n + 1)
    # Precompute shared structures
    i_vals = np.arange(1, n)
    diag_sigma = sigma(n, np.column_stack((i_vals, i_vals)))
    # All (j, i) pairs with 1 <= j < i <= n-1
    m = (n - 1) * (n - 2) // 2
    coords = np.empty((m, 2), dtype=np.int64)
    idx = 0
    for i in range(1, n):
        for j in range(1, i):
            coords[idx, 0] = j
            coords[idx, 1] = i
            idx += 1
    coords_i = coords[:, 0]
    coords_j = coords[:, 1]
    off_diag_sigma = sigma(n, coords)
    for cf in range(1, n + 1):
        x = cf / n
        omega = omega_func(i_vals / n, x, n, p)
        an = np.sum(i_vals * omega**2)
        omega_i = omega_func(coords_i / n, x, n, p)
        omega_j = omega_func(coords_j / n, x, n, p)
        s2 = np.sum(coords_i * coords_j * omega_i * omega_j * off_diag_sigma)
        b_n = np.sum(i_vals**2 * omega**2 * diag_sigma) + 2.0 * s2
        denom_array[cf] = np.sqrt(an * theta + b_n * theta**2)
    return denom_array, diag_sigma, off_diag_sigma


@njit(cache=True)
def find_win_indx(prev_start_i, prev_end_i, pos, snp_info, win_size):
    """Takes in the previous indices of the start and end of the window,
    then returns the appropriate starting and ending index for the next SNP

    The window is ``[pos - win_size / 2, pos + win_size / 2]``. Both searches
    start from the previous window's indices, so successive calls are expected
    to use non-decreasing ``pos`` on position-sorted ``snp_info``.

    Parameters:
        prev_start_i: starting index in the array of SNP for the previous core
            SNP's window, inclusive
        prev_end_i: ending index in the array for the previous SNP's window,
            inclusive
        pos: physical position of the current core SNP
        snp_info: the numpy array of all SNP locations & frequencies;
            column 0 holds the positions
        win_size: total window width in the same units as the positions

    Returns:
        tuple ``(firstI, endI)``: index of the first SNP at or after the
        window start, and index of the last SNP at or before the window end
        (both inclusive).
    """
    win_start = pos - win_size / 2
    # array index of start of window, inclusive
    firstI = prev_start_i + np.searchsorted(
        snp_info[prev_start_i:, 0], win_start, side="left"
    )
    winEnd = pos + win_size / 2
    # side="right" counts the SNPs with position <= winEnd; subtracting 1
    # turns that count into the inclusive index of the last SNP in the window
    endI = (
        prev_end_i - 1 + np.searchsorted(snp_info[prev_end_i:, 0], winEnd, side="right")
    )
    return (firstI, endI)
def run_beta_window(ac, position_masked, p=2, m=0.1, w=None, theta=None):
    """Compute the beta statistic and its standardised form for each core SNP.

    A SNP is used as a core site when its total allele count equals the
    sample size of the first SNP and its allele-1 frequency lies strictly
    between ``m`` and ``1 - m``. For every core site the statistic is
    evaluated by ``calc_t_unfolded_cached`` over the other SNPs, either in
    the whole region (``w is None``) or in a window of width ``w`` centred on
    the core site.

    Parameters
    ----------
    ac : np.ndarray, shape (S, 2)
        Allele counts per SNP; column 1 is the count used for the frequency.
    position_masked : np.ndarray, shape (S,)
        Physical positions aligned with ``ac``. The windowed mode relies on
        them being sorted in increasing order.
    p : int or float
        Sharpness parameter forwarded to the weighting functions. Default 2.
    m : float
        Frequency cutoff for core sites. Default 0.1.
    w : float or None
        Window width in position units. ``None`` uses every other SNP.
    theta : float or None
        Mutation-rate estimate used for the variance term. When ``None`` it is
        taken from ``theta_watterson``: element ``[0]`` of its result in
        whole-region mode, element ``[1] * w`` in windowed mode. A supplied
        value is used as given in both modes.

    Returns
    -------
    pl.DataFrame with columns ``positions`` (Int64), ``beta`` and ``t``
    (Float64), one row per core site. Empty input yields an empty frame.
    """
    schema = {"positions": pl.Int64, "beta": pl.Float64, "t": pl.Float64}

    _n = ac.sum(axis=1)
    # Columns: 0 = position, 1 = allele-1 frequency, 2-3 = allele counts,
    # 4 = total count per SNP
    snp_info = np.column_stack([position_masked, ac[:, 1] / _n, ac, _n])
    if snp_info.shape[0] == 0:
        return pl.DataFrame(np.zeros((0, 3)), schema=schema)

    n = int(snp_info[0, -1])
    mask = (snp_info[:, 4] == n) & (snp_info[:, 1] < (1 - m)) & (snp_info[:, 1] > m)
    # Row indices of the core sites in the *unmasked* array, so the core SNP
    # itself (and not some other row) is the one left out of each SNP set.
    core_rows = np.flatnonzero(mask)
    output = np.zeros((core_rows.size, 3))

    if w is None:
        if theta is None:
            theta = theta_watterson(snp_info[:, 2:4], snp_info[:, 0])[0]

        denom_array, _, _ = precompute_denoms(n, p, theta, omegai_nb)

        for out_row, j in enumerate(core_rows):
            snp_i = snp_info[j]
            core_freq = snp_i[1]
            # Every SNP in the region except the core site
            snp_set = np.concatenate((snp_info[:j], snp_info[j + 1 :]))
            denom = denom_array[int(round(core_freq * n))]
            B, T = calc_t_unfolded_cached(snp_set[:, 1], denom, core_freq, n, p, theta)
            output[out_row] = (snp_i[0], B, T)
    else:
        if theta is None:
            theta = theta_watterson(snp_info[:, 2:4], snp_info[:, 0])[1] * w

        denom_array, _, _ = precompute_denoms(n, p, theta, omegai_nb)

        # Window bounds only move forward, so each search resumes from the
        # previous core site's window.
        prev_start_i = 0
        prev_end_i = 0
        for out_row, j in enumerate(core_rows):
            snp_i = snp_info[j]
            core_freq = snp_i[1]

            sI, endI = find_win_indx(prev_start_i, prev_end_i, snp_i[0], snp_info, w)
            prev_start_i = sI
            prev_end_i = endI

            if endI > sI:
                # SNPs inside the window (endI is inclusive) minus the core site
                snp_set = np.concatenate(
                    (snp_info[sI:j], snp_info[(j + 1) : (endI + 1)])
                )
                # round() rather than int(): core_freq * n can land just below
                # the integer count and would otherwise select the wrong entry
                denom = denom_array[int(round(core_freq * n))]
                B, T = calc_t_unfolded_cached(
                    snp_set[:, 1], denom, core_freq, n, p, theta * w
                )
            else:
                # Window holds nothing besides the core site
                B, T = 0.0, 0.0
            output[out_row] = (snp_i[0], B, T)

    return pl.DataFrame(output, schema=schema)
@njit(parallel=False, cache=True)
def ncd1(position_masked, freqs, tf=0.5, w=3000, minIS=2):
    """Windowed root-mean-square deviation of minor-allele frequency from ``tf``.

    Windows of width ``w`` start every ``w / 2`` position units from the first
    SNP position up to (but not including) the last one. A window covers
    ``[start, start + w]`` with both ends inclusive. For each window holding
    at least ``minIS`` SNPs the value ``sqrt(mean((maf - tf) ** 2))`` is
    computed, where ``maf = min(freq, 1 - freq)``.

    Parameters
    ----------
    position_masked : np.ndarray, shape (S,)
        SNP positions, sorted in increasing order (the two-pointer sweep
        below only moves forward).
    freqs : np.ndarray, shape (S,)
        Allele frequency per SNP, aligned with ``position_masked``.
    tf : float
        Target frequency. Default 0.5.
    w : int or float
        Window width in position units. Default 3000.
    minIS : int
        Minimum number of SNPs a window must contain to be reported. Default 2.

    Returns
    -------
    np.ndarray of float64 with one value per retained window, in window order.
    Windows with fewer than ``minIS`` SNPs are dropped, so the output does not
    carry window coordinates. Empty input returns an empty array.
    """
    n = len(position_masked)
    if n == 0:
        # Nothing to scan; indexing position_masked[0] below would read out
        # of bounds in nopython mode.
        return np.empty(0, dtype=np.float64)

    maf = np.minimum(freqs, 1 - freqs)
    w1 = w / 2.0
    start_positions = np.arange(position_masked[0], position_masked[-1], w1)
    n_windows = len(start_positions)

    # Preallocate outputs; valid_mask marks the windows that get a value
    results = np.empty(n_windows, dtype=np.float64)
    valid_mask = np.zeros(n_windows, dtype=np.bool_)

    j_start = 0
    j_end = 0
    for widx in range(n_windows):
        start = start_positions[widx]
        end = start + w

        # Advance start pointer to the first SNP at or after the window start
        while j_start < n and position_masked[j_start] < start:
            j_start += 1
        # Advance end pointer past the last SNP at or before the window end
        while j_end < n and position_masked[j_end] <= end:
            j_end += 1

        # Now [j_start:j_end) are the indices within the window
        count = j_end - j_start
        if count < minIS:
            continue

        # Sum of squared deviations of MAF from the target frequency
        tmp = 0.0
        for k in range(j_start, j_end):
            diff = maf[k] - tf
            tmp += diff * diff

        results[widx] = np.sqrt(tmp / count)
        valid_mask[widx] = True

    return results[valid_mask]