diff --git a/src/clip/__init__.py b/src/clip/__init__.py
index c6d7dd09f..1d4154a7b 100644
--- a/src/clip/__init__.py
+++ b/src/clip/__init__.py
@@ -1,7 +1,7 @@
 from .factory import create_model_and_transforms
 from .loss import ClipLoss
 from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16
-from .openai_clip import load_openai
+from .openai import load_openai
 from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
     get_pretrained_url, download_pretrained
 from .tokenizer import SimpleTokenizer, tokenize
diff --git a/src/clip/factory.py b/src/clip/factory.py
index 677f477b4..c834329e1 100644
--- a/src/clip/factory.py
+++ b/src/clip/factory.py
@@ -5,8 +5,8 @@
 
 import torch
 
-from .openai_clip import load_openai
 from .model import CLIP, convert_weights_to_fp16
+from .openai import load_openai
 from .pretrained import get_pretrained_url, download_pretrained
 from .transform import image_transform
 
diff --git a/src/clip/openai_clip.py b/src/clip/openai.py
similarity index 96%
rename from src/clip/openai_clip.py
rename to src/clip/openai.py
index 1ebceb5ce..9e2c2e00e 100644
--- a/src/clip/openai_clip.py
+++ b/src/clip/openai.py
@@ -68,8 +68,8 @@ def load_openai(
         if str(device) == "cpu":
             model.float()
         return model, \
-            image_transform(model.visual.image_size, is_train=True), \
-            image_transform(model.visual.image_size, is_train=False)
+               image_transform(model.visual.image_size, is_train=True), \
+               image_transform(model.visual.image_size, is_train=False)
 
     # patch the device names
     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
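
For reviewers, a minimal usage sketch of the renamed module (not part of the patch). The import path and the three-value return follow from the hunks above; the top-level package name `clip`, the positional model-name argument, and the name "ViT-B/32" are assumptions for illustration only.

    # Usage sketch after the rename; assumes the package installs as `clip`
    # (per src/clip/ above) and that "ViT-B/32" is an available OpenAI model name.
    import torch

    from clip.openai import load_openai  # was: from clip.openai_clip import load_openai

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Per the hunk in openai.py, load_openai returns the model plus
    # train/val image transforms.
    model, preprocess_train, preprocess_val = load_openai("ViT-B/32", device=device)
    model.eval()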