CoCa: generate and beam search #393

Merged · 17 commits · Feb 1, 2023
26 changes: 26 additions & 0 deletions README.md
@@ -285,6 +285,32 @@ Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through spec
```
Credit to [lucidrains](https://github.com/lucidrains) for [initial code](https://github.com/lucidrains/CoCa-pytorch), [gpucce](https://github.com/gpucce) for adapting the code to open_clip, and [iejMac](https://github.com/iejMac) for training the models.

### Generating text with CoCa

To generate text with CoCa, the following should work:

```python
import open_clip
from PIL import Image

model, _, transform = open_clip.create_model_and_transforms(
model_name="coca_ViT-B-32",
pretrained="path/to/pretrained"
# Reviewer comment (Collaborator): point to an existing pretrained model
# instead (see the README; the key is specified at the end).
)

# load an image
im = Image.open("path/to/image").convert("RGB")
# transform the image and add a batch size dimension
im = transform(im).unsqueeze(0)

generated = model.generate(im)
# alternatively, if the computation was running on a GPU:
# generated = generated.detach()

print(open_clip.decode(generated[0]))
# "<start_of_text> some text here <end_of_text>"
```
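
`generate` defaults to beam search; the sampling-based decoders added in this PR (`top_k`, `top_p`, `top_a`) can be selected through the `generation_type` argument. A rough sketch with illustrative, untuned argument values:

```python
# beam search is the default (it relies on Hugging Face's BeamSearchScorer,
# so the transformers package must be installed)
generated = model.generate(im, generation_type="beam_search", num_beams=6, num_beam_groups=3)

# sampling-based decoding instead of beam search
generated = model.generate(im, generation_type="top_k", seq_len=30, temperature=1.0, filter_thres=0.9)
generated = model.generate(im, generation_type="top_p", seq_len=30)

print(open_clip.decode(generated[0]))
```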

### Training with pre-trained language models as text encoder:

If you wish to use a different language model as the text encoder for CLIP, you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing it and its tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters, respectively. Currently we only support RoBERTa (the "test-roberta" config); however, adding new models should be trivial. You can also determine how many layers, counting from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with a RoBERTa LM that has its last 10 layers unfrozen:
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -8,5 +8,5 @@
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
222 changes: 207 additions & 15 deletions src/open_clip/coca_model.py
@@ -13,8 +13,15 @@
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p
from .generation_utils import top_a, top_k, top_p, prepare_inputs_for_generation
from transformers import BeamSearchScorer, LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria

GENERATION_TYPES = {
"top_k": top_k,
"top_p": top_p,
"top_a": top_a,
"beam_search": "beam_search"
}

@dataclass
class MultimodalCfg(CLIPTextCfg):
@@ -108,8 +115,8 @@ def _encode_image(self, images, normalize=True):
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs

def _encode_text(self, text, normalize=True):
text = text[:, :-1] # make space for CLS token
def _encode_text(self, text, normalize=True, embed_cls=True):
text = text[:, :-1] if embed_cls else text # make space for CLS token
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
@@ -118,13 +125,14 @@ def encode_image(self, images, normalize=True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent

def encode_text(self, text, normalize=True):
text_latent, _ = self._encode_text(text, normalize=normalize)
def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent

def forward(self, image, text):
text_latent, token_embs = self._encode_text(text)
image_latent, image_embs = self._encode_image(image)
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
@@ -141,19 +149,47 @@ def forward(self, image, text):
def generate(
self,
image,
text,
seq_len,
text=None,
seq_len=77,
max_seq_len=77,
mask_prob=0.0,
temperature=1.,
filter_logits_fn=top_k,
generation_type="beam_search",
filter_thres=0.9,
min_p_pow=2.0,
min_p_ratio=0.02,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
):

assert generation_type in GENERATION_TYPES, \
f"generation_type has to be one of {'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
filter_logits_fn = GENERATION_TYPES[generation_type]

if generation_type == "beam_search":
return self.generate_beamsearch(
image_inputs = image,
max_length = seq_len,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
)

assert mask_prob < 1, "mask_prob must be smaller than 1."
device = image.device

sot_token_id = 49406 if sot_token_id is None else sot_token_id
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)

@@ -167,8 +203,7 @@ def generate(
for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)["logits"][:, -1]
logits = self(image, x, embed_cls=False)["logits"][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
@@ -184,10 +219,167 @@

out = torch.cat((out, sample), dim=-1)

out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out

def generate_beamsearch(
self,
image_inputs,
max_length,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
):

sot_token_id = 49406 if sot_token_id is None else sot_token_id
eos_token_id = 49407 if eos_token_id is None else eos_token_id
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id

device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)

input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
target_logits_processor_list = [
MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)
]
logits_processor = LogitsProcessorList(
target_logits_processor_list
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=max_length)]
stopping_criteria = StoppingCriteriaList(
stopping_criteria
)

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None

if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)

beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
# the same group don't produce the same tokens every time.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))

while True:

# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
embed_cls=False,
image_latent=image_latent,
image_embs=image_embs
)

for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx

# indices of beams of current group among all sentences in batch
batch_group_indices = []

for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]

# select outputs of beams of current group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]

next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)

next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size

# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]

input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]

# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)

input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None):
break

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
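
For reference, a minimal sketch of driving `generate_beamsearch` directly with an explicit stopping criterion (reusing `model` and `im` from the README example above; the max-length value is illustrative):

```python
import torch
import open_clip
from transformers import MaxLengthCriteria

# `model` and `im` prepared as in the README generation example above
with torch.no_grad():
    seqs = model.generate_beamsearch(
        image_inputs=im,
        max_length=40,
        num_beams=6,
        num_beam_groups=3,   # num_beams must be divisible by num_beam_groups
        min_seq_len=5,
        stopping_criteria=[MaxLengthCriteria(max_length=40)],
    )

print(open_clip.decode(seqs[0]))
```

This is the same call that `generate(..., generation_type="beam_search")` dispatches to, with `seq_len` forwarded as `max_length`.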

23 changes: 22 additions & 1 deletion src/open_clip/generation_utils.py
@@ -1,6 +1,5 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F


@@ -35,3 +34,25 @@ def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
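
Only the `'text'` and `'images'` entries of this dict are consumed in the beam-search loop above; a small sketch of what the helper returns when no `past` state is supplied (shapes are illustrative):

```python
import torch
from open_clip.generation_utils import prepare_inputs_for_generation

input_ids = torch.ones((2, 5), dtype=torch.long)  # 2 beams, 5 tokens generated so far
images = torch.randn(2, 3, 224, 224)              # image batch repeated per beam

inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=images)
print(inputs["text"].shape)    # torch.Size([2, 5]) -- full ids are kept since `past` is None
print(inputs["images"].shape)  # torch.Size([2, 3, 224, 224])
# the remaining keys (past_key_values, position_ids, attention_mask) are all None here
```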
3 changes: 3 additions & 0 deletions src/open_clip/tokenizer.py
@@ -152,6 +152,9 @@ def decode(self, tokens):

_tokenizer = SimpleTokenizer()

def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)

def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""