diff --git a/basic_experiments.py b/basic_experiments.py index 07107f2..bc1145f 100644 --- a/basic_experiments.py +++ b/basic_experiments.py @@ -22,22 +22,25 @@ def load_dataset_and_sgl(dataset, k, k_sgl, n): if dataset == 'Two moons': assert(k == 2) X, y = make_moons(n_samples=n*k, noise=.05, shuffle=True) + n_points = n*k elif dataset == 'Blops': X, y = make_blobs(n_samples=n*k, centers=k, n_features=2, random_state=0, cluster_std=0.6) + n_points = n*k elif dataset == 'Circles': X, y = make_circles(n_samples=n) + n_points = n else : raise ValueError('%s is not a valid dataset ' % dataset) # dict to store position of nodes pos = {} - for i in range(n*k): + for i in range(n_points): pos[i] = X[i] # compute sample correlation matrix S = np.dot(X, X.T) # estimate underlying graph sgl = LearnGraphTopology(S, n_iter=100, beta=0.1) - graph = sgl.learn_graph(k=k) + graph = sgl.learn_graph(k=k_sgl) # build network A = graph['adjacency']