CoCa: generate and beam search (#393)
* initial commit

* add utils

* imports

* add text default

* add readme and wrap all generation types in generate

* add assertion

* assertion

* make args explicit

* better readme

* better readme remove warning

* add decode function

* use defaults pretrained

---------

Co-authored-by: gpucce <[email protected]>
Co-authored-by: Romain Beaumont <[email protected]>
3 people authored Feb 1, 2023
1 parent ccad1ab commit 8709d8b
Showing 5 changed files with 259 additions and 17 deletions.
26 changes: 26 additions & 0 deletions README.md
@@ -285,6 +285,32 @@ Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through spec
```
Credit to [lucidrains](https://github.com/lucidrains) for [initial code](https://github.com/lucidrains/CoCa-pytorch), [gpucce](https://github.com/gpucce) for adapting the code to open_clip, and [iejMac](https://github.com/iejMac) for training the models.

### Generating text with CoCa

To generate text with CoCa, the following should work:

```python
import open_clip
from PIL import Image
model, _, transform = open_clip.create_model_and_transforms(
model_name="coca_ViT-B-32",
pretrained="laion2B-s13B-b90k"
)
# load an image
im = Image.open("path/to/image").convert("RGB")
# transform the image and add a batch size dimension
im = transform(im).unsqueeze(0)
generated = model.generate(im)
# if generation ran on a GPU, detach the output first:
# generated = generated.detach()
print(open_clip.decode(generated[0]))
# "<start_of_text> some text here <end_of_text>"
```
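
The `generate` method also exposes the sampling-based decoders added in this commit; a minimal sketch, assuming the `model` and `im` from the example above (the parameter names match the `generate` signature below, the values here are illustrative):

```python
# sampling instead of the default beam search; generation_type may be
# "top_k", "top_p", "top_a", or "beam_search" (see GENERATION_TYPES below)
generated = model.generate(
    im,
    generation_type="top_k",
    temperature=1.0,   # lower values sharpen the sampling distribution
    filter_thres=0.9,  # top_k keeps roughly the top 10% of the vocabulary
    seq_len=20,        # illustrative cap on the number of generated tokens
)
print(open_clip.decode(generated[0]))
```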

### Training with pre-trained language models as text encoder:

If you wish to use different language models as the text encoder for CLIP, you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing the config name and its tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters respectively. Currently we only support RoBERTa (the "test-roberta" config), but adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has its last 10 layers unfrozen:
2 changes: 1 addition & 1 deletion src/open_clip/__init__.py
@@ -8,5 +8,5 @@
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg
222 changes: 207 additions & 15 deletions src/open_clip/coca_model.py
@@ -13,8 +13,15 @@
MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
from .generation_utils import top_a, top_k, top_p
from .generation_utils import top_a, top_k, top_p, prepare_inputs_for_generation
from transformers import BeamSearchScorer, LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria

GENERATION_TYPES = {
"top_k": top_k,
"top_p": top_p,
"top_a": top_a,
"beam_search": "beam_search"
}

@dataclass
class MultimodalCfg(CLIPTextCfg):
@@ -108,8 +115,8 @@ def _encode_image(self, images, normalize=True):
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs

def _encode_text(self, text, normalize=True):
text = text[:, :-1] # make space for CLS token
def _encode_text(self, text, normalize=True, embed_cls=True):
text = text[:, :-1] if embed_cls else text # make space for CLS token
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
@@ -118,13 +125,14 @@ def encode_image(self, images, normalize=True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent

def encode_text(self, text, normalize=True):
text_latent, _ = self._encode_text(text, normalize=normalize)
def encode_text(self, text, normalize=True, embed_cls=True):
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
return text_latent

def forward(self, image, text):
text_latent, token_embs = self._encode_text(text)
image_latent, image_embs = self._encode_image(image)
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)

# TODO: add assertion to avoid bugs?
labels = text[:, -token_embs.shape[1]:]
@@ -141,19 +149,47 @@ def forward(self, image, text):
def generate(
self,
image,
text,
seq_len,
text=None,
seq_len=77,
max_seq_len=77,
mask_prob=0.0,
temperature=1.,
filter_logits_fn=top_k,
generation_type="beam_search",
filter_thres=0.9,
min_p_pow=2.0,
min_p_ratio=0.02,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
):

assert generation_type in GENERATION_TYPES, \
f"generation_type has to be one of {'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
filter_logits_fn = GENERATION_TYPES[generation_type]

if generation_type == "beam_search":
return self.generate_beamsearch(
image_inputs=image,
max_length=seq_len,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
)

assert mask_prob < 1, "mask_prob must be smaller than 1."
device = image.device

sot_token_id = 49406 if sot_token_id is None else sot_token_id
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)

@@ -167,8 +203,7 @@ def generate(
for _ in range(seq_len):
x = out[:, -max_seq_len:]

# TODO: adjust for dict output
logits = self(image, x)["logits"][:, -1]
logits = self(image, x, embed_cls=False)["logits"][:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres=filter_thres)
@@ -184,10 +219,167 @@

out = torch.cat((out, sample), dim=-1)

out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

self.train(was_training)
return out

def generate_beamsearch(
self,
image_inputs,
max_length,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
):

sot_token_id = 49406 if sot_token_id is None else sot_token_id
eos_token_id = 49407 if eos_token_id is None else eos_token_id
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id

device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)

input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
target_logits_processor_list = [
MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)
]
logits_processor = LogitsProcessorList(
target_logits_processor_list
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=max_length)]
stopping_criteria = StoppingCriteriaList(
stopping_criteria
)

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None

if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)

beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise the score of the first beam of each group with 0 and the rest with -1e9, so that beams in
# the same group don't all produce the same tokens.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))

while True:

# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
embed_cls=False,
image_latent=image_latent,
image_embs=image_embs
)

for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx

# indices of beams of current group among all sentences in batch
batch_group_indices = []

for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]

# select outputs of beams of the current group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]

next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)

next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size

# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]

input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]

# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)

input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None):
break

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
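
Since `beam_search` is the default `generation_type`, the method above is what the README example exercises; a sketch of reaching it through `generate` with the defaults spelled out, assuming the `model` and `im` from the README section (values illustrative):

```python
# generate() forwards these arguments to generate_beamsearch,
# with seq_len becoming max_length for the MaxLengthCriteria
out = model.generate(
    im,
    generation_type="beam_search",
    num_beams=6,        # total beams per image
    num_beam_groups=3,  # diverse beam search: 3 groups of 2 sub-beams
    min_seq_len=5,      # MinLengthLogitsProcessor blocks EOS before this length
    seq_len=30,
)
print(open_clip.decode(out[0]))
```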

23 changes: 22 additions & 1 deletion src/open_clip/generation_utils.py
@@ -1,6 +1,5 @@
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F


@@ -35,3 +34,25 @@ def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
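
A minimal illustration of what this helper returns in the call pattern `generate_beamsearch` uses above; the tensors are hypothetical toy values, and since no `past` or `attention_mask` is passed, the ids come back untruncated and the optional fields are `None`:

```python
import torch
from open_clip.generation_utils import prepare_inputs_for_generation

# hypothetical toy inputs: two partial captions (sot + one token) and two images
input_ids = torch.tensor([[49406, 320], [49406, 518]])
images = torch.randn(2, 3, 224, 224)

inputs = prepare_inputs_for_generation(input_ids, images)
# with past=None the ids are not truncated to the last step, and
# past_key_values / position_ids / attention_mask all come back None
assert inputs["text"] is input_ids
assert inputs["past_key_values"] is None and inputs["attention_mask"] is None
```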
3 changes: 3 additions & 0 deletions src/open_clip/tokenizer.py
@@ -152,6 +152,9 @@ def decode(self, tokens):

_tokenizer = SimpleTokenizer()

def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)

def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
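A round-trip sanity check for the new top-level `decode`; a sketch, noting that `tokenize` zero-pads to `context_length`, so the tail of the decoded string depends on how the SimpleTokenizer renders padding ids:

```python
import open_clip

toks = open_clip.tokenize(["a photo of a cat"])[0]  # LongTensor of length 77
print(open_clip.decode(toks))
# expected to start with "<start_of_text>a photo of a cat <end_of_text>",
# followed by the decoded zero-padding tokens
```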
