Knowledge-primed Neural Networks (KPNNs) for single cell data#

In this tutorial we show how CORNETO can be used to build custom neural network architectures informed by prior knowledge, by implementing a knowledge-primed neural network (KPNN) [1]. We use the single-cell data from the publication “Knowledge-primed neural networks enable biologically interpretable deep learning on single-cell sequencing data” by Nikolaus Fortelny and Christoph Bock. The authors used a single-cell RNA-seq dataset they had previously generated [2], which measures cellular responses to T cell receptor (TCR) stimulation in a standardized in vitro model. The dataset was chosen because of the complexity of the TCR signaling pathway and its well-characterized role in orchestrating transcriptional responses to antigen detection in T cells.

Why CORNETO?#

In the original publication, the authors built a KPNN by searching databases and constructing a Directed Acyclic Graph (DAG) from the shortest paths between the TCR and the genes. However, this approach is not guaranteed to yield an optimal architecture, for example the smallest DAG that connects the inputs to the output. CORNETO, thanks to its capabilities for modeling and optimization on networks, provides methods to find DAG architectures automatically and optimally.
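To make the shortest-path idea concrete, here is a minimal sketch that builds a DAG as the union of one breadth-first shortest path from a source node to each target. The toy network and the node names (`R`, `K1`, `K2`, `TF1`, `gene_a`, `gene_b`) are hypothetical, and the function `shortest_path_dag` is an illustration only, not the code used in the original publication.

```python
from collections import deque

def shortest_path_dag(edges, source, targets):
    """Union of one BFS shortest path from `source` to each target."""
    # Adjacency list of the directed network
    adj = {}
    for u, v in edges:
        adj.setdefault(u, []).append(v)

    # BFS records, for every node, the node it was first reached from
    parent = {source: None}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj.get(u, []):
            if v not in parent:
                parent[v] = u
                queue.append(v)

    # Walk back from each reachable target to the source, collecting edges
    dag_edges = set()
    for target in targets:
        node = target
        while node in parent and parent[node] is not None:
            dag_edges.add((parent[node], node))
            node = parent[node]
    return dag_edges

# Hypothetical toy network (not taken from the dataset)
toy_edges = [
    ("R", "K1"), ("K1", "TF1"), ("R", "K2"), ("K2", "TF1"),
    ("TF1", "gene_a"), ("TF1", "gene_b"),
]
dag = shortest_path_dag(toy_edges, "R", ["gene_a", "gene_b"])
print(sorted(dag))
```

Because every node keeps a single BFS predecessor, the selected edges form a tree rooted at the source, which is acyclic by construction. Which of several equally short paths is kept depends only on the order of the edges, which is one reason an optimization-based selection can be preferable.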

In addition, CORNETO provides methods to easily build DAG NN architectures using Keras 3+, which makes KPNN implementations flexible and interoperable with backends such as PyTorch, TensorFlow and JAX.

How does it work?#

Thanks to CORNETO’s building blocks for optimization over networks, we can easily model optimization problems to find DAG architectures from a Prior Knowledge Network. After we have the backbone, we can convert it to a neural network using the utility functions included in CORNETO.
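The idea of turning a DAG backbone into a neural network can be sketched in a few lines: each non-input node becomes a neuron that only receives signals from its parents in the graph. The snippet below is a simplified illustration in pure Python, not CORNETO's implementation. The function `dag_forward`, the toy graph, and all weights are hypothetical.

```python
import math
from graphlib import TopologicalSorter

def dag_forward(parents, weights, biases, inputs):
    """Evaluate a DAG-shaped network with one sigmoid neuron per non-input node."""
    values = dict(inputs)
    # TopologicalSorter expects {node: predecessors}, so parents are visited first
    for node in TopologicalSorter(parents).static_order():
        if node in values:
            continue  # input nodes carry the data directly
        z = biases[node] + sum(weights[(p, node)] * values[p] for p in parents[node])
        values[node] = 1.0 / (1.0 + math.exp(-z))
    return values

# Hypothetical toy DAG: two input genes feed a transcription factor (TF),
# which feeds a kinase (K); both TF and K feed the output node (OUT)
parents = {"TF": ["g1", "g2"], "K": ["TF"], "OUT": ["K", "TF"]}
weights = {("g1", "TF"): 1.0, ("g2", "TF"): -1.0, ("TF", "K"): 2.0,
           ("K", "OUT"): 1.5, ("TF", "OUT"): -0.5}
biases = {"TF": 0.0, "K": -1.0, "OUT": 0.0}

out = dag_forward(parents, weights, biases, {"g1": 1.0, "g2": 0.0})
print(out["OUT"])
```

The number of trainable parameters is one weight per edge plus one bias per neuron, so a sparse prior knowledge network yields a much smaller model than a fully connected one with the same nodes.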

References#

  1. Fortelny, N., & Bock, C. (2020). Knowledge-primed neural networks enable biologically interpretable deep learning on single-cell sequencing data. Genome Biology, 21, 1-36.

  2. Datlinger, P., Rendeiro, A. F., Schmidl, C., Krausgruber, T., Traxler, P., Klughammer, J., … & Bock, C. (2017). Pooled CRISPR screening with single-cell transcriptome readout. Nature Methods, 14(3), 297-301.

Download and import the single cell dataset#

import os
import urllib.request
import urllib.parse
import tempfile
import pandas as pd
import scanpy as sc
import numpy as np
import corneto as cn

with urllib.request.urlopen("http://kpnn.computational-epigenetics.org/") as response:
    web_input = response.geturl()
print("Effective URL:", web_input)

files = ["TCR_Edgelist.csv", "TCR_ClassLabels.csv", "TCR_Data.h5"]

temp_dir = tempfile.mkdtemp()

# Download files
file_paths = []
for file in files:
    url = urllib.parse.urljoin(web_input, file)
    output_path = os.path.join(temp_dir, file)
    print(f"Downloading {url} to {output_path}")
    try:
        with urllib.request.urlopen(url) as response:
            with open(output_path, 'wb') as f:
                f.write(response.read())
        file_paths.append(output_path)
    except Exception as e:
        print(f"Failed to download {url}: {e}")

print("Downloaded files:")
for path in file_paths:
    print(path)
Effective URL: https://medical-epigenomics.org/papers/fortelny2019/
Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_Edgelist.csv to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Edgelist.csv
Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_ClassLabels.csv to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_ClassLabels.csv
Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_Data.h5 to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Data.h5
Downloaded files:
/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Edgelist.csv
/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_ClassLabels.csv
/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Data.h5
# The data also contains the original network the authors built with shortest paths.
# We will use it to replicate the study.
df_edges = pd.read_csv(file_paths[0])
df_labels = pd.read_csv(file_paths[1])
# Import the 10x data with Scanpy
adata = sc.read_10x_h5(file_paths[2])
df_labels
barcode TCR
0 AAACCTGCACACATGT-1 0
1 AAACCTGCACGTCTCT-1 0
2 AAACCTGTCAATACCG-1 0
3 AAACCTGTCGTGGTCG-1 0
4 AAACGGGTCTGAGTGT-1 0
... ... ...
1730 TTTCCTCGTCATGCCG-2 1
1731 TTTGCGCGTAGCCTCG-2 1
1732 TTTGGTTAGATACACA-2 1
1733 TTTGGTTGTATGAATG-2 1
1734 TTTGGTTTCCAAGTAC-2 1

1735 rows × 2 columns

df_edges
parent child
0 TCR ZAP70
1 ZAP70 MAPK14
2 MAPK14 FOXO3
3 MAPK14 STAT1
4 MAPK14 STAT3
... ... ...
27574 HMGA1 MTRNR2L9_gene
27575 MYB C12orf50_gene
27576 MYB TRPC5OS_gene
27577 SOX2 TRPC5OS_gene
27578 CRTC1 MTRNR2L9_gene

27579 rows × 2 columns

adata.var
gene_ids
DDX11L1 ENSG00000223972
WASH7P ENSG00000227232
MIR6859-2 ENSG00000278267
MIR1302-10 ENSG00000243485
MIR1302-11 ENSG00000274890
... ...
Tcrlibrary_RUNX2_3_gene Tcrlibrary_RUNX2_3_gene
Tcrlibrary_ZAP70_1_gene Tcrlibrary_ZAP70_1_gene
Tcrlibrary_ZAP70_2_gene Tcrlibrary_ZAP70_2_gene
Tcrlibrary_ZAP70_3_gene Tcrlibrary_ZAP70_3_gene
Cas9_blast_gene Cas9_blast_gene

64370 rows × 1 columns

# We could normalize the data here. However, it is better to avoid
# preprocessing the whole dataset before splitting it into training and test
# sets, to prevent data leakage.
# NOTE: Normalization can be done inside the cross-validation loop
# sc.pp.normalize_total(adata, target_sum=1e6)

# Log-transforming the data does not leak information, since it does not estimate anything from the data
sc.pp.log1p(adata)
adata.obs
AAACCTGAGAAACCAT-1
AAACCTGAGAAACCGC-1
AAACCTGAGAAACCTA-1
AAACCTGAGAAACGAG-1
AAACCTGAGAAACGCC-1
...
TTTGTCATCTTTACAC-2
TTTGTCATCTTTACGT-2
TTTGTCATCTTTAGGG-2
TTTGTCATCTTTAGTC-2
TTTGTCATCTTTCCTC-2

1474560 rows × 0 columns
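The comments above distinguish element-wise transforms such as `log1p` from steps that estimate statistics from the data. A minimal sketch of the leakage-safe pattern for the second kind is shown below: the statistics are fitted on the training rows only and then applied unchanged to the validation rows. Standardization is used purely as an illustrative transform. The helpers `fit_scaler` and `apply_scaler` and the toy numbers are hypothetical and are not part of the tutorial pipeline.

```python
from statistics import mean, pstdev

def fit_scaler(train_rows):
    """Estimate per-feature mean and standard deviation from training rows only."""
    columns = list(zip(*train_rows))
    means = [mean(c) for c in columns]
    stds = [pstdev(c) or 1.0 for c in columns]  # constant features keep scale 1
    return means, stds

def apply_scaler(rows, means, stds):
    """Apply previously fitted statistics without re-estimating them."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]

# Hypothetical toy fold: three training cells, one validation cell, two features
train = [[0.0, 1.0], [2.0, 1.0], [4.0, 1.0]]
val = [[2.0, 3.0]]

means, stds = fit_scaler(train)            # uses the training fold only
train_scaled = apply_scaler(train, means, stds)
val_scaled = apply_scaler(val, means, stds)  # validation data never influences the fit
print(val_scaled)
```

Inside a cross-validation loop, the same two calls would be repeated for every split so that each validation fold is transformed with statistics it did not contribute to.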

barcodes = adata.obs_names
barcodes
Index(['AAACCTGAGAAACCAT-1', 'AAACCTGAGAAACCGC-1', 'AAACCTGAGAAACCTA-1',
       'AAACCTGAGAAACGAG-1', 'AAACCTGAGAAACGCC-1', 'AAACCTGAGAAAGTGG-1',
       'AAACCTGAGAACAACT-1', 'AAACCTGAGAACAATC-1', 'AAACCTGAGAACTCGG-1',
       'AAACCTGAGAACTGTA-1',
       ...
       'TTTGTCATCTTGGGTA-2', 'TTTGTCATCTTGTACT-2', 'TTTGTCATCTTGTATC-2',
       'TTTGTCATCTTGTCAT-2', 'TTTGTCATCTTGTTTG-2', 'TTTGTCATCTTTACAC-2',
       'TTTGTCATCTTTACGT-2', 'TTTGTCATCTTTAGGG-2', 'TTTGTCATCTTTAGTC-2',
       'TTTGTCATCTTTCCTC-2'],
      dtype='object', length=1474560)
gene_names = adata.var.index
print(gene_names)
Index(['DDX11L1', 'WASH7P', 'MIR6859-2', 'MIR1302-10', 'MIR1302-11', 'FAM138A',
       'OR4G4P', 'OR4G11P', 'OR4F5', 'RP11-34P13.7',
       ...
       'Tcrlibrary_RUNX1_1_gene', 'Tcrlibrary_RUNX1_2_gene',
       'Tcrlibrary_RUNX1_3_gene', 'Tcrlibrary_RUNX2_1_gene',
       'Tcrlibrary_RUNX2_2_gene', 'Tcrlibrary_RUNX2_3_gene',
       'Tcrlibrary_ZAP70_1_gene', 'Tcrlibrary_ZAP70_2_gene',
       'Tcrlibrary_ZAP70_3_gene', 'Cas9_blast_gene'],
      dtype='object', length=64370)
len(set(df_labels.barcode.tolist()))
1735
len(set(barcodes.tolist()))
1474560
matched_barcodes = sorted(set(barcodes.tolist()) & set(df_labels.barcode.tolist()))
len(matched_barcodes)
1735
# This is the InPathsY data in the original KPNN code
df_labels
barcode TCR
0 AAACCTGCACACATGT-1 0
1 AAACCTGCACGTCTCT-1 0
2 AAACCTGTCAATACCG-1 0
3 AAACCTGTCGTGGTCG-1 0
4 AAACGGGTCTGAGTGT-1 0
... ... ...
1730 TTTCCTCGTCATGCCG-2 1
1731 TTTGCGCGTAGCCTCG-2 1
1732 TTTGGTTAGATACACA-2 1
1733 TTTGGTTGTATGAATG-2 1
1734 TTTGGTTTCCAAGTAC-2 1

1735 rows × 2 columns

Import PKN with CORNETO#

import corneto as cn
cn.info()
Installed version:v1.0.0.dev1 (latest stable: v1.0.0-alpha)
Available backends:CVXPY v1.6.0
Default backend (corneto.opt):CVXPY
Installed solvers:CLARABEL, CVXOPT, GLPK, GLPK_MI, GUROBI, HIGHS, SCIP, SCIPY
Graphviz version:v0.20.3
Installed path:/Users/pablorodriguezmier/Documents/work/projects/corneto/corneto
Repository:https://github.com/saezlab/corneto
outputs_pkn = list(set(df_edges.parent.tolist()) - set(df_edges.child.tolist()))
inputs_pkn = set(df_edges.child.tolist()) - set(df_edges.parent.tolist())
input_pkn_genes = list(set(g.split("_")[0] for g in inputs_pkn))
len(inputs_pkn), len(outputs_pkn)
(13121, 1)
tuples = [(r.child, 1, r.parent) for _, r in df_edges.iterrows()]
G = cn.Graph.from_sif_tuples(tuples)
G = G.prune(inputs_pkn, outputs_pkn)

# Size of the original PKN provided by the authors
G.shape
(13439, 27579)
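The set operations used above can be checked on a small example. The sketch below uses a hypothetical toy edge list in the same (parent, child) layout as `df_edges`: nodes that never appear as a child are the outputs, nodes that never appear as a parent are the inputs, and edges are reversed so that information flows from the measured genes towards the TCR. The node names `TF1`, `GENE1_gene` and `GENE2_gene` are invented for illustration.

```python
# Hypothetical toy edge list in the same (parent, child) layout as df_edges
toy_edges = [
    ("TCR", "ZAP70"),
    ("ZAP70", "TF1"),
    ("TF1", "GENE1_gene"),
    ("TF1", "GENE2_gene"),
]

parents = {p for p, _ in toy_edges}
children = {c for _, c in toy_edges}

outputs = parents - children  # never a child: the root of the hierarchy (TCR)
inputs = children - parents   # never a parent: the measured genes

# Reverse every edge so that the network runs from genes towards TCR
reversed_edges = [(child, parent) for parent, child in toy_edges]

# Strip the "_gene" suffix to recover the gene symbols used in the expression matrix
input_genes = sorted(g.split("_")[0] for g in inputs)

print(outputs, sorted(inputs), input_genes)
```

Note that splitting on the first underscore assumes gene symbols themselves contain no underscore, which is the same assumption made by the tutorial code.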

Select the single cell data for training#

adata_matched = adata[adata.obs_names.isin(matched_barcodes), adata.var_names.isin(input_pkn_genes)]
adata_matched.shape
(1735, 14229)
non_zero_genes = set(adata_matched.to_df().columns[adata_matched.to_df().sum(axis=0) >= 1e-6].values)
len(non_zero_genes)
12459
len(non_zero_genes.intersection(adata_matched.var_names))
12459
adata_matched = adata_matched[:, adata_matched.var_names.isin(non_zero_genes)]
# Some duplicated gene columns still have zero counts; drop them
adata_matched = adata_matched[:, adata_matched.to_df().sum(axis=0) != 0]
adata_matched.shape
(1735, 12487)
df_expr = adata_matched.to_df()
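# Collapse duplicated gene columns by taking the maximum value per gene.
# groupby(..., axis=1) is deprecated in recent pandas versions, so we group on the transpose.
df_expr = df_expr.T.groupby(level=0).max().T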
df_expr
A1BG A2ML1 AAAS AACS AADAT AAED1 AAGAB AAK1 AAMDC AAMP ... ZSWIM8 ZUFSP ZW10 ZWILCH ZXDC ZYG11A ZYG11B ZYX ZZEF1 ZZZ3
AAACCTGCACACATGT-1 0.0 0.0 0.000000 0.693147 0.0 0.000000 0.000000 0.000000 0.0 1.098612 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACCTGCACGTCTCT-1 0.0 0.0 0.000000 0.000000 0.0 0.693147 0.000000 0.000000 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACCTGTCAATACCG-1 0.0 0.0 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACCTGTCGTGGTCG-1 0.0 0.0 1.098612 0.000000 0.0 0.000000 0.000000 0.000000 0.0 1.098612 ... 0.000000 0.693147 0.000000 0.693147 0.0 0.0 0.0 0.000000 0.000000 0.000000
AAACGGGTCTGAGTGT-1 0.0 0.0 0.000000 0.000000 0.0 0.693147 0.000000 0.000000 0.0 0.000000 ... 0.693147 0.000000 0.693147 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
TTTCCTCGTCATGCCG-2 0.0 0.0 0.000000 0.000000 0.0 0.000000 0.000000 0.693147 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
TTTGCGCGTAGCCTCG-2 0.0 0.0 1.098612 0.000000 0.0 0.000000 0.000000 0.000000 0.0 1.386294 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.000000
TTTGGTTAGATACACA-2 0.0 0.0 1.098612 0.000000 0.0 0.000000 0.000000 1.098612 0.0 0.693147 ... 0.693147 0.000000 0.693147 0.000000 0.0 0.0 0.0 0.693147 0.693147 0.000000
TTTGGTTGTATGAATG-2 0.0 0.0 0.000000 0.000000 0.0 0.693147 1.098612 0.000000 0.0 0.693147 ... 0.000000 0.693147 0.000000 0.000000 0.0 0.0 0.0 0.000000 0.000000 0.693147
TTTGGTTTCCAAGTAC-2 0.0 0.0 0.000000 0.000000 0.0 0.000000 0.000000 0.000000 0.0 0.693147 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 1.098612 0.000000 0.000000

1735 rows × 12459 columns

Building and training the KPNN#

Now we will use the PKN provided by the authors and the utility functions in CORNETO to build a KPNN similar to the one used in the original manuscript.

import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
import numpy as np
# Use the data from the experiment
X = df_expr.values
y = df_labels.set_index("barcode").loc[df_expr.index, "TCR"].values
X.shape, y.shape
((1735, 12459), (1735,))
# We can prefilter on top N genes to make this faster
top_n = None

# From the given PKN
outputs_pkn = list(set(df_edges.parent.tolist()) - set(df_edges.child.tolist()))
inputs_pkn = set(df_edges.child.tolist()) - set(df_edges.parent.tolist())
input_pkn_genes = list(set(g.split("_")[0] for g in inputs_pkn))

if top_n is not None and top_n > 0:
    input_pkn_genes = list(set(input_pkn_genes).intersection(df_expr.var(axis=0).sort_values(ascending=False).head(top_n).index))
    inputs_pkn = list(g + "_gene" for g in input_pkn_genes)

len(inputs_pkn), len(outputs_pkn)
(13121, 1)
input_nn_genes = list(set(input_pkn_genes).intersection(df_expr.columns))
input_nn = [g + "_gene" for g in input_nn_genes]
len(input_nn)
12459
# Build corneto graph
tuples = [(r.child, 1, r.parent) for _, r in df_edges.iterrows()]
G = cn.Graph.from_sif_tuples(tuples)
G = G.prune(input_nn, outputs_pkn)
G.shape
(12767, 25928)
len(input_nn), len(input_nn_genes)
(12459, 12459)
len(set(input_nn).intersection(G.V))
12459
X = df_expr.loc[:, input_nn_genes].values
y = df_labels.set_index("barcode").loc[df_expr.index, "TCR"].values
X.shape, y.shape
((1735, 12459), (1735,))
from corneto._ml import build_dagnn

def stratified_kfold(
    G,
    inputs,
    outputs,
    n_splits=5,
    shuffle=True,
    random_state=42,
    lr=0.001,
    patience=10,
    file_weights="weights",
    dagnn_config=dict(
        batch_norm_input=True,
        batch_norm_center=False,
        batch_norm_scale=False,
        bias_reg_l1=1e-3,
        bias_reg_l2=1e-2,
        dropout=0.20,
        default_hidden_activation="sigmoid",
        default_output_activation="sigmoid",
        verbose=False
    )
):
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
    models = []
    metrics = {m: [] for m in ["accuracy", "precision", "recall", "f1", "roc_auc"]}
    for i, (train_idx, val_idx) in enumerate(kfold.split(X, y)):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
    
        print("Building DAG NN model with CORNETO using Keras with JAX...")
        # Use the function arguments rather than the global variables
        print(f" > N. inputs: {len(inputs)}")
        print(f" > N. outputs: {len(outputs)}")
        model = build_dagnn(
            G,
            inputs,
            outputs,
            **dagnn_config
        )
        print(f" > N. parameters: {model.count_params()}")
    
        # Train the model with Adam
        opt=keras.optimizers.Adam(learning_rate=lr)
        early_stopping = EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True)
        print("Compiling...")
        model.compile(
            optimizer=opt,
            loss='binary_crossentropy',
            metrics=['accuracy']
        )
        print("Fitting...")
        model.fit(X_train, y_train,
                  validation_data=(X_val, y_val),
                  epochs=200,
                  batch_size=64,
                  verbose=0,
                  callbacks=[early_stopping])
        
        if file_weights is not None:
            filename = f"{file_weights}_{i}.keras"
            model.save(filename)
            print(f"Weights saved to {filename}")
    
        # Predictions and metrics calculation
        y_pred_proba = model.predict(X_val).flatten()
        y_pred = (y_pred_proba > 0.5).astype(int)
        acc = accuracy_score(y_val, y_pred)
        precision = precision_score(y_val, y_pred)
        recall = recall_score(y_val, y_pred)
        f1 = f1_score(y_val, y_pred)
        roc_auc = roc_auc_score(y_val, y_pred_proba)
        metrics["accuracy"].append(acc)
        metrics["precision"].append(precision)
        metrics["recall"].append(recall)
        metrics["f1"].append(f1)
        metrics["roc_auc"].append(roc_auc)
        print(f" > Fold {i} validation ROC-AUC={roc_auc:.3f}")
        models.append(model)
    return models, metrics

temp_weights = tempfile.mkdtemp()
models, metrics = stratified_kfold(G, input_nn, outputs_pkn, file_weights=os.path.join(temp_weights, "weights"))

print("Validation metrics:")
for k, v in metrics.items():
    print(f" - {k}: {np.mean(v):.3f}")
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_0.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 4s 208ms/step
 > Fold 0 validation ROC-AUC=0.991
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_1.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 4s 180ms/step
 > Fold 1 validation ROC-AUC=0.983
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_2.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 4s 179ms/step
 > Fold 2 validation ROC-AUC=0.986
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_3.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 4s 180ms/step
 > Fold 3 validation ROC-AUC=0.993
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 26236
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_4.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 4s 181ms/step
 > Fold 4 validation ROC-AUC=0.994
Validation metrics:
 - accuracy: 0.957
 - precision: 0.961
 - recall: 0.951
 - f1: 0.956
 - roc_auc: 0.989
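For reference, the threshold-based metrics reported above can be reproduced by hand from the confusion counts. The sketch below is a plain-Python illustration of what `accuracy_score`, `precision_score`, `recall_score` and `f1_score` compute for binary labels. The helper `binary_metrics` and the toy labels and probabilities are hypothetical.

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy, precision, recall and F1 from predicted probabilities."""
    y_pred = [int(p > threshold) for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Hypothetical validation fold: three stimulated and three unstimulated cells
y_true = [1, 1, 1, 0, 0, 0]
y_prob = [0.9, 0.8, 0.4, 0.6, 0.2, 0.1]
m = binary_metrics(y_true, y_prob)
print(m)
```

ROC-AUC is different in kind: it is computed from the ranking of the predicted probabilities and does not depend on the 0.5 threshold, which is why the tutorial passes `y_pred_proba` rather than `y_pred` to `roc_auc_score`.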

Now we analyze the learned biases for each node in the graph. Note that the authors of the KPNN paper describe a way to extract node weights based on the learned interactions while accounting for biases in the structure of the NN. Here we only show the learned biases of the NN nodes across the 5 folds. Please be careful when interpreting these values.

# We collect the weights obtained in each fold

def load_biases(file="weights", folds=5):
    biases = []
    mean_inputs = []
    for i in range(folds):
        model = keras.models.load_model(f"{file}_{i}.keras")
        for layer in model.layers:
            weights = layer.get_weights()
            if weights:
                biases.append((i, layer.name, weights[1][0]))
                mean_inputs.append((i, layer.name, weights[0].mean()))
    df_biases = pd.DataFrame(biases, columns=["fold", "gene", "bias"])
    df_biases["abs_bias"] = df_biases.bias.abs()
    df_biases["pow2_bias"] = df_biases.bias.pow(2)
    df_biases = df_biases.set_index(["fold","gene"])
    return df_biases

df_biases = load_biases(file=os.path.join(temp_weights, "weights"), folds=5)
df_biases.sort_values(by="abs_bias", ascending=False).head(10)
bias abs_bias pow2_bias
fold gene
1 DUSP3 -2.632734 2.632734 6.931287
2 PRKCD -2.197672 2.197672 4.829763
SUZ12.EZH2 -2.126150 2.126150 4.520514
NfKb.p65.p50 -2.010682 2.010682 4.042840
TCR 1.988334 1.988334 3.953472
4 ZAP70 -1.952261 1.952261 3.811324
3 YAP1 -1.872004 1.872004 3.504399
4 NfKb.p65.p50 -1.834747 1.834747 3.366298
0 PRC2 -1.826627 1.826627 3.336568
4 TCR 1.816516 1.816516 3.299730
df_biases_full = df_biases.copy().reset_index()
gene_biases_score = df_biases_full.groupby("gene")["pow2_bias"].mean().sort_values(ascending=False)
gene_biases_score
gene
TCR             3.129825
NfKb.p65.p50    2.217130
PRKCD           2.162571
PRC2            2.140556
DUSP3           1.901882
                  ...   
MLL2.complex    0.000641
HSF2            0.000268
GATA3           0.000145
SRY             0.000138
REST            0.000032
Name: pow2_bias, Length: 308, dtype: float32
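The score above averages the squared bias of each node across folds. A small sketch with hypothetical numbers shows why squaring matters: a node whose bias flips sign between folds has a mean bias close to zero but can still rank highly by mean squared bias. The node names `A`, `B`, `C` and the values are invented for illustration.

```python
from collections import defaultdict

# Hypothetical (fold, node, bias) records, mimicking the layout of df_biases
records = [
    (0, "A", -2.0), (1, "A", 2.0),
    (0, "B", 0.5), (1, "B", -0.5),
    (0, "C", 1.0), (1, "C", 1.0),
]

per_node = defaultdict(list)
for fold, node, bias in records:
    per_node[node].append(bias)

# Mean bias: opposite signs across folds cancel out
mean_bias = {node: sum(b) / len(b) for node, b in per_node.items()}
# Mean squared bias: magnitude is preserved regardless of sign
score = {node: sum(x ** 2 for x in b) / len(b) for node, b in per_node.items()}

ranking = sorted(score, key=score.get, reverse=True)
print(ranking, mean_bias, score)
```

A high score therefore indicates a large bias magnitude, not a consistent direction; checking the per-fold values is still needed before interpreting a node.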
gene_biases_score.head(30).plot.bar()
<Axes: xlabel='gene'>
[Figure: bar plot of the mean squared bias of the 30 top-scoring nodes]
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Get the top genes sorted by mean squared bias
top_genes = gene_biases_score.head(30).index

# Filter the DataFrame to include only rows with these top genes
filtered_df = df_biases_full[df_biases_full['gene'].isin(top_genes)]

plt.figure(figsize=(10, 6))
sns.violinplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
sns.stripplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
plt.xticks(rotation=90)
plt.title('Biases of the trained neurons')
plt.xlabel('Gene')
plt.ylabel('Bias')
plt.axhline(0, linestyle="--", color="k")
plt.tight_layout()
[Figure: violin and strip plots of the biases of the 30 top-scoring nodes across folds, titled "Biases of the trained neurons"]

Use CORNETO for NN pruning#

Now we show how CORNETO can be used to extract a smaller, yet complete, DAG from the original PKN provided by the authors. We add an input edge to each input node and an output edge leaving TCR to indicate which nodes are the inputs and which one is the output. We then use Acyclic Flow to find the smallest DAG comprising these nodes.
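To make the objective concrete, the following toy sketch searches by brute force for the smallest set of edges that still connects every input to the output. It only illustrates what "smallest DAG comprising these nodes" means on a graph with a handful of edges; it is not how CORNETO solves the problem, which is formulated as a mixed-integer program and handed to a solver. The functions `reaches` and `smallest_connecting_subgraph` and the toy graph are hypothetical.

```python
from itertools import combinations

def reaches(edges, start, goal):
    """Depth-first search: is `goal` reachable from `start` using `edges`?"""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, []).append(v)
    stack, seen = [start], {start}
    while stack:
        u = stack.pop()
        if u == goal:
            return True
        for v in adj.get(u, []):
            if v not in seen:
                seen.add(v)
                stack.append(v)
    return False

def smallest_connecting_subgraph(edges, inputs, output):
    """Try edge subsets by increasing size; return the first one linking all inputs to the output."""
    for k in range(len(edges) + 1):
        for subset in combinations(edges, k):
            if all(reaches(subset, i, output) for i in inputs):
                return list(subset)
    return None

# Hypothetical toy PKN (already acyclic), with edges pointing from genes towards TCR
toy_edges = [
    ("g1", "A"), ("g2", "A"), ("g2", "B"),
    ("A", "TCR"), ("B", "TCR"), ("g1", "B"),
]
small = smallest_connecting_subgraph(toy_edges, ["g1", "g2"], "TCR")
print(small)
```

The brute-force search grows exponentially with the number of edges, so it is unusable on a network with tens of thousands of edges like the one in this tutorial. That is the motivation for expressing the same objective as an optimization problem.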

G_dag = G.copy()
new_edges = []
for g in input_nn:
    new_edges.append(G_dag.add_edge((), g))
new_edges.append(G_dag.add_edge("TCR", ()))
print(G_dag.shape)

# Find a small DAG. We use Acyclic Flow to search over the space of DAGs
P = cn.opt.AcyclicFlow(G_dag)
# We enforce that the input genes and the output gene are part of the solution
P += P.expr.with_flow[new_edges] == 1
# Minimize the number of active edges
P.add_objectives(sum(P.expr.with_flow), weights=1)
P.solve(solver="SCIP", verbosity=1, max_seconds=300);
(12767, 38388)
===============================================================================
                                     CVXPY                                     
                                     v1.6.0                                    
===============================================================================
(CVXPY) Dec 23 01:19:36 PM: Your problem has 127931 variables, 332945 constraints, and 0 parameters.
(CVXPY) Dec 23 01:19:36 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) Dec 23 01:19:36 PM: (If you need to solve this problem multiple times, but with different data, consider using parameters.)
(CVXPY) Dec 23 01:19:36 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
(CVXPY) Dec 23 01:19:36 PM: Your problem is compiled with the CPP canonicalization backend.
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
(CVXPY) Dec 23 01:19:37 PM: Compiling problem (target solver=SCIP).
(CVXPY) Dec 23 01:19:37 PM: Reduction chain: Dcp2Cone -> CvxAttr2Constr -> ConeMatrixStuffing -> SCIP
(CVXPY) Dec 23 01:19:37 PM: Applying reduction Dcp2Cone
(CVXPY) Dec 23 01:19:37 PM: Applying reduction CvxAttr2Constr
(CVXPY) Dec 23 01:19:37 PM: Applying reduction ConeMatrixStuffing
(CVXPY) Dec 23 01:20:56 PM: Applying reduction SCIP
(CVXPY) Dec 23 01:20:56 PM: Finished problem compilation (took 7.962e+01 seconds).
-------------------------------------------------------------------------------
                                Numerical solver                               
-------------------------------------------------------------------------------
(CVXPY) Dec 23 01:20:56 PM: Invoking solver SCIP  to obtain a solution.
presolving:
(round 1, fast)       72218 del vars, 263764 del conss, 0 add conss, 158463 chg bounds, 17010 chg sides, 17010 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 2, fast)       83111 del vars, 288686 del conss, 0 add conss, 158636 chg bounds, 17982 chg sides, 17982 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 3, fast)       83299 del vars, 288858 del conss, 0 add conss, 158667 chg bounds, 18157 chg sides, 18150 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 4, fast)       83323 del vars, 288959 del conss, 0 add conss, 158670 chg bounds, 18308 chg sides, 18242 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 5, fast)       87954 del vars, 288998 del conss, 0 add conss, 158670 chg bounds, 18310 chg sides, 18242 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 6, fast)       87954 del vars, 288998 del conss, 0 add conss, 169675 chg bounds, 18310 chg sides, 18242 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 7, fast)       87976 del vars, 289373 del conss, 0 add conss, 169675 chg bounds, 19713 chg sides, 29203 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 8, fast)       87993 del vars, 289470 del conss, 0 add conss, 169675 chg bounds, 19713 chg sides, 29203 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 9, exhaustive) 87994 del vars, 300376 del conss, 0 add conss, 169708 chg bounds, 19713 chg sides, 29203 chg coeffs, 0 upgd conss, 0 impls, 0 clqs
(round 10, fast)       98924 del vars, 311359 del conss, 0 add conss, 169708 chg bounds, 19722 chg sides, 29210 chg coeffs, 0 upgd conss, 4004 impls, 0 clqs
(round 11, fast)       99966 del vars, 312534 del conss, 0 add conss, 169708 chg bounds, 19722 chg sides, 29210 chg coeffs, 0 upgd conss, 4004 impls, 0 clqs
(round 12, fast)       99978 del vars, 312581 del conss, 0 add conss, 169709 chg bounds, 19722 chg sides, 29210 chg coeffs, 0 upgd conss, 4004 impls, 0 clqs
(round 13, exhaustive) 99980 del vars, 312589 del conss, 0 add conss, 169715 chg bounds, 19724 chg sides, 29212 chg coeffs, 9689 upgd conss, 4004 impls, 0 clqs
(round 14, fast)       99988 del vars, 312611 del conss, 0 add conss, 169715 chg bounds, 19724 chg sides, 29212 chg coeffs, 9689 upgd conss, 14702 impls, 925 clqs
(round 15, exhaustive) 99999 del vars, 312621 del conss, 0 add conss, 169718 chg bounds, 19726 chg sides, 29214 chg coeffs, 17411 upgd conss, 14704 impls, 908 clqs
(round 16, medium)     100057 del vars, 312642 del conss, 0 add conss, 169718 chg bounds, 19726 chg sides, 29214 chg coeffs, 17411 upgd conss, 22425 impls, 1333 clqs
(round 17, medium)     100059 del vars, 312756 del conss, 0 add conss, 169718 chg bounds, 19726 chg sides, 29214 chg coeffs, 17411 upgd conss, 22425 impls, 1333 clqs
(round 18, exhaustive) 100754 del vars, 312760 del conss, 0 add conss, 169720 chg bounds, 19726 chg sides, 29214 chg coeffs, 17413 upgd conss, 22425 impls, 1333 clqs
(round 19, fast)       100756 del vars, 312828 del conss, 0 add conss, 169720 chg bounds, 19726 chg sides, 29219 chg coeffs, 17413 upgd conss, 22427 impls, 1337 clqs
   (3.0s) probing: 1000/17563 (5.7%) - 0 fixings, 1 aggregations, 370 implications, 2 bound changes
   (3.0s) probing: 1003/17563 (5.7%) - 0 fixings, 1 aggregations, 370 implications, 2 bound changes
   (3.0s) probing aborted: 1000/1000 successive useless probings
   (3.0s) symmetry computation started: requiring (bin +, int +, cont +), (fixed: bin -, int -, cont -)
   (26.3s) symmetry computation finished: 1500 generators found (max: 1500, log10 of symmetry group size: 1750.0) (symcode time: 23.21)
dynamic symmetry handling statistics:
   orbitopal reduction:        5 components: 3x3, 3x3, 3x3, 3x3, 4x3
   orbital reduction:         11 components of sizes 3, 5, 9, 4, 4, 4, 18, 17, 8, 6, 12
   lexicographic reduction:  105 permutations with support sizes 8, 6, 8, 6, 6, 8, 6, 6, 6, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6
handled 45 out of 45 symmetry components
(round 20, exhaustive) 100757 del vars, 312828 del conss, 1385 add conss, 169722 chg bounds, 19726 chg sides, 29219 chg coeffs, 17417 upgd conss, 22520 impls, 1612 clqs
(round 21, exhaustive) 101738 del vars, 312830 del conss, 1385 add conss, 169722 chg bounds, 19726 chg sides, 29219 chg coeffs, 17417 upgd conss, 22528 impls, 1616 clqs
(round 22, fast)       101738 del vars, 313811 del conss, 1385 add conss, 169722 chg bounds, 19726 chg sides, 29219 chg coeffs, 17417 upgd conss, 22528 impls, 1616 clqs
   (26.5s) probing: 1103/17563 (6.3%) - 0 fixings, 1 aggregations, 563 implications, 2 bound changes
   (26.5s) probing aborted: 1000/1000 successive useless probings
presolving (23 rounds: 23 fast, 9 medium, 7 exhaustive):
 101738 deleted vars, 313811 deleted constraints, 1385 added constraints, 169722 tightened bounds, 0 added holes, 19726 changed sides, 29219 changed coefficients
 22621 implications, 1716 cliques
presolved problem has 26193 variables (17562 bin, 0 int, 0 impl, 8631 cont) and 20519 constraints
  16888 constraints of type <varbound>
   3323 constraints of type <linear>
    308 constraints of type <logicor>
transformed objective value is always integral (scale: 1)
Presolving Time: 26.29

 time | node  | left  |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr|  dualbound   | primalbound  |  gap   | compl. 
p26.7s|     1 |     0 |   115 |     - |   locks|   0 |  26k|  20k|  20k|   0 |  0 |   0 |   0 | 2.066500e+04 | 2.744200e+04 |  32.79%| unknown
i26.8s|     1 |     0 |   115 |     - |  oneopt|   0 |  26k|  20k|  20k|   0 |  0 |  24 |   0 | 2.066500e+04 | 2.707300e+04 |  31.01%| unknown
 27.6s|     1 |     0 | 11865 |     - |   661M |   0 |  26k|  20k|  20k|   0 |  0 |  24 |   0 | 2.097314e+04 | 2.707300e+04 |  29.08%| unknown
 28.4s|     1 |     0 | 16422 |     - |   671M |   0 |  26k|  20k|  21k|1624 |  1 |  24 |   0 | 2.258309e+04 | 2.707300e+04 |  19.88%| unknown
 29.4s|     1 |     0 | 21208 |     - |   674M |   0 |  26k|  20k|  21k|1857 |  2 |  24 |   0 | 2.279908e+04 | 2.707300e+04 |  18.75%| unknown
 29.7s|     1 |     0 | 23243 |     - |   676M |   0 |  26k|  20k|  22k|2002 |  3 |  24 |   0 | 2.292005e+04 | 2.707300e+04 |  18.12%| unknown
r29.8s|     1 |     0 | 23243 |     - |shifting|   0 |  26k|  20k|  22k|2002 |  3 |  24 |   0 | 2.292005e+04 | 2.515100e+04 |   9.73%| unknown
 30.2s|     1 |     0 | 25780 |     - |   680M |   0 |  26k|  20k|  22k|2120 |  4 |  24 |   0 | 2.302305e+04 | 2.515100e+04 |   9.24%| unknown
i32.7s|     1 |     0 | 37528 |     - |  oneopt|   0 |  26k|  20k|  22k|2120 |  4 |  24 |   0 | 2.302305e+04 | 2.514600e+04 |   9.22%| unknown
(node 1) unresolved numerical troubles in LP 10 -- using pseudo solution instead (loop 1)
 35.8s|     1 |     2 | 49334 |     - |   684M |   0 |  26k|  20k|  22k|2120 |  4 |  24 |   0 | 2.302305e+04 | 2.514600e+04 |   9.22%| unknown
d72.5s|    68 |    69 | 98165 | 730.5 |veclendi|  11 |  26k|  20k|  22k|   0 |  1 |  24 |1210 | 2.302805e+04 | 2.511900e+04 |   9.08%| unknown
o82.1s|    92 |    93 |122341 | 803.5 |rootsold|  13 |  26k|  20k|  22k|2233 |  1 |  24 |1568 | 2.302805e+04 | 2.509600e+04 |   8.98%| unknown
 84.6s|   100 |   101 |127795 | 793.7 |   734M |  13 |  26k|  20k|  22k|2251 |  1 |  24 |1659 | 2.302805e+04 | 2.509600e+04 |   8.98%| unknown
d92.0s|   125 |   126 |148749 | 802.7 |guideddi|  15 |  26k|  20k|  22k|   0 |  1 |  24 |1833 | 2.303005e+04 | 2.508900e+04 |   8.94%| unknown
r 114s|   190 |   191 |204186 | 819.9 |ziroundi|  17 |  26k|  20k|  22k|2380 |  1 |  24 |2696 | 2.303205e+04 | 2.508700e+04 |   8.92%| unknown
 time | node  | left  |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr|  dualbound   | primalbound  |  gap   | compl. 
  118s|   200 |   201 |209863 | 807.3 |   775M |  17 |  26k|  20k|  22k|2406 |  1 |  24 |2831 | 2.303304e+04 | 2.508700e+04 |   8.92%| unknown
r 139s|   262 |   263 |255895 | 791.9 |ziroundi|  20 |  26k|  20k|  22k|2504 |  1 |  28 |3608 | 2.303304e+04 | 2.508500e+04 |   8.91%| unknown
r 139s|   263 |   264 |255962 | 789.1 |ziroundi|  21 |  26k|  20k|  22k|2504 |  1 |  28 |3628 | 2.303304e+04 | 2.508400e+04 |   8.90%| unknown
r 148s|   297 |   298 |269173 | 743.1 |ziroundi|  21 |  26k|  20k|  22k|2538 |  1 |  29 |4061 | 2.303404e+04 | 2.508400e+04 |   8.90%| unknown
  150s|   300 |   301 |274671 | 754.0 |   818M |  21 |  26k|  20k|  22k|2538 |  1 |  30 |4103 | 2.303404e+04 | 2.508400e+04 |   8.90%| unknown
r 154s|   317 |   318 |279506 | 728.8 |ziroundi|  21 |  26k|  20k|  22k|2555 |  1 |  30 |4323 | 2.303404e+04 | 2.508400e+04 |   8.90%| unknown
r 155s|   322 |   323 |279561 | 717.6 |ziroundi|  21 |  26k|  20k|  22k|2555 |  1 |  30 |4365 | 2.303404e+04 | 2.508300e+04 |   8.90%| unknown
r 171s|   390 |   391 |313739 | 680.0 |ziroundi|  21 |  26k|  20k|  22k|2602 |  1 |  35 |5079 | 2.303404e+04 | 2.508300e+04 |   8.90%| unknown
  174s|   400 |   401 |320440 | 679.8 |   881M |  21 |  26k|  20k|  22k|2603 |  1 |  55 |5167 | 2.303404e+04 | 2.508300e+04 |   8.90%| unknown
r 177s|   418 |   419 |326124 | 664.0 |ziroundi|  21 |  26k|  20k|  22k|2603 |  1 |  56 |5295 | 2.303404e+04 | 2.508300e+04 |   8.90%| unknown
  202s|   500 |   501 |405291 | 713.6 |   909M |  21 |  26k|  20k|  22k|2734 |  1 |  65 |5922 | 2.303504e+04 | 2.508300e+04 |   8.89%| unknown
  226s|   600 |   601 |476672 | 713.6 |   922M |  23 |  26k|  20k|  22k|2845 |  2 |  73 |6493 | 2.303504e+04 | 2.508300e+04 |   8.89%| unknown
L 230s|   618 |   619 |481926 | 701.3 |    rins|  23 |  26k|  20k|  22k|2853 |  1 |  73 |6520 | 2.303504e+04 | 2.508200e+04 |   8.89%| unknown
  244s|   700 |   701 |519687 | 673.1 |   932M |  23 |  26k|  20k|  22k|2943 |  1 |  77 |6751 | 2.303505e+04 | 2.508200e+04 |   8.89%| unknown
  258s|   800 |   801 |554187 | 632.0 |   936M |  25 |  26k|  20k|  22k|3050 |  2 |  77 |7000 | 2.303602e+04 | 2.508200e+04 |   8.88%| unknown
 time | node  | left  |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr|  dualbound   | primalbound  |  gap   | compl. 
r 271s|   860 |   861 |603796 | 645.6 |ziroundi|  25 |  26k|  20k|  22k|3093 |  1 |  78 |7085 | 2.303602e+04 | 2.508200e+04*|   8.88%| unknown
r 272s|   874 |   875 |605526 | 637.2 |ziroundi|  25 |  26k|  20k|  22k|3093 |  1 |  78 |7096 | 2.303604e+04 | 2.508000e+04 |   8.87%| unknown
r 272s|   887 |   888 |609332 | 632.2 |ziroundi|  25 |  26k|  20k|  22k|3102 |  1 |  78 |7102 | 2.303604e+04 | 2.508000e+04 |   8.87%| unknown
r 273s|   888 |   889 |611391 | 633.8 |ziroundi|  25 |  26k|  20k|  22k|3102 |  1 |  78 |7102 | 2.303604e+04 | 2.507900e+04 |   8.87%| unknown
  273s|   900 |   901 |613629 | 627.8 |   965M |  25 |  26k|  20k|  22k|3102 |  1 |  78 |7119 | 2.303604e+04 | 2.507900e+04 |   8.87%| unknown
  290s|  1000 |  1001 |659345 | 610.7 |   971M |  25 |  26k|  20k|  22k|3208 |  1 |  81 |7495 | 2.303604e+04 | 2.507900e+04 |   8.87%| unknown
r 298s|  1044 |  1045 |672493 | 597.6 |ziroundi|  27 |  26k|  20k|  22k|3251 |  1 |  81 |7598 | 2.303604e+04 | 2.507800e+04 |   8.86%| unknown
Restart triggered after 50 consecutive estimations that the remaining tree will be large
(run 1, node 1049) performing user restart

(restart) converted 1743 cuts from the global cut pool into linear constraints

presolving:
(round 1, exhaustive) 0 del vars, 10 del conss, 0 add conss, 0 chg bounds, 0 chg sides, 0 chg coeffs, 1721 upgd conss, 22621 impls, 1903 clqs
(round 2, medium)     0 del vars, 20 del conss, 10 add conss, 2 chg bounds, 0 chg sides, 2 chg coeffs, 1721 upgd conss, 22626 impls, 1905 clqs
(round 3, exhaustive) 0 del vars, 45 del conss, 10 add conss, 2 chg bounds, 0 chg sides, 2 chg coeffs, 1721 upgd conss, 22626 impls, 1905 clqs
presolving (4 rounds: 4 fast, 4 medium, 3 exhaustive):
 0 deleted vars, 45 deleted constraints, 10 added constraints, 2 tightened bounds, 0 added holes, 0 changed sides, 2 changed coefficients
 22626 implications, 1905 cliques
presolved problem has 26193 variables (17562 bin, 0 int, 0 impl, 8631 cont) and 22285 constraints
  16893 constraints of type <varbound>
   1599 constraints of type <setppc>
   3345 constraints of type <linear>
    420 constraints of type <logicor>
     28 constraints of type <bounddisjunction>
transformed objective value is always integral (scale: 1)
Presolving Time: 28.20
transformed 100/100 original solutions to the transformed problem space


SCIP Status        : solving was interrupted [time limit reached]
Solving Time (sec) : 301.66
Solving Nodes      : 0 (total of 1049 nodes in 2 runs)
Primal Bound       : +2.50779999999950e+04 (124 solutions)
Dual Bound         : +2.30360421900002e+04
Gap                : 8.86 %
-------------------------------------------------------------------------------
                                    Summary                                    
-------------------------------------------------------------------------------
(CVXPY) Dec 23 01:26:03 PM: Problem status: optimal_inaccurate
(CVXPY) Dec 23 01:26:03 PM: Optimal value: 2.508e+04
(CVXPY) Dec 23 01:26:03 PM: Compilation took 7.962e+01 seconds
(CVXPY) Dec 23 01:26:03 PM: Solver (including time spent in interface) took 3.072e+02 seconds
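The solver stopped at the time limit, so the reported solution is the best one found rather than a proven optimum, which is presumably why CVXPY labels the status `optimal_inaccurate`. The remaining gap can be recomputed from the primal and dual bounds printed in the SCIP summary. The following is a minimal sketch: the helper name `relative_gap` is hypothetical, and the formula assumes both bounds have the same sign, as they do here.

```python
def relative_gap(primal: float, dual: float) -> float:
    """Relative gap |primal - dual| / min(|primal|, |dual|) for same-sign bounds."""
    return abs(primal - dual) / min(abs(primal), abs(dual))


# Bounds copied from the SCIP summary above
primal_bound = 2.50779999999950e+04
dual_bound = 2.30360421900002e+04

gap_pct = 100 * relative_gap(primal_bound, dual_bound)
print(f"Relative gap: {gap_pct:.2f}%")
```

With these two bounds the result rounds to the 8.86 % shown in the log. The primal bound (about 25078) also coincides with the number of edges in the subgraph extracted in the next cell.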
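# Convert the boolean selection mask to edge indices before extracting the subgraph
G_subdag = G_dag.edge_subgraph(np.flatnonzero(P.expr.with_flow.value > 0.5))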
G_dag.shape, G_subdag.shape
((12767, 38388), (12619, 25078))
rel_dag_compression = (1 - (G_subdag.num_edges / G_dag.num_edges)) * 100
print(f"KPNN edge compression (0-100%): {rel_dag_compression:.2f}%")
KPNN edge compression (0-100%): 34.67%
pruned_models, pruned_metrics = stratified_kfold(
    G_subdag,
    input_nn,
    outputs_pkn,
    file_weights=os.path.join(temp_weights, "pruned_weights"),
)

print("Validation metrics:")
for k, v in pruned_metrics.items():
    print(f" - {k}: {np.mean(v):.3f}")
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12778
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_0.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 3s 75ms/step
 > Fold 0 validation ROC-AUC=0.992
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12778
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_1.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 73ms/step
 > Fold 1 validation ROC-AUC=0.983
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12778
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_2.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 69ms/step
 > Fold 2 validation ROC-AUC=0.987
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12778
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_3.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 68ms/step
 > Fold 3 validation ROC-AUC=0.994
Building DAG NN model with CORNETO using Keras with JAX...
 > N. inputs: 12459
 > N. outputs: 1
 > N. parameters: 12778
Compiling...
Fitting...
Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_4.keras
11/11 ━━━━━━━━━━━━━━━━━━━━ 1s 68ms/step
 > Fold 4 validation ROC-AUC=0.994
Validation metrics:
 - accuracy: 0.961
 - precision: 0.971
 - recall: 0.949
 - f1: 0.960
 - roc_auc: 0.990
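The summary above reports only the mean of each metric. To see how much the folds vary, the per-fold values can be summarised as well. This is a small sketch that uses the ROC-AUC values copied by hand from the training log; in the notebook itself the same numbers would presumably be available in `pruned_metrics["roc_auc"]`.

```python
from statistics import mean, stdev

# Per-fold validation ROC-AUC values copied from the training log above
fold_roc_auc = [0.992, 0.983, 0.987, 0.994, 0.994]

roc_mean = mean(fold_roc_auc)
roc_std = stdev(fold_roc_auc)  # sample standard deviation (n - 1 denominator)
print(f"ROC-AUC: {roc_mean:.3f} +/- {roc_std:.3f} over {len(fold_roc_auc)} folds")
```

The mean matches the 0.990 reported above, and the spread across folds is small relative to it.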
df_biases_pruned = load_biases(file=os.path.join(temp_weights, "pruned_weights"), folds=5)
df_biases_pruned.sort_values(by="abs_bias", ascending=False).head(10)
                          bias  abs_bias  pow2_bias
fold gene
1    SUZ12.EZH2      -2.892245  2.892245   8.365083
4    SUZ12.EZH2      -2.736529  2.736529   7.488593
2    SUZ12.EZH2      -2.575348  2.575348   6.632418
0    SETDB1.NLK.CHD7 -2.476090  2.476090   6.131024
4    ZAP70           -2.432173  2.432173   5.915468
3    ZAP70            2.325642  2.325642   5.408611
0    PRKCD           -2.315197  2.315197   5.360137
     DUSP3           -2.269532  2.269532   5.150774
4    SETDB1.NLK.CHD7 -2.219929  2.219929   4.928085
3    NFYA            -2.163613  2.163613   4.681221
df_biases_prunedr = df_biases_pruned.copy().reset_index()
gene_biases_score = df_biases_prunedr.groupby("gene")["pow2_bias"].mean().sort_values(ascending=False)
gene_biases_score
gene
SUZ12.EZH2         4.990975
TCR                3.651254
ZAP70              3.080308
SETDB1.NLK.CHD7    2.942616
NFYA               2.748323
                     ...   
PBX1               0.007306
NR2F1              0.005788
ASH2L              0.002699
BCL3               0.002370
GATA3              0.000018
Name: pow2_bias, Length: 160, dtype: float32
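Because the score averages the squared bias, it discards the sign. In the top-10 table shown earlier, ZAP70 has a negative bias in fold 4 and a positive one in fold 3, so a high score does not by itself mean that the folds agree on direction. The sketch below checks sign agreement using only the rows visible in that table. The helper `sign_agreement` is hypothetical, and because only a few folds per gene are listed, the mean squared values printed here do not reproduce `gene_biases_score`.

```python
from collections import defaultdict

# (fold, gene, bias) rows copied from the top-10 table shown earlier.
# The fold for DUSP3 is inferred from the sparsified index (blank = same as row above).
rows = [
    (1, "SUZ12.EZH2", -2.892245),
    (4, "SUZ12.EZH2", -2.736529),
    (2, "SUZ12.EZH2", -2.575348),
    (0, "SETDB1.NLK.CHD7", -2.476090),
    (4, "ZAP70", -2.432173),
    (3, "ZAP70", 2.325642),
    (0, "PRKCD", -2.315197),
    (0, "DUSP3", -2.269532),
    (4, "SETDB1.NLK.CHD7", -2.219929),
    (3, "NFYA", -2.163613),
]

# Group the visible biases by gene
biases_by_gene = defaultdict(list)
for _fold, gene, bias in rows:
    biases_by_gene[gene].append(bias)


def sign_agreement(values):
    """Fraction of values sharing the majority sign (1.0 means all agree)."""
    positives = sum(v > 0 for v in values)
    return max(positives, len(values) - positives) / len(values)


for gene, values in biases_by_gene.items():
    mean_sq = sum(v * v for v in values) / len(values)
    print(
        f"{gene:16s} n={len(values)} "
        f"mean_sq={mean_sq:.3f} agreement={sign_agreement(values):.2f}"
    )
```

On the full `df_biases_prunedr` frame the same idea could be applied per gene across all five folds, which would flag genes such as ZAP70 whose direction flips between folds.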
gene_biases_score.head(30).plot.bar()
<Axes: xlabel='gene'>
../_images/33822315de1839b55d48d834bc0f78bfc87246e48217cb27a6c073c27555885f.png
# Get the top 30 genes ranked by mean squared bias (pow2_bias) across folds
top_genes = gene_biases_score.head(30).index

# Filter the DataFrame to include only rows with these top genes
filtered_df = df_biases_prunedr[df_biases_prunedr['gene'].isin(top_genes)]

plt.figure(figsize=(10, 6))
sns.violinplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
sns.stripplot(data=filtered_df, x="gene", y="bias", order=top_genes, dodge=True)
plt.xticks(rotation=90)
plt.title('Biases of the trained neurons')
plt.xlabel('Gene')
plt.ylabel('Bias')
plt.axhline(0, linestyle="--", color="k")
plt.tight_layout()
../_images/8d55587c9719a43386566182fdfdf54d1b3d53a58e6821fd79682f8984b28e77.png
param_compression = (1-(pruned_models[0].count_params()/models[0].count_params())) * 100
print(f"Parameter compression: {param_compression:.2f}%")
Parameter compression: 51.30%
perf_degradation = ((np.mean(metrics["roc_auc"]) - np.mean(pruned_metrics["roc_auc"])) / np.mean(metrics["roc_auc"])) * 100
print(f"Degradation in ROC-AUC after compression (positive = decrease in performance, negative = increase in performance): {perf_degradation:.2f}%")
Degradation in ROC-AUC after compression (positive = decrease in performance, negative = increase in performance): -0.04%