Skip to content

Commit

Permalink
Add: Rerun the CLIP for FashionIQ
Browse files Browse the repository at this point in the history
  • Loading branch information
whats2000 committed Aug 19, 2024
1 parent 18bbb6f commit d5a319b
Show file tree
Hide file tree
Showing 8 changed files with 29,069 additions and 32 deletions.
7,243 changes: 7,243 additions & 0 deletions src/fashioniq_experiment/clip/CLIP-VIT-G14-laion.ipynb

Large diffs are not rendered by default.

7,168 changes: 7,168 additions & 0 deletions src/fashioniq_experiment/clip/CLIP-VIT-G14.ipynb

Large diffs are not rendered by default.

7,112 changes: 7,112 additions & 0 deletions src/fashioniq_experiment/clip/CLIP-VIT-H14.ipynb

Large diffs are not rendered by default.

7,112 changes: 7,112 additions & 0 deletions src/fashioniq_experiment/clip/CLIP-VIT-L14.ipynb

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions src/fashioniq_experiment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,62 @@ def prepare_and_plot_recall_pivot(data: pd.DataFrame, title: str):
plt.show()


def filter_data_by_scale(data: pd.DataFrame, scale: float):
"""
Filter the data to include only rows and columns that match the given scale (formatted to 2 decimal places).
:param data: Original data
:param scale: The scale to filter by (e.g., 0.1)
:return: Filtered data
"""
# Convert index and columns to formatted strings
data.index = [f"{float(idx):.2f}" for idx in data.index]
data.columns = [f"{float(col):.2f}" for col in data.columns]

# Create a list of formatted strings that match the desired scale
scale_values = [f"{i * scale:.2f}" for i in range(int(1 / scale) + 1)]

# Filter rows and columns by the formatted scale values
filtered_data = data.loc[data.index.isin(scale_values), data.columns.isin(scale_values)]

return filtered_data


def filter_and_plot_recall_pivot(
data: pd.DataFrame,
title: str,
font_size: int = 16,
annot_font_size: int = 14,
filter_scale: float = 0.1,
):
"""
Prepare and plot a pivot table for recall@10 or recall@50.
:param data: Pivot table data
:param title: plot title
:param font_size: Font size for the title and axis labels
:param annot_font_size: Font size for the annotations in the heatmap
:param filter_scale: Scale to filter the data by
"""
data = filter_data_by_scale(data, filter_scale)

data.index = [f"{float(idx):.2f}" for idx in data.index]
data.columns = [f"{float(col):.2f}" for col in data.columns]

plt.figure(figsize=(8, 8), dpi=300) # Adjust figure size and resolution
sns.heatmap(data, annot=True, fmt=".2f", cmap="magma", vmin=0, vmax=100,
cbar_kws={'format': '%.2f'}, annot_kws={"size": annot_font_size})
plt.title(title, fontsize=font_size)
plt.xlabel('Alpha', fontsize=font_size)
plt.ylabel('Beta', fontsize=font_size)
plt.xticks(rotation=45, fontsize=font_size)
plt.yticks(rotation=0, fontsize=font_size)

plt.tight_layout()
plt.savefig("heatmap.png", bbox_inches='tight', dpi=300)
plt.show()


def prepare_ground_truths(json_data) -> dict:
"""
Prepare ground truth data from the JSON structure.
Expand Down
Loading

0 comments on commit d5a319b

Please sign in to comment.