Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
cca4934
Update README.md
jethrocsau Apr 21, 2025
a5900b3
reorganized documents from forked repo
Apr 21, 2025
8a5dc53
format ref file & set framework
Apr 23, 2025
b7edd39
update load data utils
jethrocsau Apr 23, 2025
ee01335
Add files via upload
JosephLaiCY Apr 23, 2025
710c1b7
updated data utils
jethrocsau Apr 24, 2025
f588e49
Merge branch 'main' of https://github.com/jethrocsau/GNN-language-emb…
jethrocsau Apr 24, 2025
aff5af8
updated data_utils
jethrocsau Apr 24, 2025
f79cdf5
updated modules
Apr 24, 2025
2bb6179
creating superclass GraphAlign_e5
Apr 24, 2025
eb50c0e
include data preparation func
jethrocsau Apr 25, 2025
82da1fd
uncommented gcn
Apr 25, 2025
f7b4259
expanded start + datset
Apr 25, 2025
3b8e8c6
update save paths
jethrocsau Apr 25, 2025
0e424d1
generate and save graph alignment
Apr 25, 2025
049c9a3
Update generate_embedding.py
Apr 25, 2025
176d6f0
add load_graph
Apr 25, 2025
558b901
Update generate_embedding.py
Apr 25, 2025
fc6d77d
Update data_utils.py
1324fgg Apr 26, 2025
c062773
Here is the description for arxiv dataset and sampled mag dataset(300…
1324fgg Apr 26, 2025
bf23e8a
Add files via upload
1324fgg Apr 27, 2025
fc65a46
Add files via upload
1324fgg Apr 27, 2025
fed19f0
Add files via upload
1324fgg Apr 27, 2025
5f873c7
updated mag nodeidx2papers
jethrocsau Apr 27, 2025
ea7d3cb
Merge branch 'main' of https://github.com/jethrocsau/GNN-language-emb…
jethrocsau Apr 27, 2025
e8cdf89
Image of 200000 subgraph of mug dataset
1324fgg Apr 27, 2025
8cebab3
Merge branch 'main' of https://github.com/jethrocsau/GNN-language-emb…
jethrocsau Apr 27, 2025
2d0fcbc
updated for graphalign embeddings
jethrocsau Apr 27, 2025
7880bf8
updated
jethrocsau Apr 28, 2025
09e68d5
map_graph.py
1324fgg Apr 29, 2025
839fe03
update
jethrocsau Apr 29, 2025
4544169
Merge branch 'main' of https://github.com/jethrocsau/GNN-language-emb…
jethrocsau Apr 29, 2025
b5cc815
Create README.md
1324fgg Apr 29, 2025
b746443
Update README.md
1324fgg Apr 29, 2025
21eba53
Update README.md
1324fgg Apr 29, 2025
9787d19
Update README.md
1324fgg Apr 29, 2025
847dddf
Merge branch 'main' of https://github.com/jethrocsau/GNN-language-emb…
jethrocsau Apr 29, 2025
6955e46
graph train save to dir
jethrocsau Apr 29, 2025
8e252a7
debug fix
jethrocsau Apr 29, 2025
1edcd86
added batchify
jethrocsau Apr 29, 2025
0381490
debug
jethrocsau Apr 29, 2025
49d09f6
updated for pca normalizated & 3-layers GAT
jethrocsau Apr 30, 2025
b3cc8ae
debug and set argprase
jethrocsau Apr 30, 2025
e4b22ae
modified relu and drop out
jethrocsau Apr 30, 2025
4fd4109
update utils
jethrocsau May 2, 2025
0923f06
debug
jethrocsau May 2, 2025
0728d97
typo fix
jethrocsau May 2, 2025
8f5d317
adding epoch values
jethrocsau May 2, 2025
6393608
Add files via upload
JosephLaiCY May 4, 2025
9ea0904
organizing repo
jethrocsau May 6, 2025
fce8fb9
organizing
jethrocsau May 6, 2025
cbf306d
update readme
jethrocsau May 6, 2025
97441b3
Add files via upload
1324fgg May 9, 2025
52fad33
Delete dyf_graphsage_second_version.ipynb
1324fgg May 9, 2025
ae445f5
Add files via upload
1324fgg May 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
170 changes: 8 additions & 162 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,162 +1,8 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.tsv
data/.DS_Store
.DS_Store
fastmoe/
.DS_Store
/data/stark-mag
/data
processed/ogbn-arxiv_graphalign_embeddings.pt
23 changes: 23 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"python.terminal.activateEnvInCurrentTerminal": true,
"python.testing.pytestEnabled": true,
"files.autoSave": "afterDelay",
"python.languageServer": "Pylance",
"rewrap.wrappingColumn": 80,
"git.autofetch": true,
"diffEditor.renderSideBySide": true,
"diffEditor.ignoreTrimWhitespace": true,
"gitlens.currentLine.enabled": false,
"gitlens.hovers.enabled": false,
"gitlens.hovers.currentLine.over": "line",
"gitlens.codeLens.enabled": false,
"gitlens.defaultDateStyle": "absolute",
"cSpell.enabled": false,
"cSpell.language": "de,de-DE,en",
"files.trimTrailingWhitespace": true,
"files.insertFinalNewline": true,
"python.analysis.completeFunctionParens": true,
"vsintellicode.sql.completionsEnabled": false,
"githubIssues.issueBranchTitle": "feature/${issueNumber}_${sanitizedIssueTitle}",
"errorLens.messageEnabled": false
}
130 changes: 33 additions & 97 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,115 +1,51 @@
# GraphAlign: Pretraining One Graph Neural Network on Multiple Graphs via Feature Alignment
# Exploration into cross-domain task generalization of graphs with language embeddings

Paper link: [arxiv](https://arxiv.org/abs/2406.02953)
## ABSTRACT
Scaling and generalizing cross-domain tasks amongst graph datasets remains a challenge due to the variability in node features, edge-based relationships, and the inherent challenges of transfer learning amongst graphs. The aim of this project is to explore the capabilities of using language embeddings to achieve task generalization across different graph structures and to build models that can learn cross-domain relationships. By evaluating the performance of jointly trained graph neural networks across different language embeddings, the authors evaluate the effectiveness of various encoding architectures. Contrary to expectations, the simpler word2vec achieved greater performance compared to the E5-Small-V2 and GraphAlign pretrained embeddings. Finally, the authors discuss the limitations and conclusiveness of the study and outline future research directions in unifying cross-domain graphs with scalable architectures.

## Dependencies

1. PyTorch >= v2.1.1 and CUDA >= 11.4 are recommended.
## OVERVIEW
This project provides tools for:
- Generating text embeddings for graph nodes using various language models
- Training Graph Neural Networks on these embeddings
- Evaluating model performance on node classification tasks
- The implementation supports multiple datasets (ogbn-arxiv, ogbn-mag, combined) and embedding methods (E5, GraphAlign).

2. [dgl](https://www.dgl.ai/pages/start.html) >= 0.7.2
### Requirements
- Install the required dependencies:

3. [localclustering](https://github.com/kfoynt/LocalGraphClustering) (optional for data preprocessing)

4. Run `bash setup.sh` to install necessary dependences, including [fmoe](https://github.com/zhan72/fastmoe).

5. You can use `wandb` to monitor the training process.

## Dataset Preprocessing

For Large scale graphs, before starting mini-batch training, you'll need to generate local clusters if you want to use local-clustering for training. To generate a local cluster, you should first install [localclustering](https://github.com/kfoynt/LocalGraphClustering) and then run the following command:
```bash
pip install -r requirements.txt
```

````python
python generate_data.py \
--data_save_path <path/to/data_dir> \
--device <gpu_id> \
--batch_size 512 \
--dataset_name ogbn-arxiv ogbn-products ogbn-papers100M FB15K237 Cora WN18RR
````
### Train GNNs using preprocessed datasets

And we also provide the pre-generated local clusters which can be downloaded [here](https://drive.google.com/drive/folders/1f736S0pl_ypmh---b_pM3U1tK0PfFV0x?usp=sharing) for usage.
This module currently supports multiple graph datasets, which can be selected with the `--graph_idx` argument. Two graph sampling methods are available during training: a node-batching ("batchify") function and a multi-hop neighbor sampler. Pass the `--sample` flag to use multi-hop neighbor sampling instead of batch sampling.

## Pretrained Model Download
```python
GRAPHS = ['combined_graph_pca.bin','graph0.bin','graph1.bin','pca_graph.bin']
```

You can download our pretrainded model [here](https://drive.google.com/drive/folders/1wpTE40SPVwysw8e30I-NJnZiI2Y-rADt?usp=drive_link) and run below code to eval our model.
```bash
# Eval download pretrained model (linear probe result)
bash scripts/evaluation.sh <gpu_id> <path/to/data> <path/to/download/gnn_ckpt>
# Eval download pretrained model (few-shot result)
bash scripts/few_shot_eval.sh <gpu_id> <path/to/data> <path/to/download/gnn_ckpt>
```
# Train using batch sampling
python graph_train.py --graph_idx 0

## Quick Start
# Train using multi-hop neighbor sampling
python graph_train.py --graph_idx 0 --sample

To reproduce individually pretraining results, run: (first param is device, second param is path/to/save/data)
```bash
bash scripts/individually_pretrain.sh 0 your/path/to/save/data
# Train using multi-hop neighbor sampling
python graph_train.py --graph_idx 0 --sample
```

To reproduce GraphAlign results,
```bash
# For GNN pretraining
# Multi-GPU training is supported. <gpu_ids> can be set as "0,1" or more gpus.
bash scripts/graphalign.sh <gpu_ids> <path/to/data>

# Evaluation after GNN pretraining checkpoint
bash scripts/evaluation.sh <gpu_id> <path/to/data> <path/to/gnn_ckpt>
```
## DATASET INFO

To reproduce few-shot results:
```bash
# Evaluate the pretraining GNN in few-shot classification
bash scripts/few_shot_eval.sh <gpu_id> <path/to/data> </path/to/gnn_ckpt>
```
<img width="589" alt="image" src="https://github.com/user-attachments/assets/b7926e35-7417-4b50-b45d-7c47cb92bc8e" />
<img width="584" alt="image" src="https://github.com/user-attachments/assets/98d8d5ea-8a00-409b-98f8-230c970fb8f1" />

## Experimental Results

- Linear probing results in unsupervised representation learning for node classification

| Method | Setting | ogbn-arxiv | ogbn-products | ogbn-papers100M | Avg. gain |
| --------- | ------------------------ | -------------- | -------------- | --------------- | --------- |
| MLP | supervised | 69.85±0.36 | 73.74±0.43 | 56.62±0.21 | - |
| GAT | supervised | 74.15±0.15 | 83.42±0.35 | 66.63±0.23 | - |
| GCN | supervised | 74.77±0.34 | 80.76±0.50 | 68.15±0.08 | - |
| SGC | supervised | 71.56±0.41 | 74.36±0.27 | 58.82±0.08 | - |
| BGRL | individually-pretrain | 72.98±0.14 | 80.45±0.16 | 65.40±0.23 | - |
| | vanilla jointly-pretrain | 69.00±0.08 | 81.11±0.27 | 63.93±0.22 | -1.60 |
| | **GraphAlign** | **73.20±0.20** | **80.79±0.45** | **65.62±0.14** | **+0.26** |
| GRACE | individually-pretrain | 73.33±0.19 | 81.91±0.27 | 65.59±0.13 | - |
| | vanilla jointly-pretrain | 72.10±0.18 | 81.96±0.34 | 65.54±0.18 | -0.41 |
| | **GraphAlign** | **73.69±0.26** | **81.90±0.19** | **65.61±0.17** | **+0.12** |
| GraphMAE | individually-pretrain | 72.35±0.12 | 81.69±0.11 | 65.68±0.28 | - |
| | vanilla jointly-pretrain | 71.98±0.24 | 82.36±0.19 | 65.92±0.13 | +0.18 |
| | **GraphAlign** | **72.97±0.22** | **82.51±0.18** | **66.08±0.18** | **+0.61** |
| GraphMAE2 | individually-pretrain | 73.10±0.11 | 82.53±0.17 | 66.28±0.10 | - |
| | vanilla jointly-pretrain | 71.28±0.25 | 80.05±0.35 | 64.28±0.33 | -2.10 |
| | **GraphAlign** | **73.56±0.26** | **82.93±0.42** | **66.39±0.14** | **+0.32** |

- Few-shot node classification results on ogbn-arxiv and Cora, and link classification results on FB15K237 and WN18RR. We report *m*-way-*k*-shot accuracy(%), i.e., 5-way for ogbn-arxiv, Cora, WN18RR and 20-way for FB15K237.

| Method | ogbn-arxiv 5-shot | ogbn-arxiv 1-shot | Cora 5-shot | Cora 1-shot | WN18RR 5-shot | WN18RR 1-shot | FB15K237 5-shot | FB15K237 1-shot |
| --------------------- | ----------------- | ----------------- | ----------- | ----------- | ------------- | ------------- | --------------- | --------------- |
| GPN | 50.53±3.07 | 38.58±1.61 | - | - | - | - | - | - |
| TENT | 60.83±7.45 | 45.62±10.70 | - | - | - | - | - | - |
| GLITTER | 56.00±4.40 | 47.12±2.73 | - | - | - | - | - | - |
| Prodigy | 61.09±5.85 | 48.23±6.18 | - | - | - | - | 74.92±6.03 | 55.49±6.88 |
| OFA | 61.45±2.56 | 50.20±4.27 | 48.76±2.65 | 34.04±4.10 | 46.32±4.18 | 33.86±3.41 | 82.56±1.58 | 75.39±2.86 |
| OFA-emb-only | 61.27±7.09 | 43.22±8.45 | 58.60±6.72 | 40.87±8.26 | 54.87±9.73 | 39.72±9.35 | 59.11±6.95 | 43.03±7.17 |
| | | | | | ||||
| **GraphAlign**(GraphMAE) | 81.93±6.22 | 65.02±10.62 | 74.49±6.43 | 55.55±9.86 | 60.19±10.31 | 45.08±10.55 | 79.92±5.54 | 63.01±7.29 |
| **GraphAlign**(GraphMAE2) | 83.97±5.85 | 70.65±10.45 | 73.66±6.75 | 56.87±9.98 | 55.95±10.49 | 42.22±10.04 | 79.86±5.53 | 63.56±7.31 |
| **GraphAlign**(GRACE) | 84.76±5.71 | 71.18±10.29 | 69.85±7.19 | 52.60±10.10 | 53.11±10.24 | 39.58±9.42 | 75.04±5.98 | 60.09±7.36 |
| **GraphAlign**(BGRL) | 81.88±6.26 | 66.31±10.63 | 68.13±6.84 | 50.19±9.49 | 51.97±10.66 | 38.72±9.77 | 77.74±5.87 | 61.48±7.44 |
| E5-emb-only | 65.67±7.02 | 47.13±8.68 | 59.71±6.71 | 41.58±8.11 | 56.52±9.65 | 41.53±9.36 | 58.43±6.94 | 42.06±7.11 |

## Citing
If you find this work is helpful to your research, please consider citing our paper:

```latex
@article{hou2024graphalign,
title={GraphAlign: Pretraining One Graph Neural Network on Multiple Graphs via Feature Alignment},
author={Hou, Zhenyu and Li, Haozhan and Cen, Yukuo and Tang, Jie and Dong, Yuxiao},
journal={arXiv preprint arXiv:2406.02953},
year={2024}
}
```
### Baseline of the original embeddings in the OGB paper:

<img width="418" alt="image" src="https://github.com/user-attachments/assets/7f4ddee9-1fc7-4052-ad7c-e54f0ded301f" />
<img width="385" alt="image" src="https://github.com/user-attachments/assets/4f12edb2-fecb-473e-a800-4174ebb20b9b" />

<img width="442" alt="image" src="https://github.com/user-attachments/assets/82e65470-45da-4cf4-9b42-0e92e8c05b37" />
19 changes: 19 additions & 0 deletions batch jobs/parallel-train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash
# Slurm batch job: launch graph_train.py for graph indices 0, 1 and 2 as three
# concurrent job steps, one per allocated node (3 nodes, 1 GPU per node).
# The job exits non-zero if any of the training steps fails.

#SBATCH -J graph
#SBATCH -t 12:00:00
#SBATCH --mail-user=csauac@connect.ust.hk
#SBATCH --mail-type=begin
#SBATCH --mail-type=end
#SBATCH -p normal
#SBATCH --nodes=3 --gpus-per-node=1
#SBATCH --account=mscbdt2024
#SBATCH --output=job-%j.out
#SBATCH --error=job-%j.err

# Abort early if the project directory is missing; otherwise every training
# step would be started from the wrong working directory.
# NOTE(review): the path is relative — presumably to the directory the job is
# submitted from; confirm.
cd msbd5008 || exit 1

# NOTE(review): `conda activate` requires conda's shell integration to be
# available in this non-interactive shell — confirm it is on this cluster.
conda activate graph
module load slurm 'nvhpc-hpcx-cuda12/23.11'

# Start each step in the background so the three trainings run in parallel.
# In the foreground, srun blocks until its step finishes, which would
# serialize the runs and leave two of the three nodes idle.
pids=()
for graph_idx in 0 1 2; do
  srun -n1 -N1 --gpus-per-node=1 python graph_train.py --graph_idx "$graph_idx" &
  pids+=("$!")
done

# Wait for every step (the batch script must not exit while steps are still
# running) and remember whether any of them failed.
status=0
for pid in "${pids[@]}"; do
  wait "$pid" || status=1
done
exit "$status"
17 changes: 17 additions & 0 deletions batch jobs/train_0.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# Slurm batch job: run graph_train.py on graph index 0 using one node with
# one GPU and a 12-hour time limit.

#SBATCH -J graph
#SBATCH -t 12:00:00
#SBATCH --mail-user=csauac@connect.ust.hk
#SBATCH --mail-type=begin
#SBATCH --mail-type=end
#SBATCH -p normal
#SBATCH --nodes=1 --gpus=1
#SBATCH --account=mscbdt2024
#SBATCH --output=job-%j.out
#SBATCH --error=job-%j.err

# Abort early if the project directory is missing; otherwise training would
# be started from the wrong working directory.
# NOTE(review): the path is relative — presumably to the directory the job is
# submitted from; confirm.
cd msbd5008 || exit 1

# NOTE(review): `conda activate` requires conda's shell integration to be
# available in this non-interactive shell — confirm it is on this cluster.
conda activate graph
module load slurm 'nvhpc-hpcx-cuda12/23.11'
python graph_train.py --graph_idx 0
17 changes: 17 additions & 0 deletions batch jobs/train_1.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# Slurm batch job: run graph_train.py on graph index 1 using one node with
# one GPU and a 12-hour time limit.

#SBATCH -J graph
#SBATCH -t 12:00:00
#SBATCH --mail-user=csauac@connect.ust.hk
#SBATCH --mail-type=begin
#SBATCH --mail-type=end
#SBATCH -p normal
#SBATCH --nodes=1 --gpus-per-node=1
#SBATCH --account=mscbdt2024
#SBATCH --output=job-%j.out
#SBATCH --error=job-%j.err

# Abort early if the project directory is missing; otherwise training would
# be started from the wrong working directory.
# NOTE(review): the path is relative — presumably to the directory the job is
# submitted from; confirm.
cd msbd5008 || exit 1

# NOTE(review): `conda activate` requires conda's shell integration to be
# available in this non-interactive shell — confirm it is on this cluster.
conda activate graph
module load slurm 'nvhpc-hpcx-cuda12/23.11'
python graph_train.py --graph_idx 1
Loading