Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions examples/linkproppred/tkgl/edgebank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import argparse

import numpy as np
import torch
from tgb.linkproppred.evaluate import Evaluator
from tqdm import tqdm

from tgm import DGraph
from tgm.constants import METRIC_TGB_LINKPROPPRED
from tgm.data import DGData, DGDataLoader
from tgm.hooks import HookManager, TGBTKGNegativeEdgeSamplerHook
from tgm.nn import EdgeBankPredictor
from tgm.util.logging import enable_logging, log_latency, log_metric
from tgm.util.seed import seed_everything

# CLI configuration for the EdgeBank temporal-knowledge-graph (TKGL) example.
parser = argparse.ArgumentParser(
    description='EdgeBank LinkPropPred Example for knowledge graph',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--seed', type=int, default=1337, help='random seed to use')
parser.add_argument(
    '--dataset', type=str, default='tkgl-smallpedia', help='Dataset name'
)
parser.add_argument('--bsize', type=int, default=200, help='batch size')
# window-ratio / pos-prob / memory-mode are forwarded verbatim to
# EdgeBankPredictor below; 'fixed' memory limits the bank to a sliding window.
parser.add_argument('--window-ratio', type=float, default=0.15, help='Window ratio')
parser.add_argument('--pos-prob', type=float, default=1.0, help='Positive edge prob')
parser.add_argument(
    '--memory-mode',
    type=str,
    default='unlimited',
    choices=['unlimited', 'fixed'],
    help='Memory mode',
)
parser.add_argument(
    '--log-file-path', type=str, default=None, help='Optional path to write logs'
)

args = parser.parse_args()
# Initialize tgm logging; also writes to the given file when one is provided.
enable_logging(log_file_path=args.log_file_path)


@log_latency
def eval(  # NOTE(review): shadows the builtin `eval`; kept for same-file callers
    loader: DGDataLoader,
    model: EdgeBankPredictor,
    evaluator: Evaluator,
) -> float:
    """Score every batch in ``loader`` with EdgeBank and return the mean metric.

    For each positive edge, one query set is built containing the true
    destination followed by its sampled negatives (all sharing the positive
    source). After scoring a batch, EdgeBank's memory is updated with the
    batch's observed edges so later batches see them.
    """
    scores = []
    for batch in tqdm(loader):
        for pos_idx, negatives in enumerate(batch.neg_batch_list):
            # Positive destination first, then all negatives; the source node
            # is repeated to match.
            src_queries = batch.edge_src[pos_idx].repeat(len(negatives) + 1)
            dst_queries = torch.cat([batch.edge_dst[pos_idx].unsqueeze(0), negatives])

            preds = model(src_queries, dst_queries)
            result = evaluator.eval(
                {
                    'y_pred_pos': preds[0],
                    'y_pred_neg': preds[1:],
                    'eval_metric': [METRIC_TGB_LINKPROPPRED],
                }
            )
            scores.append(result[METRIC_TGB_LINKPROPPRED])
        # EdgeBank is non-parametric: memorize the batch after scoring it.
        model.update(batch.edge_src, batch.edge_dst, batch.edge_time)

    return float(np.mean(scores))


seed_everything(args.seed)
evaluator = Evaluator(name=args.dataset)

# Load the full dataset once to derive the destination-node id range, which
# the TGB negative-edge sampler hooks require.
data = DGData.from_tgb(args.dataset)
min_dst_node = data.edge_index[:, 1].min().int()
max_dst_node = data.edge_index[:, 1].max().int()

train_data, val_data, test_data = data.split()
train_dg = DGraph(train_data)
val_dg = DGraph(val_data)
test_dg = DGraph(test_data)

# NOTE: rebinds train_data — from here on it is the materialized view of the
# training graph (features skipped; EdgeBank only needs src/dst/time arrays).
train_data = train_dg.materialize(materialize_features=False)


# Separate hook keys so val and test each use their own TGB negative sampler.
hm = HookManager(keys=['val', 'test'])
hm.register(
    'val',
    TGBTKGNegativeEdgeSamplerHook(
        args.dataset,
        split_mode='val',
        first_dst_id=min_dst_node,
        last_dst_id=max_dst_node,
    ),
)
hm.register(
    'test',
    TGBTKGNegativeEdgeSamplerHook(
        args.dataset,
        split_mode='test',
        first_dst_id=min_dst_node,
        last_dst_id=max_dst_node,
    ),
)

val_loader = DGDataLoader(val_dg, args.bsize, hook_manager=hm)
test_loader = DGDataLoader(test_dg, args.bsize, hook_manager=hm)

# Non-parametric baseline seeded with the training edges.
model = EdgeBankPredictor(
    train_data.edge_src,
    train_data.edge_dst,
    train_data.edge_time,
    memory_mode=args.memory_mode,
    window_ratio=args.window_ratio,
    pos_prob=args.pos_prob,
)

# Order matters: eval() updates EdgeBank's memory as it consumes each loader,
# so validation edges are memorized before the test split is scored.
with hm.activate('val'):
    val_mrr = eval(val_loader, model, evaluator)
    log_metric(f'Validation {METRIC_TGB_LINKPROPPRED}', val_mrr)

with hm.activate('test'):
    test_mrr = eval(test_loader, model, evaluator)
    log_metric(f'Test {METRIC_TGB_LINKPROPPRED}', test_mrr)
3 changes: 2 additions & 1 deletion scripts/download_tgb_datasets.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DATASETS=(
"tgbl_wiki"
"tgbn_trade"
"thgl_software"
"tkgl-smallpedia"
#"tgbn_genre"
#"tgbl_coin"
#"tgbl_flight" TODO: Start working with the large graphs
Expand Down Expand Up @@ -67,7 +68,7 @@ download_dataset() {
local dataset_name="${dataset//_/-}" # 'tgbl_wiki' -> 'tgbl-wiki'
echo "Downloading dataset: $dataset_name"

if [[ "$dataset" == tgbl_* || "$dataset" == thgl_* ]]; then
if [[ "$dataset" == tgbl_* || "$dataset" == thgl_* || "$dataset" == tkgl_* ]]; then
.venv/bin/python -c "from tgb.linkproppred.dataset import LinkPropPredDataset as DS; DS(name='$dataset_name')"
elif [[ "$dataset" == tgbn_* ]]; then
.venv/bin/python -c "from tgb.nodeproppred.dataset import NodePropPredDataset as DS; DS(name='$dataset_name')"
Expand Down
36 changes: 36 additions & 0 deletions test/integration/test_edgebank.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,39 @@ def test_edgebank_linkprop_pred_fixed_memory_thgl(slurm_job_runner, dataset):
--dataset {dataset} --memory-mode fixed"""
state = slurm_job_runner(cmd)
assert state == 'COMPLETED'


@pytest.mark.integration
@pytest.mark.parametrize('dataset', ['tkgl-smallpedia'])
@pytest.mark.slurm(
    resources=[
        '--partition=main',
        '--cpus-per-task=2',
        '--mem=8G',
        '--time=1:15:00',
    ]
)
def test_edgebank_linkprop_pred_unlimited_memory_tkgl(slurm_job_runner, dataset):
    """Smoke test: the TKGL EdgeBank example (default, unlimited memory)
    runs to completion as a SLURM job."""
    cmd = f"""
    python "$ROOT_DIR/examples/linkproppred/tkgl/edgebank.py" \
    --dataset {dataset}"""
    state = slurm_job_runner(cmd)
    assert state == 'COMPLETED'


@pytest.mark.integration
@pytest.mark.parametrize('dataset', ['tkgl-smallpedia'])
@pytest.mark.slurm(
    resources=[
        '--partition=main',
        '--cpus-per-task=2',
        '--mem=8G',
        '--time=2:00:00',
    ]
)
def test_edgebank_linkprop_pred_fixed_memory_tkgl(slurm_job_runner, dataset):
    """Smoke test: the TKGL EdgeBank example with --memory-mode fixed
    (sliding-window memory) runs to completion as a SLURM job."""
    cmd = f"""
    python "$ROOT_DIR/examples/linkproppred/tkgl/edgebank.py" \
    --dataset {dataset} --memory-mode fixed"""
    state = slurm_job_runner(cmd)
    assert state == 'COMPLETED'
4 changes: 4 additions & 0 deletions test/unit/test_core/test_timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def test_tgb_native_time_deltas():
'thgl-forum': TimeDeltaDG('s'),
'thgl-github': TimeDeltaDG('s'),
'thgl-myket': TimeDeltaDG('s'),
'tkgl-smallpedia': TimeDeltaDG('Y'),
'tkgl-polecat': TimeDeltaDG('D'),
'tkgl-icews': TimeDeltaDG('D'),
'tkgl-wikidata': TimeDeltaDG('Y'),
}
assert TGB_TIME_DELTAS == exp_dict

Expand Down
131 changes: 124 additions & 7 deletions test/unit/test_data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,13 @@ def test_from_pandas_bad_node_cols_not_specified():

@pytest.fixture
def tgb_dataset_factory():
def _make_dataset(split: str = 'all', with_node_feats: bool = False, thgl=False):
def _make_dataset(
split: str = 'all',
with_node_feats: bool = False,
with_edge_feats: bool = False,
thgl=False,
tkgl=False,
):
num_events, num_train, num_val = 10, 7, 2
train_indices = np.arange(0, num_train)
val_indices = np.arange(num_train, num_train + num_val)
Expand All @@ -1205,7 +1211,12 @@ def _make_dataset(split: str = 'all', with_node_feats: bool = False, thgl=False)
sources = np.random.randint(0, 1000, size=num_events)
destinations = np.random.randint(0, 1000, size=num_events)
timestamps = np.arange(num_events)
edge_feat = None
if not with_edge_feats:
edge_feat = None
elif with_edge_feats and not tkgl:
edge_feat = np.random.rand(num_events, 10)
elif with_edge_feats and tkgl:
edge_feat = np.random.rand(num_events // 2, 10)

train_mask = np.zeros(num_events, dtype=bool)
val_mask = np.zeros(num_events, dtype=bool)
Expand Down Expand Up @@ -1249,6 +1260,9 @@ def _make_dataset(split: str = 'all', with_node_feats: bool = False, thgl=False)
max(sources.max(), destinations.max()) + 1
)

if tkgl:
mock_dataset.full_data['edge_type'] = np.arange(num_events)

return mock_dataset

return _make_dataset
Expand Down Expand Up @@ -1316,6 +1330,59 @@ def _make_dataset(split: str = 'all', with_edge_type=True, with_node_type=False)
return _make_dataset


@pytest.fixture
def bad_tkgl_dataset_factory():  # Missing edge_type
    """Build a mock TKGL dataset whose ``full_data`` lacks an ``edge_type`` key.

    Used to verify that ``DGData.from_tgb`` rejects temporal knowledge-graph
    datasets that omit relation (edge type) information.
    """

    def _make_dataset(split: str = 'all'):
        # ``split`` is kept for signature parity with the sibling dataset
        # factories; the mock exposes all three masks regardless of its value.
        num_events, num_train, num_val = 10, 7, 2
        train_indices = np.arange(0, num_train)
        val_indices = np.arange(num_train, num_train + num_val)
        test_indices = np.arange(num_train + num_val, num_events)

        sources = np.random.randint(0, 1000, size=num_events)
        destinations = np.random.randint(0, 1000, size=num_events)
        timestamps = np.arange(num_events)
        edge_feat = None  # deliberately no features; 'edge_type' omitted below
        w = np.random.rand(num_events, 10)

        train_mask = np.zeros(num_events, dtype=bool)
        val_mask = np.zeros(num_events, dtype=bool)
        test_mask = np.zeros(num_events, dtype=bool)

        train_mask[train_indices] = True
        val_mask[val_indices] = True
        test_mask[test_indices] = True

        mock_dataset = MagicMock()
        mock_dataset.train_mask = train_mask
        mock_dataset.val_mask = val_mask
        mock_dataset.test_mask = test_mask
        mock_dataset.num_edges = num_events
        # No 'edge_type' entry — the property under test.
        mock_dataset.full_data = {
            'sources': sources,
            'destinations': destinations,
            'timestamps': timestamps,
            'edge_feat': edge_feat,
            'w': w,
        }

        mock_dataset.node_feat = None

        mock_dataset.full_data['node_label_dict'] = {
            i: {i: np.zeros(10)} for i in range(5)
        }

        return mock_dataset

    return _make_dataset


@pytest.fixture
def tgb_seq_dataset_factory():
def _make_dataset(
Expand Down Expand Up @@ -1369,11 +1436,6 @@ def _make_dataset(
return _make_dataset


def test_from_tkgl():
with pytest.raises(NotImplementedError):
DGData.from_tgb('tkgl-foo')


def test_from_bad_tgb_name():
    """An unrecognized TGB dataset name should raise ValueError."""
    with pytest.raises(ValueError):
        DGData.from_tgb('foo')
Expand Down Expand Up @@ -2258,3 +2320,58 @@ def test_from_pandas_with_static_node_type():
)
assert isinstance(data, DGData)
torch.testing.assert_close(data.node_type.tolist(), node_dict['node_type'])


@pytest.mark.parametrize('with_node_feats', [True, False])
@pytest.mark.parametrize('with_edge_feats', [True, False])
@pytest.mark.parametrize('tkgl', [True])
@patch('tgb.linkproppred.dataset.LinkPropPredDataset')
@patch.dict('tgm.core.timedelta.TGB_TIME_DELTAS', {'tkgl-smallpedia': TimeDeltaDG('D')})
def test_from_tkgl(
    mock_dataset_cls, tgb_dataset_factory, with_node_feats, with_edge_feats, tkgl
):
    """DGData.from_tgb should load a (mocked) TKGL dataset end to end.

    Verifies edge index/time/type extraction, optional edge and static node
    features, and that the dataset's (patched) native time delta is adopted.
    Note: ``patch.dict`` injects no argument, so the single ``@patch`` supplies
    ``mock_dataset_cls`` as the first parameter.
    """
    dataset = tgb_dataset_factory(
        with_node_feats=with_node_feats, with_edge_feats=with_edge_feats, tkgl=tkgl
    )
    mock_dataset_cls.return_value = dataset

    mock_native_time_delta = TimeDeltaDG('D')  # Patched value

    def _get_exp_edges():
        # Expected (num_events, 2) array of [src, dst] pairs.
        src, dst = dataset.full_data['sources'], dataset.full_data['destinations']
        return np.stack([src, dst], axis=1)

    def _get_exp_times():
        return dataset.full_data['timestamps']

    def _get_exp_edge_type():
        return dataset.full_data['edge_type']

    def _get_exp_edge_feat():
        # The tkgl factory builds num_events // 2 feature rows; loading is
        # expected to duplicate them to cover all events — NOTE(review):
        # confirm this mirrors DGData.from_tgb's tkgl feature handling.
        edge_feat_np = dataset.full_data['edge_feat']
        return np.concatenate((edge_feat_np, edge_feat_np))

    data = DGData.from_tgb(name='tkgl-smallpedia')
    assert isinstance(data, DGData)
    assert data.time_delta == mock_native_time_delta
    np.testing.assert_allclose(data.edge_index.numpy(), _get_exp_edges())
    np.testing.assert_allclose(data.time.numpy(), _get_exp_times())
    np.testing.assert_allclose(data.edge_type.numpy(), _get_exp_edge_type())
    if with_edge_feats:
        np.testing.assert_allclose(data.edge_x.numpy(), _get_exp_edge_feat())

    # Confirm correct dataset instantiation
    mock_dataset_cls.assert_called_once_with(name='tkgl-smallpedia')

    if with_node_feats:
        torch.testing.assert_close(data.static_node_x, torch.Tensor(dataset.node_feat))
    else:
        assert data.static_node_x is None


@patch('tgb.linkproppred.dataset.LinkPropPredDataset')
def test_from_bad_thgl(mock_dataset_cls, bad_tkgl_dataset_factory):
    """A TKGL dataset missing 'edge_type' should be rejected with ValueError.

    NOTE(review): the function name says 'thgl' but the fixture and dataset
    are tkgl — consider renaming to test_from_bad_tkgl.
    """
    mock_dataset_cls.return_value = bad_tkgl_dataset_factory()
    with pytest.raises(ValueError):
        DGData.from_tgb(name='tkgl-smallpedia')
2 changes: 1 addition & 1 deletion test/unit/test_hooks/test_device_transfer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_device_transfer_hook_gpu_gpu(dg):
batch = dg.materialize()
batch.edge_src = batch.edge_src.to('cuda')
batch.edge_dst = batch.edge_dst.to('cuda')
batch.time = batch.time.to('cuda')
batch.edge_time = batch.edge_time.to('cuda')

# Add a custom field and ensure it's also moved
batch.foo = torch.rand(1, 2, device='cuda')
Expand Down
Loading
Loading