Skip to content

Commit

Permalink
feature: interface for converted models
Browse files Browse the repository at this point in the history
  • Loading branch information
samedii committed Sep 25, 2022
1 parent 3e1e017 commit c34abba
Show file tree
Hide file tree
Showing 414 changed files with 678 additions and 326,809 deletions.
59 changes: 22 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,12 @@
# Pytorch-LiT
# pytorch-lit

[LiT: Zero-Shot Transfer with Locked-image text Tuning](https://arxiv.org/pdf/2111.07991v3.pdf)
Converted official JAX models for [LiT: Zero-Shot Transfer with Locked-image text Tuning](https://arxiv.org/pdf/2111.07991v3.pdf)
to pytorch.

Wrapping LiT implementation in JAX with [jax2torch](https://github.com/lucidrains/jax2torch) to allow
gradient backpropagation.
_JAX -> Tensorflow -> ONNX -> Pytorch._

Installation of `JAX` and `tensorflow` is required at the moment. The
packages `flaxformer` and `t5x` are not on pypi and are copied into this
sourcecode but are only used relatively and will not conflict if another
version is installed.

## Usage

```python
from lit import LiT


model = LiT()


image_encodings = model.encode_images()
text_encodings = model.encode_texts()

cosine_similarity = model.cosine_similarity(image_encodings, text_encodings)
```
- Image encoder is loaded into pytorch and supports gradients
- Text encoder is not loaded into pytorch and runs via ONNX on cpu

## Install

Expand All @@ -37,23 +20,25 @@ or
pip install pytorch-lit
```

### CUDA Toolkit and cuDNN for tensorflow

Recommended to use miniconda to isolate CUDA and cuDNN and poetry for
python packages.
## Usage

```bash
conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0
```
```python
from lit import LiT

### Pytorch for CUDA 11.3
model = LiT()

```bash
poetry run pip uninstall torch torchvision -y && poetry run pip install torch==1.11.0 torchvision==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
images = TF.to_tensor(
Image.open("cat.png").convert("RGB").resize((224, 224))
)[None]
texts = [
"a photo of a cat",
"a photo of a dog",
"a photo of a bird",
"a photo of a fish",
]

### JAX for CUDA >= 11.1 and cuDNN >= 8.0.5
image_encodings = model.encode_images(images)
text_encodings = model.encode_texts(texts)

```bash
poetry run pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
cosine_similarity = model.cosine_similarity(image_encodings, text_encodings)
```
1 change: 1 addition & 0 deletions lit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .lit import LiT
Empty file removed lit/flaxformer/__init__.py
Empty file.
163 changes: 0 additions & 163 deletions lit/flaxformer/activation_partitioning.py

This file was deleted.

81 changes: 0 additions & 81 deletions lit/flaxformer/activation_partitioning_test.py

This file was deleted.

Empty file.
Empty file.
Loading

0 comments on commit c34abba

Please sign in to comment.