
A repository for reproducing experiments from the TxPert paper


TxPert

A model for graph-supported perturbation prediction with transcriptomic data.

Installation

Install uv

curl -LsSf https://astral.sh/uv/install.sh | sh

Clone the repo

git clone git@github.com:valence-labs/TxPert.git
cd TxPert

Create and activate a Python 3.12 environment:

uv venv --python 3.12
source .venv/bin/activate

Install dependencies:

Make sure you are on a system with compatible NVIDIA GPUs and CUDA drivers (typically a compute node). If using environment modules, load CUDA 12.4:

module load CUDA/12.4.0

Install the project (this step also installs all dependencies):

# Assuming you are in the TxPert folder
uv pip install -e .

Note: By default, uv's cache directory is in the user's home directory. On compute nodes this is often a different partition from the project, so uv fails to hard-link files from the cache. For better performance, set the UV_CACHE_DIR variable to a folder inside the project directory (in the shell or in the Slurm script). In an interactive session, this would look like:

export UV_CACHE_DIR=<PROJECT-DIR>/.cache

Data Download

Data should be downloaded automatically when the code is run. It is saved to the cache folder in the current working directory.

Data Description

Single Cell Line Cache

The single cell line adata cache is derived from the raw K562 essential gene perturbation dataset provided by Replogle et al. (2022). To ensure comparability with prior work, we follow the same preprocessing pipeline used in GEARS (Roohani et al., 2024) and CausalBench (Chevalley et al., 2022). This preprocessing includes:

  • Filtering to retain strong perturbation signals.

  • Library size normalization (fixed to 4000 counts).

  • log1p transformation.

  • Highly variable gene selection (top 5000 genes).

For consistency and comparability, the data processing closely follows that provided by the GEARS authors in this notebook. Additional discussion can be found in this GitHub issue.
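The normalization, log1p, and HVG steps above can be sketched in NumPy as follows. This is a simplified illustration, not the repository's code: the strong-signal filtering step is omitted, the function name preprocess_counts is hypothetical, and HVGs are approximated by ranking genes on variance, whereas scanpy's highly_variable_genes uses a dispersion-based criterion by default.

```python
import numpy as np


def preprocess_counts(counts, target_sum=4000.0, n_top_genes=5000):
    """Simplified preprocessing sketch (hypothetical helper).

    counts: raw count matrix of shape (n_cells, n_genes).
    Returns the processed matrix restricted to the selected genes and the
    sorted column indices of those genes.
    """
    counts = np.asarray(counts, dtype=float)
    # Library size normalization: rescale each cell so its counts sum to target_sum.
    lib_size = counts.sum(axis=1, keepdims=True)
    lib_size[lib_size == 0] = 1.0  # avoid division by zero for empty cells
    normalized = counts / lib_size * target_sum
    # log1p transformation.
    logged = np.log1p(normalized)
    # Simplified HVG selection: keep the genes with the highest variance.
    n_keep = min(n_top_genes, logged.shape[1])
    variances = logged.var(axis=0)
    top = np.sort(np.argsort(variances)[::-1][:n_keep])
    return logged[:, top], top
```

Returning the column indices keeps it easy to map the selected genes back to gene names stored alongside the matrix.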

For convenience, the repository also includes an example notebook for onboarding new datasets: data/onboard_dataset.ipynb.

To use newly onboarded data, set the following parameters (in a config file or as command-line arguments):

  • datamodule.task_type: name of the folder under ./cache where the new data and task-defining splits were saved
  • datamodule.train_cell_types: training cell line/type, as a string or a list of strings
  • datamodule.test_cell_type: test cell line/type, as a string
  • datamodule.suppress_cell_type_validation=true: required only when adding data from a cell line other than 'K562', 'RPE1', 'hepg2', and 'jurkat'
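As a sketch of how these parameters fit together, the hypothetical helper below assembles a main.py invocation with the overrides listed above. The task name, cell lines, default config name, and the extra_cell_line flag are illustrative assumptions. Lists use Hydra's [a,b] override syntax, and shlex.join quotes the brackets so the printed command is safe to paste into a shell.

```python
import shlex


def build_onboarding_command(task_type, train_cell_types, test_cell_type,
                             config_name="config-gat", extra_cell_line=False):
    """Assemble a main.py argv with overrides for a newly onboarded dataset."""
    if isinstance(train_cell_types, str):
        train = train_cell_types
    else:
        # Hydra's override grammar writes lists as [a,b,c].
        train = "[" + ",".join(train_cell_types) + "]"
    args = [
        "python", "main.py", f"--config-name={config_name}",
        f"datamodule.task_type={task_type}",
        f"datamodule.train_cell_types={train}",
        f"datamodule.test_cell_type={test_cell_type}",
    ]
    if extra_cell_line:
        # Needed only for cell lines outside the validated set.
        args.append("datamodule.suppress_cell_type_validation=true")
    return args


if __name__ == "__main__":
    # Hypothetical task folder ./cache/my_task with K562 and RPE1 for training.
    print(shlex.join(build_onboarding_command("my_task", ["K562", "RPE1"], "hepg2")))
```

Returning an argument list rather than a single string means it can also be passed directly to subprocess.run without shell quoting concerns.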

Cross Cell Line Cache

The cross-cell-line adata cache includes perturbation data from four different cell lines: K562, RPE1, HepG2, and Jurkat.

Each cell line is initially processed independently using the same procedure as for the single-cell line cache, except that highly variable gene (HVG) selection is deferred. After individual processing, the datasets are concatenated, and HVGs are selected based on the intersection of:

  • HVGs from the test cell line, and

  • All genes present in the training cell lines

This ensures a consistent and comparable gene feature space across datasets.
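The gene-space rule can be illustrated with plain Python sets. This sketch assumes one reading of "all genes present in the training cell lines", namely that a gene must be measured in every training cell line; the function name cross_cell_gene_space is hypothetical.

```python
def cross_cell_gene_space(test_hvgs, train_gene_sets):
    """Genes kept for the cross-cell-line cache (simplified sketch).

    test_hvgs: highly variable genes computed on the test cell line.
    train_gene_sets: one collection of measured genes per training cell line.
    A gene is kept if it is an HVG in the test line and present in every
    training line.
    """
    keep = set(test_hvgs)
    for genes in train_gene_sets:
        keep &= set(genes)
    # Sorted for a deterministic feature order across datasets.
    return sorted(keep)
```

Sorting the result is a convenience so that every dataset shares the same column order after concatenation.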

Differential Expression Analysis

For each cell line, we also perform differential gene expression analysis. Perturbed cells are grouped by perturbation target, and unperturbed cells are used as the reference group. The results of this analysis are stored as metadata under the uns attribute of each adata object.
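The README does not state which statistical test is used, so the following is only an illustrative sketch: for each perturbation target it compares perturbed cells with unperturbed controls using the difference of means and Welch's t-statistic. The function name, the "ctrl" label, and the output layout are assumptions.

```python
import numpy as np


def de_vs_control(expr, labels, control_label="ctrl"):
    """Compare each perturbation group against unperturbed control cells.

    expr: (n_cells, n_genes) expression matrix (e.g. log1p-normalized).
    labels: per-cell perturbation labels; control cells carry `control_label`.
    Returns {perturbation: {"mean_delta": ..., "t_stat": ...}}, one value per gene.
    """
    expr = np.asarray(expr, dtype=float)
    labels = np.asarray(labels)
    ctrl = expr[labels == control_label]
    ctrl_mean = ctrl.mean(axis=0)
    ctrl_var_over_n = ctrl.var(axis=0, ddof=1) / len(ctrl)
    results = {}
    for pert in np.unique(labels):
        if pert == control_label:
            continue
        group = expr[labels == pert]
        delta = group.mean(axis=0) - ctrl_mean
        # Welch's standard error: group variances are not assumed equal.
        se = np.sqrt(group.var(axis=0, ddof=1) / len(group) + ctrl_var_over_n)
        with np.errstate(divide="ignore", invalid="ignore"):
            t_stat = np.where(se > 0, delta / se, 0.0)
        results[str(pert)] = {"mean_delta": delta, "t_stat": t_stat}
    return results
```

The returned dictionary could then be attached to an AnnData object through adata.uns under a key of your choosing.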

Experiments

This repository provides datasets and code to reproduce the paper's experiments for two out-of-distribution (OOD) tasks:

  1. Transfer to unseen perturbations within cell type
  2. Transfer to an unseen cell line

Here, we focus on the K562 cell line and on models trained only on public data sources.

TxPert Inference with a Checkpoint

To predict the effects of unseen perturbations with one of our pretrained models, use

python main.py --config-name=config-[model_type] 

Available models (model_type) are gat, exphormer and exphormer-mg. Note that we provide only a subset of models reported in the paper as most of our best models rely on proprietary knowledge graphs (KGs). In particular, the exphormer-mg reported here uses two KGs (STRINGdb and GO) as an example of multi-graph (MG) reasoning. However, the model reported in the paper uses two additional proprietary KGs and achieves higher performance.

To run the same experiment for transfer to an unseen cell type, use

python main.py --config-name=config-x-cell-gat

All provided configs are found in the folder configs/, and can be updated or added to as desired for new defaults or common usage patterns.

Parameters can also be set on the command line. For example, python main.py --config-name=config-gat datamodule.batch_size=32 runs slower but uses less GPU memory.

Parameters from configs and the command line are handled by Hydra.
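To illustrate how dotted overrides such as datamodule.batch_size=32 map onto a nested config, here is a simplified pure-Python sketch. It is not Hydra's implementation and parses only booleans, numbers, and strings. It does mirror one Hydra rule relevant here: overriding a key that is not in the config is an error unless the key is prefixed with +, which is why the experimental-reproducibility command uses +model.test_seen=0.5. The default values in the usage example are hypothetical.

```python
import copy


def _parse_value(raw):
    """Parse a tiny subset of override values: booleans, ints, floats, else str."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Return a copy of a nested-dict config with `a.b=value` overrides applied.

    Without a leading '+', the key must already exist; with '+', missing keys
    (and intermediate sections) are created.
    """
    cfg = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        adding = dotted_key.startswith("+")
        *parents, leaf = dotted_key.lstrip("+").split(".")
        node = cfg
        for key in parents:
            if key not in node:
                if not adding:
                    raise KeyError(f"'{dotted_key}' is not in the config")
                node[key] = {}
            node = node[key]
        if leaf not in node and not adding:
            raise KeyError(f"'{dotted_key}' is not in the config")
        node[leaf] = _parse_value(raw_value)
    return cfg


if __name__ == "__main__":
    defaults = {"datamodule": {"batch_size": 64}, "model": {"model_type": "gat"}}
    print(apply_overrides(defaults, ["datamodule.batch_size=32"]))
```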

TxPert Predictions with a Checkpoint

This works like TxPert Inference with a Checkpoint, except that the predictions are saved to disk instead of being evaluated.

For example:

python main.py --config-name=config-gat mode=predict

The output location is set with the save_dir parameter (by default, a subfolder named after the task_type inside the cache/prediction folder).

The predictions are saved in two formats. The first is close to how the model processes the data: a pickle file, test_results.pkl, containing a dictionary with the following keys:

  • 'pert_idxs': perturbation indices in the graph
  • 'base_state': normalized control expression
  • 'output': predicted normalized perturbed expression
  • 'ground_truth': ground-truth normalized perturbed expression
  • 'pert_cond_names': string identifying the full cell type and perturbation to condition on, e.g. 'K562_SMG5+ctrl_1+1'
  • 'cell_types': cell type or line, as appropriate
  • 'experimental_batches': batch ID (e.g., GEM group or plate)
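A minimal sketch for inspecting test_results.pkl, assuming 'base_state', 'output', and 'ground_truth' are array-like with a shared (n_samples, n_genes) shape. The helper names and the two metrics (MSE and the Pearson correlation of deltas against control) are illustrative choices, not necessarily the metrics reported in the paper.

```python
import pickle

import numpy as np


def load_results(path):
    """Load the dictionary stored in test_results.pkl."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


def summarize_results(results):
    """Simple error metrics from a test_results-style dictionary.

    Assumes 'base_state', 'output' and 'ground_truth' share an
    (n_samples, n_genes) shape; CPU tensors convert via np.asarray.
    """
    base = np.asarray(results["base_state"], dtype=float)
    pred = np.asarray(results["output"], dtype=float)
    truth = np.asarray(results["ground_truth"], dtype=float)
    mse = float(np.mean((pred - truth) ** 2))
    # Compare predicted and observed effects relative to control expression.
    pred_delta = (pred - base).ravel()
    true_delta = (truth - base).ravel()
    pearson_delta = float(np.corrcoef(pred_delta, true_delta)[0, 1])
    return {"n_samples": pred.shape[0], "mse": mse, "pearson_delta": pearson_delta}
```

Typical use would be summarize_results(load_results(path_to_pkl)); grouping rows by 'pert_cond_names' first would give per-perturbation metrics instead of pooled ones.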

For convenience and familiarity, AnnData files are also saved for the predictions, the controls, and the ground truth. These cover only the test subset, so their observations generally do not follow the ordering of the input file; this is why all three are exported for comparability.

These files include the 'pert_cond_names', 'cell_types', and 'experimental_batches' fields described above, plus a boolean 'control' column that, as in the input, distinguishes control from non-control cells.

General Baseline

We provide a "general" baseline that combines the mean control expression in the test cell type with the mean perturbation-specific delta observed across training cell types. If the perturbation is not in the training set, the global delta is used instead. To apply this baseline to the prediction of unseen perturbation effects, run:

python main.py --config-name=config-baseline

To predict perturbation effects in unseen cell types, use --config-name=config-x-cell-baseline instead.
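The general baseline described above can be sketched as follows. Assumptions: expression is given per cell type as a dict mapping group names to (n_cells, n_genes) arrays, with control cells under a hypothetical "ctrl" key; each delta is the perturbed mean minus the control mean within the same training cell type; and the global delta is taken as the average over all training perturbation deltas.

```python
import numpy as np


def general_baseline(test_ctrl, train_data, pert, ctrl_key="ctrl"):
    """Predict mean expression for `pert` in the test cell type (sketch).

    test_ctrl: (n_cells, n_genes) control expression in the test cell type.
    train_data: {cell_type: {group_name: (n_cells, n_genes) array}}, where
        control cells are stored under `ctrl_key`.
    """
    prediction_base = np.asarray(test_ctrl, dtype=float).mean(axis=0)
    specific, all_deltas = [], []
    for groups in train_data.values():
        ctrl_mean = np.asarray(groups[ctrl_key], dtype=float).mean(axis=0)
        for name, cells in groups.items():
            if name == ctrl_key:
                continue
            # Perturbation effect within this training cell type.
            delta = np.asarray(cells, dtype=float).mean(axis=0) - ctrl_mean
            all_deltas.append(delta)
            if name == pert:
                specific.append(delta)
    # Seen perturbation: average its deltas; unseen: fall back to the global delta.
    chosen = specific if specific else all_deltas
    return prediction_base + np.mean(chosen, axis=0)
```

Computing deltas within each training cell type, rather than pooling cells across types, keeps baseline expression differences between cell lines out of the transferred effect.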

Experimental Reproducibility

A second baseline, which characterizes the experimental reproducibility of the data, is available via the following command:

python main.py --config-name=config-baseline model.model_type=experimental_baseline +model.test_seen=0.5

This estimates the per-perturbation mean from a random subsample (here, 50% of the samples per perturbation).
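One way to picture this baseline, under the assumption that the subsample mean is compared against the mean of the remaining samples, is the sketch below. The function name, the held-out comparison, and the skipping of perturbations with fewer than two samples are illustrative choices; the fraction argument plays the role of the 0.5 passed via +model.test_seen.

```python
import numpy as np


def reproducibility_baseline(expr, labels, fraction=0.5, seed=0):
    """Per-perturbation mean of a random subsample vs. the held-out remainder.

    expr: (n_samples, n_genes) expression; labels: perturbation per sample.
    Returns {perturbation: (subsample_mean, heldout_mean)}.
    """
    expr = np.asarray(expr, dtype=float)
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    results = {}
    for pert in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == pert))
        if len(idx) < 2:
            continue  # need at least one sample on each side of the split
        # Clamp so both the subsample and the held-out part are non-empty.
        n_sub = min(len(idx) - 1, max(1, round(fraction * len(idx))))
        sub, rest = idx[:n_sub], idx[n_sub:]
        results[str(pert)] = (expr[sub].mean(axis=0), expr[rest].mean(axis=0))
    return results
```

Comparing the two means with the same metrics used for model predictions gives a rough ceiling on how well any predictor could match replicate measurements.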

Changelog

  • May 2025: initial release

License

This work is licensed under the Recursion Non-Commercial End User License Agreement.
