From db2ce65fe35ce0f2b29de503362b74a1017c01b9 Mon Sep 17 00:00:00 2001 From: SamuelDiai Date: Fri, 2 Apr 2021 00:27:21 +0200 Subject: [PATCH] . --- animals.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/animals.py b/animals.py index ed8b2a8..79e347d 100644 --- a/animals.py +++ b/animals.py @@ -27,13 +27,21 @@ def animals(k, n_iter, alpha, beta): A = graph['adjacency'] G = nx.from_numpy_matrix(A) + # normalize edge weights to plot edges strength + all_weights = [] + for (node1,node2,data) in G.edges(data=True): + all_weights.append(data['weight']) + max_weight = max(all_weights) + norm_weights = [3* w / max_weight for w in all_weights] + norm_weights = norm_weights + mapping = {} for i in range(animals_names.shape[0]): mapping[i] = animals_names[i, 0] G = nx.relabel_nodes(G, mapping) fig = plt.figure(figsize=(10,10)) - nx.draw(G, with_labels=True, font_weight='bold') + nx.draw(G, with_labels=True, font_weight='bold', width=norm_weights) plt.title("Learned graph for the animal dataset k=%s n_iter=%s alpha=%.3f beta=%.3f" % (k , n_iter, alpha, beta)) filename = os.path.join(plots_dir, 'animals', 'graph') fig.savefig(filename)