An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network".

## 2. Usage

### 2 - 1. PyTorch version

```python
import torch
from pyglom import GLOM
# NOTE: the model setup below is a hedged reconstruction of the elided part
# of this example; the parameter values mirror the Lightning example further
# down and are an assumption, not necessarily the original ones.
model = GLOM(dim=256, levels=6, image_size=256, patch_size=16)

img1 = torch.randn(1, 3, 256, 256)  # e.g. three consecutive video frames
img2 = torch.randn(1, 3, 256, 256)
img3 = torch.randn(1, 3, 256, 256)

levels1 = model(img1, iters=12)                  # image 1 for 12 iterations
levels2 = model(img2, levels=levels1, iters=10)  # image 2 for 10 iterations
levels3 = model(img3, levels=levels2, iters=6)   # image 3 for 6 iterations
```
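Each forward pass returns an embedding for every patch at every level. As a hedged illustration (the `(batch, patches, levels, dim)` output layout is carried over from the upstream glom-pytorch README and is an assumption here, as are the variable names), the top level can be pooled into a single whole-image vector:

```python
# Assumes levels3 has shape (batch, patches, levels, dim); the last index on
# the levels axis is the most abstract, scene-level representation.
top_level = levels3[:, :, -1, :]    # (batch, patches, dim)
image_repr = top_level.mean(dim=1)  # (batch, dim)
```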

### 2 - 2. PyTorch-Lightning version

pyglom also provides a GLOM model implemented with PyTorch-Lightning.

```python
import os

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pyglom.glom import LightningGLOM


# MNIST digits, resized to the model's expected image size
dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
]))
train, val = random_split(dataset, [55000, 5000])

glom = LightningGLOM(
    dim=256,         # dimension
    levels=6,        # number of levels
    image_size=256,  # image size
    patch_size=16,   # patch size
    img_channels=1   # MNIST images are single-channel
)

# keep the three checkpoints with the lowest validation loss
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='.',
    filename='mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

gpus = torch.cuda.device_count()
trainer = pl.Trainer(callbacks=[checkpoint_callback], gpus=gpus)
trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=32, num_workers=2))
```
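After training, the best of the saved checkpoints can be restored with Lightning's standard `load_from_checkpoint` API. A minimal sketch, assuming `LightningGLOM` saves its constructor arguments as hyperparameters (if it does not, pass them to `load_from_checkpoint` again explicitly):

```python
# Restore the checkpoint with the lowest val_loss among those saved above.
best_glom = LightningGLOM.load_from_checkpoint(checkpoint_callback.best_model_path)
best_glom.eval()  # switch to inference mode
```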

## 3. ToDo

- contrastive / consistency regularization of top-ish levels (a rough sketch of the idea follows this list)
- re-implement the model with PyTorch-Lightning
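For the first item, one plausible shape of the idea, as a sketch only: the loss choice, the `top_k` cut-off, and the assumed `(batch, patches, levels, dim)` output layout are illustrative, not part of pyglom's API.

```python
import torch.nn.functional as F

def top_level_consistency_loss(model, view1, view2, iters=12, top_k=2):
    """Pull the top-ish level embeddings of two augmented views together."""
    levels_a = model(view1, iters=iters)         # assumed (batch, patches, levels, dim)
    levels_b = model(view2, iters=iters)
    top_a = levels_a[:, :, -top_k:, :]           # keep only the top-ish levels
    top_b = levels_b[:, :, -top_k:, :].detach()  # stop-gradient on one branch
    # negative cosine similarity: minimizing it makes the two views agree
    return -F.cosine_similarity(top_a, top_b, dim=-1).mean()
```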

## 4. Citations

```bibtex
@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network},
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```
