Skip to content

Commit

Permalink
fix download pretrained test
Browse files Browse the repository at this point in the history
  • Loading branch information
rom1504 committed Nov 7, 2022
1 parent 2a07af8 commit 84617b0
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions tests/test_download_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from urllib3 import HTTPResponse
from urllib3._collections import HTTPHeaderDict

from open_clip.pretrained import download_pretrained
from open_clip.pretrained import download_pretrained_from_url


class DownloadPretrainedTests(unittest.TestCase):
Expand All @@ -23,67 +23,67 @@ def create_response(self, data, status_code=200, content_type='application/octet
return raw

@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_openaipublic(self, urllib):
def test_download_pretrained_from_url_from_openaipublic(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
download_pretrained(url, root)
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()

@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_openaipublic_corrupted(self, urllib):
def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
with tempfile.TemporaryDirectory() as root:
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
download_pretrained(url, root)
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()

@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_openaipublic_valid_cache(self, urllib):
def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
local_file = Path(root) / 'RN50.pt'
local_file.write_bytes(file_contents)
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
download_pretrained(url, root)
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_not_called()

@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_openaipublic_corrupted_cache(self, urllib):
def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
local_file = Path(root) / 'RN50.pt'
local_file.write_bytes(b'corrupted pretrained model')
url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
download_pretrained(url, root)
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()

@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_mlfoundations(self, urllib):
def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
urllib.request.urlopen.return_value = self.create_response(file_contents)
with tempfile.TemporaryDirectory() as root:
url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
download_pretrained(url, root)
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()

@patch('open_clip.pretrained.urllib')
def test_download_pretrained_from_mlfoundations_corrupted(self, urllib):
def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
file_contents = b'pretrained model weights'
expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
with tempfile.TemporaryDirectory() as root:
url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
download_pretrained(url, root)
download_pretrained_from_url(url, root)
urllib.request.urlopen.assert_called_once()

0 comments on commit 84617b0

Please sign in to comment.