fix bugs #6

Open · wants to merge 4 commits into base: develop
12 changes: 7 additions & 5 deletions benchmarking/pyg_serial.py
@@ -36,11 +36,13 @@ def create_parser():
 
 
 def get_dataset(download_path=None):
-    dataset = PygNodePropPredDataset(
-        name="ogbn-products",
-        root=input_dir,
-        transform=T.NormalizeFeatures(),
-    )
+    # dataset = PygNodePropPredDataset(
+    #     name="ogbn-products",
+    #     root=input_dir,
+    #     transform=T.NormalizeFeatures(),
+    # )
+
+    dataset = Reddit(root=download_path, transform=T.NormalizeFeatures())
     gcn_norm = T.GCNNorm()
     return (gcn_norm.forward(dataset[0]), dataset.num_classes)
 
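Note: besides switching the benchmark from ogbn-products to Reddit, this change retires a latent bug: the old call passed root=input_dir, but get_dataset only receives download_path, so the ogbn path could only have worked if input_dir happened to exist as a global. If both datasets should stay selectable, a possible follow-up would look like the sketch below (the name argument is hypothetical, not part of this PR):

import torch_geometric.transforms as T
from torch_geometric.datasets import Reddit
from ogb.nodeproppred import PygNodePropPredDataset

def get_dataset(download_path=None, name="reddit"):
    # hypothetical selector; the PR itself hard-codes Reddit
    if name == "ogbn-products":
        dataset = PygNodePropPredDataset(
            name="ogbn-products",
            root=download_path,  # the old code used the undefined input_dir here
            transform=T.NormalizeFeatures(),
        )
    else:
        dataset = Reddit(root=download_path, transform=T.NormalizeFeatures())
    gcn_norm = T.GCNNorm()
    return (gcn_norm.forward(dataset[0]), dataset.num_classes)
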
2 changes: 1 addition & 1 deletion examples/train.py
@@ -62,7 +62,7 @@ def __init__(self, num_gcn_layers, input_size, hidden_size, output_size):
 
         self.num_gcn_layers = num_gcn_layers
 
-        self.layers = []
+        self.layers = torch.nn.ModuleList()
         for i in range(self.num_gcn_layers):
             if i == 0:
                 self.layers.append(GCNConv(input_size, hidden_size, i))
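Note: the switch to torch.nn.ModuleList is the substantive fix in this file. Submodules held in a plain Python list are not registered with the parent nn.Module, so model.parameters() never sees the GCN layers and the optimizer silently skips them; a plain list also escapes state_dict and model.to(device). A minimal runnable sketch of the failure mode (torch.nn.Linear stands in for GCNConv):

import torch

class Broken(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [torch.nn.Linear(4, 4)]  # plain list: parameters not registered

class Fixed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

print(len(list(Broken().parameters())))  # 0 -> the optimizer would update nothing
print(len(list(Fixed().parameters())))   # 2 -> weight and bias are tracked
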
33 changes: 19 additions & 14 deletions plexus/cross_entropy.py
@@ -66,20 +66,25 @@ def forward(ctx, logits, target, num_layers, num_nodes, num_classes):
         ) | (target < 0)
         softmax[invalid_nodes, :] = 0.0
 
-        # create mask for classes that are outside the local range of classes
-        invalid_logits_mask = (target < (ranks[1] * logits.shape[1])) | (
-            target >= ((ranks[1] + 1) * logits.shape[1])
-        )
-
-        # convert from global label to local label
-        target[~invalid_logits_mask] -= ranks[1] * logits.shape[1]
-        target[invalid_logits_mask] = 0
-
-        # create one hot vector from the labels
-        target = F.one_hot(target, num_classes=logits.shape[1])
-
-        # for labels out of the local range, make the target vector 0
-        target[invalid_logits_mask] = 0
+        # # create mask for classes that are outside the local range of classes
+        # invalid_logits_mask = (target < (ranks[1] * logits.shape[1])) | (
+        #     target >= ((ranks[1] + 1) * logits.shape[1])
+        # )
+
+        # # convert from global label to local label
+        # target[~invalid_logits_mask] -= ranks[1] * logits.shape[1]
+        # target[invalid_logits_mask] = 0
+
+        # # create one hot vector from the labels
+        # target = F.one_hot(target, num_classes=logits.shape[1])
+
+        # # for labels out of the local range, make the target vector 0
+        # target[invalid_logits_mask] = 0
+
+        target = F.one_hot(target, num_classes=(logits.shape[1] * num_gpus[1]))
+        target = target[
+            :, (ranks[1] * logits.shape[1]) : ((ranks[1] + 1) * logits.shape[1])
+        ]
 
         # save softmax and target for backward pass
         ctx.save_for_backward(softmax, target)
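Note: the replaced block rewrote target in place, masking and shifting global labels into the local class range before one-hot encoding against the local width. The new code instead one-hot encodes against the global class count, logits.shape[1] * num_gpus[1], and slices out this rank's column block. That yields the same local one-hot rows, produces an all-zero row for labels owned by another rank, and avoids mutating the caller's target tensor. A small runnable sketch with stand-in values (ranks[1] and num_gpus[1] are the surrounding code's tensor-parallel bookkeeping):

import torch
import torch.nn.functional as F

local_classes, n_ranks, rank = 3, 2, 1   # 6 global classes, 3 per rank
target = torch.tensor([0, 4, 5])         # global labels

one_hot = F.one_hot(target, num_classes=local_classes * n_ranks)
local = one_hot[:, rank * local_classes : (rank + 1) * local_classes]
print(local)
# tensor([[0, 0, 0],   label 0 lives on rank 0 -> all-zero row on rank 1
#         [0, 1, 0],   label 4 -> local class 1
#         [0, 0, 1]])  label 5 -> local class 2

One caveat: F.one_hot raises on negative labels, and the invalid_nodes mask above suggests target < 0 can occur, so this path appears to rely on negative labels having been sanitized upstream.
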
6 changes: 3 additions & 3 deletions plexus/gcn_conv.py
@@ -93,7 +93,7 @@ def chunked_spmm_all_reduce(csr_matrix, H, ar_group):
         # spmm for current chunk
 
         if not plx.overlap_agg:
-            ax.get_timers.start("AGG = A * H")
+            ax.get_timers().start("AGG = A * H")
 
         results[i] = torch.sparse.mm(chunk_edge_index, H)
 
@@ -184,9 +184,9 @@ def forward(
         if plx.block_agg:
             AGG = chunked_spmm_all_reduce(edge_index, H, aggregation_all_reduce_group)
         else:
-            ax.get_timers.start("AGG = A * H")
+            ax.get_timers().start("AGG = A * H")
             AGG = torch.sparse.mm(edge_index, H)
-            ax.get_timers.stop("AGG = A * H")
+            ax.get_timers().stop("AGG = A * H")
 
             _all_reduce(AGG, aggregation_all_reduce_group)
 
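Note: all three edits fix the same slip: ax.get_timers is evidently an accessor function, so the unparenthesized form returned the function object itself and the subsequent .start(...) / .stop(...) would raise AttributeError as soon as a timed path executed. Reduced to its essentials (get_timers here is a stand-in for the axonn accessor, assumed to return a timer registry):

class Timers:
    def start(self, name): print(f"start {name}")
    def stop(self, name): print(f"stop {name}")

def get_timers():
    return Timers()

get_timers().start("AGG = A * H")  # correct: call the accessor, then the method
# get_timers.start("AGG = A * H")  # AttributeError: 'function' object has no attribute 'start'
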
4 changes: 3 additions & 1 deletion plexus/utils/dataloader.py
@@ -479,6 +479,8 @@ def load(self):
             os.path.join(self.data_dir, pt_files[0]),
             weights_only=False,
         )
+
+        self.double_perm = hasattr(self.data, "edge_index_2")
 
         self.__set_graph_attributes()
 
@@ -496,7 +498,7 @@ def load(self):
                 adj_shards.append(
                     self.__split_adj(
                         self.data.edge_index_2,
-                        self.data.edge_weight,
+                        self.data.edge_weight_2,
                         i,
                     )
                 )
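Note: these two hunks pair up. load() now records whether the stored graph carries a second permutation (self.double_perm), and the second shard split reads self.data.edge_weight_2 instead of reusing edge_weight, which had silently paired the first permutation's weights with the second permutation's edges. A runnable sketch of the invariant being restored (SimpleNamespace stands in for the loaded Data object):

from types import SimpleNamespace

data = SimpleNamespace(edge_index="EI1", edge_weight="EW1",
                       edge_index_2="EI2", edge_weight_2="EW2")

pairs = [(data.edge_index, data.edge_weight)]
if hasattr(data, "edge_index_2"):  # what self.double_perm now records
    pairs.append((data.edge_index_2, data.edge_weight_2))  # not edge_weight
print(pairs)  # [('EI1', 'EW1'), ('EI2', 'EW2')] - each index with its own weights
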
3 changes: 2 additions & 1 deletion plexus/utils/dataset.py
@@ -198,7 +198,8 @@ def preprocess_graph(
     print("Completed A = A * P.T\n")
 
     del P
-    del P2
+    if double_perm:
+        del P2
     gc.collect()
 
     # convert back to edge index format
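Note: P2 is only bound when double_perm preprocessing builds a second permutation matrix, so the unconditional del P2 raised NameError on the single-permutation path. The reduced form of the bug:

double_perm = False
P = object()
if double_perm:
    P2 = object()
del P
if double_perm:  # without this guard: NameError: name 'P2' is not defined
    del P2
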
40 changes: 40 additions & 0 deletions run.sh
@@ -0,0 +1,40 @@
# Create directories for serial data and logs
mkdir -p ./data/serial
mkdir -p ./data/no_double_perm
mkdir -p ./data/double_perm
mkdir -p ./log

# Run the serial baseline
echo "Running PyG serial baseline..."
python benchmarking/pyg_serial.py --download_path ./data/serial --num_epochs 1000 > log/serial.txt

# Preprocess data for our model
echo "Preprocessing Reddit dataset without double permutation..."
python scripts/preprocess_reddit.py --no_double_perm --input_dir ./data/no_double_perm/raw --output_dir ./data/no_double_perm/processed

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 1 --G_intra_d 4 --num_epochs 1000 > log/r1c1d4.txt

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 2 --G_intra_d 2 --num_epochs 1000 > log/r1c2d2.txt

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 4 --G_intra_d 1 --num_epochs 1000 > log/r1c4d1.txt

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 2 --G_intra_c 1 --G_intra_d 2 --num_epochs 1000 > log/r2c1d2.txt


echo "Preprocessing Reddit dataset with double permutation..."
python scripts/preprocess_reddit.py --double_perm --input_dir ./data/double_perm/raw --output_dir ./data/double_perm/processed

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 1 --G_intra_d 4 --num_epochs 1000 > log/double_perm_r1c1d4.txt

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 2 --G_intra_d 2 --num_epochs 1000 > log/double_perm_r1c2d2.txt

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 4 --G_intra_d 1 --num_epochs 1000 > log/double_perm_r1c4d1.txt

torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 2 --G_intra_c 1 --G_intra_d 2 --num_epochs 1000 > log/double_perm_r2c1d2.txt

# Generate plot from logs
echo "Generating loss plot..."
python scripts/plot_loss.py --log_dir ./log

echo "All experiments completed."

136 changes: 136 additions & 0 deletions scripts/plot_loss.py
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read training log files from the log folder and plot loss curves over epochs
"""

import os
import re
import argparse
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def parse_log_file(file_path):
    """
    Parse a log file to extract epoch and loss data

    Args:
        file_path: log file path

    Returns:
        tuple: (epochs, losses) two lists
    """
    epochs = []
    losses = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Use regex to match "Epoch: XXX, Train Loss: Y.YYYY" format
            match = re.match(r'Epoch:\s*(\d+),\s*Train Loss:\s*([\d.]+)', line.strip())
            if match:
                epoch = int(match.group(1))
                loss = float(match.group(2))
                epochs.append(epoch)
                losses.append(loss)

    return epochs, losses

def plot_loss_curves(log_dir):
    """
    Plot loss curves for all log files

    Args:
        log_dir: log folder path
    """
    # Set font for better display
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial']
    plt.rcParams['axes.unicode_minus'] = False

    # Create figure
    plt.figure(figsize=(12, 8))

    # Color list; ensure there are enough colors
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    # Get all txt files
    log_path = Path(log_dir)
    txt_files = list(log_path.glob('*.txt'))

    if not txt_files:
        print(f"No txt files found in {log_dir}")
        return

    not_drawn_files = []  # Track files that were not drawn, with reasons
    # Draw a line for each file
    for i, file_path in enumerate(sorted(txt_files)):
        try:
            epochs, losses = parse_log_file(file_path)

            if not epochs or not losses:
                reason = "No valid epoch data found (format mismatch or empty content)"
                not_drawn_files.append((file_path.name, reason))
                print(f"Warning: {file_path.name} {reason}")
                continue

            # Filename (without .txt extension) as legend label
            legend_name = file_path.stem

            # Select color
            color = colors[i % len(colors)]

            # Draw line
            plt.plot(epochs, losses, label=legend_name, color=color, linewidth=2, alpha=0.8)

            print(f"Processed: {file_path.name} - {len(epochs)} epochs")

        except Exception as e:
            not_drawn_files.append((file_path.name, f"Parse error: {e}"))
            print(f"Error processing file {file_path.name}: {e}")

    if not_drawn_files:
        print("\nThe following files were not drawn:")
        for fname, reason in not_drawn_files:
            print(f"  {fname}: {reason}")
    else:
        print("\nAll files were successfully drawn!")

    # Set chart properties
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Train Loss', fontsize=12)
    # plt.title('', fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

    # Adjust layout to ensure the legend is not clipped
    plt.tight_layout()

    # Save image
    output_path = log_path / 'loss_curves.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nChart saved to: {output_path}")

    # Show chart
    plt.show()

def main():
    """Main function"""
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Plot loss curves from training log files')
    parser.add_argument('--log_dir', type=str, default="/home/cc/plexus/log",
                        help='Path to the log directory (default: /home/cc/plexus/log)')

    args = parser.parse_args()
    log_dir = args.log_dir

    if not os.path.exists(log_dir):
        print(f"Error: Directory {log_dir} does not exist")
        return

    print("Starting to process log files...")
    plot_loss_curves(log_dir)
    print("Processing completed!")

if __name__ == "__main__":
    main()
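
Note: parse_log_file only picks up lines of the exact shape "Epoch: N, Train Loss: X"; anything the training scripts print in a different format is skipped, and files yielding no matches are reported rather than plotted. A quick check of the pattern used above:

import re

line = "Epoch: 042, Train Loss: 0.6931"
m = re.match(r'Epoch:\s*(\d+),\s*Train Loss:\s*([\d.]+)', line)
print(int(m.group(1)), float(m.group(2)))  # 42 0.6931
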
59 changes: 59 additions & 0 deletions scripts/preprocess_reddit.py
@@ -0,0 +1,59 @@
#!/usr/bin/env python
import os
import sys
import argparse
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from plexus.utils.dataset import preprocess_graph
from plexus.utils.general import set_seed

def parse_args():
    parser = argparse.ArgumentParser(description='Preprocess graph dataset')
    parser.add_argument('--name', type=str, default='reddit',
                        help='Dataset name (default: reddit)')
    parser.add_argument('--input_dir', type=str, default='./data/raw',
                        help='Input directory path (default: ./data/raw)')
    parser.add_argument('--output_dir', type=str, default='./data/processed',
                        help='Output directory path (default: ./data/processed)')
    parser.add_argument('--double_perm', action='store_true', default=True,
                        help='Use double permutation (default: True)')
    parser.add_argument('--no_double_perm', dest='double_perm', action='store_false',
                        help='Disable double permutation')
    parser.add_argument('--unsupervised', action='store_true', default=False,
                        help='Unsupervised mode (default: False)')
    parser.add_argument('--directed', action='store_true', default=False,
                        help='Directed graph mode (default: False)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')

    return parser.parse_args()

def main():
    args = parse_args()

    set_seed(args.seed)

    print(f"Starting preprocessing for {args.name} dataset...")
    print(f"Input directory: {args.input_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Double permutation: {args.double_perm}")
    print(f"Unsupervised mode: {args.unsupervised}")
    print(f"Directed graph: {args.directed}")

    # Create the output directory if it doesn't exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
        print(f"Created output directory: {args.output_dir}")

    preprocess_graph(
        name=args.name,
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        double_perm=args.double_perm,
        unsupervised=args.unsupervised,
        directed=args.directed
    )
    print("Preprocessing completed!")

if __name__ == "__main__":
    main()