Skip to content

Commit

Permalink
figExperimental
Browse files Browse the repository at this point in the history
  • Loading branch information
wahabk committed Feb 4, 2025
1 parent a1402fc commit f096140
Show file tree
Hide file tree
Showing 11 changed files with 4,164 additions and 186 deletions.
7 changes: 3 additions & 4 deletions colloidoscope/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def detect(input_array:np.ndarray, diameter:Union[int, list], model:torch.nn.Mod

# model
if model is None:
n_blocks=2
n_blocks = 2
start = int(math.sqrt(32))
channels = [2**n for n in range(start, start + n_blocks)]
strides = [2 for n in range(1, n_blocks)]
Expand All @@ -169,7 +169,7 @@ def detect(input_array:np.ndarray, diameter:Union[int, list], model:torch.nn.Mod
padding='valid',
)

if 'attention_unet_202206' in weights_path:
if isinstance(weights_path, str) and ('attention_unet_202206' in weights_path):
model = monai.networks.nets.AttentionUnet(
spatial_dims=3,
in_channels=1,
Expand All @@ -188,8 +188,7 @@ def detect(input_array:np.ndarray, diameter:Union[int, list], model:torch.nn.Mod
weights_path = Path(__file__).parent / "attention_unet_202302.pt"
model_weights = torch.load(str(weights_path), map_location=device) # read trained weights
model.load_state_dict(model_weights) # add weights to model

else:
elif weights_path != "preloaded":
model_weights = torch.load(weights_path, map_location=device) # read trained weights
model.load_state_dict(model_weights) # add weights to model

Expand Down
Loading

0 comments on commit f096140

Please sign in to comment.