From c83cb0b34f79594e8997ad325262e09513291742 Mon Sep 17 00:00:00 2001 From: daniel-gomm Date: Mon, 20 Nov 2023 19:27:27 +0100 Subject: [PATCH 1/2] Fix issue where memory update at end may raise an exception when processing non-bipartite graphs --- model/tgn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/model/tgn.py b/model/tgn.py index 05704f0..c41d922 100644 --- a/model/tgn.py +++ b/model/tgn.py @@ -178,12 +178,12 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_ source_nodes, source_node_embedding, edge_times, edge_idxs) - if self.memory_update_at_start: - self.memory.store_raw_messages(unique_sources, source_id_to_messages) - self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) - else: - self.update_memory(unique_sources, source_id_to_messages) - self.update_memory(unique_destinations, destination_id_to_messages) + + self.memory.store_raw_messages(unique_sources, source_id_to_messages) + self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) + + if not self.memory_update_at_start: + self.get_updated_memory(list(range(self.n_nodes)), self.memory.messages) if self.dyrep: source_node_embedding = memory[source_nodes] From 8634faeaad4a93eda8efa3de441c6ca06722c9f8 Mon Sep 17 00:00:00 2001 From: daniel-gomm Date: Thu, 23 Nov 2023 11:04:21 +0100 Subject: [PATCH 2/2] Remove messages from memory after update --- model/tgn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/model/tgn.py b/model/tgn.py index c41d922..30ec65b 100644 --- a/model/tgn.py +++ b/model/tgn.py @@ -183,7 +183,10 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_ self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) if not self.memory_update_at_start: - self.get_updated_memory(list(range(self.n_nodes)), self.memory.messages) + unique_node_ids = np.unique(np.concatenate((unique_sources, unique_destinations))) + self.update_memory(unique_node_ids, + self.memory.messages) + self.memory.clear_messages(unique_node_ids) if self.dyrep: source_node_embedding = memory[source_nodes]