Source code for flexsweep.cnn

import importlib
import json
import math
import os
import time
from copy import deepcopy

import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.special import expit, logit
from sklearn.calibration import calibration_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    auc,
    brier_score_loss,
    confusion_matrix,
    precision_recall_curve,
    precision_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K

# from tensorflow.keras import layers
from tensorflow.keras.saving import register_keras_serializable
from tensorflow.keras.utils import Sequence

from . import np, pl


@register_keras_serializable(package="fs", name="masked_bce")
def masked_bce(y_true, y_pred):
    y_pred = tf.boolean_mask(
        y_pred, tf.not_equal(y_true, -1)
    )  # -1 will be masked/ y_true or y_pred?
    y_true = tf.boolean_mask(y_true, tf.not_equal(y_true, -1))

    return tf.keras.losses.binary_crossentropy(y_true, y_pred)


@register_keras_serializable(package="fs", name="masked_binary_accuracy")
def masked_binary_accuracy(y_true, y_pred):
    y_pred = tf.boolean_mask(y_pred, tf.not_equal(y_true, -1))
    y_true = tf.boolean_mask(y_true, tf.not_equal(y_true, -1))
    return tf.keras.metrics.binary_accuracy(y_true, y_pred)


@register_keras_serializable(package="fs", name="GradReverse")
class GradReverse(tf.keras.layers.Layer):
    """
    Gradient Reversal Layer (GRL) with tunable strength ``λ``.

    Forward pass: identity (returns the input unchanged).
    Backward pass: multiplies the incoming gradient by ``-λ``, which
    *reverses* (and scales) gradients flowing into the shared feature extractor.
    This encourages the extractor to learn **domain-invariant** features when
    the GRL feeds a domain classifier.

    Parameters
    ----------
    lambd : float, default=0.0
        Initial GRL strength ``λ``. The effective gradient multiplier is ``-λ``.
        Can be updated during training (e.g., via :class:`GRLRamp`).
    **kw : Any
        Passed to :class:`tf.keras.layers.Layer`.

    Attributes
    ----------
    lambd : tf.Variable
        Non-trainable scalar variable storing the current ``λ`` value. It can be
        modified by callbacks to schedule warm-up or annealing.

    Notes
    -----
    - Serialization: the layer is Keras-serializable and preserves the initial
      ``λ`` in configs. At runtime, the **variable** value may be updated.
    - Typical schedules **warm up** ``λ`` from 0 → 0.4–1.0 over several epochs.

    References
    ----------
    Ganin & Lempitsky (2015), "Unsupervised Domain Adaptation by
    Backpropagation" (DANN/GRL).
    """

    @staticmethod
    @tf.custom_gradient
    def _grl_with_lambda(x, lambd):
        y = tf.identity(x)

        def grad(dy):
            # grad wrt x is -λ * dy; no grad wrt λ
            return -lambd * dy, tf.zeros_like(lambd)

        return y, grad

    def __init__(self, lambd=0.0, **kw):
        super().__init__(**kw)
        # Keep JSON-safe init value for serialization
        self._lambd_init = float(lambd)
        # Non-trainable so you can control it via callback
        self.lambd = tf.Variable(
            self._lambd_init, trainable=False, dtype=tf.float32, name="grl_lambda"
        )

    def call(self, x):
        # Use the staticmethod custom op
        return GradReverse._grl_with_lambda(x, self.lambd)

    # ---- Keras serialization ----
    def get_config(self):
        cfg = super().get_config()
        cfg.update({"lambd": float(self._lambd_init)})
        return cfg

    @classmethod
    def from_config(cls, cfg):
        return cls(**cfg)


class GRLRamp(tf.keras.callbacks.Callback):
    """
    Linear warm-up schedule for GRL strength ``λ``.

    Increases the GRL factor linearly from 0 to ``max_lambda`` over
    ``epochs`` calls to :meth:`on_epoch_begin`. After warm-up, ``λ`` is held
    constant at ``max_lambda``.

    Parameters
    ----------
    grl_layer : GradReverse
        The GRL layer instance whose ``lambd`` variable will be updated.
    max_lambda : float, default=0.5
        Target value for ``λ`` at the end of the warm-up.
    epochs : int, default=50
        Number of warm-up epochs. If total training epochs exceed this value,
        ``λ`` remains fixed thereafter.

    Notes
    -----
    - Warm-up helps stabilize training by letting the classifier learn a useful
      decision surface **before** strong domain-adversarial pressure is applied.
    - Consider tuning ``max_lambda`` and warm-up length based on how quickly the
      domain accuracy approaches ~0.5 (a sign of domain invariance).
    """

    def __init__(self, grl_layer, max_lambda=0.5, epochs=50):
        """
        epochs = number of ramp epochs (not total training epochs).
        After this many epochs, λ will be held at max_lambda.
        """
        super().__init__()
        self.grl_layer = grl_layer
        self.max_lambda = float(max_lambda)
        self.ramp_epochs = int(max(1, epochs))

    def on_epoch_begin(self, epoch, logs=None):
        # linear warmup 0 → max_lambda over `ramp_epochs`, then hold
        if epoch < self.ramp_epochs:
            t = epoch / max(1, self.ramp_epochs - 1)
            lam = self.max_lambda * t
        else:
            lam = self.max_lambda
        self.grl_layer.lambd.assign(lam)


class LogGRLLambda(tf.keras.callbacks.Callback):
    def __init__(self, grl_layer, key="grl_lambda"):
        super().__init__()
        self.grl = grl_layer
        self.key = key

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            logs[self.key] = float(self.grl.lambd.numpy())


class LossWeightsScheduler(tf.keras.callbacks.Callback):
    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta

    def on_epoch_end(self, epoch, logs={}):
        gamma = 10  # 10, 5
        p = epoch / 30
        lambda_new = 2 / (1 + math.exp(-gamma * p)) - 1
        K.set_value(self.beta, lambda_new)


class LossWeightsLogger(tf.keras.callbacks.Callback):
    def __init__(self, loss_weights):
        super().__init__()
        self.loss_weights = loss_weights  # e.g., [alpha, beta]

    def on_epoch_end(self, epoch, logs=None):
        aw = float(K.get_value(self.loss_weights[0]))
        bw = float(K.get_value(self.loss_weights[1]))
        print(f"Loss Weights @ epoch {epoch + 1}: alpha={aw:.4f}, beta={bw:.4f}")
        if logs is not None:
            logs["alpha"] = aw
            logs["beta"] = bw


[docs] class CNN: """ Class to build and train a Convolutional Neural Network (CNN) for Flex-sweep. It loads/reshapes Flex-sweep feature vectors, trains, evaluates and predicts, including domain-adaptation extension. Attributes ---------- train_data : str | pl.DataFrame | None Path to training parquet/CSV (or a Polars DataFrame). source_data : str | None Path to *source* (labeled) parquet for domain adaptation. target_data : str | None Path to *target/empirical* parquet for domain adaptation (unlabeled). predict_data : str | pl.DataFrame | None Path/DataFrame with samples to predict (standard supervised path). valid_data : Any (Reserved) Optional separate validation set path/DF (unused). output_folder : str | None Directory where models, figures and predictions are written. normalize : bool If True, apply a Keras `Normalization` layer (fit on train only). model : tf.keras.Model | str | None A compiled Keras model or a path to a saved model. num_stats : int Number of per-window statistics used as channels. Default 11. center : np.ndarray[int] Center coordinates (bp) used to index columns; defaults to 500k..700k step 10k. windows : np.ndarray[int] Window sizes used to index columns; default [50k, 100k, 200k, 500k, 1M]. train_split : float Fraction of data used for training (rest split equally into val/test). gpu : bool If False, disable CUDA via `CUDA_VISIBLE_DEVICES=-1`. tf : module | None TensorFlow module, set by :meth:`check_tf`. history : pl.DataFrame | None Training history after :meth:`train` / :meth:`train_da`. prediction : pl.DataFrame | None Latest prediction table produced by :meth:`train` or :meth:`predict*`. """ def __init__( self, train_data=None, source_data=None, target_data=None, predict_data=None, valid_data=None, output_folder=None, normalize=False, model=None, num_stats=24, center=[0, 1.2e6], step=1e5, windows=np.array([100000]), ): """ Initialize a CNN runner. Parameters ---------- train_data : str | pl.DataFrame | None Path to training data (`.parquet`, `.csv[.gz]`) or Polars DataFrame. source_data : str | None Path to labeled source parquet for domain adaptation. target_data : str | None Path to unlabeled empirical/target parquet for domain adaptation. predict_data : str | pl.DataFrame | None Path/DataFrame for inference in :meth:`predict`. valid_data : Any, optional Reserved for a future explicit validation split (unused). output_folder : str | None Output directory for artifacts (models, plots, CSVs). normalize : bool, default=False If True, fit a `Normalization` layer on training features. model : tf.keras.Model | str | None Prebuilt Keras model or path to a saved model. Notes ----- Defaults assume 11 statistics × 5 windows × 21 centers organized in column names like: ``{stat}_{window}_{center}``. """ # self.sweep_data = sweep_data self.normalize = normalize self.train_data = train_data self.predict_data = predict_data self.test_train_data = None self.output_folder = output_folder self.output_prediction = "predictions.txt" self.num_stats = 24 self.center = np.arange(center[0] + step // 2, center[1], step) self.windows = np.asarray(windows) self.step = step self.train_split = 0.8 self.prediction = None self.history = None self.model = model self.gpu = True self.tf = None self.source_data = source_data self.target_data = target_data self.mean = None self.std = None self.scores = None def check_tf(self): """ Import TensorFlow (optionally forcing CPU). Returns ------- module Imported ``tensorflow`` module. Notes ----- If ``self.gpu`` is ``False``, the environment variable ``CUDA_VISIBLE_DEVICES`` is set to ``-1`` **before** importing TF. """ if self.gpu is False: os.environ["CUDA_VISIBLE_DEVICES"] = "-1" tf = importlib.import_module("tensorflow") return tf def preprocess(self, x, y=None, training=False, epsilon=1e-7): x = tf.cast(x, tf.float32) # mean = tf.cast(self.mean, tf.float32) # std = tf.cast(self.std, tf.float32) # mean = tf.reshape(mean, (self.num_stats, 1, 1)) # std = tf.reshape(std, (self.num_stats, 1, 1)) # Feature-wise normalization using training mean/std x = (x - self.mean) / (self.std + epsilon) # x = (x - mean) / (std + epsilon) if training: # # # Optional: small Gaussian noise (try stddev ~0.01-0.05) # x = x + tf.random.normal(tf.shape(x), mean=0.0, stddev=0.02, dtype=x.dtype) # # Optional: channel/stat dropout (drops whole stats) # keep_prob = 0.90 # if x.shape.rank == 3: # (S, W*C, 1) # mask = tf.cast( # tf.random.uniform((self.num_stats, 1, 1)) < keep_prob, x.dtype # ) # else: # (B, S, W*C, 1) # mask = tf.cast( # tf.random.uniform((1, self.num_stats, 1, 1)) < keep_prob, x.dtype # ) # x = x * mask / keep_prob # Horizontal flip augmentation x = tf.image.random_flip_left_right(x) if y is not None: return x, y else: return x def cnn_flexsweep_feature(self, model_input): """ Shared 2D-CNN feature extractor (three branches). Parameters ---------- model_input : tf.keras.layers.Input Input tensor of shape ``(W, C, S)``. Returns ------- tf.Tensor Flattened concatenated features from the three branches. See Also -------- cnn_flexsweep : Similar branch structure followed by classification head. """ tf = self.check_tf() He = tf.keras.initializers.HeNormal() # --- Branch 1: 3x3 convs --- b1 = tf.keras.layers.Conv2D( 64, 3, padding="same", kernel_initializer=He, name="fx_b1_c1" )(model_input) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.Conv2D( 128, 3, padding="same", kernel_initializer=He, name="fx_b1_c2" )(b1) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.Conv2D( 256, 3, padding="same", kernel_initializer=He, name="fx_b1_c3" )(b1) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.MaxPooling2D( pool_size=3, padding="same", name="fx_b1_pool" )(b1) b1 = tf.keras.layers.Dropout(0.15, name="fx_b1_drop")(b1) b1 = tf.keras.layers.Flatten(name="fx_b1_flat")(b1) # --- Branch 2: 2x2 convs with dilation (1,3) --- b2 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c1", )(model_input) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c2", )(b2) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c3", )(b2) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.MaxPooling2D(pool_size=2, name="fx_b2_pool")(b2) b2 = tf.keras.layers.Dropout(0.15, name="fx_b2_drop")(b2) b2 = tf.keras.layers.Flatten(name="fx_b2_flat")(b2) # --- Branch 3: 2x2 convs with dilation (5,1) then (1,5) --- b3 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[5, 1], padding="same", kernel_initializer=He, name="fx_b3_c1", )(model_input) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 5], padding="same", kernel_initializer=He, name="fx_b3_c2", )(b3) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 5], padding="same", kernel_initializer=He, name="fx_b3_c3", )(b3) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.MaxPooling2D(pool_size=2, name="fx_b3_pool")(b3) b3 = tf.keras.layers.Dropout(0.15, name="fx_b3_drop")(b3) b3 = tf.keras.layers.Flatten(name="fx_b3_flat")(b3) feat = tf.keras.layers.Concatenate(name="fx_concat")([b1, b2, b3]) h = tf.keras.layers.Dense(128, activation="relu")(feat) h = tf.keras.layers.Dropout(0.20)(h) h = tf.keras.layers.Dense(32, activation="relu")(h) h = tf.keras.layers.Dropout(0.10)(h) out_cls = tf.keras.layers.Dense(1, activation="sigmoid", name="classifier")(h) return out_cls def cnn_flexsweep(self, model_input, num_classes=1): """ Flex-sweep CNN architecture with multiple convolutional and pooling layers. Args: input_shape (tuple): Shape of the input data, e.g., (224, 224, 3). Default Flex-sweep input statistics, windows and centers num_classes (int): Number of output classes in the classification problem. Default: Flex-sweep binary classification Returns: Model: A Keras model instance representing the Flex-sweep CNN architecture. """ tf = self.check_tf() # 3x3 layer layer1 = tf.keras.layers.Conv2D( 64, 3, padding="same", name="convlayer1_1", kernel_initializer="glorot_uniform", )(model_input) layer1 = tf.keras.layers.ReLU(negative_slope=0)(layer1) layer1 = tf.keras.layers.Conv2D( 128, 3, padding="same", name="convlayer1_2", kernel_initializer="glorot_uniform", )(layer1) layer1 = tf.keras.layers.ReLU(negative_slope=0)(layer1) layer1 = tf.keras.layers.Conv2D(256, 3, padding="same", name="convlayer1_3")( layer1 ) layer1 = tf.keras.layers.ReLU(negative_slope=0)(layer1) layer1 = tf.keras.layers.MaxPooling2D( pool_size=3, name="poollayer1", padding="same" )(layer1) layer1 = tf.keras.layers.Dropout(0.15, name="droplayer1")(layer1) layer1 = tf.keras.layers.Flatten(name="flatlayer1")(layer1) # 2x2 layer with 1x3 dilation layer2 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[1, 3], padding="same", name="convlayer2_1", kernel_initializer="glorot_uniform", )(model_input) layer2 = tf.keras.layers.ReLU(negative_slope=0)(layer2) layer2 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 3], padding="same", name="convlayer2_2", kernel_initializer="glorot_uniform", )(layer2) layer2 = tf.keras.layers.ReLU(negative_slope=0)(layer2) layer2 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 3], padding="same", name="convlayer2_3" )(layer2) layer2 = tf.keras.layers.ReLU(negative_slope=0)(layer2) layer2 = tf.keras.layers.MaxPooling2D(pool_size=2, name="poollayer2")(layer2) layer2 = tf.keras.layers.Dropout(0.15, name="droplayer2")(layer2) layer2 = tf.keras.layers.Flatten(name="flatlayer2")(layer2) # 2x2 with 1x5 dilation layer3 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[1, 5], padding="same", name="convlayer4_1", kernel_initializer="glorot_uniform", )(model_input) layer3 = tf.keras.layers.ReLU(negative_slope=0)(layer3) layer3 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 5], padding="same", name="convlayer4_2", kernel_initializer="glorot_uniform", )(layer3) layer3 = tf.keras.layers.ReLU(negative_slope=0)(layer3) layer3 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 5], padding="same", name="convlayer4_3" )(layer3) layer3 = tf.keras.layers.ReLU(negative_slope=0)(layer3) layer3 = tf.keras.layers.MaxPooling2D(pool_size=2, name="poollayer3")(layer3) layer3 = tf.keras.layers.Dropout(0.15, name="droplayer3")(layer3) layer3 = tf.keras.layers.Flatten(name="flatlayer3")(layer3) # concatenate convolution layers concat = tf.keras.layers.concatenate([layer1, layer2, layer3]) concat = tf.keras.layers.Dense(512, name="512dense", activation="relu")(concat) concat = tf.keras.layers.Dropout(0.2, name="dropconcat1")(concat) concat = tf.keras.layers.Dense(128, name="last_dense", activation="relu")( concat ) concat = tf.keras.layers.Dropout(0.2 / 2, name="dropconcat2")(concat) output = tf.keras.layers.Dense( num_classes, name="out_dense", activation="sigmoid", kernel_initializer="glorot_uniform", )(concat) return output def load_training_data(self, _stats=None, w=None, n=None, one_dim=False): """ Load and reshape training/validation/test tensors from table-format features. Parameters ---------- _stats : list[str] | None List of statistic base names to include (e.g., ``["ihs","nsl",...]``). If None, you must pass an explicit list later in :meth:`train`. w : int | list[int] | None Restrict to specific window sizes (e.g., 100000 or [50000,100000]). Columns are selected by regex suffix ``_{window}``. n : int | None Optional number of rows to sample from parquet. one_dim : bool, default=False If True, flatten spatial grid to ``(W*C, S)`` for 1D models. Returns ------- tuple ``(X_train, X_test, Y_train, Y_test, X_valid, Y_valid)`` with shapes: - if ``one_dim`` is False: ``X_*`` → ``(N, W, C, S)``, labels are 0/1. - if ``one_dim`` is True: ``X_*`` → ``(N, W*C, S)``. Raises ------ AssertionError If ``train_data`` is missing or has an unsupported extension. Notes ----- Any ``model`` value not equal to ``"neutral"`` is coerced to ``"sweep"``. """ assert self.train_data is not None, "Please input training data" assert ( "txt" in self.train_data or "csv" in self.train_data or self.train_data.endswith(".parquet") ), "Please save your dataframe as CSV or parquet" if isinstance(self.train_data, pl.DataFrame): pass elif self.train_data.endswith(".gz"): tmp = pl.read_csv(self.train_data, separator=",") elif self.train_data.endswith(".parquet"): tmp = pl.read_parquet(self.train_data) if n is not None: tmp = tmp.sample(n) tmp = tmp.with_columns( pl.when(pl.col("model") != "neutral") .then(pl.lit("sweep")) .otherwise(pl.lit("neutral")) .alias("model") ) if w is not None: try: self.center = np.array([int(w)]) tmp = tmp.select( "iter", "s", "t", "f_i", "f_t", "model", f"^*._{int(w)}$" ) except Exception: self.center = np.sort(np.array(w).astype(int)) _tmp = [] _h = tmp.select("iter", "s", "t", "f_i", "f_t", "model") for window in self.center: _tmp.append(tmp.select(f"^*._{int(window)}$")) tmp = pl.concat(_tmp, how="horizontal") tmp = pl.concat([_h, tmp], how="horizontal") # sweep_parameters = tmp.filter("model" != "neutral").select(tmp.columns[:7]) stats = [] if _stats is not None: stats = stats + _stats train_stats = [] for i in stats: train_stats.append(tmp.select(pl.col(f"^{i}_[0-9]+_[0-9]+$"))) train_stats = pl.concat(train_stats, how="horizontal") train_stats = pl.concat( [ tmp.select("model", "iter", "s", "f_i", "f_t", "t", "mu", "r"), train_stats, ], how="horizontal", ) y = train_stats.select( ((~pl.col("model").str.contains("neutral")).cast(pl.Int8)).alias( "neutral_flag" ) )["neutral_flag"].to_numpy() test_split = round(1 - self.train_split, 2) ( X_train, X_test, Y_train, y_test, ) = train_test_split(train_stats, y, test_size=test_split, shuffle=True) X_train = ( train_stats.select(train_stats.columns[8:]) .to_numpy() .reshape( train_stats.shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) ) X_valid, X_test, Y_valid, Y_test = train_test_split( X_test, y_test, test_size=0.5 ) X_test_params = X_test.select(X_test.columns[:6]) X_test = ( X_test.select(train_stats.columns[8:]) .to_numpy() .reshape( X_test.shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) ) X_valid = ( X_valid.select(train_stats.columns[8:]) .to_numpy() .reshape( X_valid.shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) ) # Normalization on training data if self.normalize: self.stat_norm = tf.keras.layers.Normalization(axis=-1, name="stat_norm") self.stat_norm.adapt(X_train) # learns mean/std from training set only # Input stats as channel to improve performance # Avoiding changes stats order X_train = X_train.reshape( X_train.shape[0], self.windows.size, self.center.size, self.num_stats ) X_test = X_test.reshape( X_test.shape[0], self.windows.size, self.center.size, self.num_stats ) X_valid = X_valid.reshape( X_valid.shape[0], self.windows.size, self.center.size, self.num_stats ) X_test = X_test.reshape( X_test.shape[0], self.windows.size, self.center.size, self.num_stats ) if one_dim: X_train = X_train.reshape( -1, self.windows.size * self.center.size, self.num_stats ) X_valid = X_valid.reshape( -1, self.windows.size * self.center.size, self.num_stats ) X_test = X_test.reshape( -1, self.windows.size * self.center.size, self.num_stats ) self.test_train_data = [X_test, X_test_params, Y_test] return ( X_train, X_test, Y_train, Y_test, X_valid, Y_valid, ) def train( self, _iter=1, _stats=None, w=None, cnn=None, one_dim=False, preprocess=False, show_plot=False, ): """ Train a CNN on flex-sweep tensors with early stopping and checkpoints. Parameters ---------- _iter : int, default=1 Tag for output naming (kept for backwards compatibility). _stats : list[str] | None Statistic base names. If None, defaults to the 11 flex-sweep stats. w : int | list[int] | None Window size(s) to select (see :meth:`load_training_data`). cnn : callable | None A function mapping a Keras input tensor to an output tensor. Defaults to :meth:`cnn_flexsweep`. If ``one_dim=True``, you must provide a compatible 1D architecture. one_dim : bool, default=False If True, uses flattened ``(W*C, S)`` inputs. Returns ------- pl.DataFrame Predictions on the held-out test set with columns: ``['model','f_i','f_t','s','t','predicted_model','prob_sweep','prob_neutral']``. Notes ----- - Optimizer: Adam with cosine-restarts schedule. - Loss: Binary cross-entropy with label smoothing (0.05). - Early stopping monitors validation AUC (restore best weights). - Saves ``model.keras`` to ``output_folder`` if provided. """ if one_dim: assert cnn is not None, "Please input a 1D CNN architecture" # Default stats if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] if one_dim: assert cnn is not None, "Please input a 1D CNN architecture" self.num_stats = len(_stats) self.feature_names = list(_stats) # Default CNN if cnn is None: cnn = self.cnn_flexsweep ( X_train, X_test, Y_train, Y_test, X_valid, Y_valid, ) = self.load_training_data(w=w, _stats=_stats, one_dim=one_dim) self.num_stats = len(_stats) self.feature_names = list(_stats) # Default CNN if cnn is None: cnn = self.cnn_flexsweep ( X_train, X_test, Y_train, Y_test, X_valid, Y_valid, ) = self.load_training_data(w=w, _stats=_stats, one_dim=one_dim) X_train = X_train.reshape( X_train.shape[0], self.num_stats, self.center.size * self.windows.size, 1 ) X_test = X_test.reshape( X_test.shape[0], self.num_stats, self.center.size * self.windows.size, 1 ) X_valid = X_valid.reshape( X_valid.shape[0], self.num_stats, self.center.size * self.windows.size, 1 ) # put model together input_to_model = tf.keras.Input(X_train.shape[1:]) batch_size = 32 # norm = tf.keras.layers.Normalization(axis=(0, 1, 2)) # augment = tf.keras.Sequential( # [tf.keras.layers.RandomFlip("horizontal")], # name="augment", # ) if preprocess: self.mean = X_train.mean(axis=(0, 1, 2), keepdims=False) self.std = X_train.std(axis=(0, 1, 2), keepdims=False) # self.mean = X_train.mean(axis=(0, 2, 3), keepdims=True) # self.std = X_train.std(axis=(0, 2, 3), keepdims=True) train_dataset = ( tf.data.Dataset.from_tensor_slices((X_train, Y_train)) .shuffle(10000) .map(lambda x, y: self.preprocess(x, y, training=True)) .batch(batch_size) .prefetch(tf.data.AUTOTUNE) ) valid_dataset = ( tf.data.Dataset.from_tensor_slices((X_valid, Y_valid)) .map(lambda x, y: self.preprocess(x, y, training=False)) .batch(batch_size) .prefetch(tf.data.AUTOTUNE) ) test_dataset = ( tf.data.Dataset.from_tensor_slices((X_test, Y_test)) .map(lambda x, y: self.preprocess(x, y, training=False)) .batch(batch_size) .prefetch(tf.data.AUTOTUNE) ) else: train_dataset = ( tf.data.Dataset.from_tensor_slices((X_train, Y_train)) .shuffle(10000) .batch(batch_size) .prefetch(tf.data.AUTOTUNE) ) valid_dataset = ( tf.data.Dataset.from_tensor_slices((X_valid, Y_valid)) .batch(batch_size) .prefetch(tf.data.AUTOTUNE) ) test_dataset = ( tf.data.Dataset.from_tensor_slices((X_test, Y_test)) .batch(batch_size) .prefetch(tf.data.AUTOTUNE) ) model = tf.keras.models.Model( inputs=[input_to_model], outputs=[cnn(input_to_model)] ) model_path = f"{self.output_folder}/model.keras" metrics_measures = [ tf.keras.metrics.BinaryAccuracy(name="accuracy"), tf.keras.metrics.Precision(name="precision"), tf.keras.metrics.AUC(name="roc", curve="ROC"), ] lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecayRestarts( initial_learning_rate=1e-4, first_decay_steps=300 ) opt_adam = tf.keras.optimizers.Adam( learning_rate=lr_decayed_fn, epsilon=0.0000001, amsgrad=True ) # Keep only one compilation model.compile( optimizer=opt_adam, loss="binary_crossentropy", # loss=custom_loss, metrics=metrics_measures, ) earlystop = tf.keras.callbacks.EarlyStopping( monitor="val_accuracy", # monitor="val_auc", min_delta=0.0001, patience=5, verbose=2, mode="max", restore_best_weights=True, ) checkpoint = tf.keras.callbacks.ModelCheckpoint( model_path, monitor="val_accuracy", # monitor="val_auc", verbose=2, save_best_only=True, mode="max", ) callbacks_list = [checkpoint, earlystop] start = time.time() history = model.fit( train_dataset, epochs=1000, validation_data=valid_dataset, callbacks=callbacks_list, ) val_score = model.evaluate( valid_dataset, batch_size=32, steps=len(Y_valid) // 32, ) test_score = model.evaluate( test_dataset, batch_size=32, steps=len(Y_test) // 32, ) train_score = model.evaluate( train_dataset, batch_size=32, steps=len(Y_train) // 32, ) self.scores = [val_score, test_score, train_score] self.model = model df_history = pl.DataFrame(history.history) self.history = df_history print( f"Training and testing model took {round(time.time() - start, 3)} seconds" ) if self.output_folder is not None: model.save(model_path) # ROC curves and confusion matrix if self.output_folder is None: _output_prediction = self.output_prediction else: _output_prediction = f"{self.output_folder}/{self.output_prediction}" test_X, test_X_params, test_Y = deepcopy(self.test_train_data) test_X = test_X.reshape( test_X.shape[0], self.num_stats, self.windows.size * self.center.size, 1 ) # self.mean = test_X.mean(axis=(0, 1, 2), keepdims=False) # self.std = test_X.std(axis=(0, 1, 2), keepdims=False) if preprocess: preds = model.predict(self.preprocess(test_X)) else: preds = model.predict(test_X) preds = np.column_stack([1 - preds, preds]) predictions = np.argmax(preds, axis=1) prediction_dict = { 0: "neutral", 1: "sweep", } predictions_class = np.vectorize(prediction_dict.get)(predictions) df_prediction = pl.concat( [ test_X_params.select("model", "f_i", "f_t", "s", "t"), pl.DataFrame( { "predicted_model": predictions_class, "prob_sweep": preds[:, 1], "prob_neutral": preds[:, 0], } ), ], how="horizontal", ) self.prediction = df_prediction.with_columns( ( pl.when(pl.col("model").str.contains("neutral")) .then(pl.lit("neutral")) .otherwise(pl.lit("sweep")) ).alias("model") ) # self.prediction.write_csv("train_predictions.txt") self.roc_curve(show_plot=show_plot) if self.output_folder is not None: df_prediction.write_csv(_output_prediction) return df_prediction def _load_X_y(self): """Reload feature tensor and labels from train_data using stored feature_names.""" if isinstance(self.train_data, pl.DataFrame): df = self.train_data elif self.train_data.endswith(".parquet"): df = pl.read_parquet(self.train_data) else: df = pl.read_csv(self.train_data, separator=",") stat_frames = [ df.select(pl.col(f"^{name}_[0-9]+_[0-9]+$")) for name in self.feature_names ] X_df = pl.concat(stat_frames, how="horizontal") X = X_df.to_numpy().reshape( df.shape[0], len(self.feature_names), self.center.size * self.windows.size, 1, ) y = (~df["model"].str.contains("neutral")).cast(pl.Int8).to_numpy() return X, y def feature_importance(self, X=None, y=None, n_repeats=5, output_folder=None): """ Permutation feature importance over stat channels. For each stat (axis=1 of the CNN input), shuffle values across samples n_repeats times and measure the mean accuracy drop vs baseline. Parameters ---------- X : np.ndarray, shape (N, num_stats, n_positions, 1), optional Feature tensor. If None, reloads from self.train_data. y : np.ndarray, shape (N,), optional Integer labels (0=neutral, 1=sweep). Required when X is provided. n_repeats : int Shuffle repetitions per stat. Default 5. output_folder : str, optional If given, saves feature_importance.svg and feature_importance.csv. Returns ------- df : pl.DataFrame Columns: feature, mean_drop, std_drop — sorted descending by mean_drop. fig : matplotlib.figure.Figure """ assert hasattr( self, "feature_names" ), "Call train() before feature_importance()." if X is None: X, y = self._load_X_y() baseline_pred = self.model.predict(X, verbose=0).argmax(axis=1) baseline_acc = (baseline_pred == y).mean() rng = np.random.default_rng(42) records = [] for i, name in enumerate(self.feature_names): drops = [] for _ in range(n_repeats): X_perm = X.copy() perm_idx = rng.permutation(X_perm.shape[0]) X_perm[:, i, :, :] = X_perm[perm_idx, i, :, :] acc = (self.model.predict(X_perm, verbose=0).argmax(axis=1) == y).mean() drops.append(baseline_acc - acc) records.append( { "feature": name, "mean_drop": float(np.mean(drops)), "std_drop": float(np.std(drops)), } ) df = pl.DataFrame(records).sort("mean_drop", descending=True) names = df["feature"].to_list() drops_v = df["mean_drop"].to_list() errs_v = df["std_drop"].to_list() fig, ax = plt.subplots(figsize=(8, max(4, len(df) * 0.28))) ax.barh( names[::-1], drops_v[::-1], xerr=errs_v[::-1], color="steelblue", ecolor="gray", capsize=3, ) ax.axvline(0, color="black", linewidth=0.8) ax.set_xlabel("Mean accuracy drop (permutation importance)") fig.tight_layout() if output_folder is not None: fig.savefig( os.path.join(output_folder, "feature_importance.svg"), bbox_inches="tight", ) df.write_csv(os.path.join(output_folder, "feature_importance.csv")) return df, fig def predict( self, _stats=None, w=None, one_dim=False, _iter=1, fname=None, preprocess=True ): """ Predict on a feature table using a trained model. Parameters ---------- _stats : list[str] | None Statistic base names to include; defaults to the 11 flex-sweep stats. w : int | list[int] | None Window size(s) to select. simulations : bool, default=False Reserved flag; has no effect here. _iter : int, default=1 Tag for output naming (unused). Returns ------- pl.DataFrame Sorted predictions per region with columns: ``['chr','start','end','f_i','f_t','s','t','predicted_model','prob_sweep','prob_neutral']``. Raises ------ AssertionError If ``self.model`` is not set or ``predict_data`` is missing. Notes ----- If ``self.model`` is a string path, it is loaded via ``tf.keras.models.load_model``. """ if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] self.num_stats = len(_stats) assert self.model is not None, "Please input the CNN trained model" if isinstance(self.model, str): model = tf.keras.models.load_model(self.model) else: model = self.model # import data to predict assert self.predict_data is not None, "Please input training data" assert ( isinstance(self.predict_data, pl.DataFrame) or "txt" in self.predict_data or "csv" in self.predict_data or self.predict_data.endswith(".parquet") ), "Please input a parquet pl.DataFrame" df_test = pl.read_parquet(self.predict_data) df_test = df_test.with_columns( pl.when(pl.col("model") != "neutral") .then(pl.lit("sweep")) .otherwise(pl.lit("neutral")) .alias("model") ) regions = df_test["iter"].to_numpy() stats = [] if _stats is not None: stats = stats + _stats test_stats = [] for i in stats: test_stats.append(df_test.select(pl.col(f"^{i}_[0-9]+_[0-9]+$"))) X_test = pl.concat(test_stats, how="horizontal") if w is not None: try: self.center = np.array([int(w)]) X_test = X_test.select(f"^*._{int(w)}$") except Exception: self.center = np.sort(np.array(w).astype(int)) _X_test = [] for window in self.center: _X_test.append(X_test.select(f"^*._{int(window)}$")) X_test = pl.concat(_X_test, how="horizontal") test_X_params = df_test.select( "model", "iter", "s", "f_i", "f_t", "t", "mu", "r" ) test_X = X_test.to_numpy().reshape( X_test.shape[0], self.num_stats, self.windows.size * self.center.size, 1 ) if one_dim: test_X = test_X.reshape( -1, self.windows.size * self.center.size, self.num_stats ) if preprocess: self.mean = test_X.mean(axis=(0, 1, 2), keepdims=False) self.std = test_X.std(axis=(0, 1, 2), keepdims=False) # self.mean = test_X.mean(axis=(0, 2, 3), keepdims=True) # self.std = test_X.std(axis=(0, 2, 3), keepdims=True) test_X_ds = ( tf.data.Dataset.from_tensor_slices(test_X) .map(lambda x: self.preprocess(x, training=False)) .batch(32) .prefetch(tf.data.AUTOTUNE) ) preds = model.predict(self.preprocess(test_X, training=False)) else: test_X_ds = ( tf.data.Dataset.from_tensor_slices(test_X) .batch(32) .prefetch(tf.data.AUTOTUNE) ) preds = model.predict(test_X_ds) preds = np.column_stack([1 - preds, preds]) predictions = np.argmax(preds, axis=1) prediction_dict = { 0: "neutral", 1: "sweep", } predictions_class = np.vectorize(prediction_dict.get)(predictions) df_prediction = pl.concat( [ test_X_params.select("model", "f_i", "f_t", "s", "t", "mu", "r"), pl.DataFrame( { "predicted_model": predictions_class, "prob_sweep": preds[:, 1], "prob_neutral": preds[:, 0], } ), ], how="horizontal", ) # Same folder custom fvs name based on input VCF. # _output_prediction = f"{self.output_folder}/{os.path.basename(self.predict_data).replace("fvs_", "").replace(".parquet", "_predictions.txt")}" _output_prediction = f"{self.output_folder}/{os.path.basename(self.predict_data).replace('fvs_', '').replace('.parquet', '_predictions.txt')}" df_prediction = df_prediction.with_columns(pl.Series("region", regions)) try: chr_start_end = np.array( [item.replace(":", "-").split("-") for item in regions] ) df_prediction = df_prediction.with_columns( pl.Series("chr", chr_start_end[:, 0]), pl.Series("start", chr_start_end[:, 1], dtype=pl.Int64), pl.Series("end", chr_start_end[:, 2], dtype=pl.Int64), pl.Series( "nchr", pl.Series(chr_start_end[:, 0]).str.replace("chr", "").cast(int), ), ) df_prediction = df_prediction.sort("nchr", "start").select( pl.exclude("region", "iter", "model", "nchr") ) except: chr_start_end = np.zeros((regions.size, 3)) df_prediction = df_prediction.with_columns( pl.Series("chr", chr_start_end[:, 0].astype(str)), pl.Series("start", chr_start_end[:, 1], dtype=pl.Int64), pl.Series("end", chr_start_end[:, 2], dtype=pl.Int64), pl.Series( "nchr", pl.Series(chr_start_end[:, 0]), ), ) df_prediction = df_prediction.sort("nchr", "start").select( pl.exclude("region", "iter", "nchr") ) if self.output_folder is not None: if fname is not None: _output_prediction = f"{self.output_folder}/{fname}" df_prediction.write_csv(_output_prediction) df_prediction = df_prediction.select( [ "chr", "start", "end", "f_i", "f_t", "s", "t", "predicted_model", "prob_sweep", "prob_neutral", ] ) return df_prediction def roc_curve(self, _iter=1, show_plot=False): """ Build ROC curve, confusion matrix and training-history plots. Parameters ---------- _iter : int, default=1 Tag for output naming (kept for compatibility). Returns ------- tuple[matplotlib.figure.Figure, matplotlib.figure.Figure] ``(plot_roc, plot_history)`` figures. Confusion matrix is also saved to ``confusion_matrix.svg`` when ``output_folder`` is set. Notes ----- - AUC is computed treating ``'sweep'`` as the positive class. - The method expects :attr:`prediction` to contain the latest predictions including ``prob_sweep``. """ import matplotlib.pyplot as plt if isinstance(self.prediction, str): pred_data = pl.read_csv(self.prediction) else: pred_data = self.prediction pred_data = self.prediction # --- Confusion Matrix & Metrics --- y_true = pred_data["model"] y_pred = pred_data["predicted_model"] cm = confusion_matrix( y_true, y_pred, labels=["neutral", "sweep"], normalize="true" ) disp = ConfusionMatrixDisplay( confusion_matrix=cm, display_labels=["neutral", "sweep"] ) cm_plot = disp.plot(cmap="Blues") accuracy = accuracy_score(y_true, y_pred) precision = precision_score(y_true, y_pred, pos_label="sweep") print("Confusion Matrix:\n", cm) print("Accuracy:", accuracy) print("Precision:", precision) # --- ROC Curve --- roc_auc_value = roc_auc_score( (y_true == "sweep").cast(int), pred_data["prob_sweep"].cast(float), ) fpr, tpr, _ = roc_curve( (y_true == "sweep").cast(int), pred_data["prob_sweep"].cast(float), ) fig, ax = plt.subplots(figsize=(10, 6)) ax.plot( fpr, tpr, color="orange", linewidth=2, label=f"ROC Curve (AUC = {roc_auc_value:.3f})", ) ax.plot([0, 1], [0, 1], color="grey", linestyle="--") ax.set_xlabel("False Positive Rate") ax.set_ylabel("Sensitivity") ax.set_title("ROC Curve") ax.axis("equal") ax.grid(True, which="both", linestyle="--", linewidth=0.5) ax.legend() fig.tight_layout() plot_roc = fig # --- Training History --- history_data = self.history h = history_data.select( [ "loss", "val_loss", "accuracy", "val_accuracy", ] ).clone() h = h.with_columns((pl.arange(0, h.height) + 1).alias("epoch")) h_melted = h.unpivot( index=["epoch"], on=["loss", "val_loss", "accuracy", "val_accuracy"], variable_name="metric_name", value_name="metric_val", ) line_styles = { "loss": "-", "val_loss": "--", "accuracy": "-", "val_accuracy": "--", } colors = { "loss": "orange", "val_loss": "orange", "accuracy": "blue", "val_accuracy": "blue", } fig, ax = plt.subplots(figsize=(10, 6)) for group_name, group_df in h_melted.group_by("metric_name"): ax.plot( group_df["epoch"].to_numpy(), group_df["metric_val"].to_numpy(), label=group_name[0], linestyle=line_styles[group_name[0]], color=colors[group_name[0]], linewidth=2, ) ax.set_title("History") ax.set_xlabel("Epoch") ax.set_ylabel("Value") ax.tick_params(axis="both", labelsize=10) ax.grid(True) ax.legend(title="", loc="upper right") plot_history = fig ##################### y_true = (pred_data["model"] == "sweep").cast(int).to_numpy() y_score = pred_data["prob_sweep"].cast(float).to_numpy() pr, rc, _ = precision_recall_curve(y_true, y_score) auc_pr = auc(rc, pr) fig, ax = plt.subplots(figsize=(8, 5)) ax.plot(rc, pr, linewidth=2, label=f"AUC-PR = {auc_pr:.3f}") ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.set_title("Precision-Recall (positive = sweep)") ax.grid(True, linestyle="--", alpha=0.4) ax.legend(loc="lower left") fig.tight_layout() pr_curve = fig y_score_clip = np.clip(y_score, 1e-6, 1 - 1e-6) prob_true, prob_pred = calibration_curve( y_true, y_score_clip, n_bins=10, strategy="quantile" ) brier = brier_score_loss(y_true, y_score_clip) fig, ax = plt.subplots(figsize=(7, 5)) plt.plot([0, 1], [0, 1], "--", linewidth=1.5, label="perfect calibration") plt.plot( prob_pred, prob_true, marker="o", linewidth=2, label=f"model (Brier={brier:.3f})", ) plt.xlabel("Mean predicted probability (sweep)") plt.ylabel("Fraction of positives") plt.title("Calibration (Reliability Diagram)") plt.grid(True, alpha=0.4) plt.legend(loc="upper left") cal = fig # --- Save if needed --- if self.output_folder is not None: plot_roc.savefig(f"{self.output_folder}/roc_curve.svg") plot_history.savefig(f"{self.output_folder}/train_history.svg") pr_curve.savefig(f"{self.output_folder}/auprc.svg") cal.savefig(f"{self.output_folder}/calibration.svg") cm_plot.figure_.savefig(f"{self.output_folder}/confusion_matrix.svg") if show_plot: plt.show() else: plt.close("all") return plot_roc, plot_history, cm_plot def _select_stats_matrix(self, df: pl.DataFrame, stats: list[str]): # Standardize model: anything not 'neutral' -> 'sweep' df = df.with_columns( pl.when(pl.col("model") != "neutral") .then(pl.lit("sweep")) .otherwise(pl.lit("neutral")) .alias("model") ) blocks = [] windows_set = set(self.windows.tolist()) centers_set = set(self.center.tolist()) for stat in stats: blk = df.select(pl.col(f"^{stat}_[0-9]+_[0-9]+$")) cols = blk.columns keys = [] for col in cols: _, a, b = col.rsplit("_", 2) a, b = int(a), int(b) if a in windows_set and b in centers_set: wv, cv = a, b elif a in centers_set and b in windows_set: wv, cv = b, a else: cv, wv = a, b keys.append((wv, cv, col)) sorted_cols = [col for _, _, col in sorted(keys)] blocks.append(blk.select(sorted_cols)) X = pl.concat(blocks, how="horizontal") y = (df["model"] != "neutral").cast(pl.Int8).to_numpy().astype(np.float32) params = df.select("iter", "s", "t", "f_i", "f_t", "model") N = df.height X = ( X.to_numpy() .reshape(N, self.windows.size, self.center.size, len(stats)) .astype(np.float32) ) return X, y, params def load_da_data(self, _stats=None, src_val_frac=0.10): """ Prepares DA inputs for the binary (neutral=0, sweep=1) setup. Produces: - src_neutral_tr, src_sweep_tr : source train arrays per class - neutral_train_idx, sweep_train_idx : counts for generator slicing - X_tgt : unlabeled target (domain discriminator pool) - val_X, val_Y_class, val_Y_discr : validation set - test_data : (X_test, y_test, X_test_params) from held-out source """ # ---------- Load ---------- df_all = pl.read_parquet( self.source_data ) # labeled source with 'model' ∈ {'neutral','sweep'} tgt_df = pl.read_parquet(self.target_data) # target (may be unlabeled) # Hold-out from source for a final test set (kept as in your original code) (src_df, df_test) = train_test_split( df_all, test_size=(1 - self.train_split) * 0.5, shuffle=True ) stats = [] if _stats is None else list(_stats) # ---------- Source matrices ---------- X_src, y_src, _src_params = self._select_stats_matrix(src_df, stats) X_test, y_test, X_test_params = self._select_stats_matrix(df_test, stats) # Map labels to binary {0,1} if needed (accepts strings or ints) if y_src.ndim > 1 and y_src.shape[-1] == 2: # one-hot -> index y_src_bin = np.argmax(y_src, axis=-1).astype(np.int64) else: # strings or ints y_src_bin = np.array(y_src).reshape(-1) if y_src_bin.dtype.kind in {"U", "S", "O"}: map_dict = {"neutral": 0, "sweep": 1} y_src_bin = np.vectorize(map_dict.get)(y_src_bin).astype(np.int64) # Source train/val split for early stopping Xs_tr, Xs_va, ys_tr, ys_va = train_test_split( X_src, y_src_bin, test_size=src_val_frac, stratify=y_src_bin ) # Build class-specific source training arrays for the generator src_neutral_tr = Xs_tr[ys_tr == 0] src_sweep_tr = Xs_tr[ys_tr == 1] # ---------- Target matrix (unlabeled for discriminator) ---------- X_tgt, _yt_placeholder, tgt_params = self._select_stats_matrix(tgt_df, stats) # ---------- Validation set ---------- # source validation val_X = Xs_va val_Y_class = ys_va.astype(np.float32) val_Y_discr = -1 * np.ones((val_X.shape[0],), dtype=np.float32) # ---------- Package ---------- self.da_data = { "stats": stats, "src_neutral_tr": src_neutral_tr.astype(np.float32), "src_sweep_tr": src_sweep_tr.astype(np.float32), "X_tgt": X_tgt.astype(np.float32), # unlabeled target pool "tgt_params": tgt_params, # Validation (binary labels 0/1; discriminator masked with -1) "val_X": val_X.astype(np.float32), "val_Y_class": val_Y_class.astype(np.float32), "val_Y_discr": val_Y_discr.astype(np.float32), # Kept for downstream evaluation on held-out source "test_data": [ X_test.astype(np.float32), ( np.argmax(y_test, axis=-1) if (y_test.ndim > 1 and y_test.shape[-1] == 2) else y_test ).astype(np.int64), X_test_params, ], } def feature_extractor(self, model_input): """ Shared 2D-CNN feature extractor (three branches). Parameters ---------- model_input : tf.keras.layers.Input Input tensor of shape ``(W, C, S)``. Returns ------- tf.Tensor Flattened concatenated features from the three branches. See Also -------- cnn_flexsweep : Similar branch structure followed by classification head. """ He = tf.keras.initializers.HeNormal() # # ---- Channel Dropout on stats (drops whole statistic channels) ---- x = tf.keras.layers.SpatialDropout2D(0.10, name="fx_input_chdrop")(model_input) # x = model_input # ---- Stem: 1×1 mixes stats early to avoid single-stat shortcutting ---- x = tf.keras.layers.Conv2D( 64, 1, padding="same", kernel_initializer=He, name="fx_stem_conv" )(x) x = tf.keras.layers.BatchNormalization(name="fx_stem_bn")(x) x = tf.keras.layers.ReLU(name="fx_stem_relu")(x) # --- Branch 1: 3x3 convs --- b1 = tf.keras.layers.Conv2D( 64, 3, padding="same", kernel_initializer=He, name="fx_b1_c1" )(x) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.Conv2D( 128, 3, padding="same", kernel_initializer=He, name="fx_b1_c2" )(b1) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.Conv2D( 256, 3, padding="same", kernel_initializer=He, name="fx_b1_c3" )(b1) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.MaxPooling2D( pool_size=3, padding="same", name="fx_b1_pool" )(b1) b1 = tf.keras.layers.Dropout(0.15, name="fx_b1_drop")(b1) b1 = tf.keras.layers.Flatten(name="fx_b1_flat")(b1) # --- Branch 2: 2x2 convs with dilation (1,3) --- b2 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c1", )(x) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c2", )(b2) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c3", )(b2) b2 = tf.keras.layers.ReLU()(b2) # b2 = tf.keras.layers.MaxPooling2D(pool_size=2, name="fx_b2_pool")(b2) b2 = tf.keras.layers.MaxPooling2D( pool_size=(1, 2), padding="same", name="fx_b2_pool" )(b2) b2 = tf.keras.layers.Dropout(0.15, name="fx_b2_drop")(b2) b2 = tf.keras.layers.Flatten(name="fx_b2_flat")(b2) # --- Branch 3: 2x2 convs with dilation (5,1) then (1,5) --- b3 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[5, 1], padding="same", kernel_initializer=He, name="fx_b3_c1", )(x) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 5], padding="same", kernel_initializer=He, name="fx_b3_c2", )(b3) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 5], padding="same", kernel_initializer=He, name="fx_b3_c3", )(b3) b3 = tf.keras.layers.ReLU()(b3) # b3 = tf.keras.layers.MaxPooling2D(pool_size=2, name="fx_b3_pool")(b3) b3 = tf.keras.layers.MaxPooling2D( pool_size=(1, 2), padding="same", name="fx_b3_pool" )(b3) b3 = tf.keras.layers.Dropout(0.15, name="fx_b3_drop")(b3) b3 = tf.keras.layers.Flatten(name="fx_b3_flat")(b3) feat = tf.keras.layers.Concatenate(name="fx_concat")( [ b1, b2, b3, ] ) # shared representation return feat def build_grl_model(self, input_shape): """ Build a two-head domain-adversarial CNN with a Gradient Reversal Layer. Architecture ------------ - **Shared feature extractor**: :meth:`feature_extractor` over inputs shaped ``(W, C, S)`` (windows × centers × statistics), channels-last. - **Classifier head** (task): 2 dense layers + sigmoid output named ``"classifier"`` (sweep vs. neutral, BCE). - **Domain head**: GRL → 2 dense layers + sigmoid output named ``"discriminator"`` (source=0 vs. target=1, BCE). Parameters ---------- input_shape : tuple[int, int, int] ``(W, C, S)`` defining windows, centers, and number of stats (channels). Returns ------- tf.keras.Model Uncompiled Keras model with two outputs: ``[classifier(sigmoid), discriminator(sigmoid)]``. Notes ----- - The GRL instance is stored at ``self.grl`` so a callback (e.g., :class:`GRLRamp`) can update its strength during training. - Compilation (optimizer, losses, metrics) is performed in :meth:`train_da_empirical`. """ inp = tf.keras.Input(shape=input_shape) # (W, C, S), channels-last # x_in = ( # self.stat_norm_da(inp) # if hasattr(self, "stat_norm_da") and self.stat_norm_da is not None # else inp # ) feat = self.feature_extractor(inp) # classifier head h = tf.keras.layers.Dense(128, activation="relu")(feat) h = tf.keras.layers.Dropout(0.20)(h) h = tf.keras.layers.Dense(32, activation="relu")(h) h = tf.keras.layers.Dropout(0.10)(h) out_cls = tf.keras.layers.Dense(1, activation="sigmoid", name="classifier")(h) # domain head via GRL (store the layer for ramping) self.grl = GradReverse(lambd=0) g = self.grl(feat) g = tf.keras.layers.Dense(128, activation="relu")(g) g = tf.keras.layers.Dropout(0.20)(g) g = tf.keras.layers.Dense(32, activation="relu")(g) out_dom = tf.keras.layers.Dense(1, activation="sigmoid", name="discriminator")( g ) model = tf.keras.Model(inputs=inp, outputs=[out_cls, out_dom]) return model def feature_extractor_f(self, model_input): """ Shared 2D-CNN feature extractor (three branches). Parameters ---------- model_input : tf.keras.layers.Input Input tensor of shape ``(W, C, S)``. Returns ------- tf.Tensor Flattened concatenated features from the three branches. See Also -------- cnn_flexsweep : Similar branch structure followed by classification head. """ He = tf.keras.initializers.HeNormal() # --- Branch 1: 3x3 convs --- b1 = tf.keras.layers.Conv2D( 64, 3, padding="same", kernel_initializer=He, name="fx_b1_c1" )(model_input) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.Conv2D( 128, 3, padding="same", kernel_initializer=He, name="fx_b1_c2" )(b1) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.Conv2D( 256, 3, padding="same", kernel_initializer=He, name="fx_b1_c3" )(b1) b1 = tf.keras.layers.ReLU()(b1) b1 = tf.keras.layers.MaxPooling2D( pool_size=3, padding="same", name="fx_b1_pool" )(b1) b1 = tf.keras.layers.Dropout(0.15, name="fx_b1_drop")(b1) b1 = tf.keras.layers.Flatten(name="fx_b1_flat")(b1) # --- Branch 2: 2x2 convs with dilation (1,3) --- b2 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c1", )(model_input) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c2", )(b2) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 3], padding="same", kernel_initializer=He, name="fx_b2_c3", )(b2) b2 = tf.keras.layers.ReLU()(b2) b2 = tf.keras.layers.MaxPooling2D(pool_size=2, name="fx_b2_pool")(b2) b2 = tf.keras.layers.Dropout(0.15, name="fx_b2_drop")(b2) b2 = tf.keras.layers.Flatten(name="fx_b2_flat")(b2) # --- Branch 3: 2x2 convs with dilation (5,1) then (1,5) --- b3 = tf.keras.layers.Conv2D( 64, 2, dilation_rate=[5, 1], padding="same", kernel_initializer=He, name="fx_b3_c1", )(model_input) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.Conv2D( 128, 2, dilation_rate=[1, 5], padding="same", kernel_initializer=He, name="fx_b3_c2", )(b3) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.Conv2D( 256, 2, dilation_rate=[1, 5], padding="same", kernel_initializer=He, name="fx_b3_c3", )(b3) b3 = tf.keras.layers.ReLU()(b3) b3 = tf.keras.layers.MaxPooling2D(pool_size=2, name="fx_b3_pool")(b3) b3 = tf.keras.layers.Dropout(0.15, name="fx_b3_drop")(b3) b3 = tf.keras.layers.Flatten(name="fx_b3_flat")(b3) feat = tf.keras.layers.Concatenate(name="fx_concat")( [ b1, b2, b3, ] ) # shared representation return feat def build_grl_model_f(self, input_shape): inp = tf.keras.Input(shape=input_shape) # (W, C, S), channels-last feat = self.feature_extractor(inp) # classifier head h = tf.keras.layers.Dense(128, activation="relu")(feat) h = tf.keras.layers.Dropout(0.20)(h) h = tf.keras.layers.Dense(32, activation="relu")(h) h = tf.keras.layers.Dropout(0.10)(h) out_cls = tf.keras.layers.Dense(1, activation="sigmoid", name="classifier")(h) # domain head via GRL (store the layer for ramping) self.grl = GradReverse(lambd=0.0) g = self.grl(feat) g = tf.keras.layers.Dense(128, activation="relu")(g) g = tf.keras.layers.Dropout(0.20)(g) g = tf.keras.layers.Dense(32, activation="relu")(g) out_dom = tf.keras.layers.Dense(1, activation="sigmoid", name="discriminator")( g ) model = tf.keras.Model(inputs=inp, outputs=[out_cls, out_dom]) return model def build_grl_model_beta(self, input_shape, max_lambda): """ Build a two-head domain-adversarial CNN with a Gradient Reversal Layer. Architecture ------------ - **Shared feature extractor**: :meth:`feature_extractor` over inputs shaped ``(W, C, S)`` (windows × centers × statistics), channels-last. - **Classifier head** (task): 2 dense layers + sigmoid output named ``"classifier"`` (sweep vs. neutral, BCE). - **Domain head**: GRL → 2 dense layers + sigmoid output named ``"discriminator"`` (source=0 vs. target=1, BCE). Parameters ---------- input_shape : tuple[int, int, int] ``(W, C, S)`` defining windows, centers, and number of stats (channels). Returns ------- tf.keras.Model Uncompiled Keras model with two outputs: ``[classifier(sigmoid), discriminator(sigmoid)]``. Notes ----- - The GRL instance is stored at ``self.grl`` so a callback (e.g., :class:`GRLRamp`) can update its strength during training. - Compilation (optimizer, losses, metrics) is performed in :meth:`train_da_empirical`. """ inp = tf.keras.Input(shape=input_shape) # (W, C, S), channels-last # x_in = ( # self.stat_norm_da(inp) # if hasattr(self, "stat_norm_da") and self.stat_norm_da is not None # else inp # ) feat = self.feature_extractor(inp) # classifier head h = tf.keras.layers.Dense(128, activation="relu")(feat) h = tf.keras.layers.Dropout(0.20)(h) h = tf.keras.layers.Dense(32, activation="relu")(h) h = tf.keras.layers.Dropout(0.10)(h) out_cls = tf.keras.layers.Dense(1, activation="sigmoid", name="classifier")(h) self.grl = GradReverse(lambd=max_lambda) g = self.grl(feat) g = tf.keras.layers.Dense(128, activation="relu")(g) g = tf.keras.layers.Dropout(0.20)(g) g = tf.keras.layers.Dense(32, activation="relu")(g) out_dom = tf.keras.layers.Dense(1, activation="sigmoid", name="discriminator")( g ) model = tf.keras.Model(inputs=inp, outputs=[out_cls, out_dom]) return model def train_da_f( self, _stats=None, max_lambda=1, ramp_epochs=20, tgt_ratio=1, batch_size=32, preprocess=True, ): tf = self.check_tf() if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] self.predict_data = self.target_data if not hasattr(self, "da_data") or self.da_data is None: self.load_da_data(_stats=_stats) dd = self.da_data dd["src_neutral_tr"] = dd["src_neutral_tr"].reshape( dd["src_neutral_tr"].shape[0], self.windows.size * self.center.size, self.num_stats, 1, ) dd["src_sweep_tr"] = dd["src_sweep_tr"].reshape( dd["src_sweep_tr"].shape[0], self.windows.size * self.center.size, self.num_stats, 1, ) dd["val_X"] = dd["val_X"].reshape( dd["val_X"].shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) dd["test_data"][0] = dd["test_data"][0].reshape( dd["test_data"][0].shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) dd["X_tgt"] = dd["X_tgt"].reshape( dd["X_tgt"].shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) dd["val_X"] = dd["val_X"].reshape( dd["val_X"].shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) if preprocess: X_src_tr = np.vstack([dd["src_neutral_tr"], dd["src_sweep_tr"]]) self.mean = X_src_tr.mean(axis=(0, 1, 2)) self.std = X_src_tr.std(axis=(0, 1, 2)) dd["src_neutral_tr"] = self.preprocess( dd["src_neutral_tr"], training=True ).numpy() dd["src_sweep_tr"] = self.preprocess( dd["src_sweep_tr"], training=True ).numpy() dd["val_X"] = self.preprocess(dd["val_X"]).numpy() # normalize target using its own stats (as in your previous code) self.mean = dd["X_tgt"].mean(axis=(0, 1, 2)) self.std = dd["X_tgt"].std(axis=(0, 1, 2)) dd["X_tgt"] = self.preprocess(dd["X_tgt"], training=True).numpy() dd["test_data"][0] = self.preprocess( dd["test_data"][0], training=False ).numpy() val_X = dd["val_X"] val_Y_class = dd["val_Y_class"] # 0/1 (binary) val_Y_discr = dd["val_Y_discr"] # all -1 to mask discriminator on val data_gen = DAParquetSequence( src_neutral=dd["src_neutral_tr"], src_sweep=dd["src_sweep_tr"], tar_all=dd["X_tgt"], batch_size=batch_size, ) input_shape = (self.num_stats, self.windows.size * self.center.size, 1) model = self.build_grl_model_f(input_shape) opt = tf.keras.optimizers.AdamW( learning_rate=tf.keras.optimizers.schedules.CosineDecayRestarts(5e-5, 300), epsilon=1e-7, amsgrad=True, ) model.compile( optimizer=opt, loss={"classifier": masked_bce, "discriminator": masked_bce}, loss_weights={"classifier": 1.0, "discriminator": 1.0}, metrics={ "classifier": masked_binary_accuracy, "discriminator": masked_binary_accuracy, }, jit_compile=False, ) callbacks = [ GRLRamp(self.grl, max_lambda=max_lambda, epochs=ramp_epochs), LogGRLLambda(self.grl), tf.keras.callbacks.EarlyStopping( monitor="val_classifier_masked_binary_accuracy", mode="max", patience=20, min_delta=1e-4, restore_best_weights=True, ), ] if self.output_folder: callbacks.append( tf.keras.callbacks.ModelCheckpoint( f"{self.output_folder}/model_da.keras", monitor="val_classifier_masked_binary_accuracy", mode="max", save_best_only=True, verbose=1, ) ) hist = model.fit( data_gen, epochs=max(1000, ramp_epochs), steps_per_epoch=len(data_gen), validation_data=( val_X, {"classifier": val_Y_class, "discriminator": val_Y_discr}, ), callbacks=callbacks, verbose=2, ) # Logging with same keys you already read elsewhere hh = hist.history self.history = pl.DataFrame( { "loss": hh["loss"], "classifier_accuracy": hh["classifier_masked_binary_accuracy"], "discriminator_accuracy": hh["discriminator_masked_binary_accuracy"], "classifier_loss": hh["classifier_loss"], "discriminator_loss": hh["discriminator_loss"], "val_classifier_accuracy": hh["val_classifier_masked_binary_accuracy"], "val_discriminator_accuracy": hh[ "val_discriminator_masked_binary_accuracy" ], "val_classifier_loss": hh["val_classifier_loss"], "val_discriminator_loss": hh["val_discriminator_loss"], } ) self.model = model if self.output_folder: model.save(f"{self.output_folder}/model_da.keras") # quick eval on held-out source test for the plots you already have wired X_test, Y_test, X_test_params = dd["test_data"] out = model.predict(X_test, verbose=0, batch_size=32) cls = out[0] if isinstance(out, (list, tuple)) else out p = cls.ravel().astype(np.float32) # self._fit_platt(Y_test.astype(int), p) # self._save_calibration() # p_cal = self._apply_calibration(p) df_pred = ( pl.concat( [ X_test_params, pl.DataFrame( { "predicted_model": np.where(p >= 0.5, "sweep", "neutral"), "prob_sweep": p, "prob_neutral": 1.0 - p, } ), ], how="horizontal", ) .drop("model") .with_columns(pl.Series("model", np.where(Y_test == 1, "sweep", "neutral"))) ) self.prediction = df_pred return self.history def train_da( self, _stats=None, max_lambda=1, ramp_epochs=30, tgt_ratio=1, batch_size=128, preprocess=True, ): tf = self.check_tf() if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] self.predict_data = self.target_data if not hasattr(self, "da_data") or self.da_data is None: self.load_da_data(_stats=_stats) dd = self.da_data if preprocess: X_src_tr = np.vstack([dd["src_neutral_tr"], dd["src_sweep_tr"]]) self.mean = X_src_tr.mean(axis=(0, 1, 2)) self.std = X_src_tr.std(axis=(0, 1, 2)) dd["src_neutral_tr"] = self.preprocess( dd["src_neutral_tr"], training=True ).numpy() dd["src_sweep_tr"] = self.preprocess( dd["src_sweep_tr"], training=True ).numpy() dd["val_X"] = self.preprocess(dd["val_X"]).numpy() dd["test_data"][0] = self.preprocess(dd["test_data"][0]).numpy() # normalize target using its own stats (as in your previous code) self.mean = dd["X_tgt"].mean(axis=(0, 1, 2)) self.std = dd["X_tgt"].std(axis=(0, 1, 2)) dd["X_tgt"] = self.preprocess(dd["X_tgt"], training=True).numpy() dd["test_data"][0] = self.preprocess( dd["test_data"][0], training=False ).numpy() data_gen = DAParquetSequence( src_neutral=dd["src_neutral_tr"], src_sweep=dd["src_sweep_tr"], tar_all=dd["X_tgt"], batch_size=batch_size, tgt_ratio=tgt_ratio, ) val_X = dd["val_X"] val_Y_class = dd["val_Y_class"] # 0/1 (binary) val_Y_discr = dd["val_Y_discr"] # all -1 to mask discriminator on val # input_shape = (self.windows.size, self.center.size, self.num_stats) input_shape = dd["X_tgt"].shape[1:] model = self.build_grl_model(input_shape) # GRL λ fixed at 1.0 # {'max_lambda': 0.2843709709293154 # 'ramp_epochs': 64 # 'patience': 18 # 'batch_size': 128 # 'tgt_ratio': 1.651652308372508 # 'clip_value': 5.443662527929382 # 'lr': 0.000989613742860149 # 'weight_decay': 0.0003667748863292507 # 'loss_weight_discriminator': 0.5948347312577422} opt = tf.keras.optimizers.AdamW( learning_rate=tf.keras.optimizers.schedules.CosineDecayRestarts(5e-5, 300), # learning_rate=tf.keras.optimizers.schedules.CosineDecayRestarts(1e-3, 300), weight_decay=1e-4, # weight_decay=0.0003667748863292507, epsilon=1e-7, clipnorm=1.0, ) model.compile( optimizer=opt, loss={"classifier": masked_bce, "discriminator": masked_bce}, loss_weights={"classifier": 1.0, "discriminator": 0.75}, # loss_weights={"classifier": 1.0, "discriminator": 1.0}, metrics={ "classifier": masked_binary_accuracy, "discriminator": masked_binary_accuracy, }, jit_compile=False, ) callbacks = [ GRLRamp(self.grl, max_lambda=max_lambda, epochs=ramp_epochs), LogGRLLambda(self.grl), tf.keras.callbacks.EarlyStopping( monitor="val_classifier_masked_binary_accuracy", mode="max", patience=20, min_delta=1e-4, restore_best_weights=True, ), ] if self.output_folder: callbacks.append( tf.keras.callbacks.ModelCheckpoint( f"{self.output_folder}/model_da.keras", monitor="val_classifier_masked_binary_accuracy", mode="max", save_best_only=True, verbose=1, ) ) hist = model.fit( data_gen, epochs=max(1000, ramp_epochs), # epochs=30, steps_per_epoch=len(data_gen), validation_data=( val_X, {"classifier": val_Y_class, "discriminator": val_Y_discr}, ), callbacks=callbacks, verbose=2, ) # Logging with same keys you already read elsewhere hh = hist.history self.history = pl.DataFrame( { "loss": hh["loss"], "classifier_accuracy": hh["classifier_masked_binary_accuracy"], "discriminator_accuracy": hh["discriminator_masked_binary_accuracy"], "classifier_loss": hh["classifier_loss"], "discriminator_loss": hh["discriminator_loss"], "val_classifier_accuracy": hh["val_classifier_masked_binary_accuracy"], "val_discriminator_accuracy": hh[ "val_discriminator_masked_binary_accuracy" ], "val_classifier_loss": hh["val_classifier_loss"], "val_discriminator_loss": hh["val_discriminator_loss"], } ) self.model = model if self.output_folder: model.save(f"{self.output_folder}/model_da.keras") self.history.write_csv(f"{self.output_folder}/history_da.txt") # quick eval on held-out source test for the plots you already have wired X_test, Y_test, X_test_params = dd["test_data"] out = model.predict(X_test, verbose=0, batch_size=32) cls = out[0] if isinstance(out, (list, tuple)) else out p = cls.ravel().astype(np.float32) self._fit_platt(Y_test.astype(int), p) self._save_calibration() # p_cal = p p_cal = self._apply_calibration(p) df_pred = ( pl.concat( [ X_test_params, pl.DataFrame( { "predicted_model": np.where( p_cal >= 0.5, "sweep", "neutral" ), "prob_sweep": p_cal, "prob_neutral": 1.0 - p_cal, } ), ], how="horizontal", ) .drop("model") .with_columns(pl.Series("model", np.where(Y_test == 1, "sweep", "neutral"))) ) self.prediction = df_pred self.plot_da_curves() return self.history def train_da_beta( self, _stats=None, max_lambda=1, max_beta=1.0, ramp_epochs=31, tgt_ratio=1, batch_size=32, preprocess=True, ): tf = self.check_tf() if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] self.predict_data = self.target_data if not hasattr(self, "da_data") or self.da_data is None: self.load_da_data(_stats=_stats) dd = self.da_data if preprocess: X_src_tr = np.vstack([dd["src_neutral_tr"], dd["src_sweep_tr"]]) self.mean = X_src_tr.mean(axis=(0, 1, 2)) self.std = X_src_tr.std(axis=(0, 1, 2)) dd["src_neutral_tr"] = self.preprocess( dd["src_neutral_tr"], training=True ).numpy() dd["src_sweep_tr"] = self.preprocess( dd["src_sweep_tr"], training=True ).numpy() dd["val_X"] = self.preprocess(dd["val_X"]).numpy() dd["test_data"][0] = self.preprocess(dd["test_data"][0]).numpy() # normalize target using its own stats (as in your previous code) self.mean = dd["X_tgt"].mean(axis=(0, 1, 2)) self.std = dd["X_tgt"].std(axis=(0, 1, 2)) dd["X_tgt"] = self.preprocess(dd["X_tgt"], training=True).numpy() dd["test_data"][0] = self.preprocess( dd["test_data"][0], training=False ).numpy() data_gen = DAParquetSequence( src_neutral=dd["src_neutral_tr"], src_sweep=dd["src_sweep_tr"], tar_all=dd["X_tgt"], batch_size=batch_size, tgt_ratio=tgt_ratio, ) val_X = dd["val_X"] val_Y_class = dd["val_Y_class"] # 0/1 (binary) val_Y_discr = dd["val_Y_discr"] # all -1 to mask discriminator on val input_shape = (self.windows.size, self.center.size, self.num_stats) model = self.build_grl_model_beta(input_shape, max_lambda=max_lambda) opt = tf.keras.optimizers.AdamW( learning_rate=tf.keras.optimizers.schedules.CosineDecayRestarts(5e-5, 300), weight_decay=1e-4, epsilon=1e-7, clipnorm=5.0, ) # Variable loss weights (alpha fixed at 1.0, beta starts at 0.0 then ramp) alpha = K.variable(1.0, dtype="float32", name="alpha_cls") beta = K.variable(0.0, dtype="float32", name="beta_dom") model.compile( optimizer=opt, loss={"classifier": masked_bce, "discriminator": masked_bce}, loss_weights={"classifier": 1.0, "discriminator": 1.0}, metrics={ "classifier": masked_binary_accuracy, "discriminator": masked_binary_accuracy, }, jit_compile=False, ) # callbacks: your beta ramp & logger + early stopping callbacks = [ LossWeightsScheduler(alpha, beta), LossWeightsLogger([alpha, beta]), tf.keras.callbacks.EarlyStopping( monitor="val_classifier_masked_binary_accuracy", mode="max", patience=25, min_delta=1e-4, restore_best_weights=True, ), ] if self.output_folder: callbacks.append( tf.keras.callbacks.ModelCheckpoint( f"{self.output_folder}/model_da.keras", monitor="val_classifier_masked_binary_accuracy", mode="max", save_best_only=True, verbose=1, ) ) hist = model.fit( data_gen, epochs=30, steps_per_epoch=len(data_gen), validation_data=( val_X, {"classifier": val_Y_class, "discriminator": val_Y_discr}, ), callbacks=callbacks, verbose=2, ) hh = hist.history self.history = pl.DataFrame( { "loss": hh["loss"], "classifier_accuracy": hh["classifier_masked_binary_accuracy"], "discriminator_accuracy": hh["discriminator_masked_binary_accuracy"], "classifier_loss": hh["classifier_loss"], "discriminator_loss": hh["discriminator_loss"], "val_classifier_accuracy": hh["val_classifier_masked_binary_accuracy"], "val_discriminator_accuracy": hh[ "val_discriminator_masked_binary_accuracy" ], "val_classifier_loss": hh["val_classifier_loss"], "val_discriminator_loss": hh["val_discriminator_loss"], } ) self.model = model if self.output_folder: model.save(f"{self.output_folder}/model_da.keras") X_test, Y_test, X_test_params = dd["test_data"] out = model.predict(X_test, verbose=0, batch_size=32) cls = out[0] if isinstance(out, (list, tuple)) else out p = cls.ravel().astype(np.float32) # self._fit_platt(Y_test.astype(int), p) # self._save_calibration() # p_cal = self._apply_calibration(p) df_pred = ( pl.concat( [ X_test_params, pl.DataFrame( { "predicted_model": np.where(p >= 0.5, "sweep", "neutral"), "prob_sweep": p, "prob_neutral": 1.0 - p, } ), ], how="horizontal", ) .drop("model") .with_columns(pl.Series("model", np.where(Y_test == 1, "sweep", "neutral"))) ) self.prediction = df_pred return self.history def predict_da(self, _stats=None, preprocess=True, fname=None): """ Predict sweep probabilities on empirical (target) data using a DA model. Loads a trained two-head model and returns per-region predictions from the **classifier** head (sweep vs. neutral). The domain head is unused at inference. Parameters ---------- _stats : list[str] | None Statistic base names to include (must match training). Returns ------- pl.DataFrame Table with per-region predictions and metadata, including: ``['chr','start','end','f_i','f_t','s','t','predicted_model', 'prob_sweep','prob_neutral']`` sorted by chromosome and start. Raises ------ AssertionError If no model is loaded or the test data path is invalid. Notes ----- - Expects the same (W, C, S) layout used in training. - Output ``prob_sweep`` is the classifier sigmoid; ``prob_neutral=1-prob_sweep``. """ assert self.model is not None, "Call train_da() first" tf = self.check_tf() if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] if isinstance(self.model, str): model = tf.keras.models.load_model( self.model, safe_mode=True, ) else: model = self.model assert ( isinstance(self.predict_data, pl.DataFrame) or "txt" in self.predict_data or "csv" in self.predict_data or self.predict_data.endswith(".parquet") ), "Please input a pl.DataFrame or save it as CSV or parquet" try: df_test = pl.read_parquet(self.predict_data) if "test" in self.predict_data: df_test = df_test.sample( with_replacement=False, fraction=1.0, shuffle=True ) except Exception: df_test = pl.read_csv(self.predict_data, separator=",") df_test = df_test.with_columns( pl.when(pl.col("model") != "neutral") .then(pl.lit("sweep")) .otherwise(pl.lit("neutral")) .alias("model") ) regions = df_test["iter"].to_numpy() stats = [] if _stats is not None: stats = stats + _stats test_stats = [] self.num_stats = len(stats) for i in stats: test_stats.append(df_test.select(pl.col(f"^{i}_[0-9]+_[0-9]+$"))) X_test = pl.concat(test_stats, how="horizontal") test_X_params = df_test.select("model", "iter", "s", "f_i", "f_t", "t") test_X = ( X_test.select(X_test) .to_numpy() .reshape( X_test.shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) ) test_X = ( X_test.select(X_test) .to_numpy() .reshape( X_test.shape[0], self.windows.size, self.center.size, self.num_stats, ) ) if preprocess: self.mean = test_X.mean(axis=(0, 1, 2), keepdims=False) self.std = test_X.std(axis=(0, 1, 2), keepdims=False) test_X = self.preprocess(test_X) out = model.predict(test_X, batch_size=32) # two heads → [classifier_probs, discriminator_probs] cls = out[0] if isinstance(out, (list, tuple)) else out p = cls.ravel().astype(np.float32) p_cal = self._apply_calibration(p) # p_cal = p df_pred = pl.concat( [ test_X_params, pl.DataFrame( { "predicted_model": np.where(p_cal >= 0.5, "sweep", "neutral"), "prob_sweep_raw": p, "prob_sweep": p_cal, "prob_neutral": 1.0 - p_cal, } ), ], how="horizontal", ) df_prediction = df_pred.with_columns(pl.Series("region", regions)) chr_start_end = np.array( [item.replace(":", "-").split("-") for item in regions] ) df_prediction = df_prediction.with_columns( pl.Series("chr", chr_start_end[:, 0]), pl.Series("start", chr_start_end[:, 1], dtype=pl.Int64), pl.Series("end", chr_start_end[:, 2], dtype=pl.Int64), pl.Series( "nchr", pl.Series(chr_start_end[:, 0]).str.replace("chr", "").cast(int), ), ) df_prediction = df_prediction.sort("nchr", "start").select( [ "chr", "start", "end", "f_i", "f_t", "s", "t", "predicted_model", "prob_sweep", "prob_neutral", ] ) self.prediction = df_prediction if self.output_folder: # Same folder custom fvs name based on input VCF. _output_prediction = f"{self.output_folder}/{os.path.basename(self.predict_data).replace('fvs_', '').replace('.parquet', '_da_predictions.txt')}" if fname is not None: _output_prediction = f"{self.output_folder}/{fname}" df_prediction.write_csv(_output_prediction) return df_prediction def predict_da_f(self, _stats=None, preprocess=True, fname=None): """ Predict sweep probabilities on empirical (target) data using a DA model. Loads a trained two-head model and returns per-region predictions from the **classifier** head (sweep vs. neutral). The domain head is unused at inference. Parameters ---------- _stats : list[str] | None Statistic base names to include (must match training). Returns ------- pl.DataFrame Table with per-region predictions and metadata, including: ``['chr','start','end','f_i','f_t','s','t','predicted_model', 'prob_sweep','prob_neutral']`` sorted by chromosome and start. Raises ------ AssertionError If no model is loaded or the test data path is invalid. Notes ----- - Expects the same (W, C, S) layout used in training. - Output ``prob_sweep`` is the classifier sigmoid; ``prob_neutral=1-prob_sweep``. """ assert self.model is not None, "Call train_da() first" tf = self.check_tf() if _stats is None: _stats = [ "dind", "dist_kurtosis", "dist_skew", "dist_var", "h1", "h12", "h2_h1", "haf", "hapdaf_o", "hapdaf_s", "high_freq", "ihs", "isafe", "k_counts", "low_freq", "max_fda", "nsl", "omega_max", "pi", "s_ratio", "tajima_d", "theta_h", "theta_w", "zns", ] if isinstance(self.model, str): model = tf.keras.models.load_model( self.model, safe_mode=True, ) else: model = self.model assert ( isinstance(self.predict_data, pl.DataFrame) or "txt" in self.predict_data or "csv" in self.predict_data or self.predict_data.endswith(".parquet") ), "Please input a pl.DataFrame or save it as CSV or parquet" try: df_test = pl.read_parquet(self.predict_data) if "test" in self.predict_data: df_test = df_test.sample( with_replacement=False, fraction=1.0, shuffle=True ) except Exception: df_test = pl.read_csv(self.predict_data, separator=",") df_test = df_test.with_columns( pl.when(pl.col("model") != "neutral") .then(pl.lit("sweep")) .otherwise(pl.lit("neutral")) .alias("model") ) regions = df_test["iter"].to_numpy() stats = [] if _stats is not None: stats = stats + _stats test_stats = [] for i in stats: test_stats.append(df_test.select(pl.col(f"^{i}_[0-9]+_[0-9]+$"))) X_test = pl.concat(test_stats, how="horizontal") test_X_params = df_test.select("model", "iter", "s", "f_i", "f_t", "t") test_X = ( X_test.select(X_test) .to_numpy() .reshape( X_test.shape[0], self.num_stats, self.windows.size * self.center.size, 1, ) ) if preprocess: self.mean = test_X.mean(axis=(0, 1, 2), keepdims=False) self.std = test_X.std(axis=(0, 1, 2), keepdims=False) test_X = self.preprocess(test_X) out = model.predict(test_X, batch_size=32) # two heads → [classifier_probs, discriminator_probs] cls = out[0] if isinstance(out, (list, tuple)) else out p = cls.ravel().astype(np.float32) # p_cal = self._apply_calibration(p) p_cal = p df_pred = pl.concat( [ test_X_params, pl.DataFrame( { "predicted_model": np.where(p_cal >= 0.5, "sweep", "neutral"), "prob_sweep": p_cal, "prob_neutral": 1.0 - p_cal, } ), ], how="horizontal", ) df_prediction = df_pred.with_columns(pl.Series("region", regions)) chr_start_end = np.array( [item.replace(":", "-").split("-") for item in regions] ) df_prediction = df_prediction.with_columns( pl.Series("chr", chr_start_end[:, 0]), pl.Series("start", chr_start_end[:, 1], dtype=pl.Int64), pl.Series("end", chr_start_end[:, 2], dtype=pl.Int64), pl.Series( "nchr", pl.Series(chr_start_end[:, 0]).str.replace("chr", "").cast(int), ), ) df_prediction = df_prediction.sort("nchr", "start").select( [ "chr", "start", "end", "f_i", "f_t", "s", "t", "predicted_model", "prob_sweep", "prob_neutral", ] ) self.prediction = df_prediction if self.output_folder: # Same folder custom fvs name based on input VCF. _output_prediction = f"{self.output_folder}/{os.path.basename(self.predict_data).replace('fvs_', '').replace('.parquet', '_da_predictions.txt')}" if fname is not None: _output_prediction = f"{self.output_folder}/{fname}" df_prediction.write_csv(_output_prediction) return df_prediction def plot_da_curves(self): """ Saves: - classifier_accuracy_hist.png (train + val) - discriminator_accuracy_hist.png (train only) - classifier_loss_hist.png (train + val) - discriminator_loss_hist.png (train only) - classifier_auc_hist.png (optional ROC/PR AUC, train + val) - confusion_matrix.png, auprc.png, calibration_curve.png, probability_hist.png """ H = self.history outdir = self.output_folder or "." os.makedirs(outdir, exist_ok=True) def get(key): return H[key].to_numpy() if key in H.columns else np.array([]) # names aligned to new training logs loss = get("loss") cls_loss, val_cls_loss = get("classifier_loss"), get("val_classifier_loss") cls_acc = get("classifier_accuracy") val_cls_acc = get("val_classifier_accuracy") disc_loss = get("discriminator_loss") disc_acc = get("discriminator_accuracy") # def L(*arrs): # return max([len(a) for a in arrs if len(a) > 0] + [0]) # T = L(loss, cls_loss, disc_loss, cls_acc, val_cls_acc, val_cls_loss) epochs = np.arange(1, len(loss) + 1) def savefig(name): plt.tight_layout() plt.savefig(os.path.join(outdir, name), dpi=150) plt.close() def plot_series(y, label=None, ls="-", lw=2): if y.size: m = min(len(epochs), len(y)) yy = y[:m] mask = np.isfinite(yy) if mask.any(): plt.plot(epochs[:m][mask], yy[mask], ls, linewidth=lw, label=label) # classifier accuracy plt.figure(figsize=(7, 4)) plot_series(cls_acc, "train") plot_series(val_cls_acc, "val", ls="--") plt.title("classifier accuracy") plt.xlabel("epoch") plt.ylabel("accuracy") plt.grid(True, alpha=0.3) plt.legend(loc="lower right") savefig("classifier_accuracy_hist.png") # discriminator accuracy (train only) plt.figure(figsize=(7, 4)) plot_series(disc_acc) plt.title("discriminator accuracy (train)") plt.xlabel("epoch") plt.ylabel("accuracy") plt.grid(True, alpha=0.3) savefig("discriminator_accuracy_hist.png") # classifier loss plt.figure(figsize=(7, 4)) plot_series(cls_loss, "train") plot_series(val_cls_loss, "val", ls="--") plt.title("classifier loss") plt.xlabel("epoch") plt.ylabel("loss") plt.grid(True, alpha=0.3) plt.legend(loc="upper right") savefig("classifier_loss_hist.png") # discriminator loss (train only) plt.figure(figsize=(7, 4)) plot_series(disc_loss) plt.title("discriminator loss (train)") plt.xlabel("epoch") plt.ylabel("loss") plt.grid(True, alpha=0.3) savefig("discriminator_loss_hist.png") # # optional AUCs # if any(len(s) > 0 for s in [ ]): # plt.figure(figsize=(7, 4)) # plot_series(cls_auc, "ROC AUC (train)") # plot_series(val_cls_auc, "ROC AUC (val)", ls="--") # plot_series(cls_auc_pr, "PR AUC (train)") # plot_series(val_cls_auc_pr, "PR AUC (val)", ls="--") # plt.title("classifier AUCs") # plt.xlabel("epoch") # plt.ylabel("AUC") # plt.ylim(0, 1) # plt.grid(True, alpha=0.3) # plt.legend(loc="lower right") # savefig("classifier_auc_hist.png") # --- downstream prediction plots (unchanged) --- pred = self.prediction # Polars DF with: model, predicted_model, prob_sweep y_true_labels = pred["model"] y_pred_labels = pred["predicted_model"] cm = confusion_matrix( y_true_labels, y_pred_labels, labels=["neutral", "sweep"], normalize="true" ) disp = ConfusionMatrixDisplay( confusion_matrix=cm, display_labels=["neutral", "sweep"] ) disp.plot(cmap="Blues") savefig("confusion_matrix.png") y_true = (pred["model"] == "sweep").cast(int).to_numpy() y_score = pred["prob_sweep"].cast(float).to_numpy() pr, rc, _ = precision_recall_curve(y_true, y_score) auc_pr = auc(rc, pr) fig, ax = plt.subplots(figsize=(8, 5)) ax.plot(rc, pr, linewidth=2, label=f"AUC-PR = {auc_pr:.3f}") ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.set_xlabel("Recall") ax.set_ylabel("Precision") ax.set_title("Precision-Recall (positive = sweep)") ax.grid(True, linestyle="--", alpha=0.4) ax.legend(loc="lower left") fig.tight_layout() savefig("auprc.png") y_score_clip = np.clip(y_score, 1e-6, 1 - 1e-6) prob_true, prob_pred = calibration_curve( y_true, y_score_clip, n_bins=10, strategy="quantile" ) brier = brier_score_loss(y_true, y_score_clip) plt.figure(figsize=(7, 5)) plt.plot([0, 1], [0, 1], "--", linewidth=1.5, label="perfect calibration") plt.plot( prob_pred, prob_true, marker="o", linewidth=2, label=f"model (Brier={brier:.3f})", ) plt.xlabel("Mean predicted probability (sweep)") plt.ylabel("Fraction of positives") plt.title("Calibration (Reliability Diagram)") plt.grid(True, alpha=0.4) plt.legend(loc="upper left") savefig("calibration_curve.png") plt.figure(figsize=(7, 3.2)) plt.hist(y_score_clip, bins=20, range=(0, 1)) plt.xlabel("Predicted probability (sweep)") plt.ylabel("Count") plt.title("Prediction Probability Histogram") plt.grid(True, alpha=0.25) savefig("probability_hist.png") def _fit_platt(self, y, p): # Only fit on finite logit values mask = (p > 0) & (p < 1) X = logit(p[mask]).reshape(-1, 1) lr = LogisticRegression(solver="lbfgs") lr.fit(X, y[mask].astype(int)) a = float(lr.coef_[0, 0]) b = float(lr.intercept_[0]) self.calibration = {"type": "platt", "a": a, "b": b} def _fit_temperature(self, y, p): from scipy.optimize import minimize # Only fit on finite logit values mask = (p > 0) & (p < 1) z = logit(p[mask]) y_fit = y[mask] def nll(T): q = expit(z / T) # q is naturally bounded by expit, no clipping needed return -(y_fit * np.log(q) + (1 - y_fit) * np.log(1 - q)).mean() res = minimize(lambda t: nll(t[0]), x0=[1.0], bounds=[(0.5, 10.0)]) T = float(res.x[0]) self.calibration = {"type": "temperature", "T": T} def _apply_calibration(self, p): """Apply calibration only where mathematically defined.""" if getattr(self, "calibration", None) is None: return p p_cal = p.copy() # Only transform where logit is defined mask = (p > 0) & (p < 1) if not np.any(mask): return p cal = self.calibration if cal["type"] == "platt": a, b = cal["a"], cal["b"] p_cal[mask] = expit(a * logit(p[mask]) + b) elif cal["type"] == "temperature": T = cal["T"] p_cal[mask] = expit(logit(p[mask]) / T) # p=0 and p=1 pass through unchanged (not in mask) return p_cal def _save_calibration(self): if getattr(self, "output_folder", None): with open(os.path.join(self.output_folder, "calibration.json"), "w") as f: json.dump(self.calibration, f) def _load_calibration(self): try: with open(os.path.join(self.output_folder, "calibration.json")) as f: self.calibration = json.load(f) except Exception: self.calibration = None
class DAParquetSequence(Sequence): """ Domain-adversarial generator (binary: neutral=0, sweep=1) with adjustable target ratio. Matches CustomDataGenBinary's contract: __getitem__ -> X, {'classifier': y_cls, 'discriminator': y_discr} where labels are 1D float32 arrays with -1 as mask sentinel. Per step: - Classifier chunk (size = batch_size): SOURCE only (true labels 0/1), domain masked (-1). - Discriminator chunk (size = batch_size): SOURCE (domain=0) + TARGET (domain=1) split by `tgt_ratio` (target:source). With tgt_ratio=1, discriminator is 50/50. Total samples per step = 2 * batch_size (same as CustomDataGenBinary). """ def __init__( self, src_neutral, src_sweep, tar_all, batch_size, tgt_ratio=1.0, shuffle=True, **kwargs, ): super().__init__(**kwargs) assert ( batch_size % 2 == 0 ), "batch_size must be even (keeps math tidy; total per step = 2*batch_size)." self.shuffle = bool(shuffle) # Trim to train set self.src_neutral = src_neutral self.src_sweep = src_sweep self.tar_all = tar_all # Basic sizes # classifier chunk size self.B = int(batch_size) # discriminator chunk size per step (keeps total = 2*B) self.disc_chunk = self.B # --- ratio → discriminator split (target:source = tgt_ratio) --- self.tgt_ratio = max(1e-6, float(tgt_ratio)) # fraction of discriminator chunk to SOURCE src_frac = 1.0 / (1.0 + self.tgt_ratio) self.dis_src = max(1, int(round(self.disc_chunk * src_frac))) self.dis_tgt = max(1, self.disc_chunk - self.dis_src) # ensure >=1 from both # Build flat SOURCE pool (combine neutral + sweep) neu_idx = np.arange(self.src_neutral.shape[0], dtype=np.int64) swp_idx = np.arange(self.src_sweep.shape[0], dtype=np.int64) # 0=neu,1=sweep self.src_pool_tag = np.concatenate( [ np.zeros_like(neu_idx), np.ones_like(swp_idx), ] ) # local indices into arrays self.src_pool_lidx = np.concatenate([neu_idx, swp_idx]) # TARGET pool indices self.tgt_idx = np.arange(self.tar_all.shape[0], dtype=np.int64) # Build epoch pools (shuffled views) self._reset_epoch() # Steps/epoch limited by consumptions of each pool per step n_cls = len(self.src_pool_cls) // self.B n_dis_src = len(self.src_pool_dis) // self.dis_src n_dis_tgt = len(self.tgt_pool_dis) // self.dis_tgt self.n_batches = int(min(n_cls, n_dis_src, n_dis_tgt)) def __len__(self): return self.n_batches def _reset_epoch(self): rng = np.random.default_rng() # independent shuffles for classifier and discriminator source pools base = np.arange(self.src_pool_tag.size, dtype=np.int64) self.src_pool_cls = rng.permutation(base) # for classifier (labels used) self.src_pool_dis = rng.permutation(base) # for discriminator (domain=0) self.tgt_pool_dis = rng.permutation( self.tgt_idx ) # for discriminator (domain=1) def on_epoch_end(self): if self.shuffle: self._reset_epoch() def _gather_source_arrays(self, take): """Return concatenated X and per-sample class labels (0/1) for given flat-source indices.""" pools = self.src_pool_tag[take] lidx = self.src_pool_lidx[take] X_neu = self.src_neutral[lidx[pools == 0]] X_swp = self.src_sweep[lidx[pools == 1]] # Concatenate in stable order (neutral first then sweep) X = np.concatenate([X_neu, X_swp], axis=0) y = np.concatenate( [ np.zeros((X_neu.shape[0],), dtype=np.float32), np.ones((X_swp.shape[0],), dtype=np.float32), ], axis=0, ) return X, y def __getitem__(self, idx): # --- A) Classifier chunk: SOURCE (labels 0/1), domain masked --- idxA = self.src_pool_cls[idx * self.B : (idx + 1) * self.B] XA, yA = self._gather_source_arrays(idxA) yA_cls = yA # shape (B,) yA_dom = -np.ones((XA.shape[0],), dtype=np.float32) # mask domain # --- B) Discriminator chunk (SOURCE): domain=0, classifier masked --- idxB = self.src_pool_dis[idx * self.dis_src : (idx + 1) * self.dis_src] XB, _ = self._gather_source_arrays(idxB) yB_cls = -np.ones((XB.shape[0],), dtype=np.float32) # mask classifier yB_dom = np.zeros((XB.shape[0],), dtype=np.float32) # source domain=0 # --- C) Discriminator chunk (TARGET): domain=1, classifier masked --- idxC = self.tgt_pool_dis[idx * self.dis_tgt : (idx + 1) * self.dis_tgt] XC = self.tar_all[idxC] yC_cls = -np.ones((XC.shape[0],), dtype=np.float32) # mask classifier yC_dom = np.ones((XC.shape[0],), dtype=np.float32) # target domain=1 # --- Concatenate in the same order as CustomDataGenBinary --- X = np.concatenate([XA, XB, XC], axis=0) y_cls = np.concatenate([yA_cls, yB_cls, yC_cls], axis=0) y_dis = np.concatenate([yA_dom, yB_dom, yC_dom], axis=0) # Safety (total = 2*B even with arbitrary tgt_ratio because disc_chunk==B) assert X.shape[0] == 2 * self.B assert y_cls.shape[0] == y_dis.shape[0] == X.shape[0] return X, {"classifier": y_cls, "discriminator": y_dis} def subset_genomic_windows( df: pl.DataFrame, centers: list[int], metrics: list[str] | None = None, window_size: int = 100_000, update_iter: bool = True, ) -> pl.DataFrame: window_size = int(window_size) centers = np.sort(centers).astype(int) half = window_size // 2 # Validate consecutive centers if len(centers) > 1: diffs = np.diff(centers) if not np.all(diffs == window_size): raise ValueError("Centers must be consecutive and spaced by window_size") left_offset = centers[0] - half right_offset = centers[-1] + half # Define base columns to keep base_keep = ["iter", "s", "t", "f_i", "f_t", "mu", "r", "model"] # Resolve metrics if not provided if metrics is None: # Get columns that aren't in base_keep other_cols = [c for c in df.columns if c not in base_keep] # Extract base names (dist_var, etc) preserving order raw_names = ["_".join(col.split("_")[:-2]) for col in other_cols] metrics = list(dict.fromkeys(raw_names)) # Generate the specific windowed column names we expect expected_cols = [f"{m}_{window_size}_{c}" for c in centers for m in metrics] # Final column list: intersection of (base + expected) and what actually exists all_potential = base_keep + expected_cols existing_cols = [c for c in all_potential if c in df.columns] out = df.select(existing_cols) if update_iter and "iter" in out.columns: out = ( out.with_columns( pl.col("iter").str.extract(r"(chr\w+):(\d+)-", 1).alias("_chrom"), pl.col("iter") .str.extract(r":(\d+)-", 1) .cast(pl.Int64) .alias("_start"), ) .with_columns( pl.format( "{}:{}-{}", pl.col("_chrom"), pl.col("_start") + left_offset, pl.col("_start") + right_offset, ).alias("iter") ) .drop(["_chrom", "_start"]) ) return out