A model for graph-supported perturbation prediction with transcriptomic data.
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
git clone git@github.com:valence-labs/TxPert.git
cd TxPert
uv venv --python 3.12
source .venv/bin/activate
```

Make sure you are on a system with compatible NVIDIA GPUs and CUDA drivers (typically a compute node). If using environment modules, load CUDA 12.4:

```bash
module load CUDA/12.4.0
```

Install the project (this step will install all the dependencies):

```bash
# Assuming you are in the TxPert folder
uv pip install -e .
```

Note: The default uv cache directory is in the user's home directory, which causes uv to fail to hard-link files when partitions differ (as they typically do on compute nodes). For better performance, you can set the UV_CACHE_DIR variable to point to the current folder (in the shell or in the slurm script). In an interactive session, this would look like:

```bash
export UV_CACHE_DIR=<PROJECT-DIR>/.cache
```

Data should be automatically downloaded when running the code. The appropriate data will be downloaded to the cache folder in the current working directory.
The single cell line adata cache is derived from the raw K562 essential gene perturbation dataset provided by Replogle et al. (2022). To ensure comparability with prior work, we follow the same preprocessing pipeline used in GEARS (Roohani et al., 2024) and CausalBench (Chevalley et al., 2022). This preprocessing includes:
- Filtering to retain strong perturbation signals.
- Library size normalization (fixed to 4000 counts).
- log1p transformation.
- Highly variable gene selection (top 5000 genes).
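For orientation, the normalization and gene-selection steps above map onto standard scanpy calls roughly as sketched below (a minimal, hypothetical sketch; the input file name is a placeholder, and the filtering for strong perturbation signals follows the GEARS pipeline and is only indicated by a comment):

```python
import scanpy as sc

# Minimal sketch of the preprocessing steps above (not the repository's code).
adata = sc.read_h5ad("replogle_k562_essential.h5ad")  # placeholder input file

# 1. Filtering to retain strong perturbation signals: follows GEARS (not shown here).
# 2. Library size normalization to a fixed 4000 counts per cell.
sc.pp.normalize_total(adata, target_sum=4000)
# 3. log1p transformation.
sc.pp.log1p(adata)
# 4. Select the top 5000 highly variable genes and subset.
sc.pp.highly_variable_genes(adata, n_top_genes=5000)
adata = adata[:, adata.var["highly_variable"]].copy()
```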
For consistency and comparability, the data processing closely follows that provided by the GEARS authors in this notebook. Additional discussion can be found in this GitHub issue.
For convenience, there is also a repository-internal onboarding example under [data/onboard_dataset.ipynb].
To utilize newly onboarded data, a few parameters must be set (in config or as command line arguments), namely:

- `datamodule.task_type`: name of the folder under `./cache` where the new data and task-defining splits were saved
- `datamodule.train_cell_types`: str of cell line/type, or list
- `datamodule.test_cell_type`: str with cell line/type
- `datamodule.suppress_cell_type_validation=true` (if adding data with a cell line not in 'K562', 'RPE1', 'hepg2', 'jurkat')
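For example, a run on newly onboarded data could be launched as follows (the config name and values are illustrative placeholders; substitute those matching your dataset):

```bash
python main.py --config-name=config-gat \
    datamodule.task_type=my_new_task \
    datamodule.train_cell_types=K562 \
    datamodule.test_cell_type=my_new_cell_line \
    datamodule.suppress_cell_type_validation=true
```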
The cross cell line adata cache includes perturbation data from four different cell lines:
- K562 and RPE1 from Replogle et al. (2022)
- Jurkat and HepG2 from Nadig et al. (2025)
- K562_adamson from Adamson et al. (2016)
Each cell line is initially processed independently using the same procedure as for the single-cell line cache, except that highly variable gene (HVG) selection is deferred. After individual processing, the datasets are concatenated, and HVGs are selected based on the intersection of:
- HVGs from the test cell line, and
- all genes present in the training cell lines.
This ensures a consistent and comparable gene feature space across datasets.
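In code, this gene-space construction amounts to something like the sketch below (illustrative only; it assumes AnnData objects with HVGs flagged in var['highly_variable'], and reads "present in the training cell lines" as present in every training cell line):

```python
import anndata as ad

def shared_gene_space(test_adata: ad.AnnData, train_adatas: list[ad.AnnData]) -> list[str]:
    """Illustrative sketch (not the repository's code): test-cell-line HVGs
    that are also present in the training cell lines."""
    test_hvgs = test_adata.var_names[test_adata.var["highly_variable"]]
    train_genes = set.intersection(*(set(a.var_names) for a in train_adatas))
    return [gene for gene in test_hvgs if gene in train_genes]
```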
For each cell line, we also perform differential gene expression analysis. Perturbed cells are grouped by perturbation target, and unperturbed cells are used as the reference group. The results of this analysis are stored as metadata under the uns attribute of each adata object.
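For reference, this kind of grouped differential expression can be computed with scanpy roughly as follows (a hedged sketch; the column names, control label, and statistical test used for the cache are assumptions here):

```python
import scanpy as sc
from anndata import AnnData

def add_de_results(adata: AnnData, groupby: str = "condition", reference: str = "ctrl") -> None:
    """Illustrative sketch: rank genes per perturbation target against unperturbed cells.
    The "condition"/"ctrl" names are placeholders, not necessarily those used in the cache."""
    sc.tl.rank_genes_groups(adata, groupby=groupby, reference=reference, method="wilcoxon")
    # scanpy stores the results under adata.uns["rank_genes_groups"]
```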
We provide datasets and code here to reproduce the experiments shown in the paper for two OOD tasks:
- Transfer to unseen perturbations within cell type
- Transfer to an unseen cell line
Here we focus on the K562 cell type, and on models that rely only on public data sources for training.
To run inference for prediction of perturbation effects on unseen perturbations with one of our pretrained models, use
python main.py --config-name=config-[model_type]
Available models (model_type) are gat, exphormer and exphormer-mg. Note that we provide only a subset of models reported in the paper as most of our best models rely on proprietary knowledge graphs (KGs). In particular, the exphormer-mg reported here uses two KGs (STRINGdb and GO) as an example of multi-graph (MG) reasoning. However, the model reported in the paper uses two additional proprietary KGs and achieves higher performance.
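For example, following the config-[model_type] pattern above, the public multi-graph variant can be run with

python main.py --config-name=config-exphormer-mg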
To run the same experiment for transfer to an unseen cell type, use
python main.py --config-name=config-x-cell-gat
All provided configs are found in the folder configs/, and can be updated or added to
as desired for new defaults or common usage patterns.
Parameters can also be set on the command line; e.g. python main.py --config-name=config-gat datamodule.batch_size=32 will run slower but use less GPU memory.
Parameters from configs or the command line are handled via hydra.
This works the same as TxPert Inference with a Checkpoint, except that it saves the predictions to disk instead of evaluating them.
Call with e.g.
python main.py --config-name=config-gat mode=predict
The location of the results can be configured with the save_dir parameter
(by default, the cache/prediction folder, with a subfolder named after the task_type).
This saves the data both in a format close to how the model processes it, i.e. as the pickle file test_results.pkl containing a dictionary with the following keys:
- 'pert_idxs': perturbation indices in graph
- 'base_state': normalized control expression
- 'output': predicted normalized perturbed expression
- 'ground_truth': ground truth normalized perturbed expression
- 'pert_cond_names': string indicating the full cell type and perturbation to condition on, e.g. 'K562_SMG5+ctrl_1+1'
- 'cell_types': cell type or line as appropriate
- 'experimental_batches': batch id (e.g. gem group, plate, etc)
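For downstream analysis, the pickle can be loaded along these lines (a minimal sketch; the path assumes the default save location described above, so adjust it to your save_dir and task_type):

```python
import pickle

# Load the exported predictions; adjust the path to your save_dir / task_type.
with open("cache/prediction/<task_type>/test_results.pkl", "rb") as f:
    results = pickle.load(f)

print(results.keys())               # the keys listed above
preds = results["output"]           # predicted normalized perturbed expression
truth = results["ground_truth"]     # ground-truth normalized perturbed expression
conds = results["pert_cond_names"]  # e.g. 'K562_SMG5+ctrl_1+1'
```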
For convenience and familiarity, it also saves anndata files for predictions, controls, and ground truth. Note that these generally don't match the observation ordering in the input file (they are a subset of the total, test only); hence all three are exported for comparability.
These files include the 'pert_cond_names', 'cell_types', and 'experimental_batches' from above, and additionally have a boolean 'control' column to distinguish control from non-control cells as in the input.
We provide a "general" baseline that combines the mean control expression in the test cell-type with the mean perturbation-specific delta observed across training cell-types. If the perturbation is not in the training set, the global delta is used instead. To apply this baseline for the prediction of unseen perturbation effects, run the following command:
python main.py --config-name=config-baseline
For the prediction of perturbation effects in unseen cell types, use the option --config-name=config-x-cell-baseline instead.
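Conceptually, this general baseline reduces to something like the following sketch (hypothetical variable names, not the repository's implementation):

```python
import numpy as np

def general_baseline(
    control_mean_test: np.ndarray,             # mean control expression in the test cell type
    pert_deltas_train: dict[str, np.ndarray],  # per-perturbation mean deltas across training cell types
    global_delta_train: np.ndarray,            # mean delta over all training perturbations
    perturbation: str,
) -> np.ndarray:
    """Illustrative sketch: control mean plus the perturbation-specific delta,
    falling back to the global delta for perturbations unseen in training."""
    delta = pert_deltas_train.get(perturbation, global_delta_train)
    return control_mean_test + delta
```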
A second baseline characterizing experimental reproducibility is available using the following command:
python main.py --config-name=config-baseline model.model_type=experimental_baseline +model.test_seen=0.5
This will estimate the perturbation-wise mean across samples from a random subsample (here using 50% of the samples per perturbation).
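In effect, this predicts each perturbation's profile as the mean over a random 50% subsample of its cells; a rough sketch of the idea (not the repository's implementation):

```python
import numpy as np

def experimental_baseline(cells: np.ndarray, seen_fraction: float = 0.5, seed: int = 0) -> np.ndarray:
    """Illustrative sketch: estimate a perturbation's mean expression profile
    from a random subsample of its cells (cells: array of shape [n_cells, n_genes])."""
    rng = np.random.default_rng(seed)
    n_seen = max(1, int(len(cells) * seen_fraction))
    seen_idx = rng.choice(len(cells), size=n_seen, replace=False)
    return cells[seen_idx].mean(axis=0)
```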
- May 2025: initial release
This work is licensed under the Recursion Non-Commercial End User License Agreement.