-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathHEST_evaluation.py
More file actions
51 lines (38 loc) · 1.38 KB
/
HEST_evaluation.py
File metadata and controls
51 lines (38 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from hest.bench import benchmark
import torch
from torchvision import transforms
print("loading base")
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

# Load the fine-tuned teacher weights. map_location="cpu" lets a checkpoint
# whose tensors were saved on a GPU load on any host, and keeps every tensor
# on the same device as the freshly built hub model (including pos_embed,
# which is wrapped directly as a Parameter below).
ours = torch.load("checkpoints/teacher_epoch250000.pth", map_location="cpu")
checkpoint = ours["teacher"]

# Drop every tensor whose key mentions the "dino" or "ibot" heads; these are
# not part of the hub backbone being loaded.
checkpoint = {
    key: value
    for key, value in checkpoint.items()
    if "dino" not in str(key) and "ibot" not in str(key)
}

# Checkpoint keys are renamed to the hub model's keys by POSITION. This
# assumes both dicts list the same parameters in the same order. zip() would
# silently drop unmatched trailing entries, so refuse to continue when the
# counts differ instead of loading misaligned weights.
target_keys = list(dinov2.state_dict().keys())
if len(checkpoint) != len(target_keys):
    raise ValueError(
        f"Checkpoint has {len(checkpoint)} backbone tensors but the hub model "
        f"expects {len(target_keys)}; positional key remapping would misalign."
    )
checkpoint = dict(zip(target_keys, checkpoint.values()))

# pos_embed is the only tensor whose shape differs from the hub model. Replace
# the parameter first so the strict load_state_dict does not reject it on a
# shape mismatch.
pos_embed = checkpoint["pos_embed"]
dinov2.pos_embed = torch.nn.parameter.Parameter(pos_embed)
dinov2.load_state_dict(checkpoint)
# Benchmark configuration file and the encoder under evaluation.
PATH_TO_CONFIG = "./HEST/bench_config/bench_config.yaml"
model = dinov2

# NOTE(review): RESIZE_DIM is currently unused; resize/crop is disabled below.
RESIZE_DIM = 224
# Per-channel (RGB) normalization statistics: the standard ImageNet mean/std.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: no resize or center crop is applied. Inputs are only
# converted to a tensor (ToTensor scales uint8 images to [0, 1]) and then
# normalized channel-wise. To force a fixed input size, prepend
# transforms.Resize(RESIZE_DIM) and transforms.CenterCrop(RESIZE_DIM).
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
model_transforms = transforms.Compose(_preprocessing_steps)
# Evaluate the encoder with HEST's benchmark. The arguments are:
# - the model;
# - its preprocessing pipeline;
# - torch.float32 (presumably the inference precision; see hest.bench.benchmark);
# - PATH_TO_CONFIG, passed as the benchmark config.
benchmark(model, model_transforms, torch.float32, config=PATH_TO_CONFIG)