Skip to content

Commit

Permalink
Add seed to kmeans. Log seeds.
Browse files Browse the repository at this point in the history
  • Loading branch information
pvti committed Sep 9, 2023
1 parent 2bd3535 commit 301b8bc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
2 changes: 2 additions & 0 deletions experiments/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def parse_args():
help="distance metric",
)
parser.add_argument("--rank", type=int, default=1, help="decomposition rank")
parser.add_argument("--seed", type=int, default=0, help="seed for random")

return parser.parse_args()

Expand Down Expand Up @@ -165,6 +166,7 @@ def compute_inter_distance(centroids):
rank=args.rank,
num_clusters=num_clusters,
dist=args.distance,
seed=args.seed,
)

# Print the inertia list for each iteration
Expand Down
31 changes: 21 additions & 10 deletions experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from kmeans import custom_kmeans, decompose
import wandb
import random
import os


def parse_args():
Expand Down Expand Up @@ -54,6 +55,7 @@ def parse_args():
parser.add_argument(
"--initialization", type=int, default=100, help="number of initialization"
)
parser.add_argument("--output", type=str, default="./", help="output directory")

return parser.parse_args()

Expand Down Expand Up @@ -102,8 +104,9 @@ def run(data_combined, data_processed, method, rank):
if inertia < inertia_min:
inertia_min = inertia
ARI = adjusted_rand_score(ground_truth, labels)
seed = i

return ARI, inertia_min
return ARI, inertia_min, seed


def compute_std(x):
Expand Down Expand Up @@ -134,13 +137,16 @@ def topk_positions(x, y, k=10):
def main():
name = f"distance = {args.distance} rank = {args.rank} clusters-std = [{args.clusters_std_min} {args.clusters_std_max}] noise-std = [{args.noise_std_min} {args.noise_std_max}]"
wandb.init(name=name, project=f"CORING_CustomKmeans", config=vars(args))
if not os.path.isdir(args.output):
os.makedirs(args.outputI)

ARIs_tensor = []
inertias_tensor = []
ARIs_matrix = []
inertias_matrix = []
seeds_tensor = []
seeds_matrix = []
i = 0

for i in tqdm(range(args.runs)):
while i < args.runs:
# Create dataset
clusters_std = random.uniform(args.clusters_std_min, args.clusters_std_max)
noise_std = random.uniform(args.noise_std_min, args.noise_std_max)
Expand All @@ -153,7 +159,8 @@ def main():
closely_similar_filters = add_noise_to_centroids(
initial_filters, args.satellites, noise_std
)
np.save(f"{i}.npy", [initial_filters, closely_similar_filters])
path = os.path.join(args.output, f"{i}.npy")
np.save(path, [initial_filters, closely_similar_filters])

# Combine both initial_filters and closely_similar_filters
data_combined = np.vstack((initial_filters, closely_similar_filters))
Expand All @@ -166,13 +173,13 @@ def main():

# Apply K-means
try:
ARI_tensor, inertia_tensor = run(
ARI_tensor, inertia_tensor, seed_tensor = run(
data_combined=data_combined,
data_processed=data_processed_tensor,
method="tensor",
rank=args.rank,
)
ARI_matrix, inertia_matrix = run(
ARI_matrix, inertia_matrix, seed_matrix = run(
data_combined=data_combined,
data_processed=data_processed_matrix,
method="matrix",
Expand All @@ -185,13 +192,17 @@ def main():
"inertia_tensor": inertia_tensor,
"ARI_matrix": ARI_matrix,
"inertia_matrix": inertia_matrix,
"seed_tensor": seed_tensor,
"seed_matrix": seed_matrix,
}
)

ARIs_tensor.append(ARI_tensor)
inertias_tensor.append(inertia_tensor)
ARIs_matrix.append(ARI_matrix)
inertias_matrix.append(inertia_matrix)
seeds_tensor.append(seed_tensor)
seeds_matrix.append(seed_matrix)

i += 1

except Exception as error:
print(error)
Expand All @@ -206,7 +217,7 @@ def main():
top_positions = topk_positions(ARIs_tensor, ARIs_matrix)
for i, position in enumerate(top_positions):
print(
f"Position {i + 1}: ARIs_tensor[{position}] = {ARIs_tensor[position]}, ARIs_matrix[{position}] = {ARIs_matrix[position]}"
f"Position {i + 1}: ARIs_tensor[{position}] = {ARIs_tensor[position]}, ARIs_matrix[{position}] = {ARIs_matrix[position]}, seed_tensor = {seed_tensor}, seed_matrix = {seed_matrix}"
)


Expand Down

0 comments on commit 301b8bc

Please sign in to comment.