diff --git a/torchattack/att.py b/torchattack/att.py index 0901ffe..0b2a2fc 100644 --- a/torchattack/att.py +++ b/torchattack/att.py @@ -56,7 +56,7 @@ def __init__( if hook_cfg: # Explicit config name takes precedence over inferred model.model_name self.hook_cfg = hook_cfg - elif hasattr(model, 'model_name'): + elif isinstance(model, AttackModel): # If model is initialized via `torchattack.AttackModel`, the model_name # is automatically attached to the model during instantiation. self.hook_cfg = model.model_name diff --git a/torchattack/dr.py b/torchattack/dr.py index 5585045..85caee6 100644 --- a/torchattack/dr.py +++ b/torchattack/dr.py @@ -56,7 +56,7 @@ def __init__( # If model is initialized via `torchattack.AttackModel`, infer its model_name # from automatically attached attribute during instantiation. - if not model_name and hasattr(model, 'model_name'): + if not model_name and isinstance(model, AttackModel): model_name = model.model_name self.eps = eps diff --git a/torchattack/pna_patchout.py b/torchattack/pna_patchout.py index 10ca34d..61261bc 100644 --- a/torchattack/pna_patchout.py +++ b/torchattack/pna_patchout.py @@ -59,7 +59,7 @@ def __init__( if hook_cfg: # Explicit config name takes precedence over inferred model.model_name self.hook_cfg = hook_cfg - elif hasattr(model, 'model_name'): + elif isinstance(model, AttackModel): # If model is initialized via `torchattack.AttackModel`, the model_name # is automatically attached to the model during instantiation. self.hook_cfg = model.model_name diff --git a/torchattack/tgr.py b/torchattack/tgr.py index 081464b..e1f7d73 100644 --- a/torchattack/tgr.py +++ b/torchattack/tgr.py @@ -55,7 +55,7 @@ def __init__( if hook_cfg: # Explicit config name takes precedence over inferred model.model_name self.hook_cfg = hook_cfg - elif hasattr(model, 'model_name'): + elif isinstance(model, AttackModel): # If model is initialized via `torchattack.AttackModel`, the model_name # is automatically attached to the model during instantiation. self.hook_cfg = model.model_name diff --git a/torchattack/vdc.py b/torchattack/vdc.py index de644e4..66dc11b 100644 --- a/torchattack/vdc.py +++ b/torchattack/vdc.py @@ -57,10 +57,9 @@ def __init__( if hook_cfg: # Explicit config name takes precedence over inferred model.model_name self.hook_cfg = hook_cfg - elif hasattr(model, 'model_name'): + elif isinstance(model, AttackModel): # If model is initialized via `torchattack.AttackModel`, the model_name # is automatically attached to the model during instantiation. - assert isinstance(str, model.model_name) self.hook_cfg = model.model_name self.eps = eps