Feat: Remove the unnecessary distance calculation
We can use the similarities directly and sort them in descending order to rank the candidates, so the intermediate distance conversion is not needed.
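
For illustration, a minimal sketch of why the conversion is redundant (toy values, not taken from the repository): sorting `1 - similarity` in ascending order yields exactly the same ranking as sorting the similarity itself in descending order.

```python
import torch

# Toy cosine-similarity matrix: 2 queries x 4 index items (hypothetical values).
similarities = torch.tensor([[0.91, 0.10, 0.55, 0.73],
                             [0.20, 0.95, 0.40, 0.60]])

# Old approach: convert similarities to distances, then sort ascending.
distances = 1 - similarities
rank_by_distance = torch.argsort(distances, dim=-1)

# New approach: sort the similarities directly in descending order.
rank_by_similarity = torch.argsort(similarities, dim=-1, descending=True)

# Both produce the same ranking, so the subtraction is unnecessary work.
assert torch.equal(rank_by_distance, rank_by_similarity)
```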
whats2000 committed Nov 1, 2024
1 parent 10abd62 commit 503ff60
Showing 12 changed files with 4,954 additions and 4,588 deletions.
62 changes: 30 additions & 32 deletions src/ablation_experiment/validate_notebook.py
@@ -24,7 +24,7 @@ def compute_fiq_val_metrics_text_image_combinations(
beta: float,
) -> pd.DataFrame:
"""
Compute validation metrics on FashionIQ dataset combining text and image distances.
Compute validation metrics on FashionIQ dataset combining text and image similarities.
:param relative_val_dataset: FashionIQ validation dataset in relative mode
:param blip_text_encoder: BLIP text encoder
@@ -33,14 +33,14 @@ def compute_fiq_val_metrics_text_image_combinations(
:param image_index_features: validation image index features
:param image_index_names: validation image index names
:param combining_function: function that combines features
:param beta: beta value for the combination of text and image distances
:param beta: beta value for the combination of text and image similarities
:return: the computed validation metrics
"""
all_text_distances = []
all_text_similarities = []
results = []
target_names = None

# Compute distances for individual text features
# Compute similarities for individual text features
for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names):
# Generate text predictions and normalize features
predicted_text_features, target_names = generate_fiq_val_predictions(
@@ -55,12 +55,11 @@ def compute_fiq_val_metrics_text_image_combinations(
text_features = F.normalize(text_features, dim=-1)
predicted_text_features = F.normalize(predicted_text_features, dim=-1)

# Compute cosine similarity and convert to distance
# Compute cosine similarity
cosine_similarities = torch.mm(predicted_text_features, text_features.T)
distances = 1 - cosine_similarities
all_text_distances.append(distances)
all_text_similarities.append(cosine_similarities)

# Normalize and compute distances for image features if available
# Normalize and compute similarities for image features if available
if image_index_features is not None and len(image_index_features) > 0:
predicted_image_features, _ = generate_fiq_val_predictions(
blip_text_encoder,
@@ -71,17 +70,17 @@ def compute_fiq_val_metrics_text_image_combinations(
no_print_output=True,
)

# Normalize and compute distances
# Normalize and compute similarities
image_index_features = F.normalize(image_index_features, dim=-1).float()
image_distances = 1 - predicted_image_features @ image_index_features.T
image_similarities = predicted_image_features @ image_index_features.T
else:
image_distances = torch.zeros_like(all_text_distances[0])
image_similarities = torch.zeros_like(all_text_similarities[0])

# Merge text distances
merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0)
# Merge text similarities
merged_text_similarities = torch.mean(torch.stack(all_text_similarities), dim=0)

merged_distances = beta * merged_text_distances + (1 - beta) * image_distances
sorted_indices = torch.argsort(merged_distances, dim=-1).cpu()
merged_similarities = beta * merged_text_similarities + (1 - beta) * image_similarities
sorted_indices = torch.argsort(merged_similarities, dim=-1, descending=True).cpu()
sorted_index_names = np.array(image_index_names if image_index_names else multiple_text_index_names[0])[
sorted_indices]
labels = torch.tensor(
@@ -110,7 +109,7 @@ def compute_fiq_val_metrics_text_image_combinations_clip(
beta: float,
) -> pd.DataFrame:
"""
Compute validation metrics on FashionIQ dataset combining text and image distances.
Compute validation metrics on FashionIQ dataset combining text and image similarities.
:param relative_val_dataset: FashionIQ validation dataset in relative mode
:param clip_text_encoder: CLIP text encoder
@@ -120,14 +119,14 @@ def compute_fiq_val_metrics_text_image_combinations_clip(
:param image_index_features: validation image index features
:param image_index_names: validation image index names
:param combining_function: function that combines features
:param beta: beta value for the combination of text and image distances
:param beta: beta value for the combination of text and image similarities
:return: the computed validation metrics
"""
all_text_distances = []
all_text_similarities = []
results = []
target_names = None

# Compute distances for individual text features
# Compute similarities for individual text features
for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names):
# Generate text predictions and normalize features
predicted_text_features, target_names = generate_fiq_val_predictions_clip(
@@ -143,12 +142,11 @@ def compute_fiq_val_metrics_text_image_combinations_clip(
text_features = F.normalize(text_features, dim=-1)
predicted_text_features = F.normalize(predicted_text_features, dim=-1)

# Compute cosine similarity and convert to distance
# Compute cosine similarity
cosine_similarities = torch.mm(predicted_text_features, text_features.T)
distances = 1 - cosine_similarities
all_text_distances.append(distances)
all_text_similarities.append(cosine_similarities)

# Normalize and compute distances for image features if available
# Normalize and compute similarities for image features if available
if image_index_features is not None and len(image_index_features) > 0:
predicted_image_features, _ = generate_fiq_val_predictions_clip(
clip_text_encoder,
@@ -160,17 +158,17 @@ def compute_fiq_val_metrics_text_image_combinations_clip(
no_print_output=True,
)

# Normalize and compute distances
# Normalize and compute similarities
image_index_features = F.normalize(image_index_features, dim=-1).float()
image_distances = 1 - predicted_image_features @ image_index_features.T
image_similarities = predicted_image_features @ image_index_features.T
else:
image_distances = torch.zeros_like(all_text_distances[0])
image_similarities = torch.zeros_like(all_text_similarities[0])

# Merge text distances
merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0)
# Merge text similarities
merged_text_similarities = torch.mean(torch.stack(all_text_similarities), dim=0)

merged_distances = beta * merged_text_distances + (1 - beta) * image_distances
sorted_indices = torch.argsort(merged_distances, dim=-1).cpu()
merged_similarities = beta * merged_text_similarities + (1 - beta) * image_similarities
sorted_indices = torch.argsort(merged_similarities, dim=-1, descending=True).cpu()
sorted_index_names = np.array(
image_index_names if image_index_names else multiple_text_index_names[0]
)[sorted_indices]
@@ -208,7 +206,7 @@ def fiq_val_retrieval_text_image_combinations(
:param blip_img_encoder: BLIP image model
:param text_captions: text captions for the FashionIQ dataset
:param preprocess: preprocess pipeline
:param beta: beta value for the combination of text and image distances
:param beta: beta value for the combination of text and image similarities
:param cache: cache dictionary
:return: DataFrame containing the retrieval metrics for each combination of text features
"""
@@ -332,7 +330,7 @@ def fiq_val_retrieval_text_image_combinations_clip(
:param clip_tokenizer: CLIP tokenizer
:param text_captions: text captions for the FashionIQ dataset
:param preprocess: preprocess pipeline
:param beta: beta value for the combination of text and image distances
:param beta: beta value for the combination of text and image similarities
:param cache: cache dictionary
:return: DataFrame containing the retrieval metrics for each combination of text features
"""
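For context on the hunks above, here is a standalone sketch of the updated merge-and-rank step; the function name, shapes, and beta value are illustrative assumptions, not code copied from the repository.

```python
import torch

def merge_and_rank(text_similarities, image_similarities, beta):
    """Blend averaged text similarities with image similarities, rank descending.

    text_similarities: list of (num_queries, num_index) tensors, one per caption set.
    image_similarities: (num_queries, num_index) tensor, or None when unavailable.
    """
    merged_text = torch.mean(torch.stack(text_similarities), dim=0)
    if image_similarities is None:
        image_similarities = torch.zeros_like(merged_text)
    merged = beta * merged_text + (1 - beta) * image_similarities
    # Higher similarity means a better match, so rank in descending order.
    return torch.argsort(merged, dim=-1, descending=True)

# Toy usage: 2 queries, 5 index images, 2 caption sets, beta = 0.65 (assumed).
text_sims = [torch.rand(2, 5), torch.rand(2, 5)]
image_sims = torch.rand(2, 5)
print(merge_and_rank(text_sims, image_sims, beta=0.65))
```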
50 changes: 24 additions & 26 deletions src/ablation_experiment/validate_notebook_cirr.py
@@ -39,13 +39,13 @@ def compute_cirr_val_metrics_text_image_combinations(
Returns:
The computed validation metrics
"""
all_text_distances = []
all_text_similarities = []
results = []
reference_names = None
target_names = None
group_members = None

# Compute distances for individual text features
# Compute similarities for individual text features
for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names):
# Generate text predictions and normalize features
predicted_text_features, reference_names, target_names, group_members = generate_cirr_val_predictions(
@@ -60,12 +60,11 @@ def compute_cirr_val_metrics_text_image_combinations(
text_features = F.normalize(text_features, dim=-1)
predicted_text_features = F.normalize(predicted_text_features, dim=-1)

# Compute cosine similarity and convert to distance
# Compute cosine similarity
cosine_similarities = torch.mm(predicted_text_features, text_features.T)
distances = 1 - cosine_similarities
all_text_distances.append(distances)
all_text_similarities.append(cosine_similarities)

# Normalize and compute distances for image features if available
# Normalize and compute similarities for image features if available
if image_index_features is not None and len(image_index_features) > 0:
predicted_image_features, _, _, _ = generate_cirr_val_predictions(
blip_text_encoder,
@@ -76,18 +75,18 @@ def compute_cirr_val_metrics_text_image_combinations(
no_print_output=True,
)

# Normalize and compute distances
# Normalize and compute similarities
image_index_features = F.normalize(image_index_features, dim=-1).float()
image_distances = 1 - predicted_image_features @ image_index_features.T
image_similarities = predicted_image_features @ image_index_features.T
else:
image_distances = torch.zeros_like(all_text_distances[0])
image_similarities = torch.zeros_like(all_text_similarities[0])

# Merge text distances
merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0)
# Merge text similarities
merged_text_similarities = torch.mean(torch.stack(all_text_similarities), dim=0)

merged_distances = beta * merged_text_distances + (1 - beta) * image_distances
merged_similarities = beta * merged_text_similarities + (1 - beta) * image_similarities
# Sort the results
sorted_indices = torch.argsort(merged_distances, dim=-1).cpu()
sorted_indices = torch.argsort(merged_similarities, dim=-1, descending=True).cpu()
sorted_index_names = np.array(
image_index_names if image_index_names else multiple_text_index_names[0]
)[sorted_indices]
@@ -172,13 +171,13 @@ def compute_cirr_val_metrics_text_image_combinations_clip(
Returns:
The computed validation metrics
"""
all_text_distances = []
all_text_similarities = []
results = []
reference_names = None
target_names = None
group_members = None

# Compute distances for individual text features
# Compute similarities for individual text features
for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names):
# Generate text predictions and normalize features
predicted_text_features, reference_names, target_names, group_members = generate_cirr_val_predictions_clip(
@@ -194,12 +193,11 @@ def compute_cirr_val_metrics_text_image_combinations_clip(
text_features = F.normalize(text_features, dim=-1)
predicted_text_features = F.normalize(predicted_text_features, dim=-1)

# Compute cosine similarity and convert to distance
# Compute cosine similarity
cosine_similarities = torch.mm(predicted_text_features, text_features.T)
distances = 1 - cosine_similarities
all_text_distances.append(distances)
all_text_similarities.append(cosine_similarities)

# Normalize and compute distances for image features if available
# Normalize and compute similarities for image features if available
if image_index_features is not None and len(image_index_features) > 0:
predicted_image_features, _, _, _ = generate_cirr_val_predictions_clip(
clip_text_encoder,
@@ -211,18 +209,18 @@ def compute_cirr_val_metrics_text_image_combinations_clip(
no_print_output=True,
)

# Normalize and compute distances
# Normalize and compute similarities
image_index_features = F.normalize(image_index_features, dim=-1).float()
image_distances = 1 - predicted_image_features @ image_index_features.T
image_similarities = predicted_image_features @ image_index_features.T
else:
image_distances = torch.zeros_like(all_text_distances[0])
image_similarities = torch.zeros_like(all_text_similarities[0])

# Merge text distances
merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0)
# Merge text similarities
merged_text_similarities = torch.mean(torch.stack(all_text_similarities), dim=0)

merged_distances = beta * merged_text_distances + (1 - beta) * image_distances
merged_similarities = beta * merged_text_similarities + (1 - beta) * image_similarities
# Sort the results
sorted_indices = torch.argsort(merged_distances, dim=-1).cpu()
sorted_indices = torch.argsort(merged_similarities, dim=-1, descending=True).cpu()
sorted_index_names = np.array(
image_index_names if image_index_names else multiple_text_index_names[0]
)[sorted_indices]
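The metric computation that follows the ranking is collapsed in the hunks above; as a rough, hypothetical sketch (not the repository's exact code) of how Recall@K can be derived from `sorted_index_names` and the ground-truth `target_names`:

```python
import numpy as np
import torch

def recall_at_k(sorted_index_names, target_names, k_values=(10, 50)):
    """Hypothetical Recall@K from ranked index names and ground-truth targets.

    sorted_index_names: (num_queries, num_index) array of names, best match first.
    target_names: ground-truth target name for each query.
    """
    targets = np.array(target_names)[:, None]              # (num_queries, 1)
    labels = torch.tensor(sorted_index_names == targets)   # True where the target sits
    return {f"recall@{k}": labels[:, :k].any(dim=-1).float().mean().item()
            for k in k_values}

# Toy example: 2 queries ranked over 4 candidate image names.
ranked = np.array([["a", "b", "c", "d"], ["c", "a", "d", "b"]])
print(recall_at_k(ranked, ["b", "d"], k_values=(1, 2)))
```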