Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelDiai committed Apr 1, 2021
1 parent f786534 commit 2bd1da0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion SGL.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import cvxpy as cp
import numpy as np
from .utils import Laplacian_dual, Laplacian_inv, Laplacian
from utils import Laplacian_dual, Laplacian_inv, Laplacian

class LearnGraphTopology:
def __init__(self, S, alpha=0, beta=1e4, n_iter=10000, c1=0., c2=1e10, tol = 1e-6):
Expand Down
8 changes: 5 additions & 3 deletions animals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_blobs
from .SGL import LearnGraphTopology
from SGL import LearnGraphTopology
plots_dir = './plots'
if not os.path.exists(plots_dir):
os.makedirs(plots_dir)

def animals(k, n_iter, alpha, beta):
if not os.path.exists(os.path.join(plots_dir, 'animals')):
os.makedirs(os.path.join(plots_dir, 'animals'))
X_animals = np.load('data/animals_data.npy')
animals_names = np.load('data/animals_name.npy')
animals_features = np.load('data/animals_features.npy')
Expand All @@ -33,6 +35,6 @@ def animals(k, n_iter, alpha, beta):
G = nx.relabel_nodes(G, mapping)
fig = plt.figure(figsize=(15,15))
nx.draw(G, with_labels=True, font_weight='bold')
plt.title("Learned graph for the animal dataset" % dataset)
filename = os.path.join(plots_dir, 'animals', 'graph_%s_%s_%.3f_%.3f' % (k , n_iter, alpha, beta))
plt.title("Learned graph for the animal dataset k=%s n_iter=%s alpha=%.3f beta=%.3f" % (k , n_iter, alpha, beta))
filename = os.path.join(plots_dir, 'animals', 'graph')
fig.savefig(filename)
2 changes: 1 addition & 1 deletion basic_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_blobs, make_circles
from .SGL import LearnGraphTopology
from SGL import LearnGraphTopology
plots_dir = './plots'
if not os.path.exists(plots_dir):
os.makedirs(plots_dir)
Expand Down

0 comments on commit 2bd1da0

Please sign in to comment.