LXXXXR/AESL

Getting Your LLMs Ready for Reinforcement Learning with Lightweight SFT

This is the official PyTorch implementation of our ICLR 2026 paper "Getting Your LLMs Ready for Reinforcement Learning with Lightweight SFT".

Getting Started

Create Conda Environment

Set up the Python environment with conda and install the dependencies:

conda create -n AESL python=3.8
conda activate AESL
pip install -r requirements.txt

Acknowledgement

The code is implemented on top of the open-source project OpenRLHF.

Please refer to that repository for more documentation.

Supervised Finetuning (SFT)

deepspeed --module openrlhf.cli.train_sft \
   --max_len 20_000 \
   --dataset_class SFTFixDataset \
   --dataset $DATASET_PATH \
   --input_key prompt_message \
   --output_key response_message \
   --apply_chat_template \
   --drop_samples_total_length_exceeding_max_length \
   --train_batch_size 64 \
   --micro_train_batch_size 1 \
   --seed 42 \
   --max_samples 500_000 \
   --pretrain Qwen/Qwen2.5-7B-Instruct \
   --save_path $SAVE_PATH \
   --save_steps -1 \
   --logging_steps 1 \
   --eval_steps -1 \
   --zero_stage 3 \
   --max_epochs 10 \
   --loss_type "AESL_loss" \
   --temperature_scaling 5.0 \
   --packing_samples \
   --bf16 \
   --flash_attn \
   --learning_rate 5e-6 \
   --lr_scheduler "cosine_with_min_lr" \
   --adam_betas 0.9 0.95 \
   --l2 0.1 \
   --lr_warmup_ratio 0.01 \
   --gradient_checkpointing
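The --loss_type "AESL_loss" and --temperature_scaling 5.0 flags select the paper's lightweight SFT objective; its exact definition lives in the paper and in this repo's code. As a purely illustrative, standalone sketch of what temperature scaling of logits means in a token-level cross-entropy (the function name and shapes below are hypothetical, not the repo's API):

```python
import numpy as np

def temperature_scaled_ce(logits, targets, temperature=5.0):
    """Token-level cross-entropy on temperature-scaled logits.

    NOTE: a generic illustration of temperature scaling, NOT the AESL loss
    from the paper -- see the paper for the actual objective.
    logits: (seq_len, vocab) float array; targets: (seq_len,) int array.
    """
    scaled = logits / temperature                      # T > 1 softens the distribution
    scaled -= scaled.max(axis=-1, keepdims=True)       # numerical stability
    log_probs = scaled - np.log(np.exp(scaled).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]  # gather target log-probs
    return nll.mean()
```

With temperature > 1 the softmax distribution is flattened; in the limit of very large temperature the per-token loss approaches log(vocab_size) regardless of the logits.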

Reinforcement Learning (RL)

# launch the master node of ray in container
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8

# if you want to launch ray on more nodes, use
ray start --address {MASTER-NODE-ADDRESS}:6379  --num-gpus 8

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json='{"working_dir": "/openrlhf"}' \
   -- python3 -m openrlhf.cli.train_ppo_ray \
   --ref_num_nodes 1 \
   --ref_num_gpus_per_node 8 \
   --reward_num_nodes 1 \
   --reward_num_gpus_per_node 8 \
   --critic_num_nodes 1 \
   --critic_num_gpus_per_node 8 \
   --actor_num_nodes 1 \
   --actor_num_gpus_per_node 8 \
   --vllm_num_engines 2 \
   --vllm_tensor_parallel_size 4 \
   --colocate_all_models \
   --vllm_gpu_memory_utilization 0.5 \
   --pretrain $PRETRAIN_CKPT \
   --remote_rm_url openrlhf/models/math_verifier/verifier_from_gt.py \
   --save_path $SAVE_PATH/ckpts/final \
   --ckpt_path $SAVE_PATH/ckpts \
   --save_hf_ckpt \
   --micro_train_batch_size 2 \
   --train_batch_size 128 \
   --micro_rollout_batch_size 4 \
   --rollout_batch_size 128 \
   --n_samples_per_prompt 8 \
   --temperature 1.0 \
   --max_epochs 1 \
   --num_episodes 1 \
   --prompt_max_len 1024 \
   --max_samples 100000 \
   --generate_max_len 8192 \
   --zero_stage 3 \
   --bf16 \
   --actor_learning_rate 1e-6 \
   --init_kl_coef 0.001 \
   --gamma 1.0 \
   --use_kl_loss \
   --kl_estimator k3 \
   --advantage_estimator group_norm \
   --prompt_data "$PROMPT_DATA" \
   --input_key ds_prompt \
   --label_key gt_answer \
   --apply_chat_template \
   --eval_temperature 0.6 \
   --eval_n_samples_per_prompt 16 \
   --normalize_reward \
   --gradient_checkpointing \
   --packing_samples \
   --flash_attn \
   --vllm_sync_backend nccl \
   --enforce_eager \
   --vllm_enable_sleep \
   --deepspeed_enable_sleep \
   --use_wandb {wandb_token}
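Two of the flags above pick the RL estimators: --advantage_estimator group_norm computes a GRPO-style advantage by standardizing each sample's reward within its group of n_samples_per_prompt rollouts, and --kl_estimator k3 uses Schulman's k3 estimator for the KL penalty. A minimal, self-contained sketch of what these two quantities are (illustrative only, not OpenRLHF's actual implementation):

```python
import numpy as np

def group_norm_advantages(rewards, n_samples_per_prompt=8, eps=1e-8):
    """GRPO-style group-normalized advantages.

    rewards: flat array of length num_prompts * n_samples_per_prompt, where
    each consecutive group of n_samples_per_prompt rewards shares one prompt.
    Each sample's advantage is its reward standardized within its group.
    """
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, n_samples_per_prompt)
    adv = (r - r.mean(axis=1, keepdims=True)) / (r.std(axis=1, keepdims=True) + eps)
    return adv.reshape(-1)

def k3_kl(policy_logp, ref_logp):
    """Schulman's k3 KL estimator for KL(policy || ref).

    With tokens sampled from the policy, let r = exp(ref_logp - policy_logp);
    k3 = r - 1 - log(r) is an unbiased per-token estimate of the KL that,
    unlike the naive log-ratio, is nonnegative for every sample.
    """
    log_ratio = np.asarray(ref_logp) - np.asarray(policy_logp)
    return np.exp(log_ratio) - 1.0 - log_ratio
```

The group baseline removes per-prompt reward offsets without relying on a learned value baseline; k3 keeps the KL penalty nonnegative sample by sample.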

Citing

If you use this code in your research or find it helpful, please consider citing our paper:

@inproceedings{ligetting,
  title={Getting Your LLMs Ready for Reinforcement Learning with Lightweight SFT},
  author={Li, Xinran and Huzhang, Guangda and Shen, Siqi and Chen, Qing-guo and Xu, Zhao and Luo, Weihua and Zhang, Kaifu and Zhang, Jun},
  booktitle={The Fourteenth International Conference on Learning Representations (ICLR)},
  year={2026}
}
