From dd23194f85f50d4a919e9951c7d28a9d64fd2f9f Mon Sep 17 00:00:00 2001 From: YeonwooSung Date: Sat, 24 Apr 2021 19:27:18 +0900 Subject: [PATCH] Update README.md --- README.md | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index badd4b7..dec5260 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,8 @@ An implementation of Geoffrey Hinton's paper "How to represent part-whole hierar ## 2. Usage +### 2 - 1. PyTorch version + ```python import torch from pyglom import GLOM @@ -103,10 +105,51 @@ levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iteratoins levels3 = model(img3, levels = levels2, iters = 6) # image 3 for 6 iterations ``` +### 2 - 2. PyTorch-Lightning version + +The pyglom also provides the GLOM model that is implemented with PyTorch-Lightning. + +```python +from torchvision.datasets import MNIST +from torch.utils.data import DataLoader, random_split +from torchvision import transforms +import os +from pytorch_lightning.callbacks import ModelCheckpoint + + +from pyglom.glom import LightningGLOM + + +dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor() +])) +train, val = random_split(dataset, [55000, 5000]) + +glom = LightningGLOM( + dim=256, # dimension + levels=6, # number of levels + image_size=256, # image size + patch_size=16, # patch size + img_channels=1 +) + +checkpoint_callback = ModelCheckpoint( + monitor='val_loss', + dirpath='.', + filename='mnist-{epoch:02d}-{val_loss:.2f}', + save_top_k=3, + mode='min', +) + +gpus = torch.cuda.device_count() +trainer = pl.Trainer(callbacks=[checkpoint_callback], gpus=gpus) +trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=32, num_workers=2)) +``` + ## 3. ToDo - contrastive / consistency regularization of top-ish levels -- re-implement the model with PyTorchLightning ## 4. Citations