Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True)
self.eval()

self.channel_wise = channel_wise
Expand Down Expand Up @@ -297,7 +297,7 @@ class RadImageNetPerceptualSimilarity(nn.Module):

def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
super().__init__()
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True)
self.eval()

for param in self.parameters():
Expand Down
43 changes: 27 additions & 16 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,32 @@

nib, _ = optional_import("nibabel")
http_error, has_req = optional_import("requests", name="HTTPError")
file_url_error, has_gdown = optional_import("gdown.exceptions", name="FileURLRetrievalError")


quick_test_var = "QUICKTEST"
_tf32_enabled = None
_test_data_config: dict = {}

MODULE_PATH = Path(__file__).resolve().parents[1]

DOWNLOAD_EXCEPTS: tuple[type, ...] = (ContentTooShortError, HTTPError, ConnectionError)
if has_req:
DOWNLOAD_EXCEPTS += (http_error,)
if has_gdown:
DOWNLOAD_EXCEPTS += (file_url_error,)

DOWNLOAD_FAIL_MSGS = (
"unexpected EOF", # incomplete download
"network issue",
"gdown dependency", # gdown not installed
"md5 check",
"limit", # HTTP Error 503: Egress is over the account limit
"authenticate",
"timed out", # urlopen error [Errno 110] Connection timed out
"HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub
)


def testing_data_config(*keys):
"""get _test_data_config[keys0][keys1]...[keysN]"""
Expand Down Expand Up @@ -142,29 +161,21 @@ def assert_allclose(

@contextmanager
def skip_if_downloading_fails():
"""
Skips a test if downloading something raises an exception recognised to indicate a download has failed.
"""

try:
yield
except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_req else () as e: # noqa: B030
raise unittest.SkipTest(f"error while downloading: {e}") from e
except DOWNLOAD_EXCEPTS as e:
raise unittest.SkipTest(f"Error while downloading: {e}") from e
except ssl.SSLError as ssl_e:
if "decryption failed" in str(ssl_e):
raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e
except (RuntimeError, OSError) as rt_e:
err_str = str(rt_e)
if any(
k in err_str
for k in (
"unexpected EOF", # incomplete download
"network issue",
"gdown dependency", # gdown not installed
"md5 check",
"limit", # HTTP Error 503: Egress is over the account limit
"authenticate",
"timed out", # urlopen error [Errno 110] Connection timed out
"HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub
)
):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download
if any(k in err_str for k in DOWNLOAD_FAIL_MSGS):
raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download

raise rt_e

Expand Down
Loading