From 9f4662798f3cc2af9f0fc4f5ca6212edc0016d01 Mon Sep 17 00:00:00 2001 From: whats2000 <60466660+whats2000@users.noreply.github.com> Date: Wed, 28 Aug 2024 04:31:21 +0800 Subject: [PATCH] Add: Add the study about the sample combination --- .../CompareCaptionsGenerate.ipynb | 262 ----- .../CompareCaptionsGenerateBLIP.ipynb | 921 ++++++++++++++++++ .../CompareCaptionsGenerateCLIP.ipynb | 909 +++++++++++++++++ src/ablation_experiment/validate_notebook.py | 435 +++++++++ 4 files changed, 2265 insertions(+), 262 deletions(-) delete mode 100644 src/ablation_experiment/CompareCaptionsGenerate.ipynb create mode 100644 src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb create mode 100644 src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb create mode 100644 src/ablation_experiment/validate_notebook.py diff --git a/src/ablation_experiment/CompareCaptionsGenerate.ipynb b/src/ablation_experiment/CompareCaptionsGenerate.ipynb deleted file mode 100644 index bac3ac6..0000000 --- a/src/ablation_experiment/CompareCaptionsGenerate.ipynb +++ /dev/null @@ -1,262 +0,0 @@ -{ - "cells": [ - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-08-26T07:39:52.816157Z", - "start_time": "2024-08-26T07:39:52.795021Z" - } - }, - "cell_type": "code", - "source": [ - "import torch\n", - "from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor\n", - "\n", - "from src.utils import device" - ], - "id": "af51e4f14f8d4d4b", - "outputs": [], - "execution_count": 6 - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "#
Step 1: Set up the experiment
", - "id": "6b34e3ba439aaa4f" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "##
Set up the cache for the experiment
", - "id": "7c2294918930c572" - }, - { - "cell_type": "code", - "id": "initial_id", - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2024-08-26T07:02:56.696609Z", - "start_time": "2024-08-26T07:02:54.348924Z" - } - }, - "source": "cache = {}", - "outputs": [], - "execution_count": 1 - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "##
Same setup as in the script version
", - "id": "9c64305289b9fe48" - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-08-26T07:02:56.777576Z", - "start_time": "2024-08-26T07:02:56.775383Z" - } - }, - "cell_type": "code", - "source": "CLIP_NAME = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'", - "id": "22e5d9c8c2fc4547", - "outputs": [], - "execution_count": 2 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-08-26T07:03:03.069862Z", - "start_time": "2024-08-26T07:02:56.857715Z" - } - }, - "cell_type": "code", - "source": [ - "clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_NAME, torch_dtype=torch.float32, projection_dim=768)\n", - "clip_text_encoder = clip_text_encoder.float().to(device)\n", - "\n", - "print(\"clip text encoder loaded.\")\n", - "clip_text_encoder.eval()" - ], - "id": "94a46a8e90581af4", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "clip text encoder loaded.\n" - ] - }, - { - "data": { - "text/plain": [ - "CLIPTextModelWithProjection(\n", - " (text_model): CLIPTextTransformer(\n", - " (embeddings): CLIPTextEmbeddings(\n", - " (token_embedding): Embedding(49408, 768)\n", - " (position_embedding): Embedding(77, 768)\n", - " )\n", - " (encoder): CLIPEncoder(\n", - " (layers): ModuleList(\n", - " (0-11): 12 x CLIPEncoderLayer(\n", - " (self_attn): CLIPAttention(\n", - " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", - " )\n", - " (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", - " (mlp): CLIPMLP(\n", - " (activation_fn): GELUActivation()\n", - " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", - " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", - " )\n", - " (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " )\n", - " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (text_projection): Linear(in_features=768, out_features=768, bias=False)\n", - ")" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 3 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-08-26T07:03:04.748384Z", - "start_time": "2024-08-26T07:03:03.204895Z" - } - }, - "cell_type": "code", - "source": [ - "clip_img_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_NAME,torch_dtype=torch.float32, projection_dim=768)\n", - "\n", - "clip_img_encoder = clip_img_encoder.float().to(device)\n", - "print(\"clip img encoder loaded.\")\n", - "clip_img_encoder.eval()" - ], - "id": "32f7fb2e83ce7d74", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "clip img encoder loaded.\n" - ] - }, - { - "data": { - "text/plain": [ - "CLIPVisionModelWithProjection(\n", - " (vision_model): CLIPVisionTransformer(\n", - " (embeddings): CLIPVisionEmbeddings(\n", - " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n", - " (position_embedding): Embedding(257, 1024)\n", - " )\n", - " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder): CLIPEncoder(\n", - " (layers): ModuleList(\n", - " (0-23): 24 x CLIPEncoderLayer(\n", - " (self_attn): CLIPAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): 
Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (mlp): CLIPMLP(\n", - " (activation_fn): GELUActivation()\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " )\n", - " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " )\n", - " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (visual_projection): Linear(in_features=1024, out_features=768, bias=False)\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 4 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-08-26T07:40:00.097035Z", - "start_time": "2024-08-26T07:40:00.093435Z" - } - }, - "cell_type": "code", - "source": [ - "print('CLIP preprocess pipeline is used')\n", - "preprocess = CLIPImageProcessor(\n", - " crop_size={'height': 224, 'width': 224},\n", - " do_center_crop=True,\n", - " do_convert_rgb=True,\n", - " do_normalize=True,\n", - " do_rescale=True,\n", - " do_resize=True,\n", - " image_mean=[0.48145466, 0.4578275, 0.40821073],\n", - " image_std=[0.26862954, 0.26130258, 0.27577711],\n", - " resample=3,\n", - " size={'shortest_edge': 224},\n", - ")" - ], - "id": "54ac20b43a2e8b69", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CLIP preprocess pipeline is used\n" - ] - } - ], - "execution_count": 7 - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "", - "id": "5be046bb92d5e588" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb b/src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb new file mode 100644 index 0000000..a0ba3fa --- /dev/null +++ b/src/ablation_experiment/CompareCaptionsGenerateBLIP.ipynb @@ -0,0 +1,921 @@ +{ + "cells": [ + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:06:20.675533Z", + "start_time": "2024-08-27T20:06:17.920748Z" + } + }, + "cell_type": "code", + "source": [ + "import json\n", + "from typing import List\n", + "\n", + "import pandas as pd\n", + "from src.blip_modules.blip_text_encoder import BLIPTextEncoder\n", + "from src.blip_modules.blip_img_encoder import BLIPImgEncoder\n", + "\n", + "from src.ablation_experiment.validate_notebook import fiq_val_retrieval_text_image_combinations\n", + "from src.data_utils import targetpad_transform\n", + "from src.fashioniq_experiment.utils import get_combing_function_with_alpha\n", + "from src.utils import device\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ], + "id": "af51e4f14f8d4d4b", + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#
Step 1: Set up the experiment
", + "id": "6b34e3ba439aaa4f" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Set up the cache for the experiment
", + "id": "7c2294918930c572" + }, + { + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-08-27T20:06:20.699334Z", + "start_time": "2024-08-27T20:06:20.680753Z" + } + }, + "cell_type": "code", + "source": "cache = {}", + "id": "initial_id", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Same setup as in the script version
", + "id": "9c64305289b9fe48" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:06:20.857834Z", + "start_time": "2024-08-27T20:06:20.838273Z" + } + }, + "cell_type": "code", + "source": [ + "BLIP_PRETRAINED_PATH = '../../models/model_base.pth'\n", + "MED_CONFIG_PATH = '../blip_modules/med_config.json'" + ], + "id": "22e5d9c8c2fc4547", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:06:42.410241Z", + "start_time": "2024-08-27T20:06:20.885235Z" + } + }, + "cell_type": "code", + "source": [ + "blip_text_encoder = BLIPTextEncoder(\n", + " BLIP_PRETRAINED_PATH, \n", + " MED_CONFIG_PATH,\n", + " use_pretrained_proj_layer=True\n", + ")\n", + "\n", + "blip_text_encoder = blip_text_encoder.to(device)\n", + "print(\"blip text encoder loaded.\")\n", + "blip_text_encoder.eval()" + ], + "id": "94a46a8e90581af4", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "load checkpoint from ../../models/model_base.pth for text_encoder.\n", + "load checkpoint from ../../models/model_base.pth for text_proj.\n", + "blip text encoder loaded.\n" + ] + }, + { + "data": { + "text/plain": [ + "BLIPTextEncoder(\n", + " (text_encoder): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30524, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0-11): 12 x BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (crossattention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (text_proj): Linear(in_features=768, out_features=256, bias=True)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:06:52.195749Z", + 
"start_time": "2024-08-27T20:06:42.484449Z" + } + }, + "cell_type": "code", + "source": [ + "blip_img_encoder = BLIPImgEncoder(BLIP_PRETRAINED_PATH)\n", + "blip_img_encoder = blip_img_encoder.to(device)\n", + "print(\"blip img encoder loaded.\")\n", + "blip_img_encoder.eval()" + ], + "id": "32f7fb2e83ce7d74", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "reshape position embedding from 196 to 576\n", + "load checkpoint from ../../models/model_base.pth for visual_encoder.\n", + "load checkpoint from ../../models/model_base.pth for vision_proj.\n", + "blip img encoder loaded.\n" + ] + }, + { + "data": { + "text/plain": [ + "BLIPImgEncoder(\n", + " (visual_encoder): VisionTransformer(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", + " (norm): Identity()\n", + " )\n", + " (pos_drop): Dropout(p=0.0, inplace=False)\n", + " (blocks): ModuleList(\n", + " (0-11): 12 x Block(\n", + " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): Mlp(\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", + " )\n", + " (vision_proj): Linear(in_features=768, out_features=256, bias=True)\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:06:52.297318Z", + "start_time": "2024-08-27T20:06:52.273746Z" + } + }, + "cell_type": "code", + "source": [ + "print('Target pad preprocess pipeline is used')\n", + "preprocess = targetpad_transform(1.25, 384)" + ], + "id": "54ac20b43a2e8b69", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target pad preprocess pipeline is used\n" + ] + } + ], + "execution_count": 6 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#
Step 2: Load the MLLM-generated text captions
", + "id": "75ae5725572a1752" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Load the additional text captions
", + "id": "fb8114bce1dbace0" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:06:52.575831Z", + "start_time": "2024-08-27T20:06:52.364940Z" + } + }, + "cell_type": "code", + "source": [ + "with open('../../fashionIQ_dataset/labeled_images_cir_cleaned.json', 'r') as f:\n", + " text_captions = json.load(f)\n", + " \n", + "total_recall_list: List[List[pd.DataFrame]] = []\n", + "\n", + "print(f'Total number of text captions: {len(text_captions)}')" + ], + "id": "838d0c968d6fcf7a", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of text captions: 74357\n" + ] + } + ], + "execution_count": 7 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#
Step 3: Perform retrieval on the FashionIQ dataset
", + "id": "99e7660b7fce179" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Perform retrieval on the shirt category
", + "id": "5a8efc43414d05c" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:08:56.211110Z", + "start_time": "2024-08-27T20:06:52.591751Z" + } + }, + "cell_type": "code", + "source": [ + "shirt_recall = fiq_val_retrieval_text_image_combinations(\n", + " 'shirt',\n", + " get_combing_function_with_alpha(0.95),\n", + " blip_text_encoder,\n", + " blip_img_encoder,\n", + " text_captions,\n", + " preprocess,\n", + " 0.2,\n", + " cache,\n", + ")" + ], + "id": "a0f1b60d0aecba10", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluating feature combinations: 100%|██████████| 7/7 [00:24<00:00, 3.45s/it]\n" + ] + } + ], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:08:56.259118Z", + "start_time": "2024-08-27T20:08:56.230637Z" + } + }, + "cell_type": "code", + "source": "shirt_recall", + "id": "96b511edd8178b1d", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50 Combination\n", + "0 0.2 22.522080 36.457312 First set\n", + "1 0.2 21.687929 35.525024 Second set\n", + "2 0.2 22.571148 35.672227 Third set\n", + "3 0.2 22.816487 37.095192 First and second set\n", + "4 0.2 22.669284 36.261040 Second and third set\n", + "5 0.2 23.110893 36.800784 First and third set\n", + "6 0.2 23.159961 36.997056 All sets" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 9 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Perform retrieval on the dress category
", + "id": "d4ab21021db7064b" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:10:21.596635Z", + "start_time": "2024-08-27T20:08:56.295538Z" + } + }, + "cell_type": "code", + "source": [ + "dress_recall = fiq_val_retrieval_text_image_combinations(\n", + " 'dress',\n", + " get_combing_function_with_alpha(0.95),\n", + " blip_text_encoder,\n", + " blip_img_encoder,\n", + " text_captions,\n", + " preprocess,\n", + " 0.2,\n", + " cache,\n", + ")" + ], + "id": "a59ce3e3360ce5e6", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluating feature combinations: 100%|██████████| 7/7 [00:24<00:00, 3.44s/it]\n" + ] + } + ], + "execution_count": 10 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:10:21.647068Z", + "start_time": "2024-08-27T20:10:21.621898Z" + } + }, + "cell_type": "code", + "source": "dress_recall", + "id": "1813a1e5d3d6cd2c", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50 Combination\n", + "0 0.2 20.327219 38.076350 First set\n", + "1 0.2 18.889439 37.233517 Second set\n", + "2 0.2 18.641546 35.448685 Third set\n", + "3 0.2 20.525533 38.572136 First and second set\n", + "4 0.2 19.732276 37.332672 Second and third set\n", + "5 0.2 20.079325 37.729302 First and third set\n", + "6 0.2 20.475954 38.621715 All sets" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Perform retrieval on the toptee category
", + "id": "f912a96497bb2a63" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:12:15.138640Z", + "start_time": "2024-08-27T20:10:21.683852Z" + } + }, + "cell_type": "code", + "source": [ + "toptee_recall = fiq_val_retrieval_text_image_combinations(\n", + " 'toptee',\n", + " get_combing_function_with_alpha(0.95),\n", + " blip_text_encoder,\n", + " blip_img_encoder,\n", + " text_captions,\n", + " preprocess,\n", + " 0.2,\n", + " cache,\n", + ")" + ], + "id": "b705129d0dd555ee", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluating feature combinations: 100%|██████████| 7/7 [00:24<00:00, 3.47s/it]\n" + ] + } + ], + "execution_count": 12 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:12:15.188263Z", + "start_time": "2024-08-27T20:12:15.161793Z" + } + }, + "cell_type": "code", + "source": "toptee_recall", + "id": "dda4e0a6e1f697fd", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50 Combination\n", + "0 0.2 24.579297 46.098930 First set\n", + "1 0.2 23.406425 45.690975 Second set\n", + "2 0.2 24.987252 45.283020 Third set\n", + "3 0.2 25.038245 46.812850 First and second set\n", + "4 0.2 24.783275 46.863845 Second and third set\n", + "5 0.2 25.650179 46.965834 First and third set\n", + "6 0.2 25.140235 47.373790 All sets" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 13 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:12:15.284150Z", + "start_time": "2024-08-27T20:12:15.263692Z" + } + }, + "cell_type": "code", + "source": [ + "# Change the index to 'Combination' column\n", + "shirt_recall.set_index('Combination', inplace=True)\n", + "dress_recall.set_index('Combination', inplace=True)\n", + "toptee_recall.set_index('Combination', inplace=True)" + ], + "id": "12cef5a9d541bbde", + "outputs": [], + "execution_count": 14 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:12:15.366604Z", + "start_time": "2024-08-27T20:12:15.343908Z" + } + }, + "cell_type": "code", + "source": [ + "# Average the recall values\n", + "average_recall = (shirt_recall + dress_recall + toptee_recall) / 3\n", + "average_recall" + ], + "id": "7c4aead0e821fbf", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50\n", + "Combination \n", + "First set 0.2 22.476199 40.210864\n", + "Second set 0.2 21.327931 39.483172\n", + "Third set 0.2 22.066649 38.801310\n", + "First and second set 0.2 22.793422 40.826726\n", + "Second and third set 0.2 22.394945 40.152519\n", + "First and third set 0.2 22.946799 40.498640\n", + "All sets 0.2 22.925383 40.997520" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 15 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T20:12:15.560807Z", + "start_time": "2024-08-27T20:12:15.558677Z" + } + }, + "cell_type": "code", + "source": "", + "id": "72f821a5bb798988", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb b/src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb new file mode 100644 index 0000000..05dcaf6 --- /dev/null +++ b/src/ablation_experiment/CompareCaptionsGenerateCLIP.ipynb @@ -0,0 +1,909 @@ +{ + "cells": [ + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:51.095040Z", + "start_time": "2024-08-27T15:50:47.706586Z" + } + }, + "cell_type": "code", + "source": [ + "import json\n", + "from typing import List\n", + "\n", + "import pandas as pd\n", + "import torch\n", + "from clip import tokenize\n", + "from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor\n", + "\n", + "from src.ablation_experiment.validate_notebook import fiq_val_retrieval_text_image_combinations_clip\n", + "from src.fashioniq_experiment.utils import get_combing_function_with_alpha\n", + "from src.utils import device\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ], + "id": "af51e4f14f8d4d4b", + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#
Step 1: Set up the experiment
", + "id": "6b34e3ba439aaa4f" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Set up the cache for the experiment
", + "id": "7c2294918930c572" + }, + { + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-08-27T15:50:51.124450Z", + "start_time": "2024-08-27T15:50:51.106174Z" + } + }, + "cell_type": "code", + "source": "cache = {}", + "id": "initial_id", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Same setup as in the script version
", + "id": "9c64305289b9fe48" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:51.303373Z", + "start_time": "2024-08-27T15:50:51.285660Z" + } + }, + "cell_type": "code", + "source": "CLIP_NAME = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'", + "id": "22e5d9c8c2fc4547", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:57.640569Z", + "start_time": "2024-08-27T15:50:51.341862Z" + } + }, + "cell_type": "code", + "source": [ + "clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_NAME, torch_dtype=torch.float32, projection_dim=768)\n", + "clip_text_encoder = clip_text_encoder.float().to(device)\n", + "\n", + "print(\"clip text encoder loaded.\")\n", + "clip_text_encoder.eval()" + ], + "id": "94a46a8e90581af4", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "clip text encoder loaded.\n" + ] + }, + { + "data": { + "text/plain": [ + "CLIPTextModelWithProjection(\n", + " (text_model): CLIPTextTransformer(\n", + " (embeddings): CLIPTextEmbeddings(\n", + " (token_embedding): Embedding(49408, 768)\n", + " (position_embedding): Embedding(77, 768)\n", + " )\n", + " (encoder): CLIPEncoder(\n", + " (layers): ModuleList(\n", + " (0-11): 12 x CLIPEncoderLayer(\n", + " (self_attn): CLIPAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (mlp): CLIPMLP(\n", + " (activation_fn): GELUActivation()\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " )\n", + " (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (text_projection): Linear(in_features=768, out_features=768, bias=False)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:59.474130Z", + "start_time": "2024-08-27T15:50:57.657253Z" + } + }, + "cell_type": "code", + "source": [ + "clip_img_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_NAME,torch_dtype=torch.float32, projection_dim=768)\n", + "\n", + "clip_img_encoder = clip_img_encoder.float().to(device)\n", + "print(\"clip img encoder loaded.\")\n", + "clip_img_encoder.eval()" + ], + "id": "32f7fb2e83ce7d74", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "clip img encoder loaded.\n" + ] + }, + { + "data": { + "text/plain": [ + "CLIPVisionModelWithProjection(\n", + " (vision_model): CLIPVisionTransformer(\n", + " (embeddings): CLIPVisionEmbeddings(\n", + " (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)\n", + " (position_embedding): Embedding(257, 1024)\n", + " )\n", + " (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder): CLIPEncoder(\n", + " (layers): ModuleList(\n", + " (0-23): 24 x CLIPEncoderLayer(\n", + " (self_attn): CLIPAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): 
Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (mlp): CLIPMLP(\n", + " (activation_fn): GELUActivation()\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " )\n", + " (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (visual_projection): Linear(in_features=1024, out_features=768, bias=False)\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:59.515715Z", + "start_time": "2024-08-27T15:50:59.498141Z" + } + }, + "cell_type": "code", + "source": [ + "print('CLIP preprocess pipeline is used')\n", + "preprocess = CLIPImageProcessor(\n", + " crop_size={'height': 224, 'width': 224},\n", + " do_center_crop=True,\n", + " do_convert_rgb=True,\n", + " do_normalize=True,\n", + " do_rescale=True,\n", + " do_resize=True,\n", + " image_mean=[0.48145466, 0.4578275, 0.40821073],\n", + " image_std=[0.26862954, 0.26130258, 0.27577711],\n", + " resample=3,\n", + " size={'shortest_edge': 224},\n", + ")" + ], + "id": "54ac20b43a2e8b69", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CLIP preprocess pipeline is used\n" + ] + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:59.562263Z", + "start_time": "2024-08-27T15:50:59.547800Z" + } + }, + "cell_type": "code", + "source": "clip_tokenizer = tokenize", + "id": "91b281d7472fef42", + "outputs": [], + "execution_count": 7 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#
Step 2: Load the MLLM-generated text captions
", + "id": "75ae5725572a1752" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Load the additional text captions
", + "id": "fb8114bce1dbace0" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:50:59.872357Z", + "start_time": "2024-08-27T15:50:59.607774Z" + } + }, + "cell_type": "code", + "source": [ + "with open('../../fashionIQ_dataset/labeled_images_cir_cleaned.json', 'r') as f:\n", + " text_captions = json.load(f)\n", + " \n", + "total_recall_list: List[List[pd.DataFrame]] = []\n", + "\n", + "print(f'Total number of text captions: {len(text_captions)}')" + ], + "id": "838d0c968d6fcf7a", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of text captions: 74357\n" + ] + } + ], + "execution_count": 8 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#
Step 3: Perform retrieval on the FashionIQ dataset
", + "id": "99e7660b7fce179" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Perform retrieval on the shirt category
", + "id": "5a8efc43414d05c" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:56:59.038961Z", + "start_time": "2024-08-27T15:56:17.685425Z" + } + }, + "cell_type": "code", + "source": [ + "shirt_recall = fiq_val_retrieval_text_image_combinations_clip(\n", + " 'shirt',\n", + " get_combing_function_with_alpha(0.8),\n", + " clip_text_encoder,\n", + " clip_img_encoder,\n", + " clip_tokenizer,\n", + " text_captions,\n", + " preprocess,\n", + " 0.1,\n", + " cache,\n", + ")" + ], + "id": "a0f1b60d0aecba10", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluating feature combinations: 100%|██████████| 7/7 [00:41<00:00, 5.90s/it]\n" + ] + } + ], + "execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:57:03.849951Z", + "start_time": "2024-08-27T15:57:03.828771Z" + } + }, + "cell_type": "code", + "source": "shirt_recall", + "id": "96b511edd8178b1d", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50 Combination\n", + "0 0.1 32.679096 49.509323 First set\n", + "1 0.1 32.777232 48.135427 Second set\n", + "2 0.1 32.384691 47.988224 Third set\n", + "3 0.1 33.022571 49.018645 First and second set\n", + "4 0.1 32.924435 47.742885 Second and third set\n", + "5 0.1 33.071640 48.969579 First and third set\n", + "6 0.1 32.777232 48.969579 All sets" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 12 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Perform retrieval on the dress category
", + "id": "d4ab21021db7064b" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T15:59:00.165828Z", + "start_time": "2024-08-27T15:57:18.015381Z" + } + }, + "cell_type": "code", + "source": [ + "dress_recall = fiq_val_retrieval_text_image_combinations_clip(\n", + " 'dress',\n", + " get_combing_function_with_alpha(0.8),\n", + " clip_text_encoder,\n", + " clip_img_encoder,\n", + " clip_tokenizer,\n", + " text_captions,\n", + " preprocess,\n", + " 0.1,\n", + " cache,\n", + ")" + ], + "id": "a59ce3e3360ce5e6", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluating feature combinations: 100%|██████████| 7/7 [00:41<00:00, 5.89s/it]\n" + ] + } + ], + "execution_count": 13 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T16:01:09.123300300Z", + "start_time": "2024-08-27T15:59:00.189498Z" + } + }, + "cell_type": "code", + "source": "dress_recall", + "id": "1813a1e5d3d6cd2c", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50 Combination\n", + "0 0.1 25.582549 46.603867 First set\n", + "1 0.1 24.144769 46.455130 Second set\n", + "2 0.1 24.243927 46.207237 Third set\n", + "3 0.1 26.177493 47.347546 First and second set\n", + "4 0.1 24.987605 46.306396 Second and third set\n", + "5 0.1 25.929597 47.198811 First and third set\n", + "6 0.1 25.880021 47.297966 All sets" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 14 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "##
Perform retrieval on the toptee category
", + "id": "f912a96497bb2a63" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T16:01:08.524086Z", + "start_time": "2024-08-27T15:59:00.334132Z" + } + }, + "cell_type": "code", + "source": [ + "toptee_recall = fiq_val_retrieval_text_image_combinations_clip(\n", + " 'toptee',\n", + " get_combing_function_with_alpha(0.8),\n", + " clip_text_encoder,\n", + " clip_img_encoder,\n", + " clip_tokenizer,\n", + " text_captions,\n", + " preprocess,\n", + " 0.1,\n", + " cache,\n", + ")" + ], + "id": "b705129d0dd555ee", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Evaluating feature combinations: 100%|██████████| 7/7 [00:41<00:00, 5.90s/it]\n" + ] + } + ], + "execution_count": 15 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T16:01:08.578006Z", + "start_time": "2024-08-27T16:01:08.555585Z" + } + }, + "cell_type": "code", + "source": "toptee_recall", + "id": "dda4e0a6e1f697fd", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50 Combination\n", + "0 0.1 36.053035 56.348801 First set\n", + "1 0.1 35.390106 55.940849 Second set\n", + "2 0.1 35.900050 55.685872 Third set\n", + "3 0.1 36.308005 56.450790 First and second set\n", + "4 0.1 35.237125 56.297809 Second and third set\n", + "5 0.1 36.308005 56.756759 First and third set\n", + "6 0.1 35.951045 56.705761 All sets" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 16 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T19:24:38.356362Z", + "start_time": "2024-08-27T19:24:38.333283Z" + } + }, + "cell_type": "code", + "source": [ + "# Change the index to 'Combination' column\n", + "shirt_recall.set_index('Combination', inplace=True)\n", + "dress_recall.set_index('Combination', inplace=True)\n", + "toptee_recall.set_index('Combination', inplace=True)" + ], + "id": "12cef5a9d541bbde", + "outputs": [], + "execution_count": 18 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-27T19:24:41.043674Z", + "start_time": "2024-08-27T19:24:41.021762Z" + } + }, + "cell_type": "code", + "source": [ + "# Average the recall values\n", + "average_recall = (shirt_recall + dress_recall + toptee_recall) / 3\n", + "average_recall" + ], + "id": "7c4aead0e821fbf", + "outputs": [ + { + "data": { + "text/plain": [ + " beta recall_at10 recall_at50\n", + "Combination \n", + "First set 0.1 31.438227 50.820664\n", + "Second set 0.1 30.770702 50.177135\n", + "Third set 0.1 30.842889 49.960444\n", + "First and second set 0.1 31.836023 50.938994\n", + "Second and third set 0.1 31.049721 50.115697\n", + "First and third set 0.1 31.769748 50.975050\n", + "All sets 0.1 31.536099 50.991102" + ], + "text/html": [ + "
" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 19 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "72f821a5bb798988" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/ablation_experiment/validate_notebook.py b/src/ablation_experiment/validate_notebook.py new file mode 100644 index 0000000..244e25f --- /dev/null +++ b/src/ablation_experiment/validate_notebook.py @@ -0,0 +1,435 @@ +from typing import List + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from src.data_utils import FashionIQDataset +from src.utils import extract_index_features_with_text_captions_clip, extract_index_features_clip, \ + extract_index_features_with_text_captions, extract_index_features +from src.validate import generate_fiq_val_predictions +from src.validate_clip import generate_fiq_val_predictions as generate_fiq_val_predictions_clip + + +def compute_fiq_val_metrics_text_image_combinations( + relative_val_dataset: FashionIQDataset, + blip_text_encoder: torch.nn.Module, + multiple_text_index_features: List[torch.Tensor], + multiple_text_index_names: List[List[str]], + image_index_features: torch.Tensor, + image_index_names: List[str], + combining_function: callable, + beta: float, +) -> pd.DataFrame: + """ + Compute validation metrics on FashionIQ dataset combining text and image distances. 
+ + :param relative_val_dataset: FashionIQ validation dataset in relative mode + :param blip_text_encoder: BLIP text encoder + :param multiple_text_index_features: validation index features from text + :param multiple_text_index_names: validation index names from text + :param image_index_features: validation image index features + :param image_index_names: validation image index names + :param combining_function: function that combines features + :param beta: beta value for the combination of text and image distances + :return: the computed validation metrics + """ + all_text_distances = [] + results = [] + target_names = None + + # Compute distances for individual text features + for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names): + # Generate text predictions and normalize features + predicted_text_features, target_names = generate_fiq_val_predictions( + blip_text_encoder, + relative_val_dataset, + combining_function, + text_names, + text_features, + no_print_output=True, + ) + # Normalize features + text_features = F.normalize(text_features, dim=-1) + predicted_text_features = F.normalize(predicted_text_features, dim=-1) + + # Compute cosine similarity and convert to distance + cosine_similarities = torch.mm(predicted_text_features, text_features.T) + distances = 1 - cosine_similarities + all_text_distances.append(distances) + + # Normalize and compute distances for image features if available + if image_index_features is not None and len(image_index_features) > 0: + predicted_image_features, _ = generate_fiq_val_predictions( + blip_text_encoder, + relative_val_dataset, + combining_function, + image_index_names, + image_index_features, + no_print_output=True, + ) + + # Normalize and compute distances + image_index_features = F.normalize(image_index_features, dim=-1).float() + image_distances = 1 - predicted_image_features @ image_index_features.T + else: + image_distances = torch.zeros_like(all_text_distances[0]) + + # Merge text distances + merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0) + + merged_distances = beta * merged_text_distances + (1 - beta) * image_distances + sorted_indices = torch.argsort(merged_distances, dim=-1).cpu() + sorted_index_names = np.array(image_index_names if image_index_names else multiple_text_index_names[0])[ + sorted_indices] + labels = torch.tensor( + sorted_index_names == np.repeat( + np.array(target_names), + len(image_index_names if image_index_names else multiple_text_index_names[0]) + ).reshape(len(target_names), -1) + ) + assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) + recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 + recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 + results.append({"beta": beta, "recall_at10": recall_at10, "recall_at50": recall_at50}) + + return pd.DataFrame(results) + + +def compute_fiq_val_metrics_text_image_combinations_clip( + relative_val_dataset: FashionIQDataset, + clip_text_encoder: torch.nn.Module, + clip_tokenizer: callable, + multiple_text_index_features: List[torch.Tensor], + multiple_text_index_names: List[List[str]], + image_index_features: torch.Tensor, + image_index_names: List[str], + combining_function: callable, + beta: float, +) -> pd.DataFrame: + """ + Compute validation metrics on FashionIQ dataset combining text and image distances. 
+ + :param relative_val_dataset: FashionIQ validation dataset in relative mode + :param clip_text_encoder: CLIP text encoder + :param clip_tokenizer: CLIP tokenizer + :param multiple_text_index_features: validation index features from text + :param multiple_text_index_names: validation index names from text + :param image_index_features: validation image index features + :param image_index_names: validation image index names + :param combining_function: function that combines features + :param beta: beta value for the combination of text and image distances + :return: the computed validation metrics + """ + all_text_distances = [] + results = [] + target_names = None + + # Compute distances for individual text features + for text_features, text_names in zip(multiple_text_index_features, multiple_text_index_names): + # Generate text predictions and normalize features + predicted_text_features, target_names = generate_fiq_val_predictions_clip( + clip_text_encoder, + clip_tokenizer, + relative_val_dataset, + combining_function, + text_names, + text_features, + no_print_output=True, + ) + # Normalize features + text_features = F.normalize(text_features, dim=-1) + predicted_text_features = F.normalize(predicted_text_features, dim=-1) + + # Compute cosine similarity and convert to distance + cosine_similarities = torch.mm(predicted_text_features, text_features.T) + distances = 1 - cosine_similarities + all_text_distances.append(distances) + + # Normalize and compute distances for image features if available + if image_index_features is not None and len(image_index_features) > 0: + predicted_image_features, _ = generate_fiq_val_predictions_clip( + clip_text_encoder, + clip_tokenizer, + relative_val_dataset, + combining_function, + image_index_names, + image_index_features, + no_print_output=True, + ) + + # Normalize and compute distances + image_index_features = F.normalize(image_index_features, dim=-1).float() + image_distances = 1 - predicted_image_features @ image_index_features.T + else: + image_distances = torch.zeros_like(all_text_distances[0]) + + # Merge text distances + merged_text_distances = torch.mean(torch.stack(all_text_distances), dim=0) + + merged_distances = beta * merged_text_distances + (1 - beta) * image_distances + sorted_indices = torch.argsort(merged_distances, dim=-1).cpu() + sorted_index_names = np.array( + image_index_names if image_index_names else multiple_text_index_names[0] + )[sorted_indices] + + labels = torch.tensor( + sorted_index_names == np.repeat( + np.array(target_names), + len(image_index_names if image_index_names else multiple_text_index_names[0]) + ).reshape(len(target_names), -1) + ) + assert torch.equal(torch.sum(labels, dim=-1).int(), torch.ones(len(target_names)).int()) + recall_at10 = (torch.sum(labels[:, :10]) / len(labels)).item() * 100 + recall_at50 = (torch.sum(labels[:, :50]) / len(labels)).item() * 100 + results.append({"beta": beta, "recall_at10": recall_at10, "recall_at50": recall_at50}) + + return pd.DataFrame(results) + + +def fiq_val_retrieval_text_image_combinations( + dress_type: str, + combining_function: callable, + blip_text_encoder: torch.nn.Module, + blip_img_encoder: torch.nn.Module, + text_captions: List[dict], + preprocess: callable, + beta: float, + cache: dict, +) -> pd.DataFrame: + """ + Perform retrieval on FashionIQ validation set computing the metrics for different text feature combinations. 
+ + :param dress_type: FashionIQ category on which perform the retrieval + :param combining_function: function which takes as input (image_features, text_features) and outputs the combined features + :param blip_text_encoder: BLIP text model + :param blip_img_encoder: BLIP image model + :param text_captions: text captions for the FashionIQ dataset + :param preprocess: preprocess pipeline + :param beta: beta value for the combination of text and image distances + :param cache: cache dictionary + :return: DataFrame containing the retrieval metrics for each combination of text features + """ + cache_key = f"{dress_type}_cache" + + blip_text_encoder = blip_text_encoder.float().eval() + blip_img_encoder = blip_img_encoder.float().eval() + + if cache_key not in cache: + # Define the validation datasets and extract the index features + classic_val_dataset = FashionIQDataset( + 'val', + [dress_type], + 'classic', + preprocess, + no_print_output=True, + ) + + multiple_index_features, multiple_index_names = [], [] + + for i in range(3): + index_features, index_names, _ = extract_index_features_with_text_captions( + classic_val_dataset, + blip_text_encoder, + text_captions, + i + 1, + no_print_output=True, + ) + multiple_index_features.append(index_features) + multiple_index_names.append(index_names) + + image_index_features, image_index_names = extract_index_features( + classic_val_dataset, + blip_img_encoder, + no_print_output=True, + ) + + cache[cache_key] = { + "multiple_index_features": multiple_index_features, + "multiple_index_names": multiple_index_names, + "image_index_features": image_index_features, + "image_index_names": image_index_names + } + else: + multiple_index_features = cache[cache_key]["multiple_index_features"] + multiple_index_names = cache[cache_key]["multiple_index_names"] + image_index_features = cache[cache_key]["image_index_features"] + image_index_names = cache[cache_key]["image_index_names"] + + relative_val_dataset = FashionIQDataset( + 'val', + [dress_type], + 'relative', + preprocess, + no_print_output=True, + ) + + # Define the combinations of features to evaluate + feature_combinations = [ + ([multiple_index_features[0]], [multiple_index_names[0]]), # Only first set + ([multiple_index_features[1]], [multiple_index_names[1]]), # Only second set + ([multiple_index_features[2]], [multiple_index_names[2]]), # Only third set + ([multiple_index_features[0], multiple_index_features[1]], [multiple_index_names[0], multiple_index_names[1]]), # First and second set + ([multiple_index_features[1], multiple_index_features[2]], [multiple_index_names[1], multiple_index_names[2]]), # Second and third set + ([multiple_index_features[0], multiple_index_features[2]], [multiple_index_names[0], multiple_index_names[2]]), # First and third set + (multiple_index_features, multiple_index_names) # All sets + ] + + results = [] + + combination_name = [ + 'First set', + 'Second set', + 'Third set', + 'First and second set', + 'Second and third set', + 'First and third set', + 'All sets', + ] + + for idx, (features_combination, names_combination) in tqdm( + enumerate(feature_combinations), + desc="Evaluating feature combinations", + total=len(feature_combinations), + ): + result = compute_fiq_val_metrics_text_image_combinations( + relative_val_dataset, + blip_text_encoder, + features_combination, + names_combination, + image_index_features, + image_index_names, + combining_function, + beta, + ) + result['Combination'] = combination_name[idx] + results.append(result) + + # Concatenate all 
results into a single DataFrame + return pd.concat(results, ignore_index=True) + + +def fiq_val_retrieval_text_image_combinations_clip( + dress_type: str, + combining_function: callable, + clip_text_encoder: torch.nn.Module, + clip_img_encoder: torch.nn.Module, + clip_tokenizer: callable, + text_captions: List[dict], + preprocess: callable, + beta: float, + cache: dict, +) -> pd.DataFrame: + """ + Perform retrieval on FashionIQ validation set computing the metrics for different text feature combinations. + + :param dress_type: FashionIQ category on which perform the retrieval + :param combining_function: function which takes as input (image_features, text_features) and outputs the combined features + :param clip_text_encoder: CLIP text model + :param clip_img_encoder: CLIP image model + :param clip_tokenizer: CLIP tokenizer + :param text_captions: text captions for the FashionIQ dataset + :param preprocess: preprocess pipeline + :param beta: beta value for the combination of text and image distances + :param cache: cache dictionary + :return: DataFrame containing the retrieval metrics for each combination of text features + """ + cache_key = f"{dress_type}_cache" + + clip_text_encoder = clip_text_encoder.float().eval() + clip_img_encoder = clip_img_encoder.float().eval() + + if cache_key not in cache: + # Define the validation datasets and extract the index features + classic_val_dataset = FashionIQDataset( + 'val', + [dress_type], + 'classic', + preprocess, + no_print_output=True, + ) + + multiple_index_features, multiple_index_names = [], [] + + for i in range(3): + index_features, index_names, _ = extract_index_features_with_text_captions_clip( + classic_val_dataset, + clip_text_encoder, + clip_tokenizer, + text_captions, + i + 1, + no_print_output=True, + ) + multiple_index_features.append(index_features) + multiple_index_names.append(index_names) + + image_index_features, image_index_names = extract_index_features_clip( + classic_val_dataset, + clip_img_encoder, + no_print_output=True, + ) + + cache[cache_key] = { + "multiple_index_features": multiple_index_features, + "multiple_index_names": multiple_index_names, + "image_index_features": image_index_features, + "image_index_names": image_index_names + } + else: + multiple_index_features = cache[cache_key]["multiple_index_features"] + multiple_index_names = cache[cache_key]["multiple_index_names"] + image_index_features = cache[cache_key]["image_index_features"] + image_index_names = cache[cache_key]["image_index_names"] + + relative_val_dataset = FashionIQDataset( + 'val', + [dress_type], + 'relative', + preprocess, + no_print_output=True, + ) + + # Define the combinations of features to evaluate + feature_combinations = [ + ([multiple_index_features[0]], [multiple_index_names[0]]), # Only first set + ([multiple_index_features[1]], [multiple_index_names[1]]), # Only second set + ([multiple_index_features[2]], [multiple_index_names[2]]), # Only third set + ([multiple_index_features[0], multiple_index_features[1]], [multiple_index_names[0], multiple_index_names[1]]), # First and second set + ([multiple_index_features[1], multiple_index_features[2]], [multiple_index_names[1], multiple_index_names[2]]), # Second and third set + ([multiple_index_features[0], multiple_index_features[2]], [multiple_index_names[0], multiple_index_names[2]]), # First and third set + (multiple_index_features, multiple_index_names) # All sets + ] + + results = [] + combination_name = [ + 'First set', + 'Second set', + 'Third set', + 'First and second set', + 
'Second and third set', + 'First and third set', + 'All sets', + ] + for idx, (features_combination, names_combination) in tqdm( + enumerate(feature_combinations), + desc="Evaluating feature combinations", + total=len(feature_combinations), + ): + result = compute_fiq_val_metrics_text_image_combinations_clip( + relative_val_dataset, + clip_text_encoder, + clip_tokenizer, + features_combination, + names_combination, + image_index_features, + image_index_names, + combining_function, + beta, + ) + result['Combination'] = combination_name[idx] + results.append(result) + + # Concatenate all results into a single DataFrame + return pd.concat(results, ignore_index=True)
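
A minimal sketch of the evaluation logic that both notebooks drive through validate_notebook.py, using small random tensors in place of the BLIP/CLIP index and query features. The tensor shapes, the two synthetic caption sets, and the recall_at_k helper are illustrative assumptions, not code from this patch; the point is the order of operations: average the per-caption-set cosine distances, blend them with the image distances via beta, then read recall@K off the ranked index names.

import torch
import torch.nn.functional as F

def recall_at_k(sorted_names: torch.Tensor, target_names: torch.Tensor, k: int) -> float:
    # Fraction of queries whose target appears among the top-k ranked index entries.
    hits = (sorted_names[:, :k] == target_names[:, None]).any(dim=1)
    return hits.float().mean().item() * 100

torch.manual_seed(0)
num_queries, num_index, dim, beta = 4, 10, 8, 0.2

# One cosine-distance matrix per caption set (two synthetic sets here).
text_distance_sets = []
for _ in range(2):
    index_feats = F.normalize(torch.randn(num_index, dim), dim=-1)
    query_feats = F.normalize(torch.randn(num_queries, dim), dim=-1)
    text_distance_sets.append(1 - query_feats @ index_feats.T)

# Image-side cosine distances.
image_index_feats = F.normalize(torch.randn(num_index, dim), dim=-1)
image_query_feats = F.normalize(torch.randn(num_queries, dim), dim=-1)
image_distances = 1 - image_query_feats @ image_index_feats.T

# Average the caption-set distances, then blend with the image distances via beta.
merged_text_distances = torch.stack(text_distance_sets).mean(dim=0)
merged_distances = beta * merged_text_distances + (1 - beta) * image_distances

# Rank index entries per query and score recall@K against ground-truth targets.
index_names = torch.arange(num_index)                   # stand-in for image names
target_names = torch.randint(0, num_index, (num_queries,))
sorted_names = index_names[torch.argsort(merged_distances, dim=-1)]
print(recall_at_k(sorted_names, target_names, k=3))

With beta = 0.2 the image distances carry most of the weight, as in the BLIP notebook; the CLIP notebook uses beta = 0.1.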