From 9f4662798f3cc2af9f0fc4f5ca6212edc0016d01 Mon Sep 17 00:00:00 2001
From: whats2000 <60466660+whats2000@users.noreply.github.com>
Date: Wed, 28 Aug 2024 04:31:21 +0800
Subject: [PATCH] Add: Add study on caption set combinations
---
.../CompareCaptionsGenerate.ipynb | 262 -----
.../CompareCaptionsGenerateBLIP.ipynb | 921 ++++++++++++++++++
.../CompareCaptionsGenerateCLIP.ipynb | 909 +++++++++++++++++
src/ablation_experiment/validate_notebook.py | 435 +++++++++
4 files changed, 2265 insertions(+), 262 deletions(-)
delete mode 100644 src/ablation_experiment/CompareCaptionsGenerate.ipynb
create mode 100644 src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb
create mode 100644 src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb
create mode 100644 src/ablation_experiment/validate_notebook.py
diff --git a/src/ablation_experiment/CompareCaptionsGenerate.ipynb b/src/ablation_experiment/CompareCaptionsGenerate.ipynb
deleted file mode 100644
index bac3ac6..0000000
--- a/src/ablation_experiment/CompareCaptionsGenerate.ipynb
+++ /dev/null
@@ -1,262 +0,0 @@
-{
- "cells": [
- {
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-08-26T07:39:52.816157Z",
- "start_time": "2024-08-26T07:39:52.795021Z"
- }
- },
- "cell_type": "code",
- "source": [
- "import torch\n",
- "from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor\n",
- "\n",
- "from src.utils import device"
- ],
- "id": "af51e4f14f8d4d4b",
- "outputs": [],
- "execution_count": 6
- },
- {
- "metadata": {},
- "cell_type": "markdown",
- "source": "#
Step 1: Set up the experiment
",
- "id": "6b34e3ba439aaa4f"
- },
- {
- "metadata": {},
- "cell_type": "markdown",
- "source": "## Set up the cache for the experiment
",
- "id": "7c2294918930c572"
- },
- {
- "cell_type": "code",
- "id": "initial_id",
- "metadata": {
- "collapsed": true,
- "ExecuteTime": {
- "end_time": "2024-08-26T07:02:56.696609Z",
- "start_time": "2024-08-26T07:02:54.348924Z"
- }
- },
- "source": "cache = {}",
- "outputs": [],
- "execution_count": 1
- },
- {
- "metadata": {},
- "cell_type": "markdown",
- "source": "## Same concept as script version here
",
- "id": "9c64305289b9fe48"
- },
- {
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-08-26T07:02:56.777576Z",
- "start_time": "2024-08-26T07:02:56.775383Z"
- }
- },
- "cell_type": "code",
- "source": "CLIP_NAME = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'",
- "id": "22e5d9c8c2fc4547",
- "outputs": [],
- "execution_count": 2
- },
- {
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-08-26T07:03:03.069862Z",
- "start_time": "2024-08-26T07:02:56.857715Z"
- }
- },
- "cell_type": "code",
- "source": [
- "clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_NAME, torch_dtype=torch.float32, projection_dim=768)\n",
- "clip_text_encoder = clip_text_encoder.float().to(device)\n",
- "\n",
- "print(\"clip text encoder loaded.\")\n",
- "clip_text_encoder.eval()"
- ],
- "id": "94a46a8e90581af4",
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "clip text encoder loaded.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "CLIPTextModelWithProjection(\n",
- " (text_model): CLIPTextTransformer(\n",
- " (embeddings): CLIPTextEmbeddings(\n",
- " (token_embedding): Embedding(49408, 768)\n",
- " (position_embedding): Embedding(77, 768)\n",
- " )\n",
- " (encoder): CLIPEncoder(\n",
- " (layers): ModuleList(\n",
- " (0-11): 12 x CLIPEncoderLayer(\n",
- " (self_attn): CLIPAttention(\n",
- " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
- " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
- " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
- " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
- " )\n",
- " (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
- " (mlp): CLIPMLP(\n",
- " (activation_fn): GELUActivation()\n",
- " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
- " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
- " )\n",
- " (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (text_projection): Linear(in_features=768, out_features=768, bias=False)\n",
- ")"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "execution_count": 3
- },
- {
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-08-26T07:03:04.748384Z",
- "start_time": "2024-08-26T07:03:03.204895Z"
- }
- },
- "cell_type": "code",
- "source": [
- "clip_img_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_NAME,torch_dtype=torch.float32, projection_dim=768)\n",
- "\n",
- "clip_img_encoder = clip_img_encoder.float().to(device)\n",
- "print(\"clip img encoder loaded.\")\n",
- "clip_img_encoder.eval()"
- ],
- "id": "32f7fb2e83ce7d74",
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "clip img encoder loaded.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "CLIPVisionModelWithProjection(\n",
- " (vision_model): CLIPVisionTransformer(\n",
- " (embeddings): CLIPVisionEmbeddings(\n",
- " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
- " (position_embedding): Embedding(257, 1024)\n",
- " )\n",
- " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
- " (encoder): CLIPEncoder(\n",
- " (layers): ModuleList(\n",
- " (0-23): 24 x CLIPEncoderLayer(\n",
- " (self_attn): CLIPAttention(\n",
- " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
- " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
- " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
- " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
- " )\n",
- " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
- " (mlp): CLIPMLP(\n",
- " (activation_fn): GELUActivation()\n",
- " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
- " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
- " )\n",
- " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (visual_projection): Linear(in_features=1024, out_features=768, bias=False)\n",
- ")"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "execution_count": 4
- },
- {
- "metadata": {
- "ExecuteTime": {
- "end_time": "2024-08-26T07:40:00.097035Z",
- "start_time": "2024-08-26T07:40:00.093435Z"
- }
- },
- "cell_type": "code",
- "source": [
- "print('CLIP preprocess pipeline is used')\n",
- "preprocess = CLIPImageProcessor(\n",
- " crop_size={'height': 224, 'width': 224},\n",
- " do_center_crop=True,\n",
- " do_convert_rgb=True,\n",
- " do_normalize=True,\n",
- " do_rescale=True,\n",
- " do_resize=True,\n",
- " image_mean=[0.48145466, 0.4578275, 0.40821073],\n",
- " image_std=[0.26862954, 0.26130258, 0.27577711],\n",
- " resample=3,\n",
- " size={'shortest_edge': 224},\n",
- ")"
- ],
- "id": "54ac20b43a2e8b69",
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CLIP preprocess pipeline is used\n"
- ]
- }
- ],
- "execution_count": 7
- },
- {
- "metadata": {},
- "cell_type": "code",
- "outputs": [],
- "execution_count": null,
- "source": "",
- "id": "5be046bb92d5e588"
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb b/src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb
new file mode 100644
index 0000000..a0ba3fa
--- /dev/null
+++ b/src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb
@@ -0,0 +1,921 @@
+{
+ "cells": [
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:20.675533Z",
+ "start_time": "2024-08-27T20:06:17.920748Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "import json\n",
+ "from typing import List\n",
+ "\n",
+ "import pandas as pd\n",
+ "from src.blip_modules.blip_text_encoder import BLIPTextEncoder\n",
+ "from src.blip_modules.blip_img_encoder import BLIPImgEncoder\n",
+ "\n",
+ "from src.ablation_experiment.validate_notebook import fiq_val_retrieval_text_image_combinations\n",
+ "from src.data_utils import targetpad_transform\n",
+ "from src.fashioniq_experiment.utils import get_combing_function_with_alpha\n",
+ "from src.utils import device\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ],
+ "id": "af51e4f14f8d4d4b",
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Step 1: Set up the experiment
",
+ "id": "6b34e3ba439aaa4f"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Set up the cache for the experiment
",
+ "id": "7c2294918930c572"
+ },
+ {
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:20.699334Z",
+ "start_time": "2024-08-27T20:06:20.680753Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "cache = {}",
+ "id": "initial_id",
+ "outputs": [],
+ "execution_count": 2
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Same concept as script version here
",
+ "id": "9c64305289b9fe48"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:20.857834Z",
+ "start_time": "2024-08-27T20:06:20.838273Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "BLIP_PRETRAINED_PATH = '../../models/model_base.pth'\n",
+ "MED_CONFIG_PATH = '../blip_modules/med_config.json'"
+ ],
+ "id": "22e5d9c8c2fc4547",
+ "outputs": [],
+ "execution_count": 3
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:42.410241Z",
+ "start_time": "2024-08-27T20:06:20.885235Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "blip_text_encoder = BLIPTextEncoder(\n",
+ " BLIP_PRETRAINED_PATH, \n",
+ " MED_CONFIG_PATH,\n",
+ " use_pretrained_proj_layer=True\n",
+ ")\n",
+ "\n",
+ "blip_text_encoder = blip_text_encoder.to(device)\n",
+ "print(\"blip text encoder loaded.\")\n",
+ "blip_text_encoder.eval()"
+ ],
+ "id": "94a46a8e90581af4",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load checkpoint from ../../models/model_base.pth for text_encoder.\n",
+ "load checkpoint from ../../models/model_base.pth for text_proj.\n",
+ "blip text encoder loaded.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "BLIPTextEncoder(\n",
+ " (text_encoder): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(30524, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-11): 12 x BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (crossattention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (text_proj): Linear(in_features=768, out_features=256, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 4
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:52.195749Z",
+ "start_time": "2024-08-27T20:06:42.484449Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "blip_img_encoder = BLIPImgEncoder(BLIP_PRETRAINED_PATH)\n",
+ "blip_img_encoder = blip_img_encoder.to(device)\n",
+ "print(\"blip img encoder loaded.\")\n",
+ "blip_img_encoder.eval()"
+ ],
+ "id": "32f7fb2e83ce7d74",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "reshape position embedding from 196 to 576\n",
+ "load checkpoint from ../../models/model_base.pth for visual_encoder.\n",
+ "load checkpoint from ../../models/model_base.pth for vision_proj.\n",
+ "blip img encoder loaded.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "BLIPImgEncoder(\n",
+ " (visual_encoder): VisionTransformer(\n",
+ " (patch_embed): PatchEmbed(\n",
+ " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
+ " (norm): Identity()\n",
+ " )\n",
+ " (pos_drop): Dropout(p=0.0, inplace=False)\n",
+ " (blocks): ModuleList(\n",
+ " (0-11): 12 x Block(\n",
+ " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (attn): Attention(\n",
+ " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
+ " (attn_drop): Dropout(p=0.0, inplace=False)\n",
+ " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (proj_drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " (drop_path): Identity()\n",
+ " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " (mlp): Mlp(\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (act): GELU(approximate='none')\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (drop): Dropout(p=0.0, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
+ " )\n",
+ " (vision_proj): Linear(in_features=768, out_features=256, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 5
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:52.297318Z",
+ "start_time": "2024-08-27T20:06:52.273746Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "print('Target pad preprocess pipeline is used')\n",
+ "preprocess = targetpad_transform(1.25, 384)"
+ ],
+ "id": "54ac20b43a2e8b69",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Target pad preprocess pipeline is used\n"
+ ]
+ }
+ ],
+ "execution_count": 6
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Step 2: Load the MLLM generated text captions
",
+ "id": "75ae5725572a1752"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Load the addition text captions
",
+ "id": "fb8114bce1dbace0"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:06:52.575831Z",
+ "start_time": "2024-08-27T20:06:52.364940Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "with open('../../fashionIQ_dataset/labeled_images_cir_cleaned.json', 'r') as f:\n",
+ " text_captions = json.load(f)\n",
+ " \n",
+ "total_recall_list: List[List[pd.DataFrame]] = []\n",
+ "\n",
+ "print(f'Total number of text captions: {len(text_captions)}')"
+ ],
+ "id": "838d0c968d6fcf7a",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total number of text captions: 74357\n"
+ ]
+ }
+ ],
+ "execution_count": 7
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Step 3: Perform retrieval on the FashionIQ dataset
",
+ "id": "99e7660b7fce179"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Perform retrieval on the shirt category
",
+ "id": "5a8efc43414d05c"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:08:56.211110Z",
+ "start_time": "2024-08-27T20:06:52.591751Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "shirt_recall = fiq_val_retrieval_text_image_combinations(\n",
+ " 'shirt',\n",
+ " get_combing_function_with_alpha(0.95),\n",
+ " blip_text_encoder,\n",
+ " blip_img_encoder,\n",
+ " text_captions,\n",
+ " preprocess,\n",
+ " 0.2,\n",
+ " cache,\n",
+ ")"
+ ],
+ "id": "a0f1b60d0aecba10",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Evaluating feature combinations: 100%|██████████| 7/7 [00:24<00:00, 3.45s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": 8
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:08:56.259118Z",
+ "start_time": "2024-08-27T20:08:56.230637Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "shirt_recall",
+ "id": "96b511edd8178b1d",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50 Combination\n",
+ "0 0.2 22.522080 36.457312 First set\n",
+ "1 0.2 21.687929 35.525024 Second set\n",
+ "2 0.2 22.571148 35.672227 Third set\n",
+ "3 0.2 22.816487 37.095192 First and second set\n",
+ "4 0.2 22.669284 36.261040 Second and third set\n",
+ "5 0.2 23.110893 36.800784 First and third set\n",
+ "6 0.2 23.159961 36.997056 All sets"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 9
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Perform retrieval on the dress category
",
+ "id": "d4ab21021db7064b"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:10:21.596635Z",
+ "start_time": "2024-08-27T20:08:56.295538Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dress_recall = fiq_val_retrieval_text_image_combinations(\n",
+ " 'dress',\n",
+ " get_combing_function_with_alpha(0.95),\n",
+ " blip_text_encoder,\n",
+ " blip_img_encoder,\n",
+ " text_captions,\n",
+ " preprocess,\n",
+ " 0.2,\n",
+ " cache,\n",
+ ")"
+ ],
+ "id": "a59ce3e3360ce5e6",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Evaluating feature combinations: 100%|██████████| 7/7 [00:24<00:00, 3.44s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": 10
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:10:21.647068Z",
+ "start_time": "2024-08-27T20:10:21.621898Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "dress_recall",
+ "id": "1813a1e5d3d6cd2c",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50 Combination\n",
+ "0 0.2 20.327219 38.076350 First set\n",
+ "1 0.2 18.889439 37.233517 Second set\n",
+ "2 0.2 18.641546 35.448685 Third set\n",
+ "3 0.2 20.525533 38.572136 First and second set\n",
+ "4 0.2 19.732276 37.332672 Second and third set\n",
+ "5 0.2 20.079325 37.729302 First and third set\n",
+ "6 0.2 20.475954 38.621715 All sets"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 11
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Perform retrieval on the toptee category
",
+ "id": "f912a96497bb2a63"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:12:15.138640Z",
+ "start_time": "2024-08-27T20:10:21.683852Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "toptee_recall = fiq_val_retrieval_text_image_combinations(\n",
+ " 'toptee',\n",
+ " get_combing_function_with_alpha(0.95),\n",
+ " blip_text_encoder,\n",
+ " blip_img_encoder,\n",
+ " text_captions,\n",
+ " preprocess,\n",
+ " 0.2,\n",
+ " cache,\n",
+ ")"
+ ],
+ "id": "b705129d0dd555ee",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Evaluating feature combinations: 100%|██████████| 7/7 [00:24<00:00, 3.47s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": 12
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:12:15.188263Z",
+ "start_time": "2024-08-27T20:12:15.161793Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "toptee_recall",
+ "id": "dda4e0a6e1f697fd",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50 Combination\n",
+ "0 0.2 24.579297 46.098930 First set\n",
+ "1 0.2 23.406425 45.690975 Second set\n",
+ "2 0.2 24.987252 45.283020 Third set\n",
+ "3 0.2 25.038245 46.812850 First and second set\n",
+ "4 0.2 24.783275 46.863845 Second and third set\n",
+ "5 0.2 25.650179 46.965834 First and third set\n",
+ "6 0.2 25.140235 47.373790 All sets"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 13
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:12:15.284150Z",
+ "start_time": "2024-08-27T20:12:15.263692Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Change the index to 'Combination' column\n",
+ "shirt_recall.set_index('Combination', inplace=True)\n",
+ "dress_recall.set_index('Combination', inplace=True)\n",
+ "toptee_recall.set_index('Combination', inplace=True)"
+ ],
+ "id": "12cef5a9d541bbde",
+ "outputs": [],
+ "execution_count": 14
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:12:15.366604Z",
+ "start_time": "2024-08-27T20:12:15.343908Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Average the recall values\n",
+ "average_recall = (shirt_recall + dress_recall + toptee_recall) / 3\n",
+ "average_recall"
+ ],
+ "id": "7c4aead0e821fbf",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50\n",
+ "Combination \n",
+ "First set 0.2 22.476199 40.210864\n",
+ "Second set 0.2 21.327931 39.483172\n",
+ "Third set 0.2 22.066649 38.801310\n",
+ "First and second set 0.2 22.793422 40.826726\n",
+ "Second and third set 0.2 22.394945 40.152519\n",
+ "First and third set 0.2 22.946799 40.498640\n",
+ "All sets 0.2 22.925383 40.997520"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 15
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T20:12:15.560807Z",
+ "start_time": "2024-08-27T20:12:15.558677Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "",
+ "id": "72f821a5bb798988",
+ "outputs": [],
+ "execution_count": null
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb b/src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb
new file mode 100644
index 0000000..05dcaf6
--- /dev/null
+++ b/src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb
@@ -0,0 +1,909 @@
+{
+ "cells": [
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:51.095040Z",
+ "start_time": "2024-08-27T15:50:47.706586Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "import json\n",
+ "from typing import List\n",
+ "\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "from clip import tokenize\n",
+ "from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor\n",
+ "\n",
+ "from src.ablation_experiment.validate_notebook import fiq_val_retrieval_text_image_combinations_clip\n",
+ "from src.fashioniq_experiment.utils import get_combing_function_with_alpha\n",
+ "from src.utils import device\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ],
+ "id": "af51e4f14f8d4d4b",
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Step 1: Set up the experiment
",
+ "id": "6b34e3ba439aaa4f"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Set up the cache for the experiment
",
+ "id": "7c2294918930c572"
+ },
+ {
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:51.124450Z",
+ "start_time": "2024-08-27T15:50:51.106174Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "cache = {}",
+ "id": "initial_id",
+ "outputs": [],
+ "execution_count": 2
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Same concept as script version here
",
+ "id": "9c64305289b9fe48"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:51.303373Z",
+ "start_time": "2024-08-27T15:50:51.285660Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "CLIP_NAME = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'",
+ "id": "22e5d9c8c2fc4547",
+ "outputs": [],
+ "execution_count": 3
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:57.640569Z",
+ "start_time": "2024-08-27T15:50:51.341862Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_NAME, torch_dtype=torch.float32, projection_dim=768)\n",
+ "clip_text_encoder = clip_text_encoder.float().to(device)\n",
+ "\n",
+ "print(\"clip text encoder loaded.\")\n",
+ "clip_text_encoder.eval()"
+ ],
+ "id": "94a46a8e90581af4",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "clip text encoder loaded.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CLIPTextModelWithProjection(\n",
+ " (text_model): CLIPTextTransformer(\n",
+ " (embeddings): CLIPTextEmbeddings(\n",
+ " (token_embedding): Embedding(49408, 768)\n",
+ " (position_embedding): Embedding(77, 768)\n",
+ " )\n",
+ " (encoder): CLIPEncoder(\n",
+ " (layers): ModuleList(\n",
+ " (0-11): 12 x CLIPEncoderLayer(\n",
+ " (self_attn): CLIPAttention(\n",
+ " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
+ " )\n",
+ " (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+ " (mlp): CLIPMLP(\n",
+ " (activation_fn): GELUActivation()\n",
+ " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " )\n",
+ " (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (text_projection): Linear(in_features=768, out_features=768, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 4
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:59.474130Z",
+ "start_time": "2024-08-27T15:50:57.657253Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "clip_img_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_NAME,torch_dtype=torch.float32, projection_dim=768)\n",
+ "\n",
+ "clip_img_encoder = clip_img_encoder.float().to(device)\n",
+ "print(\"clip img encoder loaded.\")\n",
+ "clip_img_encoder.eval()"
+ ],
+ "id": "32f7fb2e83ce7d74",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "clip img encoder loaded.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "CLIPVisionModelWithProjection(\n",
+ " (vision_model): CLIPVisionTransformer(\n",
+ " (embeddings): CLIPVisionEmbeddings(\n",
+ " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n",
+ " (position_embedding): Embedding(257, 1024)\n",
+ " )\n",
+ " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (encoder): CLIPEncoder(\n",
+ " (layers): ModuleList(\n",
+ " (0-23): 24 x CLIPEncoderLayer(\n",
+ " (self_attn): CLIPAttention(\n",
+ " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
+ " )\n",
+ " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " (mlp): CLIPMLP(\n",
+ " (activation_fn): GELUActivation()\n",
+ " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n",
+ " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n",
+ " )\n",
+ " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (visual_projection): Linear(in_features=1024, out_features=768, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 5
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:59.515715Z",
+ "start_time": "2024-08-27T15:50:59.498141Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "print('CLIP preprocess pipeline is used')\n",
+ "preprocess = CLIPImageProcessor(\n",
+ " crop_size={'height': 224, 'width': 224},\n",
+ " do_center_crop=True,\n",
+ " do_convert_rgb=True,\n",
+ " do_normalize=True,\n",
+ " do_rescale=True,\n",
+ " do_resize=True,\n",
+ " image_mean=[0.48145466, 0.4578275, 0.40821073],\n",
+ " image_std=[0.26862954, 0.26130258, 0.27577711],\n",
+ " resample=3,\n",
+ " size={'shortest_edge': 224},\n",
+ ")"
+ ],
+ "id": "54ac20b43a2e8b69",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CLIP preprocess pipeline is used\n"
+ ]
+ }
+ ],
+ "execution_count": 6
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:59.562263Z",
+ "start_time": "2024-08-27T15:50:59.547800Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "clip_tokenizer = tokenize",
+ "id": "91b281d7472fef42",
+ "outputs": [],
+ "execution_count": 7
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Step 2: Load the MLLM generated text captions
",
+ "id": "75ae5725572a1752"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Load the addition text captions
",
+ "id": "fb8114bce1dbace0"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:50:59.872357Z",
+ "start_time": "2024-08-27T15:50:59.607774Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "with open('../../fashionIQ_dataset/labeled_images_cir_cleaned.json', 'r') as f:\n",
+ " text_captions = json.load(f)\n",
+ " \n",
+ "total_recall_list: List[List[pd.DataFrame]] = []\n",
+ "\n",
+ "print(f'Total number of text captions: {len(text_captions)}')"
+ ],
+ "id": "838d0c968d6fcf7a",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total number of text captions: 74357\n"
+ ]
+ }
+ ],
+ "execution_count": 8
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Step 3: Perform retrieval on the FashionIQ dataset
",
+ "id": "99e7660b7fce179"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Perform retrieval on the shirt category
",
+ "id": "5a8efc43414d05c"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:56:59.038961Z",
+ "start_time": "2024-08-27T15:56:17.685425Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "shirt_recall = fiq_val_retrieval_text_image_combinations_clip(\n",
+ " 'shirt',\n",
+ " get_combing_function_with_alpha(0.8),\n",
+ " clip_text_encoder,\n",
+ " clip_img_encoder,\n",
+ " clip_tokenizer,\n",
+ " text_captions,\n",
+ " preprocess,\n",
+ " 0.1,\n",
+ " cache,\n",
+ ")"
+ ],
+ "id": "a0f1b60d0aecba10",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Evaluating feature combinations: 100%|██████████| 7/7 [00:41<00:00, 5.90s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": 11
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:57:03.849951Z",
+ "start_time": "2024-08-27T15:57:03.828771Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "shirt_recall",
+ "id": "96b511edd8178b1d",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50 Combination\n",
+ "0 0.1 32.679096 49.509323 First set\n",
+ "1 0.1 32.777232 48.135427 Second set\n",
+ "2 0.1 32.384691 47.988224 Third set\n",
+ "3 0.1 33.022571 49.018645 First and second set\n",
+ "4 0.1 32.924435 47.742885 Second and third set\n",
+ "5 0.1 33.071640 48.969579 First and third set\n",
+ "6 0.1 32.777232 48.969579 All sets"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 12
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Perform retrieval on the dress category
",
+ "id": "d4ab21021db7064b"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T15:59:00.165828Z",
+ "start_time": "2024-08-27T15:57:18.015381Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dress_recall = fiq_val_retrieval_text_image_combinations_clip(\n",
+ " 'dress',\n",
+ " get_combing_function_with_alpha(0.8),\n",
+ " clip_text_encoder,\n",
+ " clip_img_encoder,\n",
+ " clip_tokenizer,\n",
+ " text_captions,\n",
+ " preprocess,\n",
+ " 0.1,\n",
+ " cache,\n",
+ ")"
+ ],
+ "id": "a59ce3e3360ce5e6",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Evaluating feature combinations: 100%|██████████| 7/7 [00:41<00:00, 5.89s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": 13
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T16:01:09.123300300Z",
+ "start_time": "2024-08-27T15:59:00.189498Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "dress_recall",
+ "id": "1813a1e5d3d6cd2c",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50 Combination\n",
+ "0 0.1 25.582549 46.603867 First set\n",
+ "1 0.1 24.144769 46.455130 Second set\n",
+ "2 0.1 24.243927 46.207237 Third set\n",
+ "3 0.1 26.177493 47.347546 First and second set\n",
+ "4 0.1 24.987605 46.306396 Second and third set\n",
+ "5 0.1 25.929597 47.198811 First and third set\n",
+ "6 0.1 25.880021 47.297966 All sets"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 14
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Perform retrieval on the toptee category
",
+ "id": "f912a96497bb2a63"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T16:01:08.524086Z",
+ "start_time": "2024-08-27T15:59:00.334132Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "toptee_recall = fiq_val_retrieval_text_image_combinations_clip(\n",
+ " 'toptee',\n",
+ " get_combing_function_with_alpha(0.8),\n",
+ " clip_text_encoder,\n",
+ " clip_img_encoder,\n",
+ " clip_tokenizer,\n",
+ " text_captions,\n",
+ " preprocess,\n",
+ " 0.1,\n",
+ " cache,\n",
+ ")"
+ ],
+ "id": "b705129d0dd555ee",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Evaluating feature combinations: 100%|██████████| 7/7 [00:41<00:00, 5.90s/it]\n"
+ ]
+ }
+ ],
+ "execution_count": 15
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T16:01:08.578006Z",
+ "start_time": "2024-08-27T16:01:08.555585Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "toptee_recall",
+ "id": "dda4e0a6e1f697fd",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50 Combination\n",
+ "0 0.1 36.053035 56.348801 First set\n",
+ "1 0.1 35.390106 55.940849 Second set\n",
+ "2 0.1 35.900050 55.685872 Third set\n",
+ "3 0.1 36.308005 56.450790 First and second set\n",
+ "4 0.1 35.237125 56.297809 Second and third set\n",
+ "5 0.1 36.308005 56.756759 First and third set\n",
+ "6 0.1 35.951045 56.705761 All sets"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 16
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T19:24:38.356362Z",
+ "start_time": "2024-08-27T19:24:38.333283Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Change the index to 'Combination' column\n",
+ "shirt_recall.set_index('Combination', inplace=True)\n",
+ "dress_recall.set_index('Combination', inplace=True)\n",
+ "toptee_recall.set_index('Combination', inplace=True)"
+ ],
+ "id": "12cef5a9d541bbde",
+ "outputs": [],
+ "execution_count": 18
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-08-27T19:24:41.043674Z",
+ "start_time": "2024-08-27T19:24:41.021762Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Average the recall values\n",
+ "average_recall = (shirt_recall + dress_recall + toptee_recall) / 3\n",
+ "average_recall"
+ ],
+ "id": "7c4aead0e821fbf",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " beta recall_at10 recall_at50\n",
+ "Combination \n",
+ "First set 0.1 31.438227 50.820664\n",
+ "Second set 0.1 30.770702 50.177135\n",
+ "Third set 0.1 30.842889 49.960444\n",
+ "First and second set 0.1 31.836023 50.938994\n",
+ "Second and third set 0.1 31.049721 50.115697\n",
+ "First and third set 0.1 31.769748 50.975050\n",
+ "All sets 0.1 31.536099 50.991102"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 19
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": "",
+ "id": "72f821a5bb798988"
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/ablation_experiment/validate_notebook.py b/src/ablation_experiment/validate_notebook.py
new file mode 100644
index 0000000..244e25f
--- /dev/null
+++ b/src/ablation_experiment/validate_notebook.py
@@ -0,0 +1,435 @@
+from typing import List
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from src.data_utils import FashionIQDataset
+from src.utils import extract_index_features_with_text_captions_clip, extract_index_features_clip, \
+ extract_index_features_with_text_captions, extract_index_features
+from src.validate import generate_fiq_val_predictions
+from src.validate_clip import generate_fiq_val_predictions as generate_fiq_val_predictions_clip
+
+
+def compute_fiq_val_metrics_text_image_combinations(
+ relative_val_dataset: FashionIQDataset,
+ blip_text_encoder: torch.nn.Module,
+ multiple_text_index_features: List[torch.Tensor],
+ multiple_text_index_names: List[List[str]],
+ image_index_features: torch.Tensor,
+ image_index_names: List[str],
+ combining_function: callable,
+ beta: float,
+) -> pd.DataFrame:
+ """
+ Compute validation metrics on the FashionIQ dataset, combining text and image distances.
+
+ :param relative_val_dataset: FashionIQ validation dataset in relative mode
+ :param blip_text_encoder: BLIP text encoder
+ :param multiple_text_index_features: list of validation index feature tensors, one per caption set
+ :param multiple_text_index_names: list of validation index name lists, one per caption set
+ :param image_index_features: validation image index features
+ :param image_index_names: validation image index names
+ :param combining_function: function that combines features
+ :param beta: beta value for the combination of text and image distances
+ :return: the computed validation metrics
+ """
+ all_text_distances = []
+ results = []
+ target_names = None
+
+ # Compute distances for individual text features
+ for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names):
+ # Generate text predictions and normalize features
+ predicted_text_features, target_names = generate_fiq_val_predictions(
+ blip_text_encoder,
+ relative_val_dataset,
+ combining_function,
+ text_names,
+ text_features,
+ no_print_output=True,
+ )
+ # Normalize features
+ text_features = F.normalize(text_features, dim=-1)
+ predicted_text_features = F.normalize(predicted_text_features, dim=-1)
+
+ # Compute cosine similarity and convert to distance
+ cosine_similarities = torch.mm(predicted_text_features, text_features.T)
+ distances = 1 - cosine_similarities
+ all_text_distances.append(distances)
+
+ # Normalize and compute distances for image features if available
+ if image_index_features is not None and len(image_index_features) > 0:
+ predicted_image_features, _ = generate_fiq_val_predictions(
+ blip_text_encoder,
+ relative_val_dataset,
+ combining_function,
+ image_index_names,
+ image_index_features,
+ no_print_output=True,
+ )
+
+ # Normalize and compute distances
+ image_index_features = F.normalize(image_index_features, dim=-1).float()
+ image_distances = 1 - predicted_image_features @ image_index_features.T
+ else:
+ image_distances = torch.zeros_like(all_text_distances[0])
+
+ # Merge text distances
+ merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0)
+
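+ # Combine the averaged caption distance with the image distance: beta weights the text side, (1 - beta) the image side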
+ merged_distances = beta * merged_text_distances + (1 - beta) * image_distances
+ sorted_indices = torch.argsort(merged_distances, dim=-1).cpu()
+ sorted_index_names = np.array(image_index_names if image_index_names else multiple_text_index_names[0])[
+ sorted_indices]
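+ # Boolean matrix with one row per query: True where the ranked candidate name equals that query's target name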
+ labels = torch.tensor(
+ sorted_index_names == np.repeat(
+ np.array(target_names),
+ len(image_index_names if image_index_names else multiple_text_index_names[0])
+ ).reshape(len(target_names), -1)
+ )
+ assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
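+ # Recall@K: percentage of queries whose target appears among the top-K ranked candidates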
+ recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
+ recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
+ results.append({"beta": beta, "recall_at10": recall_at10, "recall_at50": recall_at50})
+
+ return pd.DataFrame(results)
+
+
+def compute_fiq_val_metrics_text_image_combinations_clip(
+ relative_val_dataset: FashionIQDataset,
+ clip_text_encoder: torch.nn.Module,
+ clip_tokenizer: callable,
+ multiple_text_index_features: List[torch.Tensor],
+ multiple_text_index_names: List[List[str]],
+ image_index_features: torch.Tensor,
+ image_index_names: List[str],
+ combining_function: callable,
+ beta: float,
+) -> pd.DataFrame:
+ """
+ Compute validation metrics on the FashionIQ dataset, combining text and image distances.
+
+ :param relative_val_dataset: FashionIQ validation dataset in relative mode
+ :param clip_text_encoder: CLIP text encoder
+ :param clip_tokenizer: CLIP tokenizer
+ :param multiple_text_index_features: list of validation index feature tensors, one per caption set
+ :param multiple_text_index_names: list of validation index name lists, one per caption set
+ :param image_index_features: validation image index features
+ :param image_index_names: validation image index names
+ :param combining_function: function that combines features
+ :param beta: beta value for the combination of text and image distances
+ :return: the computed validation metrics
+ """
+ all_text_distances = []
+ results = []
+ target_names = None
+
+ # Compute distances for individual text features
+ for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names):
+ # Generate text predictions and normalize features
+ predicted_text_features, target_names = generate_fiq_val_predictions_clip(
+ clip_text_encoder,
+ clip_tokenizer,
+ relative_val_dataset,
+ combining_function,
+ text_names,
+ text_features,
+ no_print_output=True,
+ )
+ # Normalize features
+ text_features = F.normalize(text_features, dim=-1)
+ predicted_text_features = F.normalize(predicted_text_features, dim=-1)
+
+ # Compute cosine similarity and convert to distance
+ cosine_similarities = torch.mm(predicted_text_features, text_features.T)
+ distances = 1 - cosine_similarities
+ all_text_distances.append(distances)
+
+ # Normalize and compute distances for image features if available
+ if image_index_features is not None and len(image_index_features) > 0:
+ predicted_image_features, _ = generate_fiq_val_predictions_clip(
+ clip_text_encoder,
+ clip_tokenizer,
+ relative_val_dataset,
+ combining_function,
+ image_index_names,
+ image_index_features,
+ no_print_output=True,
+ )
+
+ # Normalize and compute distances
+ image_index_features = F.normalize(image_index_features, dim=-1).float()
+ image_distances = 1 - predicted_image_features @ image_index_features.T
+ else:
+ image_distances = torch.zeros_like(all_text_distances[0])
+
+ # Merge text distances
+ merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0)
+
+ merged_distances = beta * merged_text_distances + (1 - beta) * image_distances
+ sorted_indices = torch.argsort(merged_distances, dim=-1).cpu()
+ sorted_index_names = np.array(
+ image_index_names if image_index_names else multiple_text_index_names[0]
+ )[sorted_indices]
+
+ labels = torch.tensor(
+ sorted_index_names == np.repeat(
+ np.array(target_names),
+ len(image_index_names if image_index_names else multiple_text_index_names[0])
+ ).reshape(len(target_names), -1)
+ )
+ assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int())
+ recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100
+ recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100
+ results.append({"beta": beta, "recall_at10": recall_at10, "recall_at50": recall_at50})
+
+ return pd.DataFrame(results)
+
+
+def fiq_val_retrieval_text_image_combinations(
+ dress_type: str,
+ combining_function: callable,
+ blip_text_encoder: torch.nn.Module,
+ blip_img_encoder: torch.nn.Module,
+ text_captions: List[dict],
+ preprocess: callable,
+ beta: float,
+ cache: dict,
+) -> pd.DataFrame:
+ """
+ Perform retrieval on the FashionIQ validation set, computing metrics for each combination of text feature sets.
+
+ :param dress_type: FashionIQ category on which to perform the retrieval
+ :param combining_function: function that takes (image_features, text_features) as input and returns the combined features
+ :param blip_text_encoder: BLIP text model
+ :param blip_img_encoder: BLIP image model
+ :param text_captions: text captions for the FashionIQ dataset
+ :param preprocess: preprocess pipeline
+ :param beta: beta value for the combination of text and image distances
+ :param cache: cache dictionary
+ :return: DataFrame containing the retrieval metrics for each combination of text features
+ """
+ cache_key = f"{dress_type}_cache"
+
+ blip_text_encoder = blip_text_encoder.float().eval()
+ blip_img_encoder = blip_img_encoder.float().eval()
+
+ if cache_key not in cache:
+ # Define the validation datasets and extract the index features
+ classic_val_dataset = FashionIQDataset(
+ 'val',
+ [dress_type],
+ 'classic',
+ preprocess,
+ no_print_output=True,
+ )
+
+ multiple_index_features, multiple_index_names = [], []
+
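+ # Build one text index feature set per generated caption set (referred to as the first, second and third set below)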
+ for i in range(3):
+ index_features, index_names, _ = extract_index_features_with_text_captions(
+ classic_val_dataset,
+ blip_text_encoder,
+ text_captions,
+ i + 1,
+ no_print_output=True,
+ )
+ multiple_index_features.append(index_features)
+ multiple_index_names.append(index_names)
+
+ image_index_features, image_index_names = extract_index_features(
+ classic_val_dataset,
+ blip_img_encoder,
+ no_print_output=True,
+ )
+
+ cache[cache_key] = {
+ "multiple_index_features": multiple_index_features,
+ "multiple_index_names": multiple_index_names,
+ "image_index_features": image_index_features,
+ "image_index_names": image_index_names
+ }
+ else:
+ multiple_index_features = cache[cache_key]["multiple_index_features"]
+ multiple_index_names = cache[cache_key]["multiple_index_names"]
+ image_index_features = cache[cache_key]["image_index_features"]
+ image_index_names = cache[cache_key]["image_index_names"]
+
+ relative_val_dataset = FashionIQDataset(
+ 'val',
+ [dress_type],
+ 'relative',
+ preprocess,
+ no_print_output=True,
+ )
+
+ # Define the combinations of features to evaluate
+ feature_combinations = [
+ ([multiple_index_features[0]], [multiple_index_names[0]]), # Only first set
+ ([multiple_index_features[1]], [multiple_index_names[1]]), # Only second set
+ ([multiple_index_features[2]], [multiple_index_names[2]]), # Only third set
+ ([multiple_index_features[0], multiple_index_features[1]], [multiple_index_names[0], multiple_index_names[1]]), # First and second set
+ ([multiple_index_features[1], multiple_index_features[2]], [multiple_index_names[1], multiple_index_names[2]]), # Second and third set
+ ([multiple_index_features[0], multiple_index_features[2]], [multiple_index_names[0], multiple_index_names[2]]), # First and third set
+ (multiple_index_features, multiple_index_names) # All sets
+ ]
+
+ results = []
+
+ combination_name = [
+ 'First set',
+ 'Second set',
+ 'Third set',
+ 'First and second set',
+ 'Second and third set',
+ 'First and third set',
+ 'All sets',
+ ]
+
+ for idx, (features_combination, names_combination) in tqdm(
+ enumerate(feature_combinations),
+ desc="Evaluating feature combinations",
+ total=len(feature_combinations),
+ ):
+ result = compute_fiq_val_metrics_text_image_combinations(
+ relative_val_dataset,
+ blip_text_encoder,
+ features_combination,
+ names_combination,
+ image_index_features,
+ image_index_names,
+ combining_function,
+ beta,
+ )
+ result['Combination'] = combination_name[idx]
+ results.append(result)
+
+ # Concatenate all results into a single DataFrame
+ return pd.concat(results, ignore_index=True)
+
+
+def fiq_val_retrieval_text_image_combinations_clip(
+ dress_type: str,
+ combining_function: callable,
+ clip_text_encoder: torch.nn.Module,
+ clip_img_encoder: torch.nn.Module,
+ clip_tokenizer: callable,
+ text_captions: List[dict],
+ preprocess: callable,
+ beta: float,
+ cache: dict,
+) -> pd.DataFrame:
+ """
+ Perform retrieval on the FashionIQ validation set, computing metrics for each combination of text feature sets.
+
+ :param dress_type: FashionIQ category on which to perform the retrieval
+ :param combining_function: function that takes (image_features, text_features) as input and returns the combined features
+ :param clip_text_encoder: CLIP text model
+ :param clip_img_encoder: CLIP image model
+ :param clip_tokenizer: CLIP tokenizer
+ :param text_captions: text captions for the FashionIQ dataset
+ :param preprocess: preprocess pipeline
+ :param beta: beta value for the combination of text and image distances
+ :param cache: cache dictionary
+ :return: DataFrame containing the retrieval metrics for each combination of text features
+ """
+ cache_key = f"{dress_type}_cache"
+
+ clip_text_encoder = clip_text_encoder.float().eval()
+ clip_img_encoder = clip_img_encoder.float().eval()
+
+ if cache_key not in cache:
+ # Define the validation datasets and extract the index features
+ classic_val_dataset = FashionIQDataset(
+ 'val',
+ [dress_type],
+ 'classic',
+ preprocess,
+ no_print_output=True,
+ )
+
+ multiple_index_features, multiple_index_names = [], []
+
+ for i in range(3):
+ index_features, index_names, _ = extract_index_features_with_text_captions_clip(
+ classic_val_dataset,
+ clip_text_encoder,
+ clip_tokenizer,
+ text_captions,
+ i + 1,
+ no_print_output=True,
+ )
+ multiple_index_features.append(index_features)
+ multiple_index_names.append(index_names)
+
+ image_index_features, image_index_names = extract_index_features_clip(
+ classic_val_dataset,
+ clip_img_encoder,
+ no_print_output=True,
+ )
+
+ cache[cache_key] = {
+ "multiple_index_features": multiple_index_features,
+ "multiple_index_names": multiple_index_names,
+ "image_index_features": image_index_features,
+ "image_index_names": image_index_names
+ }
+ else:
+ multiple_index_features = cache[cache_key]["multiple_index_features"]
+ multiple_index_names = cache[cache_key]["multiple_index_names"]
+ image_index_features = cache[cache_key]["image_index_features"]
+ image_index_names = cache[cache_key]["image_index_names"]
+
+ relative_val_dataset = FashionIQDataset(
+ 'val',
+ [dress_type],
+ 'relative',
+ preprocess,
+ no_print_output=True,
+ )
+
+ # Define the combinations of features to evaluate
+ feature_combinations = [
+ ([multiple_index_features[0]], [multiple_index_names[0]]), # Only first set
+ ([multiple_index_features[1]], [multiple_index_names[1]]), # Only second set
+ ([multiple_index_features[2]], [multiple_index_names[2]]), # Only third set
+ ([multiple_index_features[0], multiple_index_features[1]], [multiple_index_names[0], multiple_index_names[1]]), # First and second set
+ ([multiple_index_features[1], multiple_index_features[2]], [multiple_index_names[1], multiple_index_names[2]]), # Second and third set
+ ([multiple_index_features[0], multiple_index_features[2]], [multiple_index_names[0], multiple_index_names[2]]), # First and third set
+ (multiple_index_features, multiple_index_names) # All sets
+ ]
+
+ results = []
+ combination_name = [
+ 'First set',
+ 'Second set',
+ 'Third set',
+ 'First and second set',
+ 'Second and third set',
+ 'First and third set',
+ 'All sets',
+ ]
+ for idx, (features_combination, names_combination) in tqdm(
+ enumerate(feature_combinations),
+ desc="Evaluating feature combinations",
+ total=len(feature_combinations),
+ ):
+ result = compute_fiq_val_metrics_text_image_combinations_clip(
+ relative_val_dataset,
+ clip_text_encoder,
+ clip_tokenizer,
+ features_combination,
+ names_combination,
+ image_index_features,
+ image_index_names,
+ combining_function,
+ beta,
+ )
+ result['Combination'] = combination_name[idx]
+ results.append(result)
+
+ # Concatenate all results into a single DataFrame
+ return pd.concat(results, ignore_index=True)