diff --git a/docs/getting-started.md b/docs/getting-started.md index 836714267e..5e1f7f3445 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -66,44 +66,58 @@ Inside this Docker environment you can do anything – run a Jupyter notebook, Let's pretend we've trained a model. With Cog, we can define how to run predictions on it in a standard way, so other people can easily run predictions on it without having to hunt around for a prediction script. -First, run this to get some pre-trained model weights: - -```bash -WEIGHTS_URL=https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5 -curl -O $WEIGHTS_URL - -``` - -Then, we need to write some code to describe how predictions are run on the model. +First, we need to write some code to describe how predictions are run on the model. Save this to `predict.py`: ```python from typing import Any from cog import BasePredictor, Input, Path -from tensorflow.keras.applications.resnet50 import ResNet50 -from tensorflow.keras.preprocessing import image as keras_image -from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions -import numpy as np +import torch +from PIL import Image +from torchvision import transforms +# reference to: https://colab.research.google.com/github/pytorch/pytorch.github.io/blob/master/assets/hub/pytorch_vision_resnet.ipynb class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" - self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5') + self.model = torch.hub.load( + "pytorch/vision:v0.10.0", "resnet18", pretrained=True + ) + self.model.eval() + + self.preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + with open("imagenet_classes.txt", "r") as f: 
+ self.categories = [s.strip() for s in f.readlines()] # Define the arguments and types the model takes as input def predict(self, image: Path = Input(description="Image to classify")) -> Any: """Run a single prediction on the model""" # Preprocess the image - img = keras_image.load_img(image, target_size=(224, 224)) - x = keras_image.img_to_array(img) - x = np.expand_dims(x, axis=0) - x = preprocess_input(x) - # Run the prediction - preds = self.model.predict(x) - # Return the top 3 predictions - return decode_predictions(preds, top=3)[0] + input_image = Image.open(image) + input_tensor = self.preprocess(input_image) + # create a mini-batch as expected by the model + input_batch = input_tensor.unsqueeze(0) + with torch.no_grad(): + output = self.model(input_batch) + # Return the top 5 predictions + probabilities = torch.nn.functional.softmax(output[0], dim=0) + top5_prob, top5_catid = torch.topk(probabilities, 5) + res_list = [] + for i in range(top5_prob.size(0)): + res_list.append([self.categories[top5_catid[i]], top5_prob[i].item()]) + return res_list + ``` We also need to point Cog at this, and tell it what Python dependencies to install. 
Update `cog.yaml` to look like this: @@ -113,7 +127,8 @@ build: python_version: "3.11" python_packages: - pillow==9.5.0 - - tensorflow==2.12.0 + - torch==2.3.1+cpu + - torchvision==0.18.1 predict: "predict.py:Predictor" ``` @@ -124,11 +139,15 @@ IMAGE_URL=https://gist.githubusercontent.com/bfirsh/3c2115692682ae260932a67d93fd curl $IMAGE_URL > input.jpg ``` - +Then, let's grab the imagenet_classes.txt file: +```bash +# Download ImageNet labels +curl -O https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt +``` Now, let's run the model using Cog: ```bash -cog predict -i image=@input.jpg +cog predict -i image=@input.jpg --use-cog-base-image=false ``` @@ -137,19 +156,24 @@ If you see the following output ``` [ [ - "n02123159", - "tiger_cat", - 0.4874822497367859 + "Egyptian cat", + 0.5175395011901855 ], [ - "n02123045", "tabby", - 0.23169134557247162 + 0.26614469289779663 + ], + [ + "tiger cat", + 0.2014264166355133 + ], + [ + "lynx", + 0.003854369046166539 ], [ - "n02124075", - "Egyptian_cat", - 0.09728282690048218 + "plastic bag", + 0.002946337917819619 ] ] ```