{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "k9qe_aUJSeho" }, "source": [ "# Knowledge-primed Neural Networks (KPNNs) for single cell data\n", "\n", "In this tutorial we show how CORNETO can be used to build custom neural network architectures informed by prior knowledge, using a knowledge-primed neural network (KPNN) [1] as the example. We use the single cell data from the publication \"Knowledge-primed neural networks enable biologically interpretable deep learning on single-cell sequencing data\" by Nikolaus Fortelny & Christoph Bock, who used a single-cell RNA-seq dataset they had previously generated [2] that measures cellular responses to T cell receptor (TCR) stimulation in a standardized in vitro model. The dataset was chosen because of the complexity of the TCR signaling pathway and its well-characterized role in orchestrating transcriptional responses to antigen detection in T cells.\n", "\n", "## Why CORNETO?\n", "\n", "In the original publication, the authors built a KPNN by searching databases and constructing a Directed Acyclic Graph (DAG) from shortest paths between the TCR and the genes. However, this approach does not guarantee an optimal architecture. CORNETO, thanks to its capabilities for modeling and optimization on networks, provides methods to automatically find DAG architectures in an optimal way.\n", "\n", "In addition, CORNETO provides methods to build DAG NN architectures with ease using Keras 3, making the KPNN implementation flexible and interoperable with backends such as PyTorch, TensorFlow and JAX.\n", "\n", "## How does it work?\n", "\n", "With CORNETO's building blocks for optimization over networks, we can formulate optimization problems that find DAG architectures within a Prior Knowledge Network (PKN). Once we have this backbone, we convert it into a neural network using the utility functions included in CORNETO.\n", "\n", "\n", "## References\n", "1. Fortelny, N., & Bock, C. (2020). 
Knowledge-primed neural networks enable biologically interpretable deep learning on single-cell sequencing data. Genome Biology, 21, 1-36.\n", "2. Datlinger, P., Rendeiro, A. F., Schmidl, C., Krausgruber, T., Traxler, P., Klughammer, J., ... & Bock, C. (2017). Pooled CRISPR screening with single-cell transcriptome readout. Nature Methods, 14(3), 297-301.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download and import the single cell dataset" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AYKxvyAC_ZCn", "outputId": "c4c55b18-5438-4f94-de96-bd402e760dc6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Effective URL: https://medical-epigenomics.org/papers/fortelny2019/\n", "Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_Edgelist.csv to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Edgelist.csv\n", "Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_ClassLabels.csv to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_ClassLabels.csv\n", "Downloading https://medical-epigenomics.org/papers/fortelny2019/TCR_Data.h5 to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Data.h5\n", "Downloaded files:\n", "/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Edgelist.csv\n", "/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_ClassLabels.csv\n", "/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpb21ulczk/TCR_Data.h5\n" ] } ], "source": [ "import os\n", "import urllib.request\n", "import urllib.parse\n", "import tempfile\n", "import pandas as pd\n", "import scanpy as sc\n", "import numpy as np\n", "import corneto as cn\n", "\n", "# Follow the redirect to find the URL that currently hosts the files\n", "with urllib.request.urlopen(\"http://kpnn.computational-epigenetics.org/\") as response:\n", "    web_input = response.geturl()\n", "print(\"Effective URL:\", web_input)\n", "\n", "files = [\"TCR_Edgelist.csv\", 
\"TCR_ClassLabels.csv\", \"TCR_Data.h5\"]\n", "\n", "temp_dir = tempfile.mkdtemp()\n", "\n", "# Download each file into the temporary directory\n", "file_paths = []\n", "for file in files:\n", "    url = urllib.parse.urljoin(web_input, file)\n", "    output_path = os.path.join(temp_dir, file)\n", "    print(f\"Downloading {url} to {output_path}\")\n", "    try:\n", "        with urllib.request.urlopen(url) as response:\n", "            with open(output_path, 'wb') as f:\n", "                f.write(response.read())\n", "        file_paths.append(output_path)\n", "    except Exception as e:\n", "        print(f\"Failed to download {url}: {e}\")\n", "\n", "print(\"Downloaded files:\")\n", "for path in file_paths:\n", "    print(path)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9qaRPRLeA05k", "outputId": "21b2325e-5603-434d-a65e-996925754bb3" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/pablorodriguezmier/miniforge3/envs/corneto/lib/python3.12/site-packages/anndata/_core/anndata.py:1758: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n", " utils.warn_names_duplicates(\"var\")\n" ] } ], "source": [ "# The data also contains the original network the authors built with shortest paths.\n", "# We will use it to replicate the study\n", "df_edges = pd.read_csv(file_paths[0])\n", "df_labels = pd.read_csv(file_paths[1])\n", "# Import the 10x data with Scanpy\n", "adata = sc.read_10x_h5(file_paths[2])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
barcodeTCR
0AAACCTGCACACATGT-10
1AAACCTGCACGTCTCT-10
2AAACCTGTCAATACCG-10
3AAACCTGTCGTGGTCG-10
4AAACGGGTCTGAGTGT-10
.........
1730TTTCCTCGTCATGCCG-21
1731TTTGCGCGTAGCCTCG-21
1732TTTGGTTAGATACACA-21
1733TTTGGTTGTATGAATG-21
1734TTTGGTTTCCAAGTAC-21
\n", "

1735 rows × 2 columns

\n", "
" ], "text/plain": [ " barcode TCR\n", "0 AAACCTGCACACATGT-1 0\n", "1 AAACCTGCACGTCTCT-1 0\n", "2 AAACCTGTCAATACCG-1 0\n", "3 AAACCTGTCGTGGTCG-1 0\n", "4 AAACGGGTCTGAGTGT-1 0\n", "... ... ...\n", "1730 TTTCCTCGTCATGCCG-2 1\n", "1731 TTTGCGCGTAGCCTCG-2 1\n", "1732 TTTGGTTAGATACACA-2 1\n", "1733 TTTGGTTGTATGAATG-2 1\n", "1734 TTTGGTTTCCAAGTAC-2 1\n", "\n", "[1735 rows x 2 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_labels" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
parentchild
0TCRZAP70
1ZAP70MAPK14
2MAPK14FOXO3
3MAPK14STAT1
4MAPK14STAT3
.........
27574HMGA1MTRNR2L9_gene
27575MYBC12orf50_gene
27576MYBTRPC5OS_gene
27577SOX2TRPC5OS_gene
27578CRTC1MTRNR2L9_gene
\n", "

27579 rows × 2 columns

\n", "
" ], "text/plain": [ " parent child\n", "0 TCR ZAP70\n", "1 ZAP70 MAPK14\n", "2 MAPK14 FOXO3\n", "3 MAPK14 STAT1\n", "4 MAPK14 STAT3\n", "... ... ...\n", "27574 HMGA1 MTRNR2L9_gene\n", "27575 MYB C12orf50_gene\n", "27576 MYB TRPC5OS_gene\n", "27577 SOX2 TRPC5OS_gene\n", "27578 CRTC1 MTRNR2L9_gene\n", "\n", "[27579 rows x 2 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_edges" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 423 }, "id": "gZVpIe8UN011", "outputId": "6a468338-26e6-4ac4-b115-af6169126445" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
gene_ids
DDX11L1ENSG00000223972
WASH7PENSG00000227232
MIR6859-2ENSG00000278267
MIR1302-10ENSG00000243485
MIR1302-11ENSG00000274890
......
Tcrlibrary_RUNX2_3_geneTcrlibrary_RUNX2_3_gene
Tcrlibrary_ZAP70_1_geneTcrlibrary_ZAP70_1_gene
Tcrlibrary_ZAP70_2_geneTcrlibrary_ZAP70_2_gene
Tcrlibrary_ZAP70_3_geneTcrlibrary_ZAP70_3_gene
Cas9_blast_geneCas9_blast_gene
\n", "

64370 rows × 1 columns

\n", "
" ], "text/plain": [ " gene_ids\n", "DDX11L1 ENSG00000223972\n", "WASH7P ENSG00000227232\n", "MIR6859-2 ENSG00000278267\n", "MIR1302-10 ENSG00000243485\n", "MIR1302-11 ENSG00000274890\n", "... ...\n", "Tcrlibrary_RUNX2_3_gene Tcrlibrary_RUNX2_3_gene\n", "Tcrlibrary_ZAP70_1_gene Tcrlibrary_ZAP70_1_gene\n", "Tcrlibrary_ZAP70_2_gene Tcrlibrary_ZAP70_2_gene\n", "Tcrlibrary_ZAP70_3_gene Tcrlibrary_ZAP70_3_gene\n", "Cas9_blast_gene Cas9_blast_gene\n", "\n", "[64370 rows x 1 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata.var" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5AUbkDZFP2IK", "outputId": "1ca27838-348a-4b2a-98df-ef3de27cc4c8" }, "outputs": [], "source": [ "# We could normalize the data here, but it is better not to preprocess the\n", "# whole dataset before splitting it into training and test sets,\n", "# to avoid data leakage.\n", "# NOTE: Normalization can be done inside the cross-validation loop\n", "# sc.pp.normalize_total(adata, target_sum=1e6)\n", "\n", "# Log-transforming the data does not leak information, as it does not estimate anything from the data\n", "sc.pp.log1p(adata)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 406 }, "id": "ugELTxTtQx9p", "outputId": "a476ce35-cc52-4d73-ff0e-b6487a1d95b6" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AAACCTGAGAAACCAT-1
AAACCTGAGAAACCGC-1
AAACCTGAGAAACCTA-1
AAACCTGAGAAACGAG-1
AAACCTGAGAAACGCC-1
...
TTTGTCATCTTTACAC-2
TTTGTCATCTTTACGT-2
TTTGTCATCTTTAGGG-2
TTTGTCATCTTTAGTC-2
TTTGTCATCTTTCCTC-2
\n", "

1474560 rows × 0 columns

\n", "
" ], "text/plain": [ "Empty DataFrame\n", "Columns: []\n", "Index: [AAACCTGAGAAACCAT-1, AAACCTGAGAAACCGC-1, AAACCTGAGAAACCTA-1, AAACCTGAGAAACGAG-1, AAACCTGAGAAACGCC-1, AAACCTGAGAAAGTGG-1, AAACCTGAGAACAACT-1, AAACCTGAGAACAATC-1, AAACCTGAGAACTCGG-1, AAACCTGAGAACTGTA-1, AAACCTGAGAAGAAGC-1, AAACCTGAGAAGATTC-1, AAACCTGAGAAGCCCA-1, AAACCTGAGAAGGACA-1, AAACCTGAGAAGGCCT-1, AAACCTGAGAAGGGTA-1, AAACCTGAGAAGGTGA-1, AAACCTGAGAAGGTTT-1, AAACCTGAGAATAGGG-1, AAACCTGAGAATCTCC-1, AAACCTGAGAATGTGT-1, AAACCTGAGAATGTTG-1, AAACCTGAGAATTCCC-1, AAACCTGAGAATTGTG-1, AAACCTGAGACAAAGG-1, AAACCTGAGACAAGCC-1, AAACCTGAGACAATAC-1, AAACCTGAGACACGAC-1, AAACCTGAGACACTAA-1, AAACCTGAGACAGACC-1, AAACCTGAGACAGAGA-1, AAACCTGAGACAGGCT-1, AAACCTGAGACATAAC-1, AAACCTGAGACCACGA-1, AAACCTGAGACCCACC-1, AAACCTGAGACCGGAT-1, AAACCTGAGACCTAGG-1, AAACCTGAGACCTTTG-1, AAACCTGAGACGACGT-1, AAACCTGAGACGCAAC-1, AAACCTGAGACGCACA-1, AAACCTGAGACGCTTT-1, AAACCTGAGACTAAGT-1, AAACCTGAGACTACAA-1, AAACCTGAGACTAGAT-1, AAACCTGAGACTAGGC-1, AAACCTGAGACTCGGA-1, AAACCTGAGACTGGGT-1, AAACCTGAGACTGTAA-1, AAACCTGAGACTTGAA-1, AAACCTGAGACTTTCG-1, AAACCTGAGAGAACAG-1, AAACCTGAGAGACGAA-1, AAACCTGAGAGACTAT-1, AAACCTGAGAGACTTA-1, AAACCTGAGAGAGCTC-1, AAACCTGAGAGATGAG-1, AAACCTGAGAGCAATT-1, AAACCTGAGAGCCCAA-1, AAACCTGAGAGCCTAG-1, AAACCTGAGAGCTATA-1, AAACCTGAGAGCTGCA-1, AAACCTGAGAGCTGGT-1, AAACCTGAGAGCTTCT-1, AAACCTGAGAGGACGG-1, AAACCTGAGAGGGATA-1, AAACCTGAGAGGGCTT-1, AAACCTGAGAGGTACC-1, AAACCTGAGAGGTAGA-1, AAACCTGAGAGGTTAT-1, AAACCTGAGAGGTTGC-1, AAACCTGAGAGTAAGG-1, AAACCTGAGAGTAATC-1, AAACCTGAGAGTACAT-1, AAACCTGAGAGTACCG-1, AAACCTGAGAGTCGGT-1, AAACCTGAGAGTCTGG-1, AAACCTGAGAGTGACC-1, AAACCTGAGAGTGAGA-1, AAACCTGAGAGTTGGC-1, AAACCTGAGATACACA-1, AAACCTGAGATAGCAT-1, AAACCTGAGATAGGAG-1, AAACCTGAGATAGTCA-1, AAACCTGAGATATACG-1, AAACCTGAGATATGCA-1, AAACCTGAGATATGGT-1, AAACCTGAGATCACGG-1, AAACCTGAGATCCCAT-1, AAACCTGAGATCCCGC-1, AAACCTGAGATCCGAG-1, AAACCTGAGATCCTGT-1, AAACCTGAGATCGATA-1, AAACCTGAGATCGGGT-1, AAACCTGAGATCTGAA-1, AAACCTGAGATCTGCT-1, 
AAACCTGAGATGAGAG-1, AAACCTGAGATGCCAG-1, AAACCTGAGATGCCTT-1, AAACCTGAGATGCGAC-1, ...]\n", "\n", "[1474560 rows x 0 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata.obs" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f5utYHPCE_aT", "outputId": "1a32a812-9c3d-4bdc-ca0f-652943aad40e" }, "outputs": [ { "data": { "text/plain": [ "Index(['AAACCTGAGAAACCAT-1', 'AAACCTGAGAAACCGC-1', 'AAACCTGAGAAACCTA-1',\n", " 'AAACCTGAGAAACGAG-1', 'AAACCTGAGAAACGCC-1', 'AAACCTGAGAAAGTGG-1',\n", " 'AAACCTGAGAACAACT-1', 'AAACCTGAGAACAATC-1', 'AAACCTGAGAACTCGG-1',\n", " 'AAACCTGAGAACTGTA-1',\n", " ...\n", " 'TTTGTCATCTTGGGTA-2', 'TTTGTCATCTTGTACT-2', 'TTTGTCATCTTGTATC-2',\n", " 'TTTGTCATCTTGTCAT-2', 'TTTGTCATCTTGTTTG-2', 'TTTGTCATCTTTACAC-2',\n", " 'TTTGTCATCTTTACGT-2', 'TTTGTCATCTTTAGGG-2', 'TTTGTCATCTTTAGTC-2',\n", " 'TTTGTCATCTTTCCTC-2'],\n", " dtype='object', length=1474560)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "barcodes = adata.obs_names\n", "barcodes" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "caIATAb6F6KO", "outputId": "46ae6909-b20d-438e-8cfc-24b26d43ffe2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Index(['DDX11L1', 'WASH7P', 'MIR6859-2', 'MIR1302-10', 'MIR1302-11', 'FAM138A',\n", " 'OR4G4P', 'OR4G11P', 'OR4F5', 'RP11-34P13.7',\n", " ...\n", " 'Tcrlibrary_RUNX1_1_gene', 'Tcrlibrary_RUNX1_2_gene',\n", " 'Tcrlibrary_RUNX1_3_gene', 'Tcrlibrary_RUNX2_1_gene',\n", " 'Tcrlibrary_RUNX2_2_gene', 'Tcrlibrary_RUNX2_3_gene',\n", " 'Tcrlibrary_ZAP70_1_gene', 'Tcrlibrary_ZAP70_2_gene',\n", " 'Tcrlibrary_ZAP70_3_gene', 'Cas9_blast_gene'],\n", " dtype='object', length=64370)\n" ] } ], "source": [ "gene_names = adata.var.index\n", "print(gene_names)" ] }, { "cell_type": "code", "execution_count": 10, 
"metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3_qlaVGLHLTU", "outputId": "937253d4-e12a-437c-ed3b-9077a4043fd8" }, "outputs": [ { "data": { "text/plain": [ "1735" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(set(df_labels.barcode.tolist()))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NudJw90wHM9B", "outputId": "8c46748f-d965-4706-f3d5-d5c8129fb2c0" }, "outputs": [ { "data": { "text/plain": [ "1474560" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(set(barcodes.tolist()))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-MUOcHm6HWS2", "outputId": "29d31d6c-9f11-46c2-e3da-8f2fde1c7f93" }, "outputs": [ { "data": { "text/plain": [ "1735" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "matched_barcodes = sorted(set(barcodes.tolist()) & set(df_labels.barcode.tolist()))\n", "len(matched_barcodes)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 423 }, "id": "bH_1lBq6Lp8L", "outputId": "740058ff-2918-466b-9575-76c10db36701" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
barcodeTCR
0AAACCTGCACACATGT-10
1AAACCTGCACGTCTCT-10
2AAACCTGTCAATACCG-10
3AAACCTGTCGTGGTCG-10
4AAACGGGTCTGAGTGT-10
.........
1730TTTCCTCGTCATGCCG-21
1731TTTGCGCGTAGCCTCG-21
1732TTTGGTTAGATACACA-21
1733TTTGGTTGTATGAATG-21
1734TTTGGTTTCCAAGTAC-21
\n", "

1735 rows × 2 columns

\n", "
" ], "text/plain": [ " barcode TCR\n", "0 AAACCTGCACACATGT-1 0\n", "1 AAACCTGCACGTCTCT-1 0\n", "2 AAACCTGTCAATACCG-1 0\n", "3 AAACCTGTCGTGGTCG-1 0\n", "4 AAACGGGTCTGAGTGT-1 0\n", "... ... ...\n", "1730 TTTCCTCGTCATGCCG-2 1\n", "1731 TTTGCGCGTAGCCTCG-2 1\n", "1732 TTTGGTTAGATACACA-2 1\n", "1733 TTTGGTTGTATGAATG-2 1\n", "1734 TTTGGTTTCCAAGTAC-2 1\n", "\n", "[1735 rows x 2 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#This is the InPathsY data in the original code of KPNNs\n", "df_labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import PKN with CORNETO" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", "
Installed version:v1.0.0.dev1 (latest stable: v1.0.0-alpha)
Available backends:CVXPY v1.6.0
Default backend (corneto.opt):CVXPY
Installed solvers:CLARABEL, CVXOPT, GLPK, GLPK_MI, GUROBI, HIGHS, SCIP, SCIPY
Graphviz version:v0.20.3
Installed path:/Users/pablorodriguezmier/Documents/work/projects/corneto/corneto
Repository:https://github.com/saezlab/corneto
\n", "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import corneto as cn\n", "cn.info()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(13121, 1)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "outputs_pkn = list(set(df_edges.parent.tolist()) - set(df_edges.child.tolist()))\n", "inputs_pkn = set(df_edges.child.tolist()) - set(df_edges.parent.tolist())\n", "input_pkn_genes = list(set(g.split(\"_\")[0] for g in inputs_pkn))\n", "len(inputs_pkn), len(outputs_pkn)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(13439, 27579)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tuples = [(r.child, 1, r.parent) for _, r in df_edges.iterrows()]\n", "G = cn.Graph.from_sif_tuples(tuples)\n", "G = G.prune(inputs_pkn, outputs_pkn)\n", "\n", "# Size of the original PKN provided by the authors\n", "G.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Select the single cell data for training" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XCy2cd5zK2cp", "outputId": "5307650d-0633-48fd-a086-d2a64ef1b1e7" }, "outputs": [ { "data": { "text/plain": [ "(1735, 14229)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata_matched = adata[adata.obs_names.isin(matched_barcodes), adata.var_names.isin(input_pkn_genes)]\n", "adata_matched.shape" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "L0qiukuPK-ky", "outputId": "91d4bb75-f51c-46be-e270-22afd8d73d36" }, "outputs": [ { "data": { "text/plain": [ "12459" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "non_zero_genes = 
set(adata_matched.to_df().columns[adata_matched.to_df().sum(axis=0) >= 1e-6].values)\n", "len(non_zero_genes)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "l2Ia3ViolksN", "outputId": "e8a9e21c-c52a-48f6-a213-89b88f2df238" }, "outputs": [ { "data": { "text/plain": [ "12459" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(non_zero_genes.intersection(adata_matched.var_names))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dKPAE_g4hefN", "outputId": "c1a5b3e1-c627-4ac5-e390-19061fb32cfc" }, "outputs": [ { "data": { "text/plain": [ "(1735, 12487)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata_matched = adata_matched[:, adata_matched.var_names.isin(non_zero_genes)]\n", "# Gene names are not unique: some duplicated columns still have zero total counts, so drop them\n", "adata_matched = adata_matched[:, adata_matched.to_df().sum(axis=0) != 0]\n", "adata_matched.shape" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 617 }, "id": "-dVxay1DjH0X", "outputId": "7a0816fc-cbee-48eb-ecdb-62a0e9d18488" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/ipykernel_91875/3016726407.py:2: FutureWarning: DataFrame.groupby with axis=1 is deprecated. Do `frame.T.groupby(...)` without axis instead.\n", " df_expr = df_expr.groupby(df_expr.columns, axis=1).max()\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
A1BGA2ML1AAASAACSAADATAAED1AAGABAAK1AAMDCAAMP...ZSWIM8ZUFSPZW10ZWILCHZXDCZYG11AZYG11BZYXZZEF1ZZZ3
AAACCTGCACACATGT-10.00.00.0000000.6931470.00.0000000.0000000.0000000.01.098612...0.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
AAACCTGCACGTCTCT-10.00.00.0000000.0000000.00.6931470.0000000.0000000.00.693147...0.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
AAACCTGTCAATACCG-10.00.00.0000000.0000000.00.0000000.0000000.0000000.00.693147...0.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
AAACCTGTCGTGGTCG-10.00.01.0986120.0000000.00.0000000.0000000.0000000.01.098612...0.0000000.6931470.0000000.6931470.00.00.00.0000000.0000000.000000
AAACGGGTCTGAGTGT-10.00.00.0000000.0000000.00.6931470.0000000.0000000.00.000000...0.6931470.0000000.6931470.0000000.00.00.00.0000000.0000000.000000
..................................................................
TTTCCTCGTCATGCCG-20.00.00.0000000.0000000.00.0000000.0000000.6931470.00.693147...0.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
TTTGCGCGTAGCCTCG-20.00.01.0986120.0000000.00.0000000.0000000.0000000.01.386294...0.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
TTTGGTTAGATACACA-20.00.01.0986120.0000000.00.0000000.0000001.0986120.00.693147...0.6931470.0000000.6931470.0000000.00.00.00.6931470.6931470.000000
TTTGGTTGTATGAATG-20.00.00.0000000.0000000.00.6931471.0986120.0000000.00.693147...0.0000000.6931470.0000000.0000000.00.00.00.0000000.0000000.693147
TTTGGTTTCCAAGTAC-20.00.00.0000000.0000000.00.0000000.0000000.0000000.00.693147...0.0000000.0000000.0000000.0000000.00.00.01.0986120.0000000.000000
\n", "

1735 rows × 12459 columns

\n", "
" ], "text/plain": [ " A1BG A2ML1 AAAS AACS AADAT AAED1 \\\n", "AAACCTGCACACATGT-1 0.0 0.0 0.000000 0.693147 0.0 0.000000 \n", "AAACCTGCACGTCTCT-1 0.0 0.0 0.000000 0.000000 0.0 0.693147 \n", "AAACCTGTCAATACCG-1 0.0 0.0 0.000000 0.000000 0.0 0.000000 \n", "AAACCTGTCGTGGTCG-1 0.0 0.0 1.098612 0.000000 0.0 0.000000 \n", "AAACGGGTCTGAGTGT-1 0.0 0.0 0.000000 0.000000 0.0 0.693147 \n", "... ... ... ... ... ... ... \n", "TTTCCTCGTCATGCCG-2 0.0 0.0 0.000000 0.000000 0.0 0.000000 \n", "TTTGCGCGTAGCCTCG-2 0.0 0.0 1.098612 0.000000 0.0 0.000000 \n", "TTTGGTTAGATACACA-2 0.0 0.0 1.098612 0.000000 0.0 0.000000 \n", "TTTGGTTGTATGAATG-2 0.0 0.0 0.000000 0.000000 0.0 0.693147 \n", "TTTGGTTTCCAAGTAC-2 0.0 0.0 0.000000 0.000000 0.0 0.000000 \n", "\n", " AAGAB AAK1 AAMDC AAMP ... ZSWIM8 \\\n", "AAACCTGCACACATGT-1 0.000000 0.000000 0.0 1.098612 ... 0.000000 \n", "AAACCTGCACGTCTCT-1 0.000000 0.000000 0.0 0.693147 ... 0.000000 \n", "AAACCTGTCAATACCG-1 0.000000 0.000000 0.0 0.693147 ... 0.000000 \n", "AAACCTGTCGTGGTCG-1 0.000000 0.000000 0.0 1.098612 ... 0.000000 \n", "AAACGGGTCTGAGTGT-1 0.000000 0.000000 0.0 0.000000 ... 0.693147 \n", "... ... ... ... ... ... ... \n", "TTTCCTCGTCATGCCG-2 0.000000 0.693147 0.0 0.693147 ... 0.000000 \n", "TTTGCGCGTAGCCTCG-2 0.000000 0.000000 0.0 1.386294 ... 0.000000 \n", "TTTGGTTAGATACACA-2 0.000000 1.098612 0.0 0.693147 ... 0.693147 \n", "TTTGGTTGTATGAATG-2 1.098612 0.000000 0.0 0.693147 ... 0.000000 \n", "TTTGGTTTCCAAGTAC-2 0.000000 0.000000 0.0 0.693147 ... 0.000000 \n", "\n", " ZUFSP ZW10 ZWILCH ZXDC ZYG11A ZYG11B \\\n", "AAACCTGCACACATGT-1 0.000000 0.000000 0.000000 0.0 0.0 0.0 \n", "AAACCTGCACGTCTCT-1 0.000000 0.000000 0.000000 0.0 0.0 0.0 \n", "AAACCTGTCAATACCG-1 0.000000 0.000000 0.000000 0.0 0.0 0.0 \n", "AAACCTGTCGTGGTCG-1 0.693147 0.000000 0.693147 0.0 0.0 0.0 \n", "AAACGGGTCTGAGTGT-1 0.000000 0.693147 0.000000 0.0 0.0 0.0 \n", "... ... ... ... ... ... ... 
\n", "TTTCCTCGTCATGCCG-2 0.000000 0.000000 0.000000 0.0 0.0 0.0 \n", "TTTGCGCGTAGCCTCG-2 0.000000 0.000000 0.000000 0.0 0.0 0.0 \n", "TTTGGTTAGATACACA-2 0.000000 0.693147 0.000000 0.0 0.0 0.0 \n", "TTTGGTTGTATGAATG-2 0.693147 0.000000 0.000000 0.0 0.0 0.0 \n", "TTTGGTTTCCAAGTAC-2 0.000000 0.000000 0.000000 0.0 0.0 0.0 \n", "\n", " ZYX ZZEF1 ZZZ3 \n", "AAACCTGCACACATGT-1 0.000000 0.000000 0.000000 \n", "AAACCTGCACGTCTCT-1 0.000000 0.000000 0.000000 \n", "AAACCTGTCAATACCG-1 0.000000 0.000000 0.000000 \n", "AAACCTGTCGTGGTCG-1 0.000000 0.000000 0.000000 \n", "AAACGGGTCTGAGTGT-1 0.000000 0.000000 0.000000 \n", "... ... ... ... \n", "TTTCCTCGTCATGCCG-2 0.000000 0.000000 0.000000 \n", "TTTGCGCGTAGCCTCG-2 0.000000 0.000000 0.000000 \n", "TTTGGTTAGATACACA-2 0.693147 0.693147 0.000000 \n", "TTTGGTTGTATGAATG-2 0.000000 0.000000 0.693147 \n", "TTTGGTTTCCAAGTAC-2 1.098612 0.000000 0.000000 \n", "\n", "[1735 rows x 12459 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_expr = adata_matched.to_df()\n", "# Collapse duplicated gene names by taking the maximum value per cell.\n", "# Transposing avoids the deprecated `axis=1` argument of `groupby`.\n", "df_expr = df_expr.T.groupby(level=0).max().T\n", "df_expr" ] }, { "cell_type": "markdown", "metadata": { "id": "RrIZBw2PqKQ2" }, "source": [ "## Building and training the KPNN\n", "\n", "Now we will use the PKN provided by the authors and the utility functions in CORNETO to build a KPNN similar to the one used in the original manuscript." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "5X0GxeLeqLRK" }, "outputs": [], "source": [ "import keras\n", "from sklearn.model_selection import StratifiedKFold\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score\n", "from keras.optimizers import Adam\n", "from keras.callbacks import EarlyStopping\n", "import numpy as np" ] }, { "cell_type": "code", 
"execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "I9PoXtGOtvYs", "outputId": "c22ed990-e595-4b6e-e4c4-da4ba061b93a" }, "outputs": [ { "data": { "text/plain": [ "((1735, 12459), (1735,))" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Use the data from the experiment\n", "X = df_expr.values\n", "y = df_labels.set_index(\"barcode\").loc[df_expr.index, \"TCR\"].values\n", "X.shape, y.shape" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(13121, 1)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Optionally keep only the top N most variable genes to make this faster\n", "top_n = None\n", "\n", "# From the given PKN\n", "outputs_pkn = list(set(df_edges.parent.tolist()) - set(df_edges.child.tolist()))\n", "inputs_pkn = set(df_edges.child.tolist()) - set(df_edges.parent.tolist())\n", "input_pkn_genes = list(set(g.split(\"_\")[0] for g in inputs_pkn))\n", "\n", "if top_n is not None and top_n > 0:\n", "    input_pkn_genes = list(set(input_pkn_genes).intersection(df_expr.var(axis=0).sort_values(ascending=False).head(top_n).index))\n", "    inputs_pkn = list(g + \"_gene\" for g in input_pkn_genes)\n", "\n", "len(inputs_pkn), len(outputs_pkn)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12459" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_nn_genes = list(set(input_pkn_genes).intersection(df_expr.columns))\n", "input_nn = [g + \"_gene\" for g in input_nn_genes]\n", "len(input_nn)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(12767, 25928)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Build corneto graph\n", "tuples = [(r.child, 1, r.parent) for _, r in 
df_edges.iterrows()]\n", "G = cn.Graph.from_sif_tuples(tuples)\n", "G = G.prune(input_nn, outputs_pkn)\n", "G.shape" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(12459, 12459)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(input_nn), len(input_nn_genes)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12459" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(set(input_nn).intersection(G.V))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1735, 12459), (1735,))" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = df_expr.loc[:, input_nn_genes].values\n", "y = df_labels.set_index(\"barcode\").loc[df_expr.index, \"TCR\"].values\n", "X.shape, y.shape" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 26236\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_0.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 208ms/step\n", " > Fold 0 validation ROC-AUC=0.991\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. 
parameters: 26236\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_1.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step\n", " > Fold 1 validation ROC-AUC=0.983\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 26236\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_2.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 179ms/step\n", " > Fold 2 validation ROC-AUC=0.986\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 26236\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_3.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 180ms/step\n", " > Fold 3 validation ROC-AUC=0.993\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. 
parameters: 26236\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/weights_4.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 181ms/step\n", " > Fold 4 validation ROC-AUC=0.994\n", "Validation metrics:\n", " - accuracy: 0.957\n", " - precision: 0.961\n", " - recall: 0.951\n", " - f1: 0.956\n", " - roc_auc: 0.989\n" ] } ], "source": [
"from corneto._ml import build_dagnn\n",
"\n",
"def stratified_kfold(\n",
"    G,\n",
"    inputs,\n",
"    outputs,\n",
"    n_splits=5,\n",
"    shuffle=True,\n",
"    random_state=42,\n",
"    lr=0.001,\n",
"    patience=10,\n",
"    file_weights=\"weights\",\n",
"    dagnn_config=dict(\n",
"        batch_norm_input=True,\n",
"        batch_norm_center=False,\n",
"        batch_norm_scale=False,\n",
"        bias_reg_l1=1e-3,\n",
"        bias_reg_l2=1e-2,\n",
"        dropout=0.20,\n",
"        default_hidden_activation=\"sigmoid\",\n",
"        default_output_activation=\"sigmoid\",\n",
"        verbose=False\n",
"    )\n",
"):\n",
"    # The folds are built from the global arrays X and y defined above\n",
"    kfold = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)\n",
"    models = []\n",
"    metrics = {m: [] for m in [\"accuracy\", \"precision\", \"recall\", \"f1\", \"roc_auc\"]}\n",
"    for i, (train_idx, val_idx) in enumerate(kfold.split(X, y)):\n",
"        X_train, X_val = X[train_idx], X[val_idx]\n",
"        y_train, y_val = y[train_idx], y[val_idx]\n",
"\n",
"        print(\"Building DAG NN model with CORNETO using Keras with JAX...\")\n",
"        print(f\" > N. inputs: {len(inputs)}\")\n",
"        print(f\" > N. outputs: {len(outputs)}\")\n",
"        # Use the inputs/outputs passed to the function, not the global lists\n",
"        model = build_dagnn(\n",
"            G,\n",
"            inputs,\n",
"            outputs,\n",
"            **dagnn_config\n",
"        )\n",
"        print(f\" > N. parameters: {model.count_params()}\")\n",
"\n",
"        # Train the model with Adam\n",
"        opt = keras.optimizers.Adam(learning_rate=lr)\n",
"        # Early stopping monitors the validation fold, so the metrics reported below may be optimistic\n",
"        early_stopping = EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True)\n",
"        print(\"Compiling...\")\n",
"        model.compile(\n",
"            optimizer=opt,\n",
"            loss='binary_crossentropy',\n",
"            metrics=['accuracy']\n",
"        )\n",
"        print(\"Fitting...\")\n",
"        model.fit(X_train, y_train,\n",
"                  validation_data=(X_val, y_val),\n",
"                  epochs=200,\n",
"                  batch_size=64,\n",
"                  verbose=0,\n",
"                  callbacks=[early_stopping])\n",
"\n",
"        if file_weights is not None:\n",
"            filename = f\"{file_weights}_{i}.keras\"\n",
"            model.save(filename)\n",
"            print(f\"Weights saved to {filename}\")\n",
"\n",
"        # Predictions and metrics calculation\n",
"        y_pred_proba = model.predict(X_val).flatten()\n",
"        y_pred = (y_pred_proba > 0.5).astype(int)\n",
"        acc = accuracy_score(y_val, y_pred)\n",
"        precision = precision_score(y_val, y_pred)\n",
"        recall = recall_score(y_val, y_pred)\n",
"        f1 = f1_score(y_val, y_pred)\n",
"        roc_auc = roc_auc_score(y_val, y_pred_proba)\n",
"        metrics[\"accuracy\"].append(acc)\n",
"        metrics[\"precision\"].append(precision)\n",
"        metrics[\"recall\"].append(recall)\n",
"        metrics[\"f1\"].append(f1)\n",
"        metrics[\"roc_auc\"].append(roc_auc)\n",
"        print(f\" > Fold {i} validation ROC-AUC={roc_auc:.3f}\")\n",
"        models.append(model)\n",
"    return models, metrics\n",
"\n",
"temp_weights = tempfile.mkdtemp()\n",
"models, metrics = stratified_kfold(G, input_nn, outputs_pkn, file_weights=os.path.join(temp_weights, \"weights\"))\n",
"\n",
"print(\"Validation metrics:\")\n",
"for k, v in metrics.items():\n",
"    print(f\" - {k}: {np.mean(v):.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we analyze the learned bias of each node in the network. The authors of the KPNN paper describe a method to extract node weights from the learned interactions while accounting for biases introduced by the structure of the NN. Here we only show the learned bias of each node across the 5 folds, so these values should be interpreted with care." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th></th><th>bias</th><th>abs_bias</th><th>pow2_bias</th></tr>\n",
"    <tr><th>fold</th><th>gene</th><th></th><th></th><th></th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>1</th><th>DUSP3</th><td>-2.632734</td><td>2.632734</td><td>6.931287</td></tr>\n",
"    <tr><th rowspan=\"4\" valign=\"top\">2</th><th>PRKCD</th><td>-2.197672</td><td>2.197672</td><td>4.829763</td></tr>\n",
"    <tr><th>SUZ12.EZH2</th><td>-2.126150</td><td>2.126150</td><td>4.520514</td></tr>\n",
"    <tr><th>NfKb.p65.p50</th><td>-2.010682</td><td>2.010682</td><td>4.042840</td></tr>\n",
"    <tr><th>TCR</th><td>1.988334</td><td>1.988334</td><td>3.953472</td></tr>\n",
"    <tr><th>4</th><th>ZAP70</th><td>-1.952261</td><td>1.952261</td><td>3.811324</td></tr>\n",
"    <tr><th>3</th><th>YAP1</th><td>-1.872004</td><td>1.872004</td><td>3.504399</td></tr>\n",
"    <tr><th>4</th><th>NfKb.p65.p50</th><td>-1.834747</td><td>1.834747</td><td>3.366298</td></tr>\n",
"    <tr><th>0</th><th>PRC2</th><td>-1.826627</td><td>1.826627</td><td>3.336568</td></tr>\n",
"    <tr><th>4</th><th>TCR</th><td>1.816516</td><td>1.816516</td><td>3.299730</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " bias abs_bias pow2_bias\n", "fold gene \n", "1 DUSP3 -2.632734 2.632734 6.931287\n", "2 PRKCD -2.197672 2.197672 4.829763\n", " SUZ12.EZH2 -2.126150 2.126150 4.520514\n", " NfKb.p65.p50 -2.010682 2.010682 4.042840\n", " TCR 1.988334 1.988334 3.953472\n", "4 ZAP70 -1.952261 1.952261 3.811324\n", "3 YAP1 -1.872004 1.872004 3.504399\n", "4 NfKb.p65.p50 -1.834747 1.834747 3.366298\n", "0 PRC2 -1.826627 1.826627 3.336568\n", "4 TCR 1.816516 1.816516 3.299730" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Collect the biases learned in each fold\n",
"\n",
"def load_biases(file=\"weights\", folds=5):\n",
"    biases = []\n",
"    for i in range(folds):\n",
"        model = keras.models.load_model(f\"{file}_{i}.keras\")\n",
"        for layer in model.layers:\n",
"            weights = layer.get_weights()\n",
"            if weights:\n",
"                # Assumes weights[1] holds the bias of a single-unit layer named after its node\n",
"                biases.append((i, layer.name, weights[1][0]))\n",
"    df_biases = pd.DataFrame(biases, columns=[\"fold\", \"gene\", \"bias\"])\n",
"    df_biases[\"abs_bias\"] = df_biases.bias.abs()\n",
"    df_biases[\"pow2_bias\"] = df_biases.bias.pow(2)\n",
"    df_biases = df_biases.set_index([\"fold\",\"gene\"])\n",
"    return df_biases\n",
"\n",
"df_biases = load_biases(file=os.path.join(temp_weights, \"weights\"), folds=5)\n",
"df_biases.sort_values(by=\"abs_bias\", ascending=False).head(10)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "gene\n", "TCR 3.129825\n", "NfKb.p65.p50 2.217130\n", "PRKCD 2.162571\n", "PRC2 2.140556\n", "DUSP3 1.901882\n", " ... 
\n", "MLL2.complex 0.000641\n", "HSF2 0.000268\n", "GATA3 0.000145\n", "SRY 0.000138\n", "REST 0.000032\n", "Name: pow2_bias, Length: 308, dtype: float32" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_biases_full = df_biases.copy().reset_index()\n", "gene_biases_score = df_biases_full.groupby(\"gene\")[\"pow2_bias\"].mean().sort_values(ascending=False)\n", "gene_biases_score" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gene_biases_score.head(30).plot.bar()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"import pandas as pd\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Get the top 30 genes ranked by mean squared bias across folds\n",
"top_genes = gene_biases_score.head(30).index\n",
"\n",
"# Filter the DataFrame to include only rows with these top genes\n",
"filtered_df = df_biases_full[df_biases_full['gene'].isin(top_genes)]\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"sns.violinplot(data=filtered_df, x=\"gene\", y=\"bias\", order=top_genes, dodge=True)\n",
"sns.stripplot(data=filtered_df, x=\"gene\", y=\"bias\", order=top_genes, dodge=True)\n",
"plt.xticks(rotation=90)\n",
"plt.title('Biases of the trained neurons')\n",
"plt.xlabel('Gene')\n",
"plt.ylabel('Bias')\n",
"plt.axhline(0, linestyle=\"--\", color=\"k\")\n",
"plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use CORNETO for NN pruning\n", "\n", "Next, we show how CORNETO can extract a smaller yet complete DAG from the original PKN provided by the authors. We add an input edge to each input node and an output edge from TCR to mark which nodes are the inputs and which one is the output. We then use Acyclic Flow to search for the smallest DAG that contains these nodes. Because the solver runs with a time limit, the returned DAG is not guaranteed to be the smallest one." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(12767, 38388)\n", "===============================================================================\n", " CVXPY \n", " v1.6.0 \n", "===============================================================================\n", "(CVXPY) Dec 23 01:19:36 PM: Your problem has 127931 variables, 332945 constraints, and 0 parameters.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/pablorodriguezmier/miniforge3/envs/corneto/lib/python3.12/site-packages/cvxpy/problems/problem.py:158: UserWarning: Objective contains too many subexpressions. 
Consider vectorizing your CVXPY code to speed up compilation.\n", " warnings.warn(\"Objective contains too many subexpressions. \"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(CVXPY) Dec 23 01:19:36 PM: It is compliant with the following grammars: DCP, DQCP\n", "(CVXPY) Dec 23 01:19:36 PM: (If you need to solve this problem multiple times, but with different data, consider using parameters.)\n", "(CVXPY) Dec 23 01:19:36 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.\n", "(CVXPY) Dec 23 01:19:36 PM: Your problem is compiled with the CPP canonicalization backend.\n", "-------------------------------------------------------------------------------\n", " Compilation \n", "-------------------------------------------------------------------------------\n", "(CVXPY) Dec 23 01:19:37 PM: Compiling problem (target solver=SCIP).\n", "(CVXPY) Dec 23 01:19:37 PM: Reduction chain: Dcp2Cone -> CvxAttr2Constr -> ConeMatrixStuffing -> SCIP\n", "(CVXPY) Dec 23 01:19:37 PM: Applying reduction Dcp2Cone\n", "(CVXPY) Dec 23 01:19:37 PM: Applying reduction CvxAttr2Constr\n", "(CVXPY) Dec 23 01:19:37 PM: Applying reduction ConeMatrixStuffing\n", "(CVXPY) Dec 23 01:20:56 PM: Applying reduction SCIP\n", "(CVXPY) Dec 23 01:20:56 PM: Finished problem compilation (took 7.962e+01 seconds).\n", "-------------------------------------------------------------------------------\n", " Numerical solver \n", "-------------------------------------------------------------------------------\n", "(CVXPY) Dec 23 01:20:56 PM: Invoking solver SCIP to obtain a solution.\n", "presolving:\n", "(round 1, fast) 72218 del vars, 263764 del conss, 0 add conss, 158463 chg bounds, 17010 chg sides, 17010 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 2, fast) 83111 del vars, 288686 del conss, 0 add conss, 158636 chg bounds, 17982 chg sides, 17982 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 3, fast) 83299 del vars, 288858 del 
conss, 0 add conss, 158667 chg bounds, 18157 chg sides, 18150 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 4, fast) 83323 del vars, 288959 del conss, 0 add conss, 158670 chg bounds, 18308 chg sides, 18242 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 5, fast) 87954 del vars, 288998 del conss, 0 add conss, 158670 chg bounds, 18310 chg sides, 18242 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 6, fast) 87954 del vars, 288998 del conss, 0 add conss, 169675 chg bounds, 18310 chg sides, 18242 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 7, fast) 87976 del vars, 289373 del conss, 0 add conss, 169675 chg bounds, 19713 chg sides, 29203 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 8, fast) 87993 del vars, 289470 del conss, 0 add conss, 169675 chg bounds, 19713 chg sides, 29203 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 9, exhaustive) 87994 del vars, 300376 del conss, 0 add conss, 169708 chg bounds, 19713 chg sides, 29203 chg coeffs, 0 upgd conss, 0 impls, 0 clqs\n", "(round 10, fast) 98924 del vars, 311359 del conss, 0 add conss, 169708 chg bounds, 19722 chg sides, 29210 chg coeffs, 0 upgd conss, 4004 impls, 0 clqs\n", "(round 11, fast) 99966 del vars, 312534 del conss, 0 add conss, 169708 chg bounds, 19722 chg sides, 29210 chg coeffs, 0 upgd conss, 4004 impls, 0 clqs\n", "(round 12, fast) 99978 del vars, 312581 del conss, 0 add conss, 169709 chg bounds, 19722 chg sides, 29210 chg coeffs, 0 upgd conss, 4004 impls, 0 clqs\n", "(round 13, exhaustive) 99980 del vars, 312589 del conss, 0 add conss, 169715 chg bounds, 19724 chg sides, 29212 chg coeffs, 9689 upgd conss, 4004 impls, 0 clqs\n", "(round 14, fast) 99988 del vars, 312611 del conss, 0 add conss, 169715 chg bounds, 19724 chg sides, 29212 chg coeffs, 9689 upgd conss, 14702 impls, 925 clqs\n", "(round 15, exhaustive) 99999 del vars, 312621 del conss, 0 add conss, 169718 chg bounds, 19726 chg sides, 29214 chg coeffs, 17411 upgd conss, 14704 impls, 908 clqs\n", "(round 
16, medium) 100057 del vars, 312642 del conss, 0 add conss, 169718 chg bounds, 19726 chg sides, 29214 chg coeffs, 17411 upgd conss, 22425 impls, 1333 clqs\n", "(round 17, medium) 100059 del vars, 312756 del conss, 0 add conss, 169718 chg bounds, 19726 chg sides, 29214 chg coeffs, 17411 upgd conss, 22425 impls, 1333 clqs\n", "(round 18, exhaustive) 100754 del vars, 312760 del conss, 0 add conss, 169720 chg bounds, 19726 chg sides, 29214 chg coeffs, 17413 upgd conss, 22425 impls, 1333 clqs\n", "(round 19, fast) 100756 del vars, 312828 del conss, 0 add conss, 169720 chg bounds, 19726 chg sides, 29219 chg coeffs, 17413 upgd conss, 22427 impls, 1337 clqs\n", " (3.0s) probing: 1000/17563 (5.7%) - 0 fixings, 1 aggregations, 370 implications, 2 bound changes\n", " (3.0s) probing: 1003/17563 (5.7%) - 0 fixings, 1 aggregations, 370 implications, 2 bound changes\n", " (3.0s) probing aborted: 1000/1000 successive useless probings\n", " (3.0s) symmetry computation started: requiring (bin +, int +, cont +), (fixed: bin -, int -, cont -)\n", " (26.3s) symmetry computation finished: 1500 generators found (max: 1500, log10 of symmetry group size: 1750.0) (symcode time: 23.21)\n", "dynamic symmetry handling statistics:\n", " orbitopal reduction: 5 components: 3x3, 3x3, 3x3, 3x3, 4x3\n", " orbital reduction: 11 components of sizes 3, 5, 9, 4, 4, 4, 18, 17, 8, 6, 12\n", " lexicographic reduction: 105 permutations with support sizes 8, 6, 8, 6, 6, 8, 6, 6, 6, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6\n", "handled 45 out of 45 symmetry components\n", "(round 20, exhaustive) 100757 del vars, 312828 del conss, 1385 add conss, 169722 chg bounds, 19726 chg sides, 29219 chg coeffs, 17417 upgd conss, 22520 impls, 1612 clqs\n", "(round 21, exhaustive) 
101738 del vars, 312830 del conss, 1385 add conss, 169722 chg bounds, 19726 chg sides, 29219 chg coeffs, 17417 upgd conss, 22528 impls, 1616 clqs\n", "(round 22, fast) 101738 del vars, 313811 del conss, 1385 add conss, 169722 chg bounds, 19726 chg sides, 29219 chg coeffs, 17417 upgd conss, 22528 impls, 1616 clqs\n", " (26.5s) probing: 1103/17563 (6.3%) - 0 fixings, 1 aggregations, 563 implications, 2 bound changes\n", " (26.5s) probing aborted: 1000/1000 successive useless probings\n", "presolving (23 rounds: 23 fast, 9 medium, 7 exhaustive):\n", " 101738 deleted vars, 313811 deleted constraints, 1385 added constraints, 169722 tightened bounds, 0 added holes, 19726 changed sides, 29219 changed coefficients\n", " 22621 implications, 1716 cliques\n", "presolved problem has 26193 variables (17562 bin, 0 int, 0 impl, 8631 cont) and 20519 constraints\n", " 16888 constraints of type \n", " 3323 constraints of type \n", " 308 constraints of type \n", "transformed objective value is always integral (scale: 1)\n", "Presolving Time: 26.29\n", "\n", " time | node | left |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr| dualbound | primalbound | gap | compl. 
\n", "p26.7s| 1 | 0 | 115 | - | locks| 0 | 26k| 20k| 20k| 0 | 0 | 0 | 0 | 2.066500e+04 | 2.744200e+04 | 32.79%| unknown\n", "i26.8s| 1 | 0 | 115 | - | oneopt| 0 | 26k| 20k| 20k| 0 | 0 | 24 | 0 | 2.066500e+04 | 2.707300e+04 | 31.01%| unknown\n", " 27.6s| 1 | 0 | 11865 | - | 661M | 0 | 26k| 20k| 20k| 0 | 0 | 24 | 0 | 2.097314e+04 | 2.707300e+04 | 29.08%| unknown\n", " 28.4s| 1 | 0 | 16422 | - | 671M | 0 | 26k| 20k| 21k|1624 | 1 | 24 | 0 | 2.258309e+04 | 2.707300e+04 | 19.88%| unknown\n", " 29.4s| 1 | 0 | 21208 | - | 674M | 0 | 26k| 20k| 21k|1857 | 2 | 24 | 0 | 2.279908e+04 | 2.707300e+04 | 18.75%| unknown\n", " 29.7s| 1 | 0 | 23243 | - | 676M | 0 | 26k| 20k| 22k|2002 | 3 | 24 | 0 | 2.292005e+04 | 2.707300e+04 | 18.12%| unknown\n", "r29.8s| 1 | 0 | 23243 | - |shifting| 0 | 26k| 20k| 22k|2002 | 3 | 24 | 0 | 2.292005e+04 | 2.515100e+04 | 9.73%| unknown\n", " 30.2s| 1 | 0 | 25780 | - | 680M | 0 | 26k| 20k| 22k|2120 | 4 | 24 | 0 | 2.302305e+04 | 2.515100e+04 | 9.24%| unknown\n", "i32.7s| 1 | 0 | 37528 | - | oneopt| 0 | 26k| 20k| 22k|2120 | 4 | 24 | 0 | 2.302305e+04 | 2.514600e+04 | 9.22%| unknown\n", "(node 1) unresolved numerical troubles in LP 10 -- using pseudo solution instead (loop 1)\n", " 35.8s| 1 | 2 | 49334 | - | 684M | 0 | 26k| 20k| 22k|2120 | 4 | 24 | 0 | 2.302305e+04 | 2.514600e+04 | 9.22%| unknown\n", "d72.5s| 68 | 69 | 98165 | 730.5 |veclendi| 11 | 26k| 20k| 22k| 0 | 1 | 24 |1210 | 2.302805e+04 | 2.511900e+04 | 9.08%| unknown\n", "o82.1s| 92 | 93 |122341 | 803.5 |rootsold| 13 | 26k| 20k| 22k|2233 | 1 | 24 |1568 | 2.302805e+04 | 2.509600e+04 | 8.98%| unknown\n", " 84.6s| 100 | 101 |127795 | 793.7 | 734M | 13 | 26k| 20k| 22k|2251 | 1 | 24 |1659 | 2.302805e+04 | 2.509600e+04 | 8.98%| unknown\n", "d92.0s| 125 | 126 |148749 | 802.7 |guideddi| 15 | 26k| 20k| 22k| 0 | 1 | 24 |1833 | 2.303005e+04 | 2.508900e+04 | 8.94%| unknown\n", "r 114s| 190 | 191 |204186 | 819.9 |ziroundi| 17 | 26k| 20k| 22k|2380 | 1 | 24 |2696 | 2.303205e+04 | 2.508700e+04 | 8.92%| unknown\n", 
" time | node | left |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr| dualbound | primalbound | gap | compl. \n", " 118s| 200 | 201 |209863 | 807.3 | 775M | 17 | 26k| 20k| 22k|2406 | 1 | 24 |2831 | 2.303304e+04 | 2.508700e+04 | 8.92%| unknown\n", "r 139s| 262 | 263 |255895 | 791.9 |ziroundi| 20 | 26k| 20k| 22k|2504 | 1 | 28 |3608 | 2.303304e+04 | 2.508500e+04 | 8.91%| unknown\n", "r 139s| 263 | 264 |255962 | 789.1 |ziroundi| 21 | 26k| 20k| 22k|2504 | 1 | 28 |3628 | 2.303304e+04 | 2.508400e+04 | 8.90%| unknown\n", "r 148s| 297 | 298 |269173 | 743.1 |ziroundi| 21 | 26k| 20k| 22k|2538 | 1 | 29 |4061 | 2.303404e+04 | 2.508400e+04 | 8.90%| unknown\n", " 150s| 300 | 301 |274671 | 754.0 | 818M | 21 | 26k| 20k| 22k|2538 | 1 | 30 |4103 | 2.303404e+04 | 2.508400e+04 | 8.90%| unknown\n", "r 154s| 317 | 318 |279506 | 728.8 |ziroundi| 21 | 26k| 20k| 22k|2555 | 1 | 30 |4323 | 2.303404e+04 | 2.508400e+04 | 8.90%| unknown\n", "r 155s| 322 | 323 |279561 | 717.6 |ziroundi| 21 | 26k| 20k| 22k|2555 | 1 | 30 |4365 | 2.303404e+04 | 2.508300e+04 | 8.90%| unknown\n", "r 171s| 390 | 391 |313739 | 680.0 |ziroundi| 21 | 26k| 20k| 22k|2602 | 1 | 35 |5079 | 2.303404e+04 | 2.508300e+04 | 8.90%| unknown\n", " 174s| 400 | 401 |320440 | 679.8 | 881M | 21 | 26k| 20k| 22k|2603 | 1 | 55 |5167 | 2.303404e+04 | 2.508300e+04 | 8.90%| unknown\n", "r 177s| 418 | 419 |326124 | 664.0 |ziroundi| 21 | 26k| 20k| 22k|2603 | 1 | 56 |5295 | 2.303404e+04 | 2.508300e+04 | 8.90%| unknown\n", " 202s| 500 | 501 |405291 | 713.6 | 909M | 21 | 26k| 20k| 22k|2734 | 1 | 65 |5922 | 2.303504e+04 | 2.508300e+04 | 8.89%| unknown\n", " 226s| 600 | 601 |476672 | 713.6 | 922M | 23 | 26k| 20k| 22k|2845 | 2 | 73 |6493 | 2.303504e+04 | 2.508300e+04 | 8.89%| unknown\n", "L 230s| 618 | 619 |481926 | 701.3 | rins| 23 | 26k| 20k| 22k|2853 | 1 | 73 |6520 | 2.303504e+04 | 2.508200e+04 | 8.89%| unknown\n", " 244s| 700 | 701 |519687 | 673.1 | 932M | 23 | 26k| 20k| 22k|2943 | 1 | 77 |6751 | 2.303505e+04 | 
2.508200e+04 | 8.89%| unknown\n", " 258s| 800 | 801 |554187 | 632.0 | 936M | 25 | 26k| 20k| 22k|3050 | 2 | 77 |7000 | 2.303602e+04 | 2.508200e+04 | 8.88%| unknown\n", " time | node | left |LP iter|LP it/n|mem/heur|mdpt |vars |cons |rows |cuts |sepa|confs|strbr| dualbound | primalbound | gap | compl. \n", "r 271s| 860 | 861 |603796 | 645.6 |ziroundi| 25 | 26k| 20k| 22k|3093 | 1 | 78 |7085 | 2.303602e+04 | 2.508200e+04*| 8.88%| unknown\n", "r 272s| 874 | 875 |605526 | 637.2 |ziroundi| 25 | 26k| 20k| 22k|3093 | 1 | 78 |7096 | 2.303604e+04 | 2.508000e+04 | 8.87%| unknown\n", "r 272s| 887 | 888 |609332 | 632.2 |ziroundi| 25 | 26k| 20k| 22k|3102 | 1 | 78 |7102 | 2.303604e+04 | 2.508000e+04 | 8.87%| unknown\n", "r 273s| 888 | 889 |611391 | 633.8 |ziroundi| 25 | 26k| 20k| 22k|3102 | 1 | 78 |7102 | 2.303604e+04 | 2.507900e+04 | 8.87%| unknown\n", " 273s| 900 | 901 |613629 | 627.8 | 965M | 25 | 26k| 20k| 22k|3102 | 1 | 78 |7119 | 2.303604e+04 | 2.507900e+04 | 8.87%| unknown\n", " 290s| 1000 | 1001 |659345 | 610.7 | 971M | 25 | 26k| 20k| 22k|3208 | 1 | 81 |7495 | 2.303604e+04 | 2.507900e+04 | 8.87%| unknown\n", "r 298s| 1044 | 1045 |672493 | 597.6 |ziroundi| 27 | 26k| 20k| 22k|3251 | 1 | 81 |7598 | 2.303604e+04 | 2.507800e+04 | 8.86%| unknown\n", "Restart triggered after 50 consecutive estimations that the remaining tree will be large\n", "(run 1, node 1049) performing user restart\n", "\n", "(restart) converted 1743 cuts from the global cut pool into linear constraints\n", "\n", "presolving:\n", "(round 1, exhaustive) 0 del vars, 10 del conss, 0 add conss, 0 chg bounds, 0 chg sides, 0 chg coeffs, 1721 upgd conss, 22621 impls, 1903 clqs\n", "(round 2, medium) 0 del vars, 20 del conss, 10 add conss, 2 chg bounds, 0 chg sides, 2 chg coeffs, 1721 upgd conss, 22626 impls, 1905 clqs\n", "(round 3, exhaustive) 0 del vars, 45 del conss, 10 add conss, 2 chg bounds, 0 chg sides, 2 chg coeffs, 1721 upgd conss, 22626 impls, 1905 clqs\n", "presolving (4 rounds: 4 fast, 4 medium, 3 
exhaustive):\n", " 0 deleted vars, 45 deleted constraints, 10 added constraints, 2 tightened bounds, 0 added holes, 0 changed sides, 2 changed coefficients\n", " 22626 implications, 1905 cliques\n", "presolved problem has 26193 variables (17562 bin, 0 int, 0 impl, 8631 cont) and 22285 constraints\n", " 16893 constraints of type \n", " 1599 constraints of type \n", " 3345 constraints of type \n", " 420 constraints of type \n", " 28 constraints of type \n", "transformed objective value is always integral (scale: 1)\n", "Presolving Time: 28.20\n", "transformed 100/100 original solutions to the transformed problem space\n", "\n", "\n", "SCIP Status : solving was interrupted [time limit reached]\n", "Solving Time (sec) : 301.66\n", "Solving Nodes : 0 (total of 1049 nodes in 2 runs)\n", "Primal Bound : +2.50779999999950e+04 (124 solutions)\n", "Dual Bound : +2.30360421900002e+04\n", "Gap : 8.86 %\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/pablorodriguezmier/miniforge3/envs/corneto/lib/python3.12/site-packages/cvxpy/problems/problem.py:1481: UserWarning: Solution may be inaccurate. 
Try another solver, adjusting the solver settings, or solve with verbose=True for more information.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-------------------------------------------------------------------------------\n", " Summary \n", "-------------------------------------------------------------------------------\n", "(CVXPY) Dec 23 01:26:03 PM: Problem status: optimal_inaccurate\n", "(CVXPY) Dec 23 01:26:03 PM: Optimal value: 2.508e+04\n", "(CVXPY) Dec 23 01:26:03 PM: Compilation took 7.962e+01 seconds\n", "(CVXPY) Dec 23 01:26:03 PM: Solver (including time spent in interface) took 3.072e+02 seconds\n" ] } ], "source": [
"G_dag = G.copy()\n",
"# Add an input edge to each input gene and an output edge from TCR\n",
"new_edges = []\n",
"for g in input_nn:\n",
"    new_edges.append(G_dag.add_edge((), g))\n",
"new_edges.append(G_dag.add_edge(\"TCR\", ()))\n",
"print(G_dag.shape)\n",
"\n",
"# Search for a small DAG with Acyclic Flow, which optimizes over the space of DAGs\n",
"P = cn.opt.AcyclicFlow(G_dag)\n",
"# We enforce that the input genes and the output gene are part of the solution\n",
"P += P.expr.with_flow[new_edges] == 1\n",
"# Minimize the number of active edges\n",
"P.add_objectives(sum(P.expr.with_flow), weights=1)\n",
"# With a 300 s time limit the solver may stop before proving optimality\n",
"P.solve(solver=\"SCIP\", verbosity=1, max_seconds=300);" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "((12767, 38388), (12619, 25078))" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Keep only the edges that carry flow in the solution\n",
"G_subdag = G_dag.edge_subgraph(P.expr.with_flow.value > 0.5)\n",
"G_dag.shape, G_subdag.shape" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KPNN edge compression (0-100%): 34.67%\n" ] } ], "source": [ "rel_dag_compression = (1 - (G_subdag.num_edges / G_dag.num_edges)) * 100\n", "print(f\"KPNN edge compression (0-100%): 
{rel_dag_compression:.2f}%\")" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 12778\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_0.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 75ms/step\n", " > Fold 0 validation ROC-AUC=0.992\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 12778\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_1.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 73ms/step\n", " > Fold 1 validation ROC-AUC=0.983\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 12778\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_2.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 69ms/step\n", " > Fold 2 validation ROC-AUC=0.987\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 12778\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_3.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 68ms/step\n", " > Fold 3 validation ROC-AUC=0.994\n", "Building DAG NN model with CORNETO using Keras with JAX...\n", " > N. 
inputs: 12459\n", " > N. outputs: 1\n", " > N. parameters: 12778\n", "Compiling...\n", "Fitting...\n", "Weights saved to /var/folders/b4/gwkwsdb93sv11rtztqbm3l040000gn/T/tmpj5f_9545/pruned_weights_4.keras\n", "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 68ms/step\n", " > Fold 4 validation ROC-AUC=0.994\n", "Validation metrics:\n", " - accuracy: 0.961\n", " - precision: 0.971\n", " - recall: 0.949\n", " - f1: 0.960\n", " - roc_auc: 0.990\n" ] } ], "source": [
"# Retrain on the pruned DAG with the same cross-validation setup\n",
"pruned_models, pruned_metrics = stratified_kfold(G_subdag, input_nn, outputs_pkn, file_weights=os.path.join(temp_weights, \"pruned_weights\"))\n",
"\n",
"print(\"Validation metrics:\")\n",
"for k, v in pruned_metrics.items():\n",
"    print(f\" - {k}: {np.mean(v):.3f}\")" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th></th><th>bias</th><th>abs_bias</th><th>pow2_bias</th></tr>\n",
"    <tr><th>fold</th><th>gene</th><th></th><th></th><th></th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>1</th><th>SUZ12.EZH2</th><td>-2.892245</td><td>2.892245</td><td>8.365083</td></tr>\n",
"    <tr><th>4</th><th>SUZ12.EZH2</th><td>-2.736529</td><td>2.736529</td><td>7.488593</td></tr>\n",
"    <tr><th>2</th><th>SUZ12.EZH2</th><td>-2.575348</td><td>2.575348</td><td>6.632418</td></tr>\n",
"    <tr><th>0</th><th>SETDB1.NLK.CHD7</th><td>-2.476090</td><td>2.476090</td><td>6.131024</td></tr>\n",
"    <tr><th>4</th><th>ZAP70</th><td>-2.432173</td><td>2.432173</td><td>5.915468</td></tr>\n",
"    <tr><th>3</th><th>ZAP70</th><td>2.325642</td><td>2.325642</td><td>5.408611</td></tr>\n",
"    <tr><th rowspan=\"2\" valign=\"top\">0</th><th>PRKCD</th><td>-2.315197</td><td>2.315197</td><td>5.360137</td></tr>\n",
"    <tr><th>DUSP3</th><td>-2.269532</td><td>2.269532</td><td>5.150774</td></tr>\n",
"    <tr><th>4</th><th>SETDB1.NLK.CHD7</th><td>-2.219929</td><td>2.219929</td><td>4.928085</td></tr>\n",
"    <tr><th>3</th><th>NFYA</th><td>-2.163613</td><td>2.163613</td><td>4.681221</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " bias abs_bias pow2_bias\n", "fold gene \n", "1 SUZ12.EZH2 -2.892245 2.892245 8.365083\n", "4 SUZ12.EZH2 -2.736529 2.736529 7.488593\n", "2 SUZ12.EZH2 -2.575348 2.575348 6.632418\n", "0 SETDB1.NLK.CHD7 -2.476090 2.476090 6.131024\n", "4 ZAP70 -2.432173 2.432173 5.915468\n", "3 ZAP70 2.325642 2.325642 5.408611\n", "0 PRKCD -2.315197 2.315197 5.360137\n", " DUSP3 -2.269532 2.269532 5.150774\n", "4 SETDB1.NLK.CHD7 -2.219929 2.219929 4.928085\n", "3 NFYA -2.163613 2.163613 4.681221" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_biases_pruned = load_biases(file=os.path.join(temp_weights, \"pruned_weights\"), folds=5)\n", "df_biases_pruned.sort_values(by=\"abs_bias\", ascending=False).head(10)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "gene\n", "SUZ12.EZH2 4.990975\n", "TCR 3.651254\n", "ZAP70 3.080308\n", "SETDB1.NLK.CHD7 2.942616\n", "NFYA 2.748323\n", " ... \n", "PBX1 0.007306\n", "NR2F1 0.005788\n", "ASH2L 0.002699\n", "BCL3 0.002370\n", "GATA3 0.000018\n", "Name: pow2_bias, Length: 160, dtype: float32" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_biases_prunedr = df_biases_pruned.copy().reset_index()\n", "gene_biases_score = df_biases_prunedr.groupby(\"gene\")[\"pow2_bias\"].mean().sort_values(ascending=False)\n", "gene_biases_score" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "gene_biases_score.head(30).plot.bar()" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Get the top genes sorted by mean bias\n", "top_genes = gene_biases_score.head(30).index\n", "\n", "# Filter the DataFrame to include only rows with these top genes\n", "filtered_df = df_biases_prunedr[df_biases_prunedr['gene'].isin(top_genes)]\n", "\n", "plt.figure(figsize=(10, 6))\n", "sns.violinplot(data=filtered_df, x=\"gene\", y=\"bias\", order=top_genes, dodge=True)\n", "sns.stripplot(data=filtered_df, x=\"gene\", y=\"bias\", order=top_genes, dodge=True)\n", "plt.xticks(rotation=90)\n", "plt.title('Biases of the trained neurons')\n", "plt.xlabel('Gene')\n", "plt.ylabel('Bias')\n", "plt.axhline(0, linestyle=\"--\", color=\"k\")\n", "plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter compression: 51.30%\n" ] } ], "source": [ "param_compression = (1-(pruned_models[0].count_params()/models[0].count_params())) * 100\n", "print(f\"Parameter compression: {param_compression:.2f}%\")" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Degradation in ROC-AUC after compression (positive = decrease in performance, negative = increase in performance): -0.04%\n" ] } ], "source": [ "perf_degradation = ((np.mean(metrics[\"roc_auc\"]) - np.mean(pruned_metrics[\"roc_auc\"])) / np.mean(metrics[\"roc_auc\"])) * 100\n", "print(f\"Degradation in ROC-AUC after compression (positive = decrease in performance, negative = increase in performance): {perf_degradation:.2f}%\")" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.12.7" }, "mystnb": { "execution_mode": "off" } }, "nbformat": 4, "nbformat_minor": 4 }