CLAX is a modular framework for building click models with gradient-based optimization in JAX and Flax NNX. CLAX is built to be fast, providing orders-of-magnitude speedups over classic EM-based frameworks such as PyClick by leveraging autodiff and vectorized computation on GPUs.
The current documentation is available here and our pre-print here.
CLAX requires JAX. To install JAX with CUDA support, please refer to the JAX documentation. CLAX itself is available via PyPI:

```bash
pip install clax-models
```
CLAX is designed with sensible defaults, while also allowing a high degree of customization. For example, training a User Browsing Model in CLAX is as simple as:
```python
from clax import Trainer, UserBrowsingModel
from flax import nnx
from optax import adamw

model = UserBrowsingModel(
    query_doc_pairs=100_000_000,  # Number of query-document pairs in the dataset
    positions=10,                 # Number of ranks per result page
    rngs=nnx.Rngs(42),            # NNX random number generator
)

trainer = Trainer(
    optimizer=adamw(0.003),
    epochs=50,
)

train_df = trainer.train(model, train_loader, val_loader)
test_df = trainer.test(model, test_loader)
```

However, the modular design of CLAX also allows for more complex models, ranging from two-tower models and mixture models to plugging in custom Flax modules as model parameters. We provide usage examples for getting started under `examples/`.
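To illustrate the gradient-based approach that CLAX builds on, the following standalone sketch (plain JAX, not CLAX's actual API) fits a simple position-based model by minimizing the negative click log-likelihood with autodiff and SGD instead of classic EM updates:

```python
import jax
import jax.numpy as jnp

# Standalone sketch (not CLAX's API): fit a position-based click model,
# P(click) = attractiveness(doc) * examination(rank), by gradient descent
# on the negative log-likelihood instead of EM.

def neg_log_likelihood(params, docs, ranks, clicks):
    alpha = jax.nn.sigmoid(params["alpha"][docs])   # per-document attractiveness
    gamma = jax.nn.sigmoid(params["gamma"][ranks])  # per-rank examination
    p = jnp.clip(alpha * gamma, 1e-6, 1 - 1e-6)     # click probability
    return -jnp.mean(clicks * jnp.log(p) + (1 - clicks) * jnp.log(1 - p))

@jax.jit
def sgd_step(params, docs, ranks, clicks, lr=0.1):
    grads = jax.grad(neg_log_likelihood)(params, docs, ranks, clicks)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Toy data: 3 documents shown across 2 ranks; clicks only occur at rank 0.
docs = jnp.array([0, 1, 2, 0, 1, 2])
ranks = jnp.array([0, 1, 0, 1, 0, 1])
clicks = jnp.array([1.0, 0.0, 1.0, 0.0, 1.0, 0.0])

params = {"alpha": jnp.zeros(3), "gamma": jnp.zeros(2)}
for _ in range(200):
    params = sgd_step(params, docs, ranks, clicks)
```

Because the whole update is a pure, jittable function of array-valued parameters, it vectorizes over millions of sessions on a GPU, which is the source of the speedups mentioned above.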
In the following, we cover how to reproduce the experiments from our paper and how to set up a fork of CLAX for development.
- Install the UV package manager

  UV is a fast Python dependency manager. Install it from: https://github.com/astral-sh/uv

- Clone the CLAX repository

  ```bash
  git clone git@github.com:philipphager/clax.git
  cd clax/
  ```

- Install dependencies

  ```bash
  uv sync
  ```

  This creates a virtual environment and installs all required dependencies.
Our paper's experiments are located in the `experiments/` directory. Each experiment contains:

- A Python script with the experiment logic: `main.py`
- A Hydra config file for configuration management: `config.yaml`
- A bash script with all experimental configurations: `main.sh`
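As a rough illustration, a `config.yaml` might contain a dataset path like the one below; only `dataset_dir` is documented in this README, and the remaining fields vary per experiment:

```yaml
# Sketch of an experiment's Hydra config; other fields differ per experiment.
dataset_dir: ./clax-datasets/
```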
To run an experiment, follow these steps.
- Install experiment dependencies

  Installs additional packages for SLURM support and data analysis/plotting.

  ```bash
  uv sync --group experiments
  ```

- Download datasets

  Clone the Yandex and Baidu-ULTR datasets from HuggingFace. If you have Git LFS installed, clone the datasets using:

  ```bash
  git lfs install
  git clone https://huggingface.co/datasets/philipphager/clax-datasets
  ```

  Otherwise, download the datasets manually from HuggingFace.
  Note: The full datasets require 85 GB of disk space. By default, CLAX expects the datasets at `./clax-datasets/` relative to the project root. To use a custom path, update the `dataset_dir` parameter in each experiment's `config.yaml`:

  ```yaml
  dataset_dir: /my/custom/path/to/datasets/
  ```

- Run an experiment script
  Navigate to your experiment of interest and run the bash script, e.g.:

  ```bash
  cd experiments/1-yandex-baseline/
  chmod +x ./main.sh
  ./main.sh
  ```

  Optionally, you can run the script directly on a SLURM cluster using:

  ```bash
  sbatch ./main.sh +launcher=slurm
  ```

  You can adjust the SLURM configuration for your cluster under `experiments/config/slurm.yaml`.
Baseline experiments using PyClick require the PyPy interpreter and are maintained in a separate repository: https://github.com/philipphager/clax-baselines
If CLAX is useful to you, please consider citing our paper:
```bibtex
@misc{hager2025clax,
  title = {CLAX: Fast and Flexible Neural Click Models in JAX},
  author = {Philipp Hager and Onno Zoeter and Maarten de Rijke},
  year = {2025},
  booktitle = {arxiv}
}
```