From b83465296c417b5325be8f760468b8dd1bd4089a Mon Sep 17 00:00:00 2001 From: SamuelDiai Date: Fri, 2 Apr 2021 00:32:37 +0200 Subject: [PATCH] . --- cancer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/cancer.py b/cancer.py index 24cd117..1f62537 100644 --- a/cancer.py +++ b/cancer.py @@ -26,6 +26,16 @@ def Cancer(df_cancer, y_cancer, alpha, beta, k, n_iter): A = graph['adjacency'] G = nx.from_numpy_matrix(A) pos = nx.spring_layout(G) + + # normalize edge weights to plot edges strength + all_weights = [] + for (node1,node2,data) in G.edges(data=True): + all_weights.append(data['weight']) + max_weight = max(all_weights) + norm_weights = [3* w / max_weight for w in all_weights] + norm_weights = norm_weights + + fig = plt.figure(figsize=(12,12)) # Color labels color_map = [] @@ -33,7 +43,7 @@ def Cancer(df_cancer, y_cancer, alpha, beta, k, n_iter): for i in range(y_cancer.shape[0]): color_map.append(color_dict[y_cancer['Class'][i]]) # Plot graph - nx.draw(G, node_color=color_map, with_labels=True, pos = pos, font_weight='bold') + nx.draw(G, node_color=color_map, width=norm_weights, pos = pos, font_weight='bold') plt.title("Learned graph for the cancer dataset k=%s n_iter=%s alpha=%.3f beta=%.3f" % (k , n_iter, alpha, beta)) filename = os.path.join(plots_dir, 'cancer', 'graph') fig.savefig(filename)