Skip to content

Commit 9013cde

Browse files
authored
Merge pull request #410 from jhlegarreta/fix/fix-multivoxel-dti-fit
FIX: Multi-voxel DTI fitting tests - oracle using rotation-invariants
2 parents b1e674e + f9b461c commit 9013cde

1 file changed

Lines changed: 82 additions & 5 deletions

File tree

test/test_model_dmri.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Unit tests exercising dMRI models."""
2424

25+
import functools
2526
import re
2627
import warnings
2728
from contextlib import nullcontext
@@ -61,6 +62,20 @@
6162
)
6263

6364

65+
def ignore_dipy_invalid_divide(func):
66+
@functools.wraps(func)
67+
def _wrapped(*args, **kwargs):
68+
with warnings.catch_warnings():
69+
warnings.filterwarnings(
70+
"ignore",
71+
message=r".*invalid value encountered in divide.*",
72+
category=RuntimeWarning,
73+
)
74+
return func(*args, **kwargs)
75+
76+
return _wrapped
77+
78+
6479
def _get_attributes(instance):
6580
"""Return a dictionary of non-callable, non-dunder scalar- or array-like attributes."""
6681
_attrs = {}
@@ -82,8 +97,20 @@ def _get_attributes(instance):
8297
return _attrs
8398

8499

100+
_EIGVEC_DEPENDENT_ATTRS = frozenset({"model_params", "evecs", "directions"})
101+
"""Attributes that encode eigenvector orientation and are subject to sign/rotation
102+
ambiguity when eigenvalues are degenerate. Skipped in favour of comparing the
103+
reconstructed diffusion tensor (``quadratic_form``), which is invariant."""
104+
105+
85106
def _compare_instance_attributes(instance1, instance2):
86-
"""Compare non-callable, non-dunder attributes of two instances for numerical equality."""
107+
"""Compare non-callable, non-dunder attributes of two instances for numerical equality.
108+
109+
Eigenvector-dependent attributes (``model_params``, ``evecs``, ``directions``)
110+
are skipped because degenerate eigenvalues make the eigenvector basis
111+
non-unique. Instead, the reconstructed diffusion tensor
112+
(``quadratic_form``) is compared, which is basis-invariant.
113+
"""
87114
# Get attributes of both instances, excluding dunder and method attributes
88115
attributes1 = _get_attributes(instance1)
89116
attributes2 = _get_attributes(instance2)
@@ -96,6 +123,10 @@ def _compare_instance_attributes(instance1, instance2):
96123
# Compare the values of the attributes
97124
all_equal = True
98125
for attr in attributes1:
126+
# Skip eigenvector-dependent attributes — compared via quadratic_form
127+
if attr in _EIGVEC_DEPENDENT_ATTRS:
128+
continue
129+
99130
value1 = attributes1.get(attr)
100131
value2 = attributes2.get(attr)
101132

@@ -106,10 +137,6 @@ def _compare_instance_attributes(instance1, instance2):
106137
all_equal = False
107138
continue
108139

109-
elif value1 is None or value2 is None:
110-
print(f"Attribute '{attr}' differs: {value1} != {value2}")
111-
all_equal = False
112-
113140
try:
114141
array1 = np.asarray(value1).ravel()
115142
array2 = np.asarray(value2).ravel()
@@ -555,6 +582,14 @@ def test_dti_prediction_shape(setup_random_dwi_data, index):
555582
"snr": None,
556583
"vol_shape": (1, 1, 1),
557584
},
585+
{
586+
"bval_shell": 1000,
587+
"S0": 1,
588+
"evals": (0.0015, 0.0003, 0.0003),
589+
"hsph_dirs": 3,
590+
"snr": None,
591+
"vol_shape": (2, 4, 1),
592+
},
558593
{
559594
"bval_shell": 1000,
560595
"S0": 1,
@@ -563,6 +598,14 @@ def test_dti_prediction_shape(setup_random_dwi_data, index):
563598
"snr": None,
564599
"vol_shape": (1, 1, 1),
565600
},
601+
{
602+
"bval_shell": 1000,
603+
"S0": 1,
604+
"evals": (0.0016, 0.0004, 0.0004),
605+
"hsph_dirs": 6,
606+
"snr": None,
607+
"vol_shape": (2, 5, 1),
608+
},
566609
{
567610
"bval_shell": 1000,
568611
"S0": 1,
@@ -571,12 +614,21 @@ def test_dti_prediction_shape(setup_random_dwi_data, index):
571614
"snr": None,
572615
"vol_shape": (1, 1, 1),
573616
},
617+
{
618+
"bval_shell": 1000,
619+
"S0": 1,
620+
"evals": (0.0015, 0.0003, 0.0003),
621+
"hsph_dirs": 8,
622+
"snr": None,
623+
"vol_shape": (2, 3, 1),
624+
},
574625
],
575626
indirect=True,
576627
)
577628
@pytest.mark.parametrize("index", (None, 3, 5))
578629
@pytest.mark.parametrize("ignore_bzero", (False, True))
579630
@pytest.mark.parametrize("use_mask", (False, True))
631+
@ignore_dipy_invalid_divide
580632
def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
581633
"""Ensure that we get the same result obtained through the DTI model
582634
implemented in DIPY."""
@@ -616,6 +668,14 @@ def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
616668
"snr": None,
617669
"vol_shape": (1, 1, 1),
618670
},
671+
{
672+
"bval_shell": 1000,
673+
"S0": 1,
674+
"evals": (0.0015, 0.0003, 0.0003),
675+
"hsph_dirs": 3,
676+
"snr": None,
677+
"vol_shape": (2, 4, 1),
678+
},
619679
{
620680
"bval_shell": 1000,
621681
"S0": 1,
@@ -624,6 +684,14 @@ def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
624684
"snr": None,
625685
"vol_shape": (1, 1, 1),
626686
},
687+
{
688+
"bval_shell": 1000,
689+
"S0": 1,
690+
"evals": (0.0016, 0.0004, 0.0004),
691+
"hsph_dirs": 6,
692+
"snr": None,
693+
"vol_shape": (2, 5, 1),
694+
},
627695
{
628696
"bval_shell": 1000,
629697
"S0": 1,
@@ -632,12 +700,21 @@ def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
632700
"snr": None,
633701
"vol_shape": (1, 1, 1),
634702
},
703+
{
704+
"bval_shell": 1000,
705+
"S0": 1,
706+
"evals": (0.0015, 0.0003, 0.0003),
707+
"hsph_dirs": 8,
708+
"snr": None,
709+
"vol_shape": (2, 3, 1),
710+
},
635711
],
636712
indirect=True,
637713
)
638714
@pytest.mark.parametrize("index", (None, 3, 5))
639715
@pytest.mark.parametrize("ignore_bzero", (False, True))
640716
@pytest.mark.parametrize("use_mask", (False, True))
717+
@ignore_dipy_invalid_divide
641718
def test_dti_model_predict(single_shell_test_data, index, ignore_bzero, use_mask):
642719
"""Ensure that we get the same result obtained through the DTI model
643720
implemented in DIPY."""

0 commit comments

Comments
 (0)