Galadrielle Humblot-Renaux, Gianni Franchi, Sergio Escalera, Thomas B. Moeslund
OOD detection methods typically revolve around a single classifier, leading to a split in the research field between the classical supervised setting (e.g. a ResNet18 classifier trained on CIFAR100) and the zero-shot setting (class names fed as prompts to CLIP).
Instead, COOkeD is a heterogeneous ensemble combining the predictions of a closed-world classifier trained end-to-end on a specific dataset, a zero-shot CLIP classifier, and a linear probe classifier trained on CLIP image features. While bulky at first sight, this approach is modular and post-hoc, and it leverages the availability of pre-trained VLMs, thus introducing little overhead compared to training a single standard classifier.
We evaluate COOkeD on the popular CIFAR100 and ImageNet benchmarks, but also consider more challenging, realistic settings ranging from training-time label noise, to test-time covariate shift, to zero-shot shift, which has previously been overlooked. Despite its simplicity, COOkeD achieves state-of-the-art performance and greater robustness compared to both classical and CLIP-based OOD detection methods.
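At its core, the combination step is just an average of the three classifiers' softmax distributions, which is then scored with MSP or entropy. A minimal sketch with toy tensors (the full demo below uses the repo's utilities):
import torch

# toy softmax outputs standing in for the three classifiers (shape: 1 x num_classes)
p_zero_shot, p_probe, p_classifier = (torch.softmax(torch.randn(1, 1000), dim=1) for _ in range(3))

p_ensemble = torch.stack([p_zero_shot, p_probe, p_classifier]).mean(0)  # average the three distributions
prediction = p_ensemble.argmax(dim=1)  # ensemble class prediction
msp = p_ensemble.max(dim=1).values  # maximum softmax probability (higher = more likely ID)
entropy = torch.distributions.Categorical(probs=p_ensemble).entropy()  # predictive entropy (higher = more likely OOD)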
Code (see demo.py):
from PIL import Image
import torch
from model_utils import get_classifier_model, get_clip_model, get_probe_model
from data_utils import preprocess_image_for_clip, preprocess_image_for_cls, get_label_to_class_mapping
import glob
# load trained models
device = "cuda" # or "cpu"
clip_variant = "ViT-B-16+openai" # or ViT-B-16+openai, ViT-L-14+openai, ViT-H-14+laion2b_s32b_b79k
classifier = get_classifier_model("imagenet","resnet18-ft", is_torchvision_ckpt=True, device=device)
probe = get_probe_model("imagenet", clip_variant, device=device)
clip, clip_tokenizer, clip_logit_scale = get_clip_model(clip_variant, device=device)
clip.eval() # pre-trained CLIP model from open_clip
probe.eval() # linear probe trained on CLIP image features from ID dataset
classifier.eval() # Resnet18 trained on ID dataset
# define ID classes and encode prompts
class_mapping = get_label_to_class_mapping("imagenet")
prompts = ["a photo of a [cls]".replace("[cls]",f"{class_mapping[idx]}") for idx in range(len(class_mapping))]
with torch.no_grad():
    prompt_features = clip.encode_text(clip_tokenizer(prompts).to(device))
    prompt_features_normed = prompt_features / prompt_features.norm(dim=-1, keepdim=True)
image_paths = glob.glob("illustrations/*")
# OOD scoring function: MSP is used below; swap in the entropy line to use entropy instead
# ood_scoring = lambda softmax_probs: torch.distributions.Categorical(probs=softmax_probs).entropy().item() # entropy as OOD score
ood_scoring = lambda softmax_probs: torch.max(softmax_probs, dim=1).values.item() # maximum softmax probability (MSP) as OOD score
for image_path in image_paths:
print(f"---------------{image_path}-------------------")
image = Image.open(image_path).convert("RGB")
# note: different normalization for CLIP image encoder vs. standard classifier
image_normalized_clip = preprocess_image_for_clip(image).to(device)
image_normalized_cls = preprocess_image_for_cls(image).to(device)
with torch.no_grad():
# 1. get zero-shot CLIP prediction
clip_image_features = clip.encode_image(image_normalized_clip)
clip_image_features_normed = clip_image_features / clip_image_features.norm(dim=-1, keepdim=True)
text_sim = (clip_image_features_normed @ prompt_features_normed.T)
softmax_clip_t100 = (clip_logit_scale * text_sim).softmax(dim=1)
# 2. get probe CLIP prediction
softmax_probe = probe(clip_image_features).softmax(dim=1)
# 3. get classifier prediction
softmax_classifier = classifier(image_normalized_cls).softmax(dim=1)
# 4. combined prediction
softmax_ensemble = torch.stack([softmax_clip_t100, softmax_probe, softmax_classifier]).mean(0)
# class prediction and OOD scores
pred = softmax_ensemble.argmax(dim=1)
ood_score = ood_scoring(softmax_ensemble)
print("CLIP:", class_mapping[softmax_clip_t100.argmax(dim=1).item()], f"(MSP: {ood_scoring(softmax_clip_t100):.2f})")
print("Probe:", class_mapping[softmax_probe.argmax(dim=1).item()], f"(MSP: {ood_scoring(softmax_probe):.2f})")
print("Classifier:", class_mapping[softmax_classifier.argmax(dim=1).item()], f"(MSP: {ood_scoring(softmax_classifier):.2f})")
print("---> COOkeD:", class_mapping[pred.item()] , f"(MSP: {ood_score:.2f})")
print(f"--------------------------------------------------------------------------------------------------------------")This code was tested on Ubuntu 18.04 with Python 3.11.3 + PyTorch 2.5.1+cu121 + TorchVision 0.20.1+cu121
conda create --name cooked python=3.11.3
conda activate cooked
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt

Run the following script to download the ID datasets (ImageNet-1K, ImageNet-200, CIFAR100, DTD, PatternNet) and the corresponding OOD datasets automatically:
python3 data_download.py
Expected directory structure:
data/
├── benchmark_imglist
│   ├── cifar100
│   ├── imagenet
│   ├── imagenet200
│   └── ooddb
├── images_classic
│   ├── cifar10
│   │   ├── test
│   │   └── train
│   ├── cifar100
│   │   ├── test
│   │   └── train
│   ├── mnist
│   │   ├── test
│   │   └── train
│   ├── places365
│   │   ├── airfield
│   │   ├── ...
│   │   └── zen_garden
│   ├── svhn
│   │   └── test
│   ├── texture
│   │   ├── banded
│   │   ├── ...
│   │   └── zigzagged
│   └── tin
│       ├── test
│       ├── train
│       ├── val
│       ├── wnids.txt
│       └── words.txt
└── images_largescale
    ├── DTD
    │   ├── images
    │   ├── imdb
    │   └── labels
    ├── imagenet_1k
    │   ├── train
    │   └── val
    ├── imagenet_c
    │   ├── brightness
    │   ├── ...
    │   └── zoom_blur
    ├── imagenet_r
    │   ├── n01443537
    │   ├── ...
    │   └── n12267677
    ├── imagenet_v2
    │   ├── 0
    │   ├── ...
    │   └── 999
    ├── inaturalist
    │   ├── images
    │   └── imglist.txt
    ├── ninco
    │   ├── amphiuma_means
    │   ├── ...
    │   └── windsor_chair
    ├── openimage_o
    │   └── images
    ├── PatternNet
    │   ├── images
    │   └── patternnet_description.pdf
    └── ssb_hard
        ├── n00470682
        ├── ...
        └── n13033134

Classifier checkpoints will be downloaded automatically when you run the demo or eval scripts. For ImageNet-1K, we use pre-trained classifiers from TorchVision (downloaded to checkpoints/torchvision), and for the other ID datasets we share our own trained classifiers at https://huggingface.co/glhr/COOkeD-checkpoints (downloaded to checkpoints/classifiers).
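For reference, the ImageNet-1K case boils down to loading a standard TorchVision model. A minimal sketch of what gets fetched (the repo's get_classifier_model handles checkpoint paths, so this is only an illustration, not the exact code path):
import torchvision

# downloads the pre-trained ImageNet-1K ResNet18 weights to the default torch hub cache
classifier = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
classifier.eval()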
The script eval.py evaluates COOkeD in terms of classification accuracy and OOD detection for a given ID dataset, classifier architecture and CLIP variant. Running the following should give you the same results as Table 3 in the paper:
classifier=resnet18-ft # or resnet50-ft
clip_variant=ViT-B-16+openai # or ViT-L-14+openai
python eval.py --id_name imagenet --classifier $classifier --clip_variant $clip_variant # standard evaluation on ImageNet-1K
python eval.py --id_name imagenet --classifier $classifier --clip_variant $clip_variant --csid # test-time covariate shift
python eval.py --id_name cifar100n_noisyfine --classifier $classifier --clip_variant $clip_variant # training-time label noise
python eval.py --id_name ooddb_dtd_0 --classifier $classifier --clip_variant $clip_variant # zero-shot shift (texture images as ID dataset)

Full results with both MSP and entropy as the OOD score are saved as CSVs to the results directory.
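The exact CSV names and columns depend on the run, so purely as a convenience sketch (assuming pandas is installed), the saved results can be collected like this:
import glob
import pandas as pd

# gather every results CSV written by eval.py and print a quick summary of each
for csv_path in sorted(glob.glob("results/**/*.csv", recursive=True)):
    df = pd.read_csv(csv_path)
    print(csv_path, df.shape)
    print(df.head())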
If you use our work, please cite our paper:
@InProceedings{cooked_2025,
author = {Humblot-Renaux, Galadrielle and Franchi, Gianni and Escalera, Sergio and Moeslund, Thomas B.},
title = {{COOkeD}: Ensemble-based {OOD} detection in the era of {CLIP}},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
year = {2025}
}

If you have any issues or doubts about the code, please create a GitHub issue. Otherwise, you can contact me at [email protected]
The codebase structure and dataset splits for ImageNet and CIFAR100 are based on OpenOOD. We also use data splits from OODDB. We use open_clip to load pre-trained CLIP models.
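For context, loading a CLIP model directly through open_clip looks roughly like the sketch below; the get_clip_model helper presumably wraps something similar for the ViT-B-16+openai variant used in the demo:
import open_clip

# load the pre-trained ViT-B/16 CLIP model (OpenAI weights) and its image preprocessing transform
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
tokenizer = open_clip.get_tokenizer("ViT-B-16")
model.eval()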



