diff --git a/benchmarking/pyg_serial.py b/benchmarking/pyg_serial.py
index 8f20405..f89d1ef 100644
--- a/benchmarking/pyg_serial.py
+++ b/benchmarking/pyg_serial.py
@@ -36,11 +36,13 @@ def create_parser():
 
 
 def get_dataset(download_path=None):
-    dataset = PygNodePropPredDataset(
-        name="ogbn-products",
-        root=input_dir,
-        transform=T.NormalizeFeatures(),
-    )
+    # dataset = PygNodePropPredDataset(
+    #     name="ogbn-products",
+    #     root=input_dir,
+    #     transform=T.NormalizeFeatures(),
+    # )
+
+    dataset = Reddit(root=download_path, transform=T.NormalizeFeatures())
     gcn_norm = T.GCNNorm()
     return (gcn_norm.forward(dataset[0]), dataset.num_classes)
 
diff --git a/examples/train.py b/examples/train.py
index adac42e..1eee2ba 100644
--- a/examples/train.py
+++ b/examples/train.py
@@ -62,7 +62,7 @@ def __init__(self, num_gcn_layers, input_size, hidden_size, output_size):
 
         self.num_gcn_layers = num_gcn_layers
 
-        self.layers = []
+        self.layers = torch.nn.ModuleList()
         for i in range(self.num_gcn_layers):
             if i == 0:
                 self.layers.append(GCNConv(input_size, hidden_size, i))
diff --git a/plexus/cross_entropy.py b/plexus/cross_entropy.py
index cd92894..c3a4d34 100644
--- a/plexus/cross_entropy.py
+++ b/plexus/cross_entropy.py
@@ -66,20 +66,25 @@ def forward(ctx, logits, target, num_layers, num_nodes, num_classes):
         ) | (target < 0)
         softmax[invalid_nodes, :] = 0.0
 
-        # create mask for classes that are outside the local range of classes
-        invalid_logits_mask = (target < (ranks[1] * logits.shape[1])) | (
-            target >= ((ranks[1] + 1) * logits.shape[1])
-        )
-
-        # convert from global label to local label
-        target[~invalid_logits_mask] -= ranks[1] * logits.shape[1]
-        target[invalid_logits_mask] = 0
-
-        # create one hot vector from the labels
-        target = F.one_hot(target, num_classes=logits.shape[1])
-
-        # for labels out of the local range, make the target vector 0
-        target[invalid_logits_mask] = 0
+        # # create mask for classes that are outside the local range of classes
+        # invalid_logits_mask = (target < (ranks[1] * logits.shape[1])) | (
+        #     target >= ((ranks[1] + 1) * logits.shape[1])
+        # )
+
+        # # convert from global label to local label
+        # target[~invalid_logits_mask] -= ranks[1] * logits.shape[1]
+        # target[invalid_logits_mask] = 0
+
+        # # create one hot vector from the labels
+        # target = F.one_hot(target, num_classes=logits.shape[1])
+
+        # # for labels out of the local range, make the target vector 0
+        # target[invalid_logits_mask] = 0
+
+        target = F.one_hot(target, num_classes=(logits.shape[1] * num_gpus[1]))
+        target = target[
+            :, (ranks[1] * logits.shape[1]) : ((ranks[1] + 1) * logits.shape[1])
+        ]
 
         # save softmax and target for backward pass
         ctx.save_for_backward(softmax, target)
diff --git a/plexus/gcn_conv.py b/plexus/gcn_conv.py
index 7565880..07a8fb7 100644
--- a/plexus/gcn_conv.py
+++ b/plexus/gcn_conv.py
@@ -93,7 +93,7 @@ def chunked_spmm_all_reduce(csr_matrix, H, ar_group):
 
         # spmm for current chunk
         if not plx.overlap_agg:
-            ax.get_timers.start("AGG = A * H")
+            ax.get_timers().start("AGG = A * H")
 
         results[i] = torch.sparse.mm(chunk_edge_index, H)
 
@@ -184,9 +184,9 @@ def forward(
 
         if plx.block_agg:
             AGG = chunked_spmm_all_reduce(edge_index, H, aggregation_all_reduce_group)
         else:
-            ax.get_timers.start("AGG = A * H")
+            ax.get_timers().start("AGG = A * H")
             AGG = torch.sparse.mm(edge_index, H)
-            ax.get_timers.stop("AGG = A * H")
+            ax.get_timers().stop("AGG = A * H")
 
             _all_reduce(AGG, aggregation_all_reduce_group)
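Two review notes on the hunks above. In examples/train.py, swapping the plain list for torch.nn.ModuleList is a real bug fix: submodules held in a Python list are not registered with the parent nn.Module, so their parameters are invisible to model.parameters() and the optimizer silently never updates the GCN layers. In plexus/cross_entropy.py, the mask-shift-one-hot construction of the local target block is replaced by a one-hot over all classes followed by a column slice; the new code assumes num_gpus[1] (the size of the column-parallel group) is in scope, just as the old code assumed ranks[1], and unlike the old code it does not mutate target in place.

Below is a minimal standalone sketch of why the two constructions agree; the names (num_ranks, local_classes, rank) are illustrative, not identifiers from plexus:

import torch
import torch.nn.functional as F

# Illustrative sizes: 8 global classes split across 4 column-parallel ranks,
# so each rank holds logits for local_classes = 2 classes.
num_ranks, local_classes = 4, 2
rank = 1                                   # this rank owns global classes [2, 4)
target = torch.tensor([0, 2, 3, 7])        # global labels

# New approach: one-hot over ALL classes, then slice out this rank's columns.
one_hot = F.one_hot(target, num_classes=local_classes * num_ranks)
new_target = one_hot[:, rank * local_classes : (rank + 1) * local_classes]

# Old approach: mask out-of-range labels, shift to local ids, one-hot locally,
# then zero the rows whose label lives on another rank.
invalid = (target < rank * local_classes) | (target >= (rank + 1) * local_classes)
shifted = torch.where(invalid, torch.zeros_like(target), target - rank * local_classes)
old_target = F.one_hot(shifted, num_classes=local_classes)
old_target[invalid] = 0

assert torch.equal(new_target, old_target)

The slice version also never writes into target, so the caller's label tensor cannot be corrupted across epochs, which the in-place shift could do.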
diff --git a/plexus/utils/dataloader.py b/plexus/utils/dataloader.py
index 74b19d8..6faedd1 100644
--- a/plexus/utils/dataloader.py
+++ b/plexus/utils/dataloader.py
@@ -479,6 +479,8 @@ def load(self):
             os.path.join(self.data_dir, pt_files[0]),
             weights_only=False,
         )
+
+        self.double_perm = hasattr(self.data, "edge_index_2")
 
         self.__set_graph_attributes()
 
@@ -496,7 +498,7 @@ def load(self):
                 adj_shards.append(
                     self.__split_adj(
                         self.data.edge_index_2,
-                        self.data.edge_weight,
+                        self.data.edge_weight_2,
                         i,
                     )
                 )
diff --git a/plexus/utils/dataset.py b/plexus/utils/dataset.py
index 7cca2a0..4c0f679 100644
--- a/plexus/utils/dataset.py
+++ b/plexus/utils/dataset.py
@@ -198,7 +198,8 @@ def preprocess_graph(
     print("Completed A = A * P.T\n")
 
     del P
-    del P2
+    if double_perm:
+        del P2
     gc.collect()
 
     # convert back to edge index format
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..3c1762e
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,40 @@
+# Create directories for serial data and logs
+mkdir -p ./data/serial
+mkdir -p ./data/no_double_perm
+mkdir -p ./data/double_perm
+mkdir -p ./log
+
+# Run the serial baseline
+echo "Running PyG serial baseline..."
+python benchmarking/pyg_serial.py --download_path ./data/serial --num_epochs 1000 > log/serial.txt
+
+# Preprocess data for our model
+echo "Preprocessing Reddit dataset without double permutation..."
+python scripts/preprocess_reddit.py --no_double_perm --input_dir ./data/no_double_perm/raw --output_dir ./data/no_double_perm/processed
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 1 --G_intra_d 4 --num_epochs 1000 > log/r1c1d4.txt
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 2 --G_intra_d 2 --num_epochs 1000 > log/r1c2d2.txt
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 4 --G_intra_d 1 --num_epochs 1000 > log/r1c4d1.txt
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/no_double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 2 --G_intra_c 1 --G_intra_d 2 --num_epochs 1000 > log/r2c1d2.txt
+
+
+echo "Preprocessing Reddit dataset with double permutation..."
+python scripts/preprocess_reddit.py --double_perm --input_dir ./data/double_perm/raw --output_dir ./data/double_perm/processed
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 1 --G_intra_d 4 --num_epochs 1000 > log/double_perm_r1c1d4.txt
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 2 --G_intra_d 2 --num_epochs 1000 > log/double_perm_r1c2d2.txt
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 1 --G_intra_c 4 --G_intra_d 1 --num_epochs 1000 > log/double_perm_r1c4d1.txt
+
+torchrun --nproc_per_node=4 --master_port=29500 examples/train.py --data_dir ./data/double_perm/processed --block_aggregation --tune_gemms --gpus_per_node 4 --G_intra_r 2 --G_intra_c 1 --G_intra_d 2 --num_epochs 1000 > log/double_perm_r2c1d2.txt
+
+# Generate plot from logs
+echo "Generating loss plot..."
+python scripts/plot_loss.py --log_dir ./log
+
+echo "All experiments completed."
+
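The torchrun sweep in run.sh covers the four ways of factoring 4 GPUs into the (G_intra_r, G_intra_c, G_intra_d) grid; since each run launches with --nproc_per_node=4, every configuration's product must equal 4. A trivial, hypothetical sanity check one could keep alongside the script (not part of the repo):

# Hypothetical helper: verify each (r, c, d) configuration in the sweep
# uses exactly the 4 ranks that torchrun launches.
configs = [(1, 1, 4), (1, 2, 2), (1, 4, 1), (2, 1, 2)]
for r, c, d in configs:
    assert r * c * d == 4, f"invalid 3D-parallel config: {(r, c, d)}"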
diff --git a/scripts/plot_loss.py b/scripts/plot_loss.py
new file mode 100644
index 0000000..8f7b580
--- /dev/null
+++ b/scripts/plot_loss.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Read training log files from the log folder and plot loss curves over epochs
+"""
+
+import os
+import re
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+from pathlib import Path
+
+def parse_log_file(file_path):
+    """
+    Parse log file to extract epoch and loss data
+
+    Args:
+        file_path: log file path
+
+    Returns:
+        tuple: (epochs, losses) two lists
+    """
+    epochs = []
+    losses = []
+
+    with open(file_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            # Use regex to match "Epoch: XXX, Train Loss: Y.YYYY" format
+            match = re.match(r'Epoch:\s*(\d+),\s*Train Loss:\s*([\d.]+)', line.strip())
+            if match:
+                epoch = int(match.group(1))
+                loss = float(match.group(2))
+                epochs.append(epoch)
+                losses.append(loss)
+
+    return epochs, losses
+
+def plot_loss_curves(log_dir):
+    """
+    Plot loss curves for all log files
+
+    Args:
+        log_dir: log folder path
+    """
+    # Set font for better display
+    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial']
+    plt.rcParams['axes.unicode_minus'] = False
+
+    # Create figure
+    plt.figure(figsize=(12, 8))
+
+    # Color list, ensure enough colors
+    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
+              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
+
+    # Get all txt files
+    log_path = Path(log_dir)
+    txt_files = list(log_path.glob('*.txt'))
+
+    if not txt_files:
+        print(f"No txt files found in {log_dir}")
+        return
+
+    not_drawn_files = []  # Track files that were not drawn and reasons
+    # Draw a line for each file
+    for i, file_path in enumerate(sorted(txt_files)):
+        try:
+            epochs, losses = parse_log_file(file_path)
+
+            if not epochs or not losses:
+                reason = "No valid epoch data found (format mismatch or empty content)"
+                not_drawn_files.append((file_path.name, reason))
+                print(f"Warning: {file_path.name} {reason}")
+                continue
+
+            # Filename (without .txt extension) as legend
+            legend_name = file_path.stem
+
+            # Select color
+            color = colors[i % len(colors)]
+
+            # Draw line
+            plt.plot(epochs, losses, label=legend_name, color=color, linewidth=2, alpha=0.8)
+
+            print(f"Processed: {file_path.name} - {len(epochs)} epochs")
+
+        except Exception as e:
+            not_drawn_files.append((file_path.name, f"Parse error: {e}"))
+            print(f"Error processing file {file_path.name}: {e}")
+
+    if not_drawn_files:
+        print("\nThe following files were not drawn:")
+        for fname, reason in not_drawn_files:
+            print(f"  {fname}: {reason}")
+    else:
+        print("\nAll files were successfully drawn!")
+
+    # Set chart properties
+    plt.xlabel('Epoch', fontsize=12)
+    plt.ylabel('Train Loss', fontsize=12)
+    # plt.title('', fontsize=14, fontweight='bold')
+    plt.grid(True, alpha=0.3)
+    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
+
+    # Adjust layout to ensure legend is not clipped
+    plt.tight_layout()
+
+    # Save image
+    output_path = log_path / 'loss_curves.png'
+    plt.savefig(output_path, dpi=300, bbox_inches='tight')
+    print(f"\nChart saved to: {output_path}")
+
+    # Show chart
+    plt.show()
+
+def main():
+    """Main function"""
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description='Plot loss curves from training log files')
+    parser.add_argument('--log_dir', type=str, default="/home/cc/plexus/log",
+                        help='Path to the log directory (default: /home/cc/plexus/log)')
+
+    args = parser.parse_args()
+    log_dir = args.log_dir
+
+    if not os.path.exists(log_dir):
+        print(f"Error: Directory {log_dir} does not exist")
+        return
+
+    print("Starting to process log files...")
+    plot_loss_curves(log_dir)
+    print("Processing completed!")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
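plot_loss.py keys entirely off the regex in parse_log_file, so a log line must look like "Epoch: <int>, Train Loss: <float>" to be picked up; everything else (timer output, warnings) is skipped. A quick illustration of what the pattern does and does not accept, assuming only the format the regex itself implies:

import re

pattern = re.compile(r'Epoch:\s*(\d+),\s*Train Loss:\s*([\d.]+)')
assert pattern.match("Epoch: 12, Train Loss: 0.6931")       # the expected format
assert pattern.match("Epoch:3, Train Loss:1.25")            # whitespace after ':' is optional
assert pattern.match("epoch: 12, train loss: 0.5") is None  # matching is case-sensitive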
diff --git a/scripts/preprocess_reddit.py b/scripts/preprocess_reddit.py
new file mode 100644
index 0000000..6e6b903
--- /dev/null
+++ b/scripts/preprocess_reddit.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+import os
+import sys
+import argparse
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from plexus.utils.dataset import preprocess_graph
+from plexus.utils.general import set_seed
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Preprocess graph dataset')
+    parser.add_argument('--name', type=str, default='reddit',
+                        help='Dataset name (default: reddit)')
+    parser.add_argument('--input_dir', type=str, default='./data/raw',
+                        help='Input directory path (default: ./data/raw)')
+    parser.add_argument('--output_dir', type=str, default='./data/processed',
+                        help='Output directory path (default: ./data/processed)')
+    parser.add_argument('--double_perm', action='store_true', default=True,
+                        help='Use double permutation (default: True)')
+    parser.add_argument('--no_double_perm', dest='double_perm', action='store_false',
+                        help='Disable double permutation')
+    parser.add_argument('--unsupervised', action='store_true', default=False,
+                        help='Unsupervised mode (default: False)')
+    parser.add_argument('--directed', action='store_true', default=False,
+                        help='Directed graph mode (default: False)')
+    parser.add_argument('--seed', type=int, default=42,
+                        help='Random seed (default: 42)')
+
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+
+    set_seed(args.seed)
+
+    print(f"Starting preprocessing for {args.name} dataset...")
+    print(f"Input directory: {args.input_dir}")
+    print(f"Output directory: {args.output_dir}")
+    print(f"Double permutation: {args.double_perm}")
+    print(f"Unsupervised mode: {args.unsupervised}")
+    print(f"Directed graph: {args.directed}")
+
+    # Create output directory if it doesn't exist
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+        print(f"Created output directory: {args.output_dir}")
+
+    preprocess_graph(
+        name=args.name,
+        input_dir=args.input_dir,
+        output_dir=args.output_dir,
+        double_perm=args.double_perm,
+        unsupervised=args.unsupervised,
+        directed=args.directed
+    )
+    print("Preprocessing completed!")
+
+if __name__ == "__main__":
+    main()
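One note on the flag pair in preprocess_reddit.py: --double_perm and --no_double_perm both write to the same double_perm destination, the default is True, and the last flag given on the command line wins. A minimal reproduction of the pattern:

import argparse

# Paired boolean flags sharing one destination, as in parse_args() above.
parser = argparse.ArgumentParser()
parser.add_argument('--double_perm', action='store_true', default=True)
parser.add_argument('--no_double_perm', dest='double_perm', action='store_false')

assert parser.parse_args([]).double_perm is True
assert parser.parse_args(['--no_double_perm']).double_perm is False
assert parser.parse_args(['--double_perm']).double_perm is True

On Python 3.9+, argparse.BooleanOptionalAction expresses the same pair with a single add_argument call.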