Source code for flexsweep.enrichment

import os
import sys
import time

import polars_bio as pb
from numba import njit

from . import Parallel, delayed, np, pl


def compute_distances(vip_gene_ids: list, annotation_file: str) -> pl.DataFrame:
    """
    Return, for every annotated gene, the distance to its closest VIP gene.

    Each gene is reduced to a 1-bp interval located at the midpoint of its
    annotated span; ``polars_bio.nearest`` then pairs every midpoint with the
    nearest VIP midpoint.

    Parameters
    ----------
    vip_gene_ids    : list of VIP (case) gene ID strings
    annotation_file : path to a tab-separated, header-less BED file with the
                      columns chrom, start, end, gene_id, strand

    Returns
    -------
    pl.DataFrame with columns gene_id and distance, sorted by gene_id
    """
    point_cols = ["chrom", "start", "end", "gene_id"]
    midpoint = (pl.col("start") + pl.col("end")) // 2

    # Both expressions are evaluated against the *original* start/end columns,
    # so the new "end" is midpoint + 1 of the untouched coordinates.
    genes = pl.read_csv(
        annotation_file,
        separator="\t",
        has_header=False,
        new_columns=["chrom", "start", "end", "gene_id", "strand"],
        schema_overrides={"chrom": pl.Utf8},
    ).with_columns(
        midpoint.cast(pl.Int64).alias("start"),
        (midpoint + 1).cast(pl.Int64).alias("end"),
    )

    vip_ids = set(vip_gene_ids)
    all_points = genes.select(point_cols)
    vip_points = genes.filter(pl.col("gene_id").is_in(vip_ids)).select(point_cols)
    for frame in (all_points, vip_points):
        frame.config_meta.set(coordinate_system_zero_based=True)

    gap = pl.col("distance")
    nearest = (
        pb.nearest(
            all_points,
            vip_points,
            cols1=["chrom", "start", "end"],
            cols2=["chrom", "start", "end"],
            suffixes=("_1", "_2"),
            output_type="polars.LazyFrame",
        )
        .select(pl.col("gene_id_1").alias("gene_id"), gap)
        .collect()
    )

    # Per the original implementation note: polars_bio.nearest() reports the
    # gap between intervals (b - a - 1) whereas closestBed -d reports b - a,
    # so every non-overlapping pair gets +1 to stay comparable with the Perl
    # pipeline. Overlapping pairs (distance 0) are left untouched.
    adjusted = nearest.with_columns(
        pl.when(gap > 0).then(gap + 1).otherwise(gap).alias("distance")
    )
    return adjusted.sort("gene_id")


@njit(cache=True)
def iterative_control_set_fast(
    vip_X: np.ndarray,
    nonvip_X: np.ndarray,
    vip_number: int,
    tolerance: float,
    max_rep: int,
    seed: int = None,
    skip_factor_idx: int = 8,
    n_batches: int = 10,
) -> np.ndarray:
    """
    JIT-compiled bootstrap builder for the nomatchomega_fast variant.

    Control genes are drawn in pairs from ``nonvip_X``. A pair is accepted
    only if the updated running mean of every factor stays inside
    ``[(1 - tolerance) * vip_mean, (1 + tolerance) * vip_mean]`` and neither
    gene has reached the current repetition limit. Accepted row indices are
    appended to one continuous list from which control sets are sliced.

    Runs n_batches sequential internal batches with **continuous state**,
    matching Perl's within-process behaviour exactly:
      - used_counts, current_means, fake_seed, sc_gn persist across batches.
      - Target for batch r: 100 + 11 * vip_number * r  (grows each batch).
      - 10 sets sliced from the end of the list after each batch.

    Additional semantic differences vs iterative_control_set (sweep_count.py):
      1. Factor skip_factor_idx (default 8, dN/dS) is unconstrained (±1e18).
      2. The +2 offset on factor 27 has already been applied by the caller.

    Parameters
    ----------
    vip_X           : (vip_number, n_factors) float64 — offset pre-applied
    nonvip_X        : (n_nonvips,  n_factors) float64 — offset pre-applied
    vip_number      : number of VIP genes
    tolerance       : ± fraction for confounder matching (e.g. 0.05)
    max_rep         : dynamic rep-limit numerator
    seed            : RNG seed for this parallel call
    skip_factor_idx : 0-based factor index to leave unconstrained
    n_batches       : number of sequential internal batches (≡ Perl Iterations_number/10)

    Returns
    -------
    np.ndarray of shape (n_batches * 10, vip_number), dtype int64.
    Each row is one control set of nonvip_X row indices.

    Notes
    -----
    - The sampling loop only terminates if acceptable pairs exist. With a
      single row in ``nonvip_X`` the "draw a different second index" loop
      never exits, and ``vip_number == 0`` divides by zero.
    - For a constrained factor whose VIP mean is negative (and tolerance > 0)
      the lower bound exceeds the upper bound, so no pair can ever be
      accepted.
    - NOTE(review): ``seed`` is forwarded unchanged to ``np.random.seed``;
      the ``None`` default is presumably not usable under numba's nopython
      mode — callers should pass an integer. Confirm.
    """
    n_nonvips = nonvip_X.shape[0]
    n_factors = nonvip_X.shape[1]

    # Compute VIP mean vector.
    vip_vec = np.zeros(n_factors)
    for i in range(vip_X.shape[0]):
        for f in range(n_factors):
            vip_vec[f] += vip_X[i, f]
    for f in range(n_factors):
        vip_vec[f] /= vip_X.shape[0]

    # Tolerance bounds; skip_factor_idx is unconstrained.
    inf_vec = (1.0 - tolerance) * vip_vec
    sup_vec = (1.0 + tolerance) * vip_vec
    inf_vec[skip_factor_idx] = -1e18
    sup_vec[skip_factor_idx] = 1e18

    # Base length added to every batch target (see target_len below).
    init_fake = 100

    # Preallocate for the full continuous list across all batches.
    # The buffer is doubled further down if it ever turns out too small.
    capacity = (init_fake + 11 * vip_number * n_batches) * 2
    good_idx = np.empty(capacity, dtype=np.int64)

    # Continuous state — persists across all n_batches (matches Perl).
    used_counts = np.zeros(n_nonvips, dtype=np.int32)  # times each control row was accepted
    current_means = vip_vec.copy()  # running per-factor value, starts at the VIP mean
    fake_seed = 100  # weight of the VIP mean in the update; decays by 1 per accepted pair
    sc_gn = 0  # number of accepted indices stored in good_idx (always even)

    np.random.seed(seed)

    out = np.empty((n_batches * 10, vip_number), dtype=np.int64)
    out_row = 0  # next free row of `out`

    for r in range(1, n_batches + 1):
        target_len = init_fake + 11 * vip_number * r  # grows each batch

        while sc_gn < target_len:
            # if sc_gn % 1000 == 0:
            #     print("batch", r, "progress", sc_gn, "/", target_len)
            # Draw two distinct control rows.
            i1 = np.random.randint(0, n_nonvips)
            i2 = np.random.randint(0, n_nonvips)
            while i2 == i1:
                i2 = np.random.randint(0, n_nonvips)

            # Repetition cap grows with the number of genes accepted so far;
            # a pair is rejected if either member has already reached it.
            rep_limit = max_rep * (sc_gn + 1) / vip_number
            if used_counts[i1] >= rep_limit or used_counts[i2] >= rep_limit:
                continue

            # Candidate value per factor: VIP mean weighted by fake_seed,
            # current running value weighted by sc_gn, new pair weighted by 2.
            denom = sc_gn + 2 + fake_seed
            ok = True
            for f in range(n_factors):
                sub = (nonvip_X[i1, f] + nonvip_X[i2, f]) * 0.5
                new_m = (
                    fake_seed * vip_vec[f] + current_means[f] * sc_gn + sub * 2.0
                ) / denom
                if new_m < inf_vec[f] or new_m > sup_vec[f]:
                    ok = False
                    break

            if ok:
                # Grow the index buffer (doubling) before writing past its end.
                if sc_gn + 2 > capacity:
                    new_cap = capacity * 2
                    new_buf = np.empty(new_cap, dtype=np.int64)
                    new_buf[:sc_gn] = good_idx[:sc_gn]
                    good_idx = new_buf
                    capacity = new_cap

                good_idx[sc_gn] = i1
                good_idx[sc_gn + 1] = i2
                used_counts[i1] += 1
                used_counts[i2] += 1

                # Commit the same update that was just validated above.
                for f in range(n_factors):
                    sub = (nonvip_X[i1, f] + nonvip_X[i2, f]) * 0.5
                    current_means[f] = (
                        fake_seed * vip_vec[f] + current_means[f] * sc_gn + sub * 2.0
                    ) / denom

                sc_gn += 2
                if fake_seed > 0:
                    fake_seed -= 1

        # Slice 10 sets from the end of good_idx for batch r
        total = sc_gn
        n_sets = 0
        for p in range(10):
            # Stop early if fewer than (p + 1) full sets fit in the list.
            if total - (p + 1) * vip_number < 0:
                break
            n_sets += 1

        # p = 0 is the most recent window; it is written to the last row of
        # this batch so rows end up in the order the genes were accepted.
        for p in range(n_sets):
            sup = total - p * vip_number
            inf_i = sup - vip_number
            for j in range(vip_number):
                out[out_row + (n_sets - 1 - p), j] = good_idx[inf_i + j]
        out_row += n_sets

    # Trim rows that were never filled (only if a batch produced < 10 sets).
    return out[:out_row]


def run_bootstrap_nomatchomega(
    case_genes: list,
    factors_file: str,
    annotation_file: str,
    runs_number: int,
    tolerance: float,
    min_dist: float,
    flip: bool,
    max_rep: int,
    seed: int = None,
    nthreads: int = 1,
    control_genes: list = None,
    distance_file: str = None,
    skip_factor_idx: int = 8,
    offset_factor_idx: int = 27,
    offset_value: float = 2.0,
    n_batches: int = 10,
):
    """
    Step 6: nomatchomega_fast variant of the matched bootstrap control set generator.

    Implements the behaviour of bootstrap_test_script_nomatchomega_fast.pl:
      - Factor skip_factor_idx (default 8, omega/dN/dS) is unconstrained.
      - Factor offset_factor_idx (default 27, Perl column 28) gets +offset_value.
      - n_batches sequential internal batches per parallel call, with
        continuous state (used_counts, current_means, fake_seed persist across
        batches), matching Perl's within-process continuous state.
      - Target for batch r: 100 + 11 * vip_number * r (grows each batch).
      - Total control sets = runs_number * n_batches * 10.

    Parameters
    ----------
    case_genes        : list of VIP gene IDs (HLA/histone excluded by caller)
    factors_file      : path to confounding factors table (space-separated,
                        no header, first column gene_id). Rows whose 20th
                        column is below -1e21 are dropped.
    annotation_file   : path to BED gene coordinates file
                        (chrom start end gene_id strand). Distances to the
                        nearest VIP are computed via compute_distances unless
                        distance_file is given.
    runs_number       : number of independent parallel calls (≡ Perl Runs_number)
    tolerance         : allowed ± fraction for confounder matching (e.g. 0.05)
    min_dist          : minimum distance (bp) from VIPs for control eligibility
    flip              : if True, swap VIP and control pools
    max_rep           : max times a control gene may be sampled on average
    seed              : base RNG seed; run i derives its own seed from
                        seed + i. None draws fresh entropy for every run.
    nthreads          : number of parallel workers
    control_genes     : optional explicit candidate control gene IDs. If None,
                        all genes in factors_file that are not case genes are
                        used. Pass the "no"-labelled genes from genes_set_file
                        to match Perl behaviour exactly.
    distance_file     : optional path to pre-computed distance TSV
                        (gene_id \\t distance, no header). When provided,
                        distances are read directly instead of calling
                        compute_distances(). Use Perl's
                        distance_genes_set_file.txt for exact control pool
                        matching.
    skip_factor_idx   : 0-based factor index to leave unconstrained
                        (default 8 = Perl factor 9, dN/dS / omega).
    offset_factor_idx : 0-based factor index to apply +offset_value to
                        (default 27 = Perl column 28).
    offset_value      : value added to offset_factor_idx column (default 2.0).
    n_batches         : sequential internal batches per parallel call
                        (≡ Perl Iterations_number/10; default 10).

    Returns
    -------
    df_set               : pl.DataFrame — VIP genes (gene_id column)
    df_bootstrap_control : pl.DataFrame — rows = control sets,
                           cols = gene positions (sample_id + column_0..column_N)

    Raises
    ------
    ValueError : if no gene of interest remains, or fewer than two control
                 candidates remain, after filtering. The JIT sampler would
                 otherwise divide by zero or loop forever.
    """
    case_ids = set(case_genes)

    # Perl checks index 19 of the raw line. In Polars, this is column_20.
    # (The file is read exactly once; the filter is applied straight away.)
    df_factors = (
        pl.read_csv(factors_file, separator=" ", has_header=False)
        .rename({"column_1": "gene_id"})
        .filter(pl.col("column_20") >= -1e21)
    )

    # Background pool
    if control_genes is not None:
        background_ids = set(control_genes) - case_ids
    else:
        background_ids = set(df_factors["gene_id"].to_list()) - case_ids

    if distance_file is not None:
        distance_df = pl.read_csv(
            distance_file,
            separator="\t",
            has_header=False,
            new_columns=["gene_id", "distance"],
        )
    else:
        distance_df = compute_distances(case_genes, annotation_file)

    # Genes far enough from every VIP to be eligible as controls.
    set_by_dist = set(
        distance_df.filter(pl.col("distance") >= min_dist)["gene_id"].to_list()
    )

    df_set = df_factors.filter(pl.col("gene_id").is_in(case_ids))
    df_control_pool = df_factors.filter(
        pl.col("gene_id").is_in(background_ids) & pl.col("gene_id").is_in(set_by_dist)
    )

    if flip:
        df_set, df_control_pool = df_control_pool, df_set

    factor_cols = [c for c in df_factors.columns if c != "gene_id"]
    control_factors = df_control_pool.select(["gene_id"] + factor_cols)

    vip_X = df_set.select(factor_cols).to_numpy().astype(np.float64)
    nonvip_X = control_factors.select(factor_cols).to_numpy().astype(np.float64)
    nonvip_genes = control_factors["gene_id"].to_numpy().astype(str)
    vip_number = df_set.height

    # Apply offset once in Python — shared read-only across all parallel workers.
    if 0 <= offset_factor_idx < vip_X.shape[1]:
        vip_X_off = vip_X.copy()
        nonvip_X_off = nonvip_X.copy()
        vip_X_off[:, offset_factor_idx] += offset_value
        nonvip_X_off[:, offset_factor_idx] += offset_value
    else:
        vip_X_off = vip_X
        nonvip_X_off = nonvip_X

    print(
        f"There are {df_set.height} genes of interest and {df_control_pool.height} potential control genes at distance of at least {min_dist} bases."
    )
    if df_control_pool.height <= 1.5 * df_set.height:
        print(
            "The number of control genes is less than 1.5× the number of VIPs. FDR may be high."
        )

    # Fail fast with a clear message: the sampler divides by vip_number and
    # needs two distinct control rows per draw.
    if vip_number == 0:
        raise ValueError(
            "No genes of interest found in factors_file after filtering; "
            "cannot build control sets."
        )
    if nonvip_X.shape[0] < 2:
        raise ValueError(
            "At least two candidate control genes are required; "
            f"found {nonvip_X.shape[0]}."
        )

    def _run_one_batch(run_idx: int) -> np.ndarray:
        # Derive an independent integer seed per run so results are
        # reproducible for a given base seed regardless of worker scheduling.
        _batch_seed = int(
            np.random.default_rng(None if seed is None else seed + run_idx).integers(
                0, 2**31
            )
        )
        idx = iterative_control_set_fast(
            vip_X_off,
            nonvip_X_off,
            vip_number,
            tolerance,
            max_rep,
            _batch_seed,
            skip_factor_idx,
            n_batches,
        )
        # Translate row indices back to gene IDs.
        return nonvip_genes[idx]

    batch_results = Parallel(n_jobs=nthreads, backend="loky", verbose=10)(
        delayed(_run_one_batch)(run_idx) for run_idx in range(runs_number)
    )

    np_control = np.concatenate(batch_results, axis=0)

    df_bootstrap_control = pl.DataFrame(np_control)
    df_bootstrap_control = df_bootstrap_control.with_columns(
        (
            pl.lit("sample_")
            + pl.arange(1, df_bootstrap_control.height + 1).cast(pl.Utf8)
        ).alias("sample_id")
    ).select("sample_id", *df_bootstrap_control.columns)

    return df_set.select("gene_id"), df_bootstrap_control
def build_gene_neighbors(coord_file: str, valid_genes: list, dist: int) -> pl.DataFrame: """ Find all gene neighbours within ±dist bp using polars-bio overlap. Uses the full gene body intervals (not centre approximation). Each gene's interval is expanded by dist on both sides, then overlapped against all gene centre points — equivalent to the original pybedtools bed.window(w=dist). Parameters ---------- coord_file : path to BED gene coordinates (chrom start end gene_id strand) valid_genes : list of gene IDs to include (QC-passing universe) dist : window radius in bp (e.g. 500_000) Returns ------- pl.DataFrame with columns: gene_id, neighbors (space-joined string) """ valid = set(valid_genes) df_coords = ( pl.read_csv( coord_file, separator="\t", has_header=False, new_columns=["chrom", "start", "end", "gene_id", "strand"], schema_overrides={"chrom": pl.Utf8}, ) .with_columns( [ pl.col("start").cast(pl.Int64), pl.col("end").cast(pl.Int64), ] ) .filter(pl.col("gene_id").is_in(valid)) .with_columns(((pl.col("start") + pl.col("end")) // 2).alias("center")) ) # Expand each gene body by ±dist (pybedtools window equivalent) df_windows = df_coords.select( [ pl.col("chrom"), (pl.col("start") - dist).alias("win_start"), (pl.col("end") + dist).alias("win_end"), pl.col("gene_id"), ] ) # Target: 1-bp gene centres for overlap detection df_centers = df_coords.select( [ pl.col("chrom"), pl.col("center").alias("start"), (pl.col("center") + 1).alias("end"), pl.col("gene_id"), ] ) df_windows.config_meta.set(coordinate_system_zero_based=True) df_centers.config_meta.set(coordinate_system_zero_based=True) pairs = ( pb.overlap( df_windows, df_centers, cols1=["chrom", "win_start", "win_end"], cols2=["chrom", "start", "end"], suffixes=("_1", "_2"), output_type="polars.LazyFrame", ) .select( [ pl.col("gene_id_1").alias("gene_id"), pl.col("gene_id_2").alias("neighbor"), ] ) .filter(pl.col("gene_id") != pl.col("neighbor")) .collect() ) df_neighbors = pairs.group_by("gene_id").agg( 
pl.col("neighbor").sort().str.join(" ").alias("neighbors") ) df_all_genes = df_coords.select("gene_id") return ( df_all_genes.join(df_neighbors, on="gene_id", how="left") .with_columns(pl.col("neighbors").fill_null("")) .sort("gene_id") ) def build_gene_neighbors_numpy( coord_file: str, valid_genes: list, dist: int ) -> pl.DataFrame: """ Find all gene neighbours within ±dist bp using pure NumPy (10 kb binning). Gene centres are binned to a 10 kb grid. For each gene, numpy.searchsorted finds all genes whose binned centre falls within [center - dist, center + dist]. Parameters ---------- coord_file : path to BED gene coordinates (chrom start end gene_id strand) valid_genes : list of gene IDs to include (QC-passing universe) dist : neighbourhood radius in bp (e.g. 500_000) Returns ------- pl.DataFrame with columns: gene_id, neighbors (space-joined string) """ valid = set(valid_genes) df = ( pl.read_csv( coord_file, separator="\t", has_header=False, new_columns=["chrom", "start", "end", "gene_id", "strand"], schema_overrides={"chrom": pl.Utf8}, ) .with_columns( [ pl.col("start").cast(pl.Int64), pl.col("end").cast(pl.Int64), ] ) .filter(pl.col("gene_id").is_in(valid)) .with_columns(((pl.col("start") + pl.col("end")) // 2).alias("center")) .with_columns(((pl.col("center") / 10000).floor() * 10000).alias("bin")) .select(["chrom", "bin", "gene_id"]) .sort(["chrom", "bin"]) ) results = [] for chrom, sub in df.group_by("chrom", maintain_order=True): bins = sub["bin"].to_numpy() genes = sub["gene_id"].to_list() n = len(genes) left_idx = np.searchsorted(bins, bins - dist, side="left") right_idx = np.searchsorted(bins, bins + dist, side="right") for i in range(n): neighbors = [g for g in genes[left_idx[i] : right_idx[i]] if g != genes[i]] results.append((genes[i], " ".join(neighbors))) return pl.DataFrame(results, schema=["gene_id", "neighbors"], orient="row").sort( "gene_id" ) def simplify_sweeps_gz( rank_file: str, thresholds: list, col_index: int = 1, df: "pl.DataFrame | 
None" = None, ) -> pl.DataFrame: """ Assign each gene the most stringent threshold it passes (smallest cutoff c such that rank <= c). Parameters ---------- rank_file : TSV with gene_id in column 0 and rank values in subsequent columns (no header). For single-population files the rank is in column 1. For multi-population files (e.g. all_ihsfreqafr with ESN/GWD/LWK/MSL/YRI in columns 1-5) pass col_index to select the population column (1-based, default 1). thresholds : list of int rank thresholds col_index : 1-based index of the rank column to use (default 1 = second column, i.e. first rank column after gene_id). df : optional pre-loaded pl.DataFrame (same column layout as the file). When provided, rank_file is ignored and df is used directly — enables in-memory shuffle workflows without disk I/O. Returns ------- pl.DataFrame with columns: gene_id, selected_cutoff (genes with no hit excluded) """ if df is not None: df_raw = df else: # Auto-detect delimiter: try tab first, fall back to space. # (AFR/pop rank files use space; some single-pop files may use tab.) 
import gzip as _gzip _open = _gzip.open if rank_file.endswith(".gz") else open with _open(rank_file, "rt") as _fh: _raw = _fh.readline() _sep = "\t" if "\t" in _raw else " " df_raw = pl.read_csv( rank_file, separator=_sep, has_header=False, infer_schema_length=0, # read all as Utf8 first to handle variable cols ) col_name = df_raw.columns[col_index] df = df_raw.select( [ pl.col(df_raw.columns[0]).alias("gene_id").cast(pl.Utf8), pl.col(col_name).alias("rank").cast(pl.Int64), ] ) cuts = np.sort(np.array(thresholds, dtype=int)) expr = None for c in cuts: if expr is None: expr = pl.when(pl.col("rank") <= c).then(pl.lit(c)) else: expr = expr.when(pl.col("rank") <= c).then(pl.lit(c)) return ( df.with_columns(expr.otherwise(None).alias("selected_cutoff")) .drop_nulls("selected_cutoff") .select(["gene_id", "selected_cutoff"]) ) def simplify_sweeps_gz_list(rank_file: str, thresholds: list) -> pl.DataFrame: """ For a single-population rank file, return all thresholds each gene passes as a comma-separated string. Parameters ---------- rank_file : TSV with columns gene_id, rank (no header) thresholds : list of int rank thresholds Returns ------- pl.DataFrame with columns: gene_id, cutoffs (comma-separated string) """ df = pl.read_csv( rank_file, separator="\t", has_header=False, new_columns=["gene_id", "rank"] ) cuts = np.sort(thresholds)[::-1] df_cut = df.with_columns( pl.when(pl.col("rank").is_not_null()) .then( pl.struct(pl.col("rank")).map_elements( lambda s: ",".join(str(c) for c in cuts if s["rank"] <= c), return_dtype=pl.Utf8, ) ) .alias("cutoffs") ) return df_cut.filter(pl.col("cutoffs") != "") def _count_sweep_events(query_genes: set, sweep_genes: set, neighbor_map: dict) -> int: """ Count distinct sweep events = connected components of the neighbour graph restricted to `sweep_genes`, that are touched by at least one gene from `query_genes`. Each connected component in the sub-graph of sweep_genes (using neighbor_map edges) is a single sweep event. 
We count components that share at least one gene with query_genes. """ visited: set = set() count = 0 for g in query_genes: if g not in sweep_genes or g in visited: continue # DFS to find the connected component of g within sweep_genes stack = [g] while stack: node = stack.pop() if node in visited or node not in sweep_genes: continue visited.add(node) stack.extend(neighbor_map.get(node, set()) - visited) count += 1 return count
def _sweep_count_row(
    threshold: int,
    vip_count: int,
    ctrl_mean: float,
    ci_low: float,
    ci_high: float,
    p_val: float,
) -> dict:
    """Assemble one per-threshold result row; a 0.1 pseudo-count is added to
    numerator and denominator of every ratio so an all-zero control mean does
    not divide by zero."""
    denom = ctrl_mean + 0.1
    return {
        "threshold": threshold,
        "vip_count": vip_count,
        "ctrl_mean": ctrl_mean,
        "ctrl_ci_lo": float(ci_low),
        "ctrl_ci_hi": float(ci_high),
        "ratio": float((vip_count + 0.1) / denom),
        "ci_lo_ratio": float((ci_low + 0.1) / denom),
        "ci_hi_ratio": float((ci_high + 0.1) / denom),
        "p_value": p_val,
    }


def count_sweeps_singlepop(
    vip_genes: list,
    control_sets: list,
    simplified_sweeps: pl.DataFrame,
    neighbors: pl.DataFrame,
    thresholds: list,
    use_clust: bool = True,
    count_sweeps: bool = False,
    neighbor_col: str = None,
) -> pl.DataFrame:
    """
    Count sweeps overlapping VIPs vs control sets for a single population.

    Parameters
    ----------
    vip_genes        : list of VIP gene ID strings
    control_sets     : list of lists of control gene IDs (one list per
                       control set)
    simplified_sweeps: pl.DataFrame from simplify_sweeps_gz
                       (gene_id, selected_cutoff)
    neighbors        : pl.DataFrame from build_gene_neighbors_numpy
                       (gene_id, neighbors)
    thresholds       : list of int rank thresholds (sorted internally)
    use_clust        : expand gene sets by neighbours before counting
                       (de-duplication)
    count_sweeps     : if True, count distinct sweep events (connected
                       components of the neighbour graph within sweep genes)
                       rather than individual genes.
    neighbor_col     : column name for neighbours (auto-detected if None:
                       "neighbors", else "gene_id_b")

    Returns
    -------
    pl.DataFrame with columns: threshold, vip_count, ctrl_mean, ctrl_ci_lo,
    ctrl_ci_hi, ratio, ci_lo_ratio, ci_hi_ratio, p_value.
    p_value is the fraction of control sets whose count is strictly greater
    than the VIP count (NaN when there are no control sets).
    """
    cuts = sorted(int(x) for x in thresholds)

    _sw = simplified_sweeps.select(["gene_id", "selected_cutoff"])
    sel_map = dict(zip(_sw["gene_id"].to_list(), _sw["selected_cutoff"].to_list()))

    neighbor_map: dict = {}
    if use_clust or count_sweeps:
        if neighbor_col is None:
            if "neighbors" in neighbors.columns:
                neighbor_col = "neighbors"
            elif "gene_id_b" in neighbors.columns:
                neighbor_col = "gene_id_b"
        if neighbor_col and neighbor_col in neighbors.columns:
            for row in neighbors.iter_rows(named=True):
                raw = row.get(neighbor_col, None)
                # Accumulate rather than overwrite: if a gene appears on more
                # than one row, every listed neighbour is kept. With one row
                # per gene this is identical to plain assignment.
                bucket = neighbor_map.setdefault(row["gene_id"], set())
                if raw:
                    bucket.update(raw.split(" "))

    def expand(ids: set) -> set:
        """Add each gene's neighbours; genes already pulled in as part of an
        earlier cluster are not expanded again."""
        if not use_clust or not neighbor_map:
            return ids
        visited: set = set()
        expanded: set = set()
        for g in ids:
            if g in visited:
                continue
            cluster = neighbor_map.get(g, set()) | {g}
            expanded |= cluster
            visited |= cluster
        return expanded

    vip_set = expand(set(vip_genes))

    # Keep raw lists for counting (duplicates matter — Perl bootstrap samples
    # with replacement and scores = sum of repetition counts for sweep genes
    # in ctrl list). Materialised once so control_sets is iterated only once.
    ctrl_lists = [list(ctrl) for ctrl in control_sets]

    # Vectorized path for the common case (use_clust=False,
    # count_sweeps=False): replaces O(n_sets × n_genes × n_thresholds) Python
    # loops with numpy ops.
    fast_path = not use_clust and not count_sweeps and bool(ctrl_lists)

    rows = []
    if fast_path:
        # === FAST PATH: batch all thresholds in one tensor operation ===
        int32_max = np.iinfo(np.int32).max  # sentinel: "passes no threshold"

        vocab = sorted({g for ctrl in ctrl_lists for g in ctrl})
        g2i = {g: i for i, g in enumerate(vocab)}
        ctrl_matrix = np.array(
            [[g2i[g] for g in ctrl] for ctrl in ctrl_lists], dtype=np.int32
        )  # shape: (n_ctrl_sets, n_genes_per_set)

        # Cutoff vector aligned to the control-gene vocabulary.
        cutoff_arr = np.full(len(vocab), int32_max, dtype=np.int32)
        for g, c in sel_map.items():
            if g in g2i:
                cutoff_arr[g2i[g]] = int(c)

        # VIP cutoffs come from sel_map directly (independent of the control
        # vocabulary — covers all VIPs).
        vip_cutoffs = np.array(
            [int(sel_map.get(g, int32_max)) for g in sorted(vip_set)],
            dtype=np.int32,
        )

        thresh_arr = np.array(cuts, dtype=np.int32)  # (T,)
        sweep_mask = cutoff_arr[None, :] <= thresh_arr[:, None]  # (T, V) bool
        all_ctrl = sweep_mask[:, ctrl_matrix].sum(axis=2)  # (T, S)
        all_vip = (vip_cutoffs[None, :] <= thresh_arr[:, None]).sum(axis=1)  # (T,)
        all_means = all_ctrl.mean(axis=1)  # (T,)
        all_ci = np.percentile(all_ctrl, [2.5, 97.5], axis=1)  # (2, T)
        all_pvals = (all_ctrl > all_vip[:, None]).mean(axis=1)  # (T,)

        for t_idx, t in enumerate(cuts):
            rows.append(
                _sweep_count_row(
                    t,
                    int(all_vip[t_idx]),
                    float(all_means[t_idx]),
                    float(all_ci[0, t_idx]),
                    float(all_ci[1, t_idx]),
                    float(all_pvals[t_idx]),
                )
            )
    else:
        # === SLOW PATHS: per-threshold loop (use_clust=True or
        # count_sweeps=True, or no control sets at all) ===
        # Neighbour-expanded, de-duplicated control sets are only needed here.
        ctrl_sets_expanded = [expand(set(ctrl)) for ctrl in ctrl_lists]

        for t in cuts:
            S = {g for g, c in sel_map.items() if c <= t}

            if count_sweeps:
                ctrl_counts = [
                    _count_sweep_events(C, S, neighbor_map) for C in ctrl_sets_expanded
                ]
                vip_count = _count_sweep_events(vip_set, S, neighbor_map)
            else:
                # use_clust=True: expand by neighbours, count unique genes
                vip_count = len(vip_set & S)
                ctrl_counts = [len(C & S) for C in ctrl_sets_expanded]

            ctrl_counts_arr = np.asarray(ctrl_counts)
            ctrl_mean = float(np.mean(ctrl_counts_arr)) if ctrl_counts_arr.size else 0.0
            if ctrl_counts_arr.size > 1:
                ci_low, ci_high = np.percentile(ctrl_counts_arr, [2.5, 97.5])
            else:
                ci_low = ci_high = ctrl_mean
            p_val = (
                float(np.mean(ctrl_counts_arr > vip_count))
                if ctrl_counts_arr.size > 0
                else float("nan")
            )
            rows.append(
                _sweep_count_row(t, vip_count, ctrl_mean, ci_low, ci_high, p_val)
            )

    return pl.DataFrame(rows)
def _resolve_target_pops(pop_interest: str, populations: list, groups: list) -> list: """Return the list of population names matching pop_interest.""" if pop_interest == "All": return populations if pop_interest in populations: return [pop_interest] # treat as group name return [p for p, g in zip(populations, groups) if g == pop_interest] def _group_sweep_sets( pops_in_group: list, simplified_sweeps_by_pop: dict, thresholds: list, ) -> pl.DataFrame: """ Build a unified simplified-sweeps DataFrame for a group of populations: a gene is 'in sweep at threshold t' if it is in ANY member population's sweep set. The selected_cutoff is the minimum (most stringent) cutoff across member pops. """ cuts = sorted(int(x) for x in thresholds) # For each gene collect the minimum selected_cutoff across member pops gene_cutoff: dict = {} for pop in pops_in_group: df_pop = simplified_sweeps_by_pop.get(pop) if df_pop is None: continue for gene_id, cutoff in df_pop.select( ["gene_id", "selected_cutoff"] ).iter_rows(): if gene_id not in gene_cutoff or cutoff < gene_cutoff[gene_id]: gene_cutoff[gene_id] = cutoff if not gene_cutoff: return pl.DataFrame( {"gene_id": [], "selected_cutoff": []}, schema={"gene_id": pl.Utf8, "selected_cutoff": pl.Int64}, ) return pl.DataFrame( list(gene_cutoff.items()), schema=["gene_id", "selected_cutoff"], orient="row", )
def count_sweeps_multipop(
    vip_genes: list,
    control_sets: list,
    simplified_sweeps_by_pop: dict,
    neighbors: pl.DataFrame,
    thresholds: list,
    populations: list,
    groups: list,
    pop_interest: str,
    count_sweeps: bool = False,
    use_clust: bool = True,
    nthreads: int = 1,
) -> pl.DataFrame:
    """
    Step 7: Multi-population sweep counting with group aggregation and
    adaptive depth.

    Scopes are processed in parallel. For a group or 'All', the sweep sets of
    member populations are unioned before counting (each sweep counted once
    per group, not once per population). Adaptive depth escalates from
    100 → 1000 → all control sets depending on observed p-values.

    Parameters
    ----------
    vip_genes               : list of VIP gene IDs
    control_sets            : list of lists of control gene IDs (up to 10 000)
    simplified_sweeps_by_pop: dict mapping population name → pl.DataFrame
                              (output of simplify_sweeps_gz per pop)
    neighbors               : pl.DataFrame from build_gene_neighbors_numpy
    thresholds              : list of int rank thresholds
    populations             : ordered list of population codes
    groups                  : group label for each population (same length)
    pop_interest            : single pop name, group name, or "All"
    count_sweeps            : if True count distinct sweep events (connected
                              components of the neighbour graph within sweep
                              genes); if False count genes in sweeps
                              (default; matches Perl behaviour)
    use_clust               : expand gene sets by neighbours before counting
                              (de-duplication; maps to Perl Count_sweeps
                              parameter)
    nthreads                : parallel workers

    Returns
    -------
    pl.DataFrame with columns: scope, threshold, vip_count, ctrl_mean,
    ctrl_ci_lo, ctrl_ci_hi, ratio, ci_lo_ratio, ci_hi_ratio, p_value.
    scope is the population name, group name, or "All". Only the deepest
    round that was run is returned.

    Raises
    ------
    ValueError : if pop_interest matches no population or group, or if none
                 of the selected populations has an entry in
                 simplified_sweeps_by_pop.
    """
    target_pops = _resolve_target_pops(pop_interest, populations, groups)
    if not target_pops:
        raise ValueError(
            f"pop_interest '{pop_interest}' not found in populations or groups"
        )

    # Determine scopes to evaluate:
    #   - each individual target population
    #   - each group that appears among target pops (if > 1 pop in group)
    #   - "All" if pop_interest == "All" and more than one population
    # Group/All scopes use gene union across member pops (min cutoff), which
    # matches the Perl recurrence-weighted formula: each gene counted once
    # regardless of how many populations it appears in.
    scopes: dict = {}  # scope_name → simplified_sweeps df

    for pop in target_pops:
        if pop in simplified_sweeps_by_pop:
            scopes[pop] = simplified_sweeps_by_pop[pop]

    # Group-level scopes
    group_to_pops: dict = {}
    for pop in target_pops:
        g = groups[populations.index(pop)]
        group_to_pops.setdefault(g, []).append(pop)

    for grp, members in group_to_pops.items():
        if len(members) > 1:
            scopes[grp] = _group_sweep_sets(
                members, simplified_sweeps_by_pop, thresholds
            )

    # "All" scope
    if pop_interest == "All" and len(target_pops) > 1:
        scopes["All"] = _group_sweep_sets(
            target_pops, simplified_sweeps_by_pop, thresholds
        )

    if not scopes:
        # Without this check the failure surfaces later as an opaque
        # "cannot concat empty list" error.
        raise ValueError(
            f"No sweep data available for pop_interest '{pop_interest}': none of "
            f"{target_pops} is present in simplified_sweeps_by_pop"
        )

    output_columns = [
        "scope",
        "threshold",
        "vip_count",
        "ctrl_mean",
        "ctrl_ci_lo",
        "ctrl_ci_hi",
        "ratio",
        "ci_lo_ratio",
        "ci_hi_ratio",
        "p_value",
    ]

    def _count_scope(scope_name, sw_df, ctrl_subset):
        # One unit of parallel work: a full single-population count for one
        # scope against the current subset of control sets.
        df = count_sweeps_singlepop(
            vip_genes=vip_genes,
            control_sets=ctrl_subset,
            simplified_sweeps=sw_df,
            neighbors=neighbors,
            thresholds=thresholds,
            use_clust=use_clust,
            count_sweeps=count_sweeps,
        )
        return scope_name, df

    scope_items = list(scopes.items())

    # Adaptive depth rounds: 100 → 1000 → all
    adaptive_limits = [100, 1000, len(control_sets)]
    combined = None

    for round_idx, n_ctrl in enumerate(adaptive_limits):
        ctrl_subset = control_sets[:n_ctrl]

        # Run each scope in parallel
        results_round = Parallel(n_jobs=nthreads, backend="loky")(
            delayed(_count_scope)(name, sw_df, ctrl_subset)
            for name, sw_df in scope_items
        )

        # Collect into a single DataFrame for this round; it replaces the
        # previous round's result because later rounds use more control sets.
        frames = [
            df.with_columns(pl.lit(scope_name).alias("scope"))
            for scope_name, df in results_round
        ]
        combined = pl.concat(frames).select(output_columns)

        # Adaptive depth decision
        p_values = combined["p_value"].drop_nulls().to_list()

        if round_idx == 0:
            # Escalate if any p ≤ 0.05 or p ≥ 0.95
            if not any(p <= 0.05 or p >= 0.95 for p in p_values):
                break  # not significant at any threshold — stop here
        elif round_idx == 1:
            # Escalate if any p ≤ 0.002 or p ≥ 0.998
            if not any(p <= 0.002 or p >= 0.998 for p in p_values):
                break  # weakly significant only — no need for full 10 000 sets
        # round_idx == 2: always final

        # Don't escalate beyond what we have
        if n_ctrl >= len(control_sets):
            break

    return combined
def shuffle_genome(
    coord_file: str,
    valid_genes: list,
    sweep_files: list,
    n_shuffles: int,
    shuffling_segments_number: int,
    output_dir: str = None,
    max_rank_boundary: int = 2000,
    seed: int = None,
) -> list:
    """
    Step 8: Shuffle genomic positions of sweep signals to build an FDR null.

    Autosomal genes (chromosomes 1-22) are ordered by chromosome and by the
    10 kb bin of their centre, then cut into segments. For every replicate the
    segment order is randomised and each segment is, with probability 1/2,
    replaced by its mirror image (index i -> n_genes - 1 - i). The resulting
    index vector tells, for each gene position, which gene's sweep ranks are
    assigned to it. One shuffled dataset is produced per original sweep file
    and per replicate.

    Parameters
    ----------
    coord_file               : path to BED gene coordinates (tab-separated, no
                               header: chrom, start, end, gene_id, strand)
    valid_genes              : list of gene IDs to include (QC-passing universe)
    sweep_files              : list of original sweep file paths (all window
                               sizes); tab- or space-separated, plain or .gz
    n_shuffles               : total number of shuffled replicates (must be a
                               multiple of 8)
    shuffling_segments_number: number of GENES per candidate segment (values
                               below 1 are treated as 1). Despite the name this
                               is a segment size, not a segment count; segments
                               grow past this size until a valid cut is found.
    output_dir               : directory to write shuffled sweep files. If None
                               (default), data is kept in memory and a list of
                               {sweep_file: pl.DataFrame} dicts is returned (one
                               dict per replicate) - no files written to disk.
    max_rank_boundary        : genes ranked <= this in any rank column of any
                               sweep file are "strong signal"; a candidate cut
                               flanked by two such genes is not used
    seed                     : seed passed to numpy.random.default_rng; None
                               gives a non-reproducible stream

    Returns
    -------
    output_dir is not None -> list of str (paths to written shuffled files)
    output_dir is None     -> list of dict[str, pl.DataFrame], one per replicate.
                              Each dict maps sweep_file path -> shuffled
                              DataFrame with the same column layout as the
                              original file. Pass directly to estimate_fdr as
                              shuffled_sweep_files_by_shuffle.

    Raises
    ------
    ValueError if n_shuffles is not a multiple of 8, or if no valid autosomal
    gene is found in coord_file.
    """
    import gzip  # local import: only needed to sniff separators of .gz inputs

    if n_shuffles % 8 != 0:
        raise ValueError("n_shuffles must be a multiple of 8")

    in_memory = output_dir is None
    if not in_memory:
        os.makedirs(output_dir, exist_ok=True)

    # --- 1. Valid gene universe ---
    valid_genes = set(valid_genes)

    # --- 2. Load coordinates and build chromosomal order ---
    df_coords = (
        pl.read_csv(
            coord_file,
            separator="\t",
            has_header=False,
            new_columns=["chrom", "start", "end", "gene_id", "strand"],
            schema_overrides={"chrom": pl.Utf8},
        )
        .filter(pl.col("gene_id").is_in(valid_genes))
        .with_columns(((pl.col("start") + pl.col("end")) // 2).alias("center"))
        .with_columns(
            ((pl.col("center") / 10000).floor() * 10000).cast(pl.Int64).alias("bin")
        )
    )

    # Build ordered gene list (autosomes 1-22, numeric chromosome order)
    ordered_genes = (
        df_coords.filter(pl.col("chrom").is_in([str(i) for i in range(1, 23)]))
        .with_columns(pl.col("chrom").cast(pl.Int32).alias("_chrom_int"))
        .sort(["_chrom_int", "bin"])["gene_id"]
        .to_list()
    )
    n_genes = len(ordered_genes)
    if n_genes == 0:
        raise ValueError("No valid genes found in coord_file")

    # Perl: $seg_size = $segcut (direct segment size, not count).
    # shuffling_segments_number IS the number of genes per segment, not the
    # number of segments. Segments grow until a valid cut is found.
    seg_size = max(1, shuffling_segments_number)

    # --- 3. Load sweep data for cut-point validation ---
    # A gene is "strong signal" if its rank <= max_rank_boundary in ANY sweep file
    strong_signal_genes: set = set()
    sweep_data_by_file: dict = {}  # disk mode: file -> {gene_id: tab-joined line}
    rank_matrix_by_file: dict = {}  # in-memory: file -> (gene_col, rank_cols, matrix)

    for sf in sweep_files:
        # Sniff the separator from the first line. Open through gzip for .gz
        # inputs (reading compressed bytes as text would fail or mis-detect),
        # and use a context manager so the handle is always closed.
        opener = gzip.open if str(sf).endswith(".gz") else open
        with opener(sf, "rt") as fh:
            first_line = fh.readline()
        sep = "\t" if "\t" in first_line else " "

        # infer_schema_length=0 reads every column as a string
        df_sw = pl.read_csv(
            sf,
            has_header=False,
            separator=sep,
            infer_schema_length=0,
        )

        # First column = gene_id, remaining columns = per-population ranks.
        # A gene is strong-signal if its rank in ANY population <= max_rank_boundary
        if df_sw.width >= 2:
            # Vectorized: cast rank columns to Int64, take row-wise min, filter
            rank_exprs = [
                pl.col(c).cast(pl.Int64, strict=False) for c in df_sw.columns[1:]
            ]
            strong_signal_genes |= set(
                df_sw.filter(pl.min_horizontal(rank_exprs) <= max_rank_boundary)[
                    :, 0
                ].to_list()
            )

        gene_ranks: dict = {}
        if in_memory:
            # In-memory fast path: pre-align sweep data to ordered_genes order
            # as a numpy int32 matrix so each shuffle is a single fancy-index op.
            # Genes absent from the sweep file (or with unparsable ranks) get
            # the sentinel rank 10_000_000.
            gene_col = df_sw.columns[0]
            rank_cols = list(df_sw.columns[1:])
            aligned = (
                pl.DataFrame({"_g": ordered_genes, "_pos": list(range(n_genes))})
                .join(df_sw.rename({gene_col: "_g"}), on="_g", how="left")
                .sort("_pos")
            )
            rank_matrix_by_file[sf] = (
                gene_col,
                rank_cols,
                aligned.select(rank_cols)
                .with_columns(
                    [
                        pl.col(c).cast(pl.Int32, strict=False).fill_null(10_000_000)
                        for c in rank_cols
                    ]
                )
                .to_numpy(),  # shape: (n_genes, n_pops)
            )
        else:
            # Disk mode: store tab-separated lines for gene remapping
            for row in df_sw.iter_rows():
                gene_ranks[row[0]] = "\t".join(str(v) for v in row)
        sweep_data_by_file[sf] = gene_ranks

    # --- 4/5. Detect valid cut points and build segments ---
    # Candidate cuts sit every seg_size genes. A cut is rejected when both the
    # gene just before AND the gene just after it are strong-signal genes; the
    # current segment then simply keeps growing to the next candidate.
    segments = []
    seg_start = 0
    for cut_idx in range(seg_size, n_genes, seg_size):
        both_strong = (
            ordered_genes[cut_idx - 1] in strong_signal_genes
            and ordered_genes[cut_idx] in strong_signal_genes
        )
        if not both_strong:
            segments.append(np.arange(seg_start, cut_idx, dtype=np.int32))
            seg_start = cut_idx
    # tail segment
    if seg_start < n_genes:
        segments.append(np.arange(seg_start, n_genes, dtype=np.int32))

    rng = np.random.default_rng(seed)
    written_paths = []
    inmem_replicates = []

    # Replicates are numbered 1..n_shuffles; the RNG is consumed in the same
    # order as nested "outer x 8 inner" loops would consume it.
    for shuffle_id in range(1, n_shuffles + 1):
        # Shuffle segment order
        seg_order = list(range(len(segments)))
        rng.shuffle(seg_order)

        # Per-segment: either keep forward indices or mirror to the opposite
        # end of the genome (Perl dice mechanism: inds[n-1-i] instead of inds[i]).
        # NOTE(review): mirrored indices may coincide with indices of other
        # segments, so the vector is not necessarily a strict permutation;
        # kept as is because it is stated to match the Perl implementation.
        collected = []
        for si in seg_order:
            seg = segments[si]
            if rng.integers(0, 2) == 1:
                seg = np.int32(n_genes - 1) - seg  # new array; original untouched
            collected.append(seg)
        shuffled_indices = np.concatenate(collected)

        # Produce one shuffled dataset per original sweep file
        replicate_dict = {}
        for sf in sweep_files:
            if in_memory:
                gene_col, rank_cols, rank_matrix = rank_matrix_by_file[sf]
                reindexed = rank_matrix[shuffled_indices]  # (n_genes, n_pops)
                replicate_dict[sf] = pl.DataFrame(
                    {gene_col: ordered_genes}
                    | {name: reindexed[:, j] for j, name in enumerate(rank_cols)}
                )
            else:
                # Disk mode: map each gene to the gene whose sweep line it
                # receives, then rewrite that line under the original gene ID.
                perm_map = {
                    gene: ordered_genes[int(src)]
                    for gene, src in zip(ordered_genes, shuffled_indices)
                }
                gene_ranks = sweep_data_by_file[sf]
                rows = []
                for orig_gene in ordered_genes:
                    sweep_line = gene_ranks.get(perm_map.get(orig_gene, orig_gene))
                    if sweep_line is None:
                        continue  # source gene absent from this sweep file
                    parts = sweep_line.split("\t")
                    parts[0] = orig_gene
                    rows.append(parts)

                base = os.path.basename(sf)
                out_path = os.path.join(output_dir, f"fake_{base}_shuffle{shuffle_id}")
                with open(out_path, "w") as fout:
                    fout.writelines("\t".join(parts) + "\n" for parts in rows)
                written_paths.append(out_path)

        if in_memory:
            inmem_replicates.append(replicate_dict)

    return inmem_replicates if in_memory else written_paths
def estimate_fdr(
    real_results: pl.DataFrame,
    shuffled_sweep_files_by_shuffle: list,
    vip_genes: list,
    control_sets: list,
    neighbors: pl.DataFrame,
    thresholds: list,
    populations: list,
    groups: list,
    pop_interest: str,
    simplified_sweeps_by_pop_fn=None,
    count_sweeps: bool = False,
    use_clust: bool = True,
    nthreads: int = 1,
    p_cutoff: float = 0.05,
    max_threshold: int = None,
    min_threshold: int = 0,
    min_vip_count: int = 0,
) -> pl.DataFrame:
    """
    Step 9: Estimate FDR by comparing real results to a null distribution built
    from genome-shuffled sweep files (output of shuffle_genome).

    Matches the behaviour of estimate_FPR.pl: the test statistic is
    sum(vip_count - ctrl_mean) across rows that pass the significance and
    threshold filters. FDR p-value = fraction of null replicates where that
    statistic >= the real statistic.

    Parameters
    ----------
    real_results    : pl.DataFrame - output of count_sweeps_multipop on the real
                      data (columns: scope, threshold, vip_count, ctrl_mean,
                      p_value)
    shuffled_sweep_files_by_shuffle : list with one entry per shuffle replicate.
                      Each entry is either
                        - a list of file paths (disk mode).
                          Layout A: one file per population
                          [[esn_s1, gwd_s1, ...], ...].
                          Layout B: one multi-population file per replicate
                          [[multi_s1], [multi_s2], ...].
                        - a dict {sweep_file: pl.DataFrame} (in-memory mode, as
                          returned by shuffle_genome(output_dir=None)).
    vip_genes       : list of VIP gene IDs
    control_sets    : control sets from run_bootstrap
    neighbors       : pl.DataFrame from build_gene_neighbors_numpy
    thresholds      : list of int thresholds
    populations     : ordered population list
    groups          : group label per population
    pop_interest    : single pop, group name, or "All"
    simplified_sweeps_by_pop_fn : callable(sweep_file, col_index) -> pl.DataFrame,
                      or callable(sweep_file) -> dict[pop, pl.DataFrame] for
                      multi-pop files. Defaults to simplify_sweeps_gz. Only used
                      in disk mode.
    count_sweeps    : forwarded to count_sweeps_singlepop
    use_clust       : expand by neighbours before counting
    nthreads        : parallel workers
    p_cutoff        : p-value threshold for "significant" rows in the test
                      statistic (Perl: cutoff, default 0.05)
    max_threshold   : upper bound on rank threshold included in the statistic
                      (Perl: limit; None = no limit)
    min_threshold   : lower bound on rank threshold (Perl: cutoff2, default 0)
    min_vip_count   : minimum vip_count for a row to contribute to the statistic
                      (Perl: enough, default 0)

    Returns
    -------
    pl.DataFrame with columns: scope, p_value, n_replicates, real_stat,
    max_null_stat. A single row, for the scope named by pop_interest.
    real_stat = sum(vip_count - ctrl_mean) for significant rows.
    p_value   = fraction of null replicates where null_stat >= real_stat
                (NaN when there are no replicates).

    Raises
    ------
    ValueError if pop_interest resolves to no population.
    """
    target_pops = _resolve_target_pops(pop_interest, populations, groups)
    # Same guard as count_sweeps_multipop: fail early with a clear message
    # instead of producing empty null replicates.
    if not target_pops:
        raise ValueError(
            f"pop_interest '{pop_interest}' not found in populations or groups"
        )

    if simplified_sweeps_by_pop_fn is None:

        def simplified_sweeps_by_pop_fn(sweep_file, col_index=1):
            return simplify_sweeps_gz(sweep_file, thresholds, col_index=col_index)

    def _run_one_shuffle(shuffle_data) -> pl.DataFrame:
        """Run count_sweeps_singlepop per population on one shuffled replicate.

        shuffle_data is either:
          - list[str]                - file paths (disk mode)
          - dict[str, pl.DataFrame]  - {sweep_file: df} (in-memory mode from
                                       shuffle_genome(output_dir=None))
        """
        # NOTE(review): in every branch below the rank column is taken as the
        # population's position within target_pops (+1), not within
        # `populations`. The two agree when target_pops lists all populations
        # in file order; confirm for single-population / group scopes.
        simplified_by_pop = {}

        if isinstance(shuffle_data, dict):
            # In-memory mode: shuffle_data maps sweep_file -> shuffled DataFrame.
            # Use simplify_sweeps_gz with df= kwarg to avoid any disk I/O.
            # If more than one sweep file, union the simplified sets per pop
            # (keeping the smallest selected_cutoff per gene).
            for i, pop in enumerate(target_pops):
                col_idx = i + 1
                frames = [
                    simplify_sweeps_gz(sf, thresholds, col_index=col_idx, df=df_shuf)
                    for sf, df_shuf in shuffle_data.items()
                    if col_idx < len(df_shuf.columns)
                ]
                if len(frames) > 1:
                    simplified_by_pop[pop] = (
                        pl.concat(frames)
                        .group_by("gene_id")
                        .agg(pl.col("selected_cutoff").min())
                    )
                elif frames:
                    simplified_by_pop[pop] = frames[0]
        elif len(shuffle_data) == 1:
            # Single multi-pop file (disk mode)
            multi_file = shuffle_data[0]
            for i, pop in enumerate(target_pops):
                try:
                    result = simplified_sweeps_by_pop_fn(multi_file, i + 1)
                except TypeError:
                    # Callable does not accept a column index: call it with
                    # the file only.
                    full = simplified_sweeps_by_pop_fn(multi_file)
                    if isinstance(full, dict):
                        # One call yields every population at once.
                        simplified_by_pop = {
                            p: full[p] for p in target_pops if p in full
                        }
                        break
                    # Non-dict return: use it for this population, mirroring
                    # the per-population-file branch (previously `result` was
                    # unbound or stale here).
                    result = full
                simplified_by_pop[pop] = result
        else:
            # One file per population (disk mode)
            for i, pop in enumerate(target_pops):
                if i < len(shuffle_data):
                    try:
                        simplified_by_pop[pop] = simplified_sweeps_by_pop_fn(
                            shuffle_data[i], i + 1
                        )
                    except TypeError:
                        simplified_by_pop[pop] = simplified_sweeps_by_pop_fn(
                            shuffle_data[i]
                        )

        # For FDR, _compute_stat sums individual population rows (not group/All
        # rows). Only compute per-population scopes here - skip group/All overhead.
        frames = []
        for pop in target_pops:
            if pop not in simplified_by_pop:
                continue
            df_pop = count_sweeps_singlepop(
                vip_genes=vip_genes,
                control_sets=control_sets,
                simplified_sweeps=simplified_by_pop[pop],
                neighbors=neighbors,
                thresholds=thresholds,
                use_clust=use_clust,
                count_sweeps=count_sweeps,
            )
            frames.append(df_pop.with_columns(pl.lit(pop).alias("scope")))
        return pl.concat(frames) if frames else pl.DataFrame()

    null_results = Parallel(n_jobs=nthreads, backend="loky", verbose=1)(
        delayed(_run_one_shuffle)(shuf_files)
        for shuf_files in shuffled_sweep_files_by_shuffle
    )

    # --- Compute FDR, matching estimate_FPR.pl ---
    # Test statistic = sum(vip_count - ctrl_mean) for rows passing all filters.
    # FDR p-value = fraction of null replicates where null_stat >= real_stat.
    def _compute_stat(df: pl.DataFrame, scope: str) -> float:
        """Perl estimate_FPR.pl: total_diff = sum(vip_count) - sum(ctrl_mean) for
        sig rows.

        For scope == "All": Perl accumulates counts PER population (sum-based),
        not a gene union. Summing individual population rows reproduces this.
        """
        # A replicate with no usable population comes back as a column-less
        # frame; filtering it on named columns would fail, so it contributes 0.
        if df.is_empty():
            return 0.0

        if "scope" in df.columns:
            if scope == "All":
                # Perl "All:" = sum of individual population contributions.
                # Each gene in k populations contributes k to the total, not 1.
                s = df.filter(pl.col("scope").is_in(populations))
            else:
                s = df.filter(pl.col("scope") == scope)
        else:
            s = df

        keep = (
            (pl.col("p_value") <= p_cutoff)
            & (pl.col("vip_count") >= min_vip_count)
            & (pl.col("threshold") >= min_threshold)
        )
        if max_threshold is not None:
            keep = keep & (pl.col("threshold") <= max_threshold)
        s = s.filter(keep)

        if s.is_empty():
            return 0.0
        return float((s["vip_count"].cast(pl.Float64) - s["ctrl_mean"]).sum())

    # Only the scope matching pop_interest is reported (estimate_FPR.pl uses a
    # regex match: if($pop =~ $splitter_line[1])). Whether or not that scope
    # occurs in real_results, the reported scope is pop_interest itself.
    scope = pop_interest

    real_stat = _compute_stat(real_results, scope)
    null_stats = [_compute_stat(df, scope) for df in null_results]
    n_reps = len(null_stats)
    if n_reps > 0:
        pval = float(np.sum(np.array(null_stats) >= real_stat) / n_reps)
        max_null = float(np.max(null_stats))
    else:
        pval = float("nan")
        max_null = float("nan")

    return pl.DataFrame(
        [
            {
                "scope": scope,
                "p_value": pval,
                "n_replicates": n_reps,
                "real_stat": real_stat,
                "max_null_stat": max_null,
            }
        ]
    )
def run_enrichment(
    sweep_files: list,
    gene_set: str,
    factors_file: str,
    annotation_file: str,
    populations: list,
    groups: list,
    thresholds: list,
    pop_interest: str = "All",
    cluster_distance: int = 500_000,
    n_runs: int = 10,
    tolerance: float = 0.05,
    min_distance: int = 1_250_000,
    flip: bool = False,
    max_rep: int = 25,
    nthreads: int = 1,
    n_shuffles: int = 8,
    shuffling_segs: int = 2,
    bootstrap_dir: str = None,
    distance_file: str = None,
) -> tuple:
    """
    Run the full gene-set sweep enrichment pipeline.

    Steps:
      1. Load gene set and derive valid gene universe (factors ∩ sweep_files[0]).
      2. Bootstrap control sets matched on confounding factors (or load
         pre-computed sets from ``bootstrap_dir``).
      3. For each sweep file: count sweep overlaps, shuffle genome for null
         distribution, and estimate FDR.

    Progress is reported on stdout (``TIMING_*`` lines and one ``FDR`` line per
    sweep file).

    :param sweep_files: Paths to sweep rank files (gene_id + per-population rank
        columns, tab- or space-separated, optionally gzipped). Must not be empty.
    :param gene_set: TSV with ``gene_id`` and ``yes``/``no`` label columns (no
        header). Genes labelled ``yes`` are VIPs.
    :param factors_file: Confounding factors file (gene_id + factor columns,
        tab- or space-separated, no header).
    :param annotation_file: BED gene coordinates file (0-based, no header).
    :param populations: Population codes matching sweep file column order.
    :param groups: Group label per population (same length as populations).
    :param thresholds: Rank cutoffs for enrichment curve (e.g. [6000, ..., 20]).
    :param pop_interest: Currently unused: sweep counting and FDR estimation are
        both run with the scope fixed to ``'All'`` (matching the Perl pipeline).
    :param cluster_distance: Max bp between genes to count as neighbours.
    :param n_runs: Bootstrap batches.
    :param tolerance: Allowed ± fraction deviation in factor averages for
        control gene matching.
    :param min_distance: Minimum bp distance from VIPs for control eligibility.
    :param flip: Flip test direction when ``len(VIPs) > len(controls)``
        (increases power).
    :param max_rep: Max average resamples per control gene across bootstrap sets.
    :param nthreads: Parallel workers (joblib).
    :param n_shuffles: FDR shuffle replicates (must be a multiple of 8).
    :param shuffling_segs: Genes per genomic shuffle segment.
    :param bootstrap_dir: Folder with pre-computed ``VIPs/file_1`` and
        ``nonVIPs/file_1``. When not ``None``, the bootstrap step is skipped and
        these files are loaded instead.
    :param distance_file: Optional pre-computed distance TSV
        (gene_id \\t distance, no header). Passed to
        ``run_bootstrap_nomatchomega`` to bypass internal distance computation
        and match Perl's control pool exactly.
    :returns: Tuple ``(df_case, df_control, curves, fdr)``: the case-gene frame,
        the control-set frame, the concatenated enrichment-curve results and the
        concatenated FDR results. The last two carry a ``dataset`` column
        holding the originating sweep file path.
    :raises ValueError: If ``sweep_files`` is empty.
    """
    import gzip as _gz

    if not sweep_files:
        # Fail with a clear message instead of an IndexError on sweep_files[0]
        # (or a failing concat of zero result frames).
        raise ValueError("sweep_files must contain at least one sweep rank file")

    _t_start = time.time()

    # --- Load gene set and derive valid universe from factors ∩ sweep_files[0] ---
    _df_geneset = pl.read_csv(
        gene_set,
        has_header=False,
        separator="\t",
        new_columns=["gene_id", "label"],
    )
    # Genes in these module-level lists are never used as cases or controls.
    _exclude = set(hla_genes + hist_genes)

    # Sniff the separator of each file from its first line.
    with open(factors_file, "rt") as _fh:
        _fsep = "\t" if "\t" in _fh.readline() else " "
    _factors_ids = set(
        pl.read_csv(
            factors_file, has_header=False, separator=_fsep, infer_schema_length=0
        )[:, 0].to_list()
    )

    _sf0 = sweep_files[0]
    _open_fn = _gz.open if _sf0.endswith(".gz") else open
    with _open_fn(_sf0, "rt") as _fh:
        _sep0 = "\t" if "\t" in _fh.readline() else " "
    _sweep_ids = set(
        pl.read_csv(_sf0, has_header=False, separator=_sep0, infer_schema_length=0)[
            :, 0
        ].to_list()
    )
    _valid_genes_list = sorted(_factors_ids & _sweep_ids)

    _case_genes = (
        _df_geneset.filter(
            (pl.col("label") == "yes")
            & ~pl.col("gene_id").is_in(_exclude)
            & pl.col("gene_id").is_in(_valid_genes_list)
        )
        .drop_nulls()["gene_id"]
        .to_list()
    )
    _control_genes = _df_geneset.filter(
        (pl.col("label") == "no")
        & ~pl.col("gene_id").is_in(_exclude)
        & pl.col("gene_id").is_in(_valid_genes_list)
    )["gene_id"].to_list()

    # --- Step 6: Bootstrap ---
    if bootstrap_dir is not None:
        _vip_file = os.path.join(bootstrap_dir, "VIPs", "file_1")
        _ctrl_file = os.path.join(bootstrap_dir, "nonVIPs", "file_1")

        # One VIP gene ID per non-blank line
        with open(_vip_file) as _fh:
            _vip_list_raw = [_ln.strip() for _ln in _fh if _ln.strip()]
        df_case = pl.DataFrame({"gene_id": _vip_list_raw})

        # One control set per non-blank line: "<sample label> gene gene ..."
        _ctrl_sets_raw = []
        with open(_ctrl_file) as _fh:
            for _line in _fh:
                _parts = _line.strip().split()
                if _parts:
                    _ctrl_sets_raw.append(_parts[1:])  # drop sample_N prefix

        # Pad ragged control sets with None so they form a rectangular frame.
        _n_ctrl_genes = max(len(r) for r in _ctrl_sets_raw) if _ctrl_sets_raw else 0
        df_control = pl.DataFrame(
            {"sample_id": [f"sample_{i + 1}" for i in range(len(_ctrl_sets_raw))]}
            | {
                f"gene_{j}": [(r[j] if j < len(r) else None) for r in _ctrl_sets_raw]
                for j in range(_n_ctrl_genes)
            }
        )
        print("TIMING_BOOTSTRAP: 0.000", flush=True)
    else:
        df_case, df_control = run_bootstrap_nomatchomega(
            case_genes=_case_genes,
            factors_file=factors_file,
            annotation_file=annotation_file,
            runs_number=n_runs,
            tolerance=tolerance,
            min_dist=min_distance,
            flip=flip,
            max_rep=max_rep,
            nthreads=nthreads,
            control_genes=_control_genes,
            distance_file=distance_file,
        )
        print(f"TIMING_BOOTSTRAP: {time.time() - _t_start:.3f}", flush=True)

    _t_post_bootstrap = time.time()

    # --- Step 7: Sweep counting ---
    df_n = build_gene_neighbors_numpy(
        annotation_file, _valid_genes_list, cluster_distance
    )
    _vip_list = df_case["gene_id"].to_list()
    _ctrl_sets = df_control.drop("sample_id").to_numpy().tolist()

    fdr_results = []
    curve_results = []
    for _rank_file in sweep_files:
        # Label used in the FDR log line: text after the last "_" of the file name
        _size = os.path.basename(_rank_file).rsplit("_", 1)[-1]

        _simplified_by_pop = {
            pop: simplify_sweeps_gz(_rank_file, thresholds, col_index=i + 1)
            for i, pop in enumerate(populations)
        }
        results = count_sweeps_multipop(
            vip_genes=_vip_list,
            control_sets=_ctrl_sets,
            simplified_sweeps_by_pop=_simplified_by_pop,
            neighbors=df_n,
            thresholds=thresholds,
            populations=populations,
            groups=groups,
            pop_interest="All",  # FDR requires all-population scopes; matches Perl "All:"
            use_clust=False,
            count_sweeps=False,
            nthreads=nthreads,
        )

        # --- Step 8: Genome shuffling (in-memory) ---
        shuffled_data = shuffle_genome(
            coord_file=annotation_file,
            valid_genes=_valid_genes_list,
            sweep_files=[_rank_file],
            n_shuffles=n_shuffles,
            shuffling_segments_number=shuffling_segs,
            output_dir=None,
        )

        # --- Step 9: FDR estimation ---
        fdr = estimate_fdr(
            real_results=results,
            shuffled_sweep_files_by_shuffle=shuffled_data,
            vip_genes=_vip_list,
            control_sets=_ctrl_sets,
            neighbors=df_n,
            thresholds=thresholds,
            populations=populations,
            groups=groups,
            pop_interest="All",  # matches Perl estimate_FPR.pl hardcoded "All:"
            use_clust=False,
            count_sweeps=False,
            nthreads=nthreads,
            p_cutoff=1.0,
            max_threshold=max(thresholds),
            min_threshold=0,
            min_vip_count=1,
        )

        row = fdr.row(0, named=True)
        print(
            f"FDR {_size}: {row['p_value']:.4f} {row['n_replicates']} "
            f"{row['real_stat']:.2f} {row['max_null_stat']:.2f}",
            flush=True,
        )

        # Tag both outputs with the sweep file they came from
        results = results.with_columns(pl.lit(_rank_file).alias("dataset"))
        fdr = fdr.with_columns(pl.lit(_rank_file).alias("dataset"))
        curve_results.append(results)
        fdr_results.append(fdr)

    print(f"TIMING_POSTBOOTSTRAP: {time.time() - _t_post_bootstrap:.3f}", flush=True)
    print(f"TIMING_TOTAL: {time.time() - _t_start:.3f}", flush=True)

    return df_case, df_control, pl.concat(curve_results), pl.concat(fdr_results)
# Module-level rank cutoffs, listed from the most permissive (6000) down to
# the most stringent (20).
thresholds = [
    6000, 5000, 4000, 3000, 2500, 2000, 1500, 1000,
    900, 800, 700, 600, 500, 450, 400, 350,
    300, 250, 200, 150, 100, 50, 20,
]  # fmt: skip

# ---------------------------------------------------------------------------
# Historical exclusion lists, kept for reference only.
# ---------------------------------------------------------------------------
# hla_genes = [  # original expanded list — DO NOT restore
#     "ENSG00000233095", "ENSG00000223980", "ENSG00000230254", "ENSG00000204642",
#     "ENSG00000204632", "ENSG00000206503", "ENSG00000204592", "ENSG00000114455",
#     "ENSG00000235220", "ENSG00000235346", "ENSG00000235657", "ENSG00000229252",
#     "ENSG00000228299", "ENSG00000228987", "ENSG00000236884", "ENSG00000236418",
#     "ENSG00000233209", "ENSG00000223793", "ENSG00000228813", "ENSG00000243496",
#     "ENSG00000226264", "ENSG00000241394", "ENSG00000235744", "ENSG00000229685",
#     "ENSG00000237710", "ENSG00000206435", "ENSG00000228964", "ENSG00000234794",
#     "ENSG00000227357", "ENSG00000228080", "ENSG00000232062", "ENSG00000231939",
#     "ENSG00000231526", "ENSG00000226165", "ENSG00000239457", "ENSG00000241296",
#     "ENSG00000243215", "ENSG00000232957", "ENSG00000236177", "ENSG00000226826",
#     "ENSG00000237508", "ENSG00000226260", "ENSG00000227826", "ENSG00000229074",
#     "ENSG00000225890", "ENSG00000235680", "ENSG00000225824", "ENSG00000231834",
#     "ENSG00000231823", "ENSG00000229493", "ENSG00000239329", "ENSG00000243189",
#     "ENSG00000231558", "ENSG00000228163", "ENSG00000236632", "ENSG00000229295",
#     "ENSG00000237022", "ENSG00000224608", "ENSG00000229698", "ENSG00000132297",
#     "ENSG00000206509", "ENSG00000206506", "ENSG00000206505", "ENSG00000206493",
#     "ENSG00000206452", "ENSG00000206450", "ENSG00000204525", "ENSG00000234745",
#     "ENSG00000232962", "ENSG00000224103", "ENSG00000230708", "ENSG00000206308",
#     "ENSG00000196101", "ENSG00000206306", "ENSG00000206305", "ENSG00000206302",
#     "ENSG00000225103", "ENSG00000196610", "ENSG00000241674", "ENSG00000242685",
#     "ENSG00000206292", "ENSG00000206291", "ENSG00000215048", "ENSG00000233841",
#     "ENSG00000223532", "ENSG00000204287", "ENSG00000198502", "ENSG00000196126",
#     "ENSG00000196735", "ENSG00000179344", "ENSG00000237541", "ENSG00000232629",
#     "ENSG00000241106", "ENSG00000137403", "ENSG00000230413", "ENSG00000224320",
#     "ENSG00000197568", "ENSG00000242574", "ENSG00000204257", "ENSG00000204252",
#     "ENSG00000231389", "ENSG00000223865", "ENSG00000233904", "ENSG00000234487",
#     "ENSG00000237216", "ENSG00000229215", "ENSG00000227715", "ENSG00000230726",
#     "ENSG00000231021", "ENSG00000228284", "ENSG00000231286", "ENSG00000225201",
#     "ENSG00000206301", "ENSG00000228254", "ENSG00000241910", "ENSG00000242386",
#     "ENSG00000239463", "ENSG00000230141", "ENSG00000168384", "ENSG00000230763",
#     "ENSG00000230463", "ENSG00000233192", "ENSG00000230675", "ENSG00000227993",
#     "ENSG00000243612", "ENSG00000231679", "ENSG00000206240", "ENSG00000206237",
#     "ENSG00000204276", "ENSG00000224305", "ENSG00000241386", "ENSG00000225691",
#     "ENSG00000242092", "ENSG00000242361", "ENSG00000235844", "ENSG00000236693",
#     "ENSG00000232126", "ENSG00000234154", "ENSG00000243719",
# ]
#
# hist_genes = [  # original histone list — DO NOT restore
#     "ENSG00000164508", "ENSG00000146047", "ENSG00000124610", "ENSG00000198366",
#     "ENSG00000196176", "ENSG00000124529", "ENSG00000124693", "ENSG00000137259",
#     "ENSG00000196226", "ENSG00000196532", "ENSG00000187837", "ENSG00000197061",
#     "ENSG00000187475", "ENSG00000180596", "ENSG00000180573", "ENSG00000168298",
#     "ENSG00000158373", "ENSG00000197697", "ENSG00000188987", "ENSG00000197409",
#     "ENSG00000196866", "ENSG00000197846", "ENSG00000198518", "ENSG00000187990",
#     "ENSG00000168274", "ENSG00000196966", "ENSG00000124575", "ENSG00000198327",
#     "ENSG00000124578", "ENSG00000256316", "ENSG00000197459", "ENSG00000256018",
#     "ENSG00000168242", "ENSG00000158406", "ENSG00000124635", "ENSG00000196787",
#     "ENSG00000197903", "ENSG00000198339", "ENSG00000184825", "ENSG00000185130",
#     "ENSG00000196747", "ENSG00000203813", "ENSG00000182611", "ENSG00000196374",
#     "ENSG00000197238", "ENSG00000197914", "ENSG00000184348", "ENSG00000233822",
#     "ENSG00000198374", "ENSG00000184357", "ENSG00000182572", "ENSG00000198558",
#     "ENSG00000197153", "ENSG00000233224", "ENSG00000196331", "ENSG00000168148",
#     "ENSG00000181218", "ENSG00000196890", "ENSG00000203818", "ENSG00000203814",
#     "ENSG00000183598", "ENSG00000183941", "ENSG00000203811", "ENSG00000183558",
#     "ENSG00000203812", "ENSG00000203852", "ENSG00000182217", "ENSG00000184678",
#     "ENSG00000184260", "ENSG00000184270", "ENSG00000197837", "ENSG00000265232",
#     "ENSG00000263376", "ENSG00000265133", "ENSG00000265198", "ENSG00000266225",
#     "ENSG00000266725", "ENSG00000264719", "ENSG00000263521", "ENSG00000263965",
# ]

# Active exclusion lists. run_enrichment concatenates them
# (hla_genes + hist_genes) and removes those genes from both the case and the
# control candidates, so both must remain plain lists.
hla_genes = [
    "ENSG00000179344", "ENSG00000196126", "ENSG00000196735", "ENSG00000198502",
    "ENSG00000204252", "ENSG00000204257", "ENSG00000204287", "ENSG00000204525",
    "ENSG00000204592", "ENSG00000204632", "ENSG00000204642", "ENSG00000206503",
    "ENSG00000223865", "ENSG00000231389", "ENSG00000232629", "ENSG00000234745",
    "ENSG00000237541", "ENSG00000241106", "ENSG00000242574",
]  # fmt: skip

# No histone genes are excluded at present.
hist_genes = []