This repository contains code for training a GNN model (GIN) using PyTorch Geometric and generating extended persistence diagrams from graph datasets with giotto-deep.
Follow the steps below to set up the required environment:
git clone https://github.com/sehunfromdaegu/xpert.git
cd xpert
conda create --name xpert python=3.9
conda activate xpert
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
pip install -r requirements.txt
python -m pip install -U giotto-deep
conda install pyg=2.5.2 -c pyg

If you encounter any errors related to scikit-learn, you can reinstall the correct version as follows:
pip uninstall scikit-learn
pip install scikit-learn==1.1.1

There are issues with specific versions of PyTorch Geometric and giotto-deep. To resolve these, apply the following fixes:
Locate the library path by running:
pip show torch_geometric

The library path will be listed under Location: PATH_TO_LIBS.
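As an alternative to reading the `pip show` output, the same directory can be looked up from Python. The snippet below is a minimal sketch that assumes both packages are regular (non-namespace) packages; the helper name `libs_path` is hypothetical and not part of this repository.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def libs_path(package: str) -> Optional[Path]:
    """Return the directory that contains `package`, or None if it is not installed.

    Hypothetical helper: for a regular package this should match the
    "Location" field printed by `pip show`.
    """
    spec = importlib.util.find_spec(package)
    if spec is None or spec.origin is None:
        return None
    # spec.origin points at <libs>/<package>/__init__.py for a regular package,
    # so two levels up is the directory that holds the package folder.
    return Path(spec.origin).parent.parent


# Report the location of the two libraries that need manual fixes.
for name in ("torch_geometric", "gdeep"):
    location = libs_path(name)
    if location is None:
        print(f"{name}: not installed in this environment")
    else:
        print(f"{name}: PATH_TO_LIBS = {location}")
```

Run it inside the activated `xpert` environment so that the lookup uses the same interpreter as the training scripts.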
Open the file /PATH_TO_LIBS/torch_geometric/io/fs.py and modify line 193 as follows:
fs1.mv(path1, path2, recursive) # Original
fs1.mv(path1, path2) # Updated

- Open the file /PATH_TO_LIBS/gdeep/data/datasets/persistence_diagrams_from_builder.py
- Modify line 206 to correct the graph labeling logic:
labels = (np.loadtxt(graph_labels, delimiter=",", dtype=np.int32).T + 1) // 2 # Original
labels = np.loadtxt(graph_labels, delimiter=",", dtype=np.int32).T # Updated

To perform classification on graph datasets, use the following command:
python graph_classification.py --dataname <dataset> --model <modelname>

- dataset: Choose from the available options:
  - 'IMDB-BINARY'
  - 'IMDB-MULTI'
  - 'MUTAG'
  - 'PROTEINS'
  - 'COX2'
  - 'DHFR'
- modelname: Specify the model to be used:
  - 'xpert' (Extended Persistence Transformer)
  - 'gin' (Graph Isomorphism Network)
  - 'gin_assisted_concat' (GIN + xPerT, concatenating the two representations)
  - 'gin_assisted_sum' (GIN + xPerT, summing the two representations)
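The dataset and model options above can be combined to run every configuration in turn. The following is a minimal sketch, not a script shipped with this repository: the names `build_commands` and `run_all` are hypothetical, and it assumes `graph_classification.py` is in the current working directory and accepts only the two flags shown above.

```python
import itertools
import subprocess
import sys

# Option values copied from the lists above.
DATASETS = ["IMDB-BINARY", "IMDB-MULTI", "MUTAG", "PROTEINS", "COX2", "DHFR"]
MODELS = ["xpert", "gin", "gin_assisted_concat", "gin_assisted_sum"]


def build_commands(datasets=DATASETS, models=MODELS):
    """Return one argument list per (dataset, model) pair."""
    return [
        [sys.executable, "graph_classification.py", "--dataname", d, "--model", m]
        for d, m in itertools.product(datasets, models)
    ]


def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

With `dry_run=True` the sketch only prints the commands, which is a cheap way to confirm the sweep before committing compute to it.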
For example, to train the xPerT model on the MUTAG dataset, run:
python graph_classification.py --dataname MUTAG --model xpert

To perform classification on ORBIT5K datasets, use the following command:
python orbit_classification.py

To run the same classification with 20,000 samples per class, use the following command:
python orbit_classification.py --samples_per_class 20000
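For context on what `--samples_per_class` scales, the sketch below generates a single orbit using the construction commonly used in the literature for ORBIT5K (five parameter values, 1,000 points per orbit). It is an illustration only: the parameter values, the function names, and the generator itself are assumptions, and `orbit_classification.py` may build its data differently.

```python
import random

# Parameter values commonly used for the five orbit classes
# (assumption: the repository's script may use other values).
R_VALUES = [2.5, 3.5, 4.0, 4.1, 4.3]


def generate_orbit(r, n_points=1000, rng=None):
    """Generate one orbit of the discrete dynamical system on the unit square."""
    rng = rng or random.Random()
    x, y = rng.random(), rng.random()  # random starting point in [0, 1)^2
    points = []
    for _ in range(n_points):
        # Each coordinate is updated using the most recent value of the other.
        x = (x + r * y * (1.0 - y)) % 1.0
        y = (y + r * x * (1.0 - x)) % 1.0
        points.append((x, y))
    return points


def dataset_size(samples_per_class, n_classes=len(R_VALUES)):
    """Total number of orbits when every class has the same sample count."""
    return samples_per_class * n_classes


orbit = generate_orbit(4.1, rng=random.Random(0))
print(len(orbit), "points in one orbit")
print(dataset_size(20000), "orbits at 20000 samples per class")
```

Under these assumptions, 1,000 samples per class gives the 5,000 orbits of ORBIT5K, while 20,000 samples per class gives 100,000 orbits.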