Skip to content

Commit

Permalink
fix: truncate long texts correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Sep 25, 2022
1 parent d28c05c commit 2abf967
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
25 changes: 25 additions & 0 deletions lit/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def tokenize_texts(self, texts: List[str]):
tokens = self.text_tokenizer(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.n_text_tokens,
).input_ids
return F.pad(tokens, (0, self.n_text_tokens - tokens.shape[1]))

Expand Down Expand Up @@ -232,3 +235,25 @@ def test_documentation_usage():

cosine_similarity = model.cosine_similarity(image_encodings, text_encodings)
assert cosine_similarity[0].argmax() == 0


def test_long_text():
    """Smoke-test encoding and tokenizing of two lengthy prompts.

    The prompts are first passed through ``encode_texts`` (the result is
    discarded; the call only has to complete without raising), and then
    through ``tokenize_texts``, whose output must have a trailing dimension
    equal to ``n_text_tokens``.
    """
    from lit import LiT

    lit_model = LiT()
    long_prompts = [
        "a photo of a cat in a house with a bird and a dog and a fish and a human. this should not crash",
        "another really long text. a photo of a cat in a house with a bird and a dog and a fish and a human. this should not crash",
    ]

    # NOTE(review): presumably these prompts exceed the model's token limit,
    # which is what makes this a truncation test -- confirm against the
    # configured n_text_tokens.
    lit_model.encode_texts(long_prompts)

    token_ids = lit_model.tokenize_texts(long_prompts)
    sequence_width = token_ids.shape[-1]
    assert sequence_width == lit_model.n_text_tokens


def test_padded_tokens():
    """Check the tokenized width for a single short prompt.

    ``tokenize_texts`` is called with one brief caption and the trailing
    dimension of the result must equal ``n_text_tokens``.
    """
    from lit import LiT

    lit_model = LiT()
    short_prompts = ["a photo of a cat"]

    # NOTE(review): the test name suggests the short prompt only reaches the
    # required width through padding -- confirm in tokenize_texts.
    token_ids = lit_model.tokenize_texts(short_prompts)
    sequence_width = token_ids.shape[-1]
    assert sequence_width == lit_model.n_text_tokens
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytorch-zero-lit"
version = "0.2.1"
version = "0.2.2"
description = "LiT: Zero-Shot Transfer with Locked-image text Tuning"
authors = ["Richard Löwenström <[email protected]>"]
packages = [
Expand Down

0 comments on commit 2abf967

Please sign in to comment.