2222#
2323"""Unit tests exercising dMRI models."""
2424
25+ import functools
2526import re
2627import warnings
2728from contextlib import nullcontext
6162)
6263
6364
65+ def ignore_dipy_invalid_divide (func ):
66+ @functools .wraps (func )
67+ def _wrapped (* args , ** kwargs ):
68+ with warnings .catch_warnings ():
69+ warnings .filterwarnings (
70+ "ignore" ,
71+ message = r".*invalid value encountered in divide.*" ,
72+ category = RuntimeWarning ,
73+ )
74+ return func (* args , ** kwargs )
75+
76+ return _wrapped
77+
78+
6479def _get_attributes (instance ):
6580 """Return a dictionary of non-callable, non-dunder scalar- or array-like attributes."""
6681 _attrs = {}
@@ -82,8 +97,20 @@ def _get_attributes(instance):
8297 return _attrs
8398
8499
100+ _EIGVEC_DEPENDENT_ATTRS = frozenset ({"model_params" , "evecs" , "directions" })
101+ """Attributes that encode eigenvector orientation and are subject to sign/rotation
102+ ambiguity when eigenvalues are degenerate. Skipped in favour of comparing the
103+ reconstructed diffusion tensor (``quadratic_form``), which is invariant."""
104+
105+
85106def _compare_instance_attributes (instance1 , instance2 ):
86- """Compare non-callable, non-dunder attributes of two instances for numerical equality."""
107+ """Compare non-callable, non-dunder attributes of two instances for numerical equality.
108+
109+ Eigenvector-dependent attributes (``model_params``, ``evecs``, ``directions``)
110+ are skipped because degenerate eigenvalues make the eigenvector basis
111+ non-unique. Instead, the reconstructed diffusion tensor
112+ (``quadratic_form``) is compared, which is basis-invariant.
113+ """
87114 # Get attributes of both instances, excluding dunder and method attributes
88115 attributes1 = _get_attributes (instance1 )
89116 attributes2 = _get_attributes (instance2 )
@@ -96,6 +123,10 @@ def _compare_instance_attributes(instance1, instance2):
96123 # Compare the values of the attributes
97124 all_equal = True
98125 for attr in attributes1 :
126+ # Skip eigenvector-dependent attributes — compared via quadratic_form
127+ if attr in _EIGVEC_DEPENDENT_ATTRS :
128+ continue
129+
99130 value1 = attributes1 .get (attr )
100131 value2 = attributes2 .get (attr )
101132
@@ -106,10 +137,6 @@ def _compare_instance_attributes(instance1, instance2):
106137 all_equal = False
107138 continue
108139
109- elif value1 is None or value2 is None :
110- print (f"Attribute '{ attr } ' differs: { value1 } != { value2 } " )
111- all_equal = False
112-
113140 try :
114141 array1 = np .asarray (value1 ).ravel ()
115142 array2 = np .asarray (value2 ).ravel ()
@@ -555,6 +582,14 @@ def test_dti_prediction_shape(setup_random_dwi_data, index):
555582 "snr" : None ,
556583 "vol_shape" : (1 , 1 , 1 ),
557584 },
585+ {
586+ "bval_shell" : 1000 ,
587+ "S0" : 1 ,
588+ "evals" : (0.0015 , 0.0003 , 0.0003 ),
589+ "hsph_dirs" : 3 ,
590+ "snr" : None ,
591+ "vol_shape" : (2 , 4 , 1 ),
592+ },
558593 {
559594 "bval_shell" : 1000 ,
560595 "S0" : 1 ,
@@ -563,6 +598,14 @@ def test_dti_prediction_shape(setup_random_dwi_data, index):
563598 "snr" : None ,
564599 "vol_shape" : (1 , 1 , 1 ),
565600 },
601+ {
602+ "bval_shell" : 1000 ,
603+ "S0" : 1 ,
604+ "evals" : (0.0016 , 0.0004 , 0.0004 ),
605+ "hsph_dirs" : 6 ,
606+ "snr" : None ,
607+ "vol_shape" : (2 , 5 , 1 ),
608+ },
566609 {
567610 "bval_shell" : 1000 ,
568611 "S0" : 1 ,
@@ -571,12 +614,21 @@ def test_dti_prediction_shape(setup_random_dwi_data, index):
571614 "snr" : None ,
572615 "vol_shape" : (1 , 1 , 1 ),
573616 },
617+ {
618+ "bval_shell" : 1000 ,
619+ "S0" : 1 ,
620+ "evals" : (0.0015 , 0.0003 , 0.0003 ),
621+ "hsph_dirs" : 8 ,
622+ "snr" : None ,
623+ "vol_shape" : (2 , 3 , 1 ),
624+ },
574625 ],
575626 indirect = True ,
576627)
577628@pytest .mark .parametrize ("index" , (None , 3 , 5 ))
578629@pytest .mark .parametrize ("ignore_bzero" , (False , True ))
579630@pytest .mark .parametrize ("use_mask" , (False , True ))
631+ @ignore_dipy_invalid_divide
580632def test_dti_model_fit (single_shell_test_data , index , ignore_bzero , use_mask ):
581633 """Ensure that we get the same result obtained through the DTI model
582634 implemented in DIPY."""
@@ -616,6 +668,14 @@ def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
616668 "snr" : None ,
617669 "vol_shape" : (1 , 1 , 1 ),
618670 },
671+ {
672+ "bval_shell" : 1000 ,
673+ "S0" : 1 ,
674+ "evals" : (0.0015 , 0.0003 , 0.0003 ),
675+ "hsph_dirs" : 3 ,
676+ "snr" : None ,
677+ "vol_shape" : (2 , 4 , 1 ),
678+ },
619679 {
620680 "bval_shell" : 1000 ,
621681 "S0" : 1 ,
@@ -624,6 +684,14 @@ def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
624684 "snr" : None ,
625685 "vol_shape" : (1 , 1 , 1 ),
626686 },
687+ {
688+ "bval_shell" : 1000 ,
689+ "S0" : 1 ,
690+ "evals" : (0.0016 , 0.0004 , 0.0004 ),
691+ "hsph_dirs" : 6 ,
692+ "snr" : None ,
693+ "vol_shape" : (2 , 5 , 1 ),
694+ },
627695 {
628696 "bval_shell" : 1000 ,
629697 "S0" : 1 ,
@@ -632,12 +700,21 @@ def test_dti_model_fit(single_shell_test_data, index, ignore_bzero, use_mask):
632700 "snr" : None ,
633701 "vol_shape" : (1 , 1 , 1 ),
634702 },
703+ {
704+ "bval_shell" : 1000 ,
705+ "S0" : 1 ,
706+ "evals" : (0.0015 , 0.0003 , 0.0003 ),
707+ "hsph_dirs" : 8 ,
708+ "snr" : None ,
709+ "vol_shape" : (2 , 3 , 1 ),
710+ },
635711 ],
636712 indirect = True ,
637713)
638714@pytest .mark .parametrize ("index" , (None , 3 , 5 ))
639715@pytest .mark .parametrize ("ignore_bzero" , (False , True ))
640716@pytest .mark .parametrize ("use_mask" , (False , True ))
717+ @ignore_dipy_invalid_divide
641718def test_dti_model_predict (single_shell_test_data , index , ignore_bzero , use_mask ):
642719 """Ensure that we get the same result obtained through the DTI model
643720 implemented in DIPY."""
0 commit comments