-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPaligemma_Processing.py
More file actions
130 lines (106 loc) · 4.14 KB
/
Paligemma_Processing.py
File metadata and controls
130 lines (106 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Dict, List, Tuple, Optional, Union, Iterable
import numpy as np
from PIL import Image
import torch
# Per-channel (R, G, B) normalization statistics applied after pixel values
# are rescaled to [0, 1]; (x - 0.5) / 0.5 maps that range onto [-1, 1].
IMAGENET_STANDARD_MEAN = [0.5] * 3  # RGB
IMAGENET_STANDARD_STD = [0.5] * 3
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
    """Build the text fed to the tokenizer for one prompt.

    Layout: ``image_token`` repeated ``image_seq_len`` times, then the BOS
    (beginning-of-sequence) token, then the prompt, then a trailing newline.
    The newline arguably should be tokenized separately, but the Hugging Face
    implementation appends it this way, so it is kept for compatibility.
    """
    image_placeholders = image_token * image_seq_len
    return f"{image_placeholders}{bos_token}{prefix_prompt}\n"
def resize(
    image: Image.Image,
    size: Tuple[int, int],
    resample: Optional[Image.Resampling] = None,
    reducing_gap: Optional[float] = None,
) -> Image.Image:
    """Resize a PIL image to ``size``.

    Args:
        image: Source PIL image.
        size: Target size as ``(height, width)``. PIL's ``Image.resize``
            expects ``(width, height)``, so the pair is swapped below.
        resample: PIL resampling filter; ``None`` lets Pillow choose its
            default filter for the image mode.
        reducing_gap: Forwarded unchanged to ``Image.resize`` (Pillow's
            optional two-step downscaling optimization).

    Returns:
        The resized image as a PIL image (not a NumPy array); callers convert
        it with ``np.array`` when they need pixel data.
    """
    height, width = size
    # PIL takes (width, height); our public convention is (height, width).
    return image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
def rescale(
    image: np.ndarray, scale: float, dtype: np.dtype = np.float32
) -> np.ndarray:
    """Multiply ``image`` by ``scale`` and cast the product to ``dtype``.

    The processor calls this with ``scale=1 / 255.0`` to turn raw 8-bit pixel
    values into floats in [0, 1].
    """
    return (image * scale).astype(dtype)
def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
) -> np.ndarray:
    """Standardize ``image`` as ``(image - mean) / std``.

    ``mean`` and ``std`` may be scalars or per-channel sequences. Sequences
    broadcast against the last axis, so channel-last (H, W, C) input is
    expected. Both are cast to ``image.dtype`` first, so a floating-point
    input keeps its dtype in the result.
    """
    target_dtype = image.dtype
    centered = image - np.array(mean, dtype=target_dtype)
    return centered / np.array(std, dtype=target_dtype)
def process_images(
    images: List[Image.Image],
    size: Optional[Tuple[int, int]] = None,
    resample: Optional[Image.Resampling] = None,
    rescale_factor: Optional[float] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
) -> List[np.ndarray]:
    """Convert PIL images into normalized, channel-first float arrays.

    Each image goes through: RGB conversion (only if needed), then resize,
    then ``np.array``, then rescale, then normalize, then an HWC-to-CHW
    transpose.

    Args:
        images: Input PIL images.
        size: Target ``(height, width)``. Required despite the ``None`` default.
        resample: PIL resampling filter passed to ``resize``.
        rescale_factor: Multiplier applied to raw pixel values (the processor
            uses ``1 / 255.0``). Required.
        image_mean: Scalar or per-channel mean for ``normalize``. Required.
        image_std: Scalar or per-channel std for ``normalize``. Required.

    Returns:
        One ``float32`` array of shape ``(3, height, width)`` per input image.

    Raises:
        TypeError: If any required argument is left as ``None``.
    """
    # The None defaults exist only for signature compatibility. A None mean or
    # std would not fail: NumPy coerces None to NaN for float dtypes, so the
    # output would silently be all-NaN. Fail loudly instead.
    missing = [
        name
        for name, value in (
            ("size", size),
            ("rescale_factor", rescale_factor),
            ("image_mean", image_mean),
            ("image_std", image_std),
        )
        if value is None
    ]
    if missing:
        raise TypeError(
            "process_images() missing required argument(s): " + ", ".join(missing)
        )

    height, width = size[0], size[1]
    processed = []
    for image in images:
        # The rest of the pipeline assumes exactly 3 channels (3-element
        # mean/std, CHW transpose). Grayscale, palette or RGBA images would
        # otherwise fail during normalization or the transpose.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = resize(image=image, size=(height, width), resample=resample)
        pixels = rescale(np.array(image), scale=rescale_factor)
        pixels = normalize(pixels, mean=image_mean, std=image_std)
        processed.append(pixels.transpose(2, 0, 1))  # (H, W, C) -> (C, H, W)
    return processed
# PaliGemma processor: combines the Gemma tokenizer with the image preprocessing functions above.
class PaliGemmaProcessor:
    """Turn a (prompt, image) pair into model inputs for PaliGemma.

    Wraps a Hugging Face-style tokenizer, i.e. one providing
    ``add_special_tokens``, ``add_tokens``, ``convert_tokens_to_ids`` and
    ``bos_token``, and callable with ``return_tensors="pt"``. It also applies
    this module's image preprocessing.
    """

    # Placeholder token. It is repeated ``image_seq_length`` times at the start
    # of every prompt to reserve positions for the image embeddings.
    IMAGE_TOKEN = "<image>"

    def __init__(self, tokenizer, num_image_token: int, image_size: int):
        """Register PaliGemma's extra tokens on ``tokenizer`` and store config.

        Args:
            tokenizer: Tokenizer to extend. It is mutated in place: tokens are
                added and automatic BOS/EOS insertion is switched off.
            num_image_token: Number of ``<image>`` placeholders per prompt.
            image_size: Side length that images are resized to (square).
        """
        super().__init__()
        self.image_seq_length = num_image_token
        self.image_size = image_size

        tokenizer.add_special_tokens({"additional_special_tokens": [self.IMAGE_TOKEN]})
        # Extra vocabulary: 1024 location tokens <loc0000>..<loc1023> and
        # 128 segmentation tokens <seg000>..<seg127>. The names follow the
        # Hugging Face PaliGemma processor (``<seg...>``, not ``<img...>``), so
        # they resolve to ids of the pretrained vocabulary.
        extra_tokens = [f"<loc{i:04d}>" for i in range(1024)]
        extra_tokens += [f"<seg{i:03d}>" for i in range(128)]
        tokenizer.add_tokens(extra_tokens)
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
        # BOS is inserted manually (after the image tokens) by
        # add_image_tokens_to_prompt, so the tokenizer must not add its own;
        # no EOS is appended either.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False
        self.tokenizer = tokenizer

    def __call__(
        self,
        text: List[str],
        images: List[Image.Image],
        padding: str = "longest",
        truncation: bool = True,
    ) -> dict:
        """Preprocess one prompt and one image into model inputs.

        Args:
            text: A list holding exactly one prompt string.
            images: A list holding exactly one PIL image.
            padding: Forwarded to the tokenizer.
            truncation: Forwarded to the tokenizer.

        Returns:
            Dict with ``pixel_values``, plus every field returned by the
            tokenizer (which fields depends on the tokenizer). For an RGB
            image, ``pixel_values`` is a float32 tensor of shape
            ``(1, 3, image_size, image_size)``.

        Raises:
            ValueError: If not given exactly one prompt and one image.
        """
        # Only a batch of one is supported. This is an explicit check rather
        # than ``assert``, so it still fires under ``python -O``.
        if len(images) != 1 or len(text) != 1:
            raise ValueError(
                "Expected exactly 1 image and 1 prompt; "
                f"received {len(images)} images for {len(text)} prompts."
            )

        pixel_values = process_images(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD,
        )
        # Stack the per-image (C, H, W) arrays into a batch. from_numpy wraps
        # the freshly stacked buffer without copying it again.
        pixel_values = torch.from_numpy(np.stack(pixel_values, axis=0))

        input_strings = [
            add_image_tokens_to_prompt(
                prefix_prompt=prompt,
                bos_token=self.tokenizer.bos_token,
                image_seq_len=self.image_seq_length,
                image_token=self.IMAGE_TOKEN,
            )
            for prompt in text
        ]
        inputs = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=padding,
            truncation=truncation,
        )
        return {"pixel_values": pixel_values, **inputs}