diff --git a/.circleci/config.yml b/.circleci/config.yml index 47b0e00e..dbbf54db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -142,7 +142,8 @@ jobs: -v /tmp/data/nitransforms-tests:/data -e TEST_DATA_HOME=/data \ -e COVERAGE_FILE=/tmp/summaries/.pytest.coverage \ -v /tmp/fslicense/license.txt:/opt/freesurfer/license.txt:ro \ - -v /tmp/tests:/tmp nitransforms:latest \ + -v /tmp/tests:/tmp -e TEST_OUTPUT_DIR=/tmp/artifacts \ + nitransforms:latest \ pytest --junit-xml=/tmp/summaries/pytest.xml \ --cov nitransforms --cov-report xml:/tmp/summaries/unittests.xml \ nitransforms/ diff --git a/CHANGES.rst b/CHANGES.rst index 31628681..44579977 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,21 @@ +25.0.0 (TBD) +============ +A new major release with critical updates. +The new release includes a critical hotfix for 4D resamplings. +The second major improvement is the inclusion of a first implementation of the X5 format (BIDS). +The X5 implementation is currently restricted to reading/writing of linear transforms. + +CHANGES +------- +* FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247 +* ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243 +* ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241 +* DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242 +* MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248 +* MNT: Refactor io/lta to reduce one partial line by @oesteban in https://github.com/nipy/nitransforms/pull/246 +* MNT: Move flake8 config into ``pyproject.toml`` by @oesteban in https://github.com/nipy/nitransforms/pull/245 +* MNT: Configure coverage to omit tests by @oesteban in https://github.com/nipy/nitransforms/pull/244 + 24.1.2 (June 02, 2025) ====================== New patch release that addresses a crash when applying a 3D transform to a 4D image. diff --git a/docs/examples.rst b/docs/examples.rst index a41c4c5d..09d078a2 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -9,3 +9,4 @@ A collection of Jupyter Notebooks to serve as interactive tutorials. notebooks/isbi2020 notebooks/Reading and Writing transforms.ipynb + notebooks/Visualizing transforms.ipynb diff --git a/docs/notebooks/Visualizing transforms.ipynb b/docs/notebooks/Visualizing transforms.ipynb new file mode 100644 index 00000000..a322ea4b --- /dev/null +++ b/docs/notebooks/Visualizing transforms.ipynb @@ -0,0 +1,379 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vis - visualizing transforms\n", + "This notebook showcases the `nitransforms.vis` module, which implements the functionality to illustrate and view transforms." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preamble\n", + "Prepare a Python environment and use a temporal directory for the outputs. After that, fetch the actual file from NiBabel documentation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "from tempfile import TemporaryDirectory\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import nibabel as nb\n", + "from nitransforms.vis import PlotDenseField\n", + "\n", + "cwd = TemporaryDirectory()\n", + "os.chdir(cwd.name)\n", + "print(f\"This notebook is being executed under <{os.getcwd()}>.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_dir = \"Desktop/nitransforms-tests\"\n", + "\n", + "anat_file = Path(os.getenv(\"TEST_DATA_HOME\", str(Path.home() / test_dir))) / \"someones_anatomy.nii.gz\"\n", + "transform_file = Path(os.getenv(\"TEST_DATA_HOME\", str(Path.home() / test_dir))) / \"someones_displacement_field.nii.gz\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load a displacement field\n", + "Info about the transform file here..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the example\n", + "nii = nb.load(anat_file)\n", + "hdr = nii.header.copy()\n", + "aff = nii.affine.copy()\n", + "data = np.asanyarray(nii.dataobj)\n", + "nii.orthoview()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Displacement fields\n", + "About displacement fields here..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `PlotDenseField()` from the `nitransforms.vis` module, we can prepare the ground for the illustration of a dense field transform object. `PlotDenseField()` takes three aruments: the transform file, an indication regarding the nature of the transform (default: deltas field) and a reference (default: None). We must also define slices to select which axial, coronal and sagittal planes we wish to visualise.\n", + "\n", + "Looking at the corresponding field for the example nifti file showcased above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xslice, yslice, zslice = 30, 30, 30\n", + "\n", + "pdf = PlotDenseField(\n", + " transform=transform_file,\n", + " is_deltas=True,\n", + ")\n", + "print(pdf, pdf._xfm._field.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot distortion\n", + "To sbegin, we can use the `DenseFieldTransform().plot_distortion` module to visualise the deformed grid and superimpose it onto the density map of the deformation. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n", + "pdf.plot_distortion(\n", + " axes=axes,\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example, we notice that the coronal slice contains no deformations, hinting that the transform is only dependent on the y-dimension, ie that the deltas field only contains non-zero $V_y$ vector-components. The distortions of both Axial and Sagittal planes confirm this, with the grid being unaffected in $x$ and $z$ (ie $V_{x, i} = V_{y, i} = 0$ for all $i$). \n", + "\n", + "## Quiver deformation scalar map\n", + "To verify this, we can plot the field using `DenseFieldTransform().plot_quiverdsm` and view the transformation as a quiver plot. Here, the colors represent the dominant vector components of the transform field accroding to the standard convention for deformation scalar maps: red ($V_x$-dominant), green ($V_y$-dominant) and blue ($V_z$-dominant)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n", + "pdf.plot_quiverdsm(\n", + " axes=axes,\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As expected, the field is restrained to a single spatial dimension ($V_y$) and is most intense in the areas where gridlines are subject to the most distortions. The magnitude of the field at any point is represented by the length of the correspodning arrow, while the colour represence the dominance of the dimensional component. In this example, given that the transformation is confined to one dimension, these measures are equivalent and, indeed, the color intensity is also a proxy for arrow length. As we will see further down, this is not always true when dealing with non-zero, three-dimensional transformation fields." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Transformation fields in three dimensions\n", + "Let's now look at a more complicated dense field transform, this time containing deformations in all three spatial dimensions. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transform_3d = Path(os.getenv(\"TEST_DATA_HOME\", str(Path.home() / \"workspace/nitransforms/nitransforms/tests/data\"))) / \"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xslice, yslice, zslice = 50, 50, 50\n", + "\n", + "pdf_3d = PlotDenseField(\n", + " transform=transform_3d,\n", + " is_deltas=True,\n", + ")\n", + "print(pdf_3d, pdf_3d._xfm._field.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can start by taking a look at the deformation grid." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n", + "pdf_3d.plot_distortion(\n", + " axes=axes,\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice,\n", + " lw=0.25,\n", + " show_brain=False,\n", + " show_grid=True,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The three-dimensional deformations can easily be identified through the distortion of the coordinate grid alone. \n", + "\n", + "As in 1D, we again want to see how the field looks as a quiver plot." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Quiver and deformation scalar map\n", + "Once again, it is straightfowrad to highlight the dominance of each dimension using `PlotDensefield().plot_quiverdsm`... " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n", + "pdf_3d.plot_quiverdsm(\n", + " axes = axes,\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "... which now allows us to identify which regions of the transform field have dominant $x$- (red), $y$- (green) or $z$- (blue) vector components.\n", + "\n", + "The color of each arrow therefore highlights which of the three directional components ($V_x$, $V_y$ or $V_z$) is dominant. The intensity of the color maps the strength of the dominant dimension (eg pale red -> dark red represents low $V_x$ -> high $V_x$, and similarly for $y$- and $z$- components with green and blue, respectively). The magnitude of the vector is still represented by the arrow length. \n", + "\n", + "For example, a vector $\\textbf{V} = (0.1, 0.2, 0.5)$ with magnitude $V \\approx 0.5$ has a slightly dominant $V_z$ and a relatively low magnitude, so will appear pale blue with a relatively short arrow. On the other hand, although a different vector $\\textbf{V} = (7.2, 7.4, 7.8)$ with $V \\approx 13$ still has a very small $z$-dominace, the $V_z$ component is significantly larger causing the arrow to take a dark blue color. Independently of the color, this vector will be long due to it's large magnitude." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linear, planar and spherical coefficients\n", + "We have shown that the transformation field can be mapped according to the dominance of its individual vector components. However, in the case where the difference between the most dominant vector component and the second most dominant component is very small, the transformation is in fact dictated by two dimensions. \n", + "\n", + "We therefore compute the Linear, Planar and Spherical coefficients of each point in the vector field to visualise whether the transformation is dictated by one, two or three dimensions, respectively. \n", + "\n", + "For example:\n", + "- A vector $V = (3, 0.2, 0.4)$ dominant in one dimension ($V_x$) will have a linear coefficient $c_L > c_p, c_s$\n", + "- A vector $V = (3, 2.8, 0.4)$ dominant in two dimensions ($V_x$ and $V_y$) will have a planar coefficient $c_p > c_l, c_s$\n", + "- A vector $V = (3, 2.8, 3.1)$ dominant in three dimensions will have a spherical coefficient $c_s > c_l, c_p$\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(3, 3, figsize=(10, 9), layout='constrained')\n", + "pdf_3d.plot_coeffs(\n", + " fig=fig,\n", + " axes=axes,\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Jacobians\n", + "Now that our field has components in more than one dimension, we can also use `PlotDenseField().plot_jacobian` to map out the jacobians of the vector field and highlight regions in which the transformation represents an expansion (red, $J>0$) or a contraction (blue, $J<0$) of the brain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(12, 4.5), tight_layout=True)\n", + "pdf_3d.plot_jacobian(\n", + " axes=axes,\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Transform overview\n", + "Now that we have unpacked the dense field transform, we can now bring this all together with `PlotDenseField().show_transform`. This creates a 3x3 grid of plots to provide the user with an overview of the transformation field contained in the nifti file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pdf_3d.show_transform(\n", + " xslice=xslice,\n", + " yslice=yslice,\n", + " zslice=zslice,\n", + ")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nitransforms/conftest.py b/nitransforms/conftest.py index 70680882..e68d4833 100644 --- a/nitransforms/conftest.py +++ b/nitransforms/conftest.py @@ -7,6 +7,8 @@ import tempfile _testdir = Path(os.getenv("TEST_DATA_HOME", "~/.nitransforms/testdata")).expanduser() +_outdir = os.getenv("TEST_OUTPUT_DIR", None) + _datadir = Path(__file__).parent / "tests" / "data" @@ -43,6 +45,12 @@ def testdata_path(): return _testdir +@pytest.fixture +def output_path(): + """Return an output folder.""" + return Path(_outdir) if _outdir is not None else None + + @pytest.fixture def get_testdata(): """Generate data in the requested orientation.""" diff --git a/nitransforms/io/__init__.py b/nitransforms/io/__init__.py index f9030724..a2ec7e6b 100644 --- a/nitransforms/io/__init__.py +++ b/nitransforms/io/__init__.py @@ -1,6 +1,7 @@ # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Read and write transforms.""" + from nitransforms.io import afni, fsl, itk, lta, x5 from nitransforms.io.base import TransformIOError, TransformFileError @@ -27,7 +28,37 @@ def get_linear_factory(fmt, is_array=True): - """Return the type required by a given format.""" + """ + Return the type required by a given format. + + Parameters + ---------- + fmt : :obj:`str` + A format identifying string. + is_array : :obj:`bool` + Whether the array version of the class should be returned. + + Returns + ------- + type + The class object (not an instance) of the linear transfrom to be created + (for example, :obj:`~nitransforms.io.itk.ITKLinearTransform`). + + Examples + -------- + >>> get_linear_factory("itk") + + >>> get_linear_factory("itk", is_array=False) + + >>> get_linear_factory("fsl") + + >>> get_linear_factory("fsl", is_array=False) + + >>> get_linear_factory("fakepackage") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + TypeError: Unsupported transform format . + + """ if fmt.lower() not in _IO_TYPES: raise TypeError(f"Unsupported transform format <{fmt}>.") diff --git a/nitransforms/io/afni.py b/nitransforms/io/afni.py index 7c66d434..fb27eda6 100644 --- a/nitransforms/io/afni.py +++ b/nitransforms/io/afni.py @@ -198,7 +198,7 @@ def from_image(cls, imgobj): hdr = imgobj.header.copy() shape = hdr.get_data_shape() - if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3): + if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3): raise TransformFileError( 'Displacements field "%s" does not come from AFNI.' % imgobj.file_map["image"].filename diff --git a/nitransforms/io/fsl.py b/nitransforms/io/fsl.py index f454227e..3425f5d0 100644 --- a/nitransforms/io/fsl.py +++ b/nitransforms/io/fsl.py @@ -180,7 +180,7 @@ def from_image(cls, imgobj): hdr = imgobj.header.copy() shape = hdr.get_data_shape() - if len(shape) != 4 or not shape[-1] in (2, 3): + if len(shape) != 4 or shape[-1] not in (2, 3): raise TransformFileError( 'Displacements field "%s" does not come from FSL.' % imgobj.file_map['image'].filename) diff --git a/nitransforms/io/itk.py b/nitransforms/io/itk.py index afabfd98..02fd9fe9 100644 --- a/nitransforms/io/itk.py +++ b/nitransforms/io/itk.py @@ -337,7 +337,7 @@ def from_image(cls, imgobj): hdr = imgobj.header.copy() shape = hdr.get_data_shape() - if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3): + if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3): raise TransformFileError( 'Displacements field "%s" does not come from ITK.' % imgobj.file_map["image"].filename diff --git a/nitransforms/io/lta.py b/nitransforms/io/lta.py index 334266bb..1e7445bf 100644 --- a/nitransforms/io/lta.py +++ b/nitransforms/io/lta.py @@ -1,4 +1,5 @@ """Read/write linear transforms.""" + import numpy as np from nibabel.volumeutils import Recoder from nibabel.affines import voxel_sizes, from_matvec @@ -29,12 +30,12 @@ class VolumeGeometry(StringBasedStruct): template_dtype = np.dtype( [ ("valid", "i4"), # Valid values: 0, 1 - ("volume", "i4", (3, )), # width, height, depth - ("voxelsize", "f4", (3, )), # xsize, ysize, zsize + ("volume", "i4", (3,)), # width, height, depth + ("voxelsize", "f4", (3,)), # xsize, ysize, zsize ("xras", "f8", (3, 1)), # x_r, x_a, x_s ("yras", "f8", (3, 1)), # y_r, y_a, y_s ("zras", "f8", (3, 1)), # z_r, z_a, z_s - ("cras", "f8", (3, )), # c_r, c_a, c_s + ("cras", "f8", (3,)), # c_r, c_a, c_s ("filename", "U1024"), ] ) # Not conformant (may be >1024 bytes) @@ -109,14 +110,19 @@ def from_string(cls, string): label, valstring = lines.pop(0).split(" =") assert label.strip() == key - val = "" - if valstring.strip(): - parsed = np.genfromtxt( + parsed = ( + np.genfromtxt( [valstring.encode()], autostrip=True, dtype=cls.dtype[key] ) - if parsed.size: - val = parsed.reshape(sa[key].shape) - sa[key] = val + if valstring.strip() + else None + ) + + if parsed is not None and parsed.size: + sa[key] = parsed.reshape(sa[key].shape) + else: # pragma: no coverage + """Do not set sa[key]""" + return volgeom @@ -218,11 +224,15 @@ def to_ras(self, moving=None, reference=None): def to_string(self, partial=False): """Convert this transform to text.""" sa = self.structarr - lines = [ - "# LTA file created by NiTransforms", - "type = {}".format(sa["type"]), - "nxforms = 1", - ] if not partial else [] + lines = ( + [ + "# LTA file created by NiTransforms", + "type = {}".format(sa["type"]), + "nxforms = 1", + ] + if not partial + else [] + ) # Standard preamble lines += [ @@ -232,10 +242,7 @@ def to_string(self, partial=False): ] # Format parameters matrix - lines += [ - " ".join(f"{v:18.15e}" for v in sa["m_L"][i]) - for i in range(4) - ] + lines += [" ".join(f"{v:18.15e}" for v in sa["m_L"][i]) for i in range(4)] lines += [ "src volume info", @@ -324,10 +331,7 @@ def __getitem__(self, idx): def to_ras(self, moving=None, reference=None): """Set type to RAS2RAS and return the new matrix.""" self.structarr["type"] = 1 - return [ - xfm.to_ras(moving=moving, reference=reference) - for xfm in self.xforms - ] + return [xfm.to_ras(moving=moving, reference=reference) for xfm in self.xforms] def to_string(self): """Convert this LTA into text format.""" @@ -396,9 +400,11 @@ def from_ras(cls, ras, moving=None, reference=None): sa["type"] = 1 sa["nxforms"] = ras.shape[0] for i in range(sa["nxforms"]): - lt._xforms.append(cls._inner_type.from_ras( - ras[i, ...], moving=moving, reference=reference - )) + lt._xforms.append( + cls._inner_type.from_ras( + ras[i, ...], moving=moving, reference=reference + ) + ) sa["subject"] = "unset" sa["fscale"] = 0.0 @@ -407,8 +413,10 @@ def from_ras(cls, ras, moving=None, reference=None): def _drop_comments(string): """Drop comments.""" - return "\n".join([ - line.split("#")[0].strip() - for line in string.splitlines() - if line.split("#")[0].strip() - ]) + return "\n".join( + [ + line.split("#")[0].strip() + for line in string.splitlines() + if line.split("#")[0].strip() + ] + ) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index 463a1336..a86a8554 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -136,3 +136,53 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): # "AdditionalParameters", data=node.additional_parameters # ) return fname + + +def from_filename(fname: str | Path) -> List[X5Transform]: + """Read a list of :class:`X5Transform` objects from an X5 HDF5 file.""" + try: + with h5py.File(str(fname), "r") as in_file: + if in_file.attrs.get("Format") != "X5": + raise TypeError("Input file is not in X5 format") + + tg = in_file["TransformGroup"] + return [ + _read_x5_group(node) + for _, node in sorted(tg.items(), key=lambda kv: int(kv[0])) + ] + except OSError as err: + if "file signature not found" in err.args[0]: + raise TypeError("Input file is not HDF5.") + + raise # pragma: no cover + + +def _read_x5_group(node) -> X5Transform: + x5 = X5Transform( + type=node.attrs["Type"], + transform=np.asarray(node["Transform"]), + subtype=node.attrs.get("SubType"), + representation=node.attrs.get("Representation"), + metadata=json.loads(node.attrs["Metadata"]) + if "Metadata" in node.attrs + else None, + dimension_kinds=[ + k.decode() if isinstance(k, bytes) else k + for k in node["DimensionKinds"][()] + ], + domain=None, + inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None, + jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None, + array_length=int(node.attrs.get("ArrayLength", 1)), + ) + + if "Domain" in node: + dgrp = node["Domain"] + x5.domain = X5Domain( + grid=bool(int(np.asarray(dgrp["Grid"]))), + size=tuple(np.asarray(dgrp["Size"])), + mapping=np.asarray(dgrp["Mapping"]), + coordinates=dgrp.attrs.get("Coordinates"), + ) + + return x5 diff --git a/nitransforms/linear.py b/nitransforms/linear.py index cf8f8465..26bf3374 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -9,6 +9,7 @@ """Linear transforms.""" import warnings +from collections import namedtuple import numpy as np from pathlib import Path @@ -27,7 +28,12 @@ EQUALITY_TOL, ) from nitransforms.io import get_linear_factory, TransformFileError -from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5 +from nitransforms.io.x5 import ( + X5Transform, + X5Domain, + to_filename as save_x5, + from_filename as load_x5, +) class Affine(TransformBase): @@ -174,8 +180,29 @@ def ndim(self): return self._matrix.ndim + 1 @classmethod - def from_filename(cls, filename, fmt=None, reference=None, moving=None): + def from_filename( + cls, filename, fmt=None, reference=None, moving=None, x5_position=0 + ): """Create an affine from a transform file.""" + + if fmt and fmt.upper() == "X5": + x5_xfm = load_x5(filename)[x5_position] + Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping + if ( + x5_xfm.domain + and not x5_xfm.domain.grid + and len(x5_xfm.domain.size) == 3 + ): # pragma: no cover + raise NotImplementedError( + "Only 3D regularly gridded domains are supported" + ) + elif x5_xfm.domain: + # Override reference + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + + return Transform(x5_xfm.transform, reference=reference) + fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") if fmt is not None and not Path(filename).exists(): @@ -265,7 +292,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False): if fmt.upper() == "X5": return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)]) - writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping)) + writer = get_linear_factory( + fmt, is_array=isinstance(self, LinearTransformsMapping) + ) if fmt.lower() in ("itk", "ants", "elastix"): writer.from_ras(self.matrix).to_filename(filename) @@ -348,11 +377,6 @@ def __init__(self, transforms, reference=None): ) self._inverse = np.linalg.inv(self._matrix) - def __iter__(self): - """Enable iterating over the series of transforms.""" - for _m in self.matrix: - yield Affine(_m, reference=self._reference) - def __getitem__(self, i): """Enable indexed access to the series of matrices.""" return Affine(self.matrix[i, ...], reference=self._reference) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 53750206..98ef4454 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -10,6 +10,7 @@ import asyncio from os import cpu_count +from contextlib import suppress from functools import partial from pathlib import Path from typing import Callable, TypeVar, Union @@ -108,12 +109,17 @@ async def _apply_serial( semaphore = asyncio.Semaphore(max_concurrent) for t in range(n_resamplings): - xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] + xfm_t = ( + transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t] + ) - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image + targets_t = ( + ImageGrid(spatialimage).index( _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim) ) + if targets is None + else targets[t, ...] + ) data_t = ( data @@ -127,7 +133,7 @@ async def _apply_serial( partial( ndi.map_coordinates, data_t, - targets, + targets_t, output=output[..., t], order=order, mode=mode, @@ -255,11 +261,22 @@ def apply( dim=_ref.ndim, ) ) - elif xfm_nvols == 1: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + else: + # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints. + targets = ( + ImageGrid(spatialimage).index( + _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) + ) + if targets is None + else targets ) + if targets.ndim == 3: + targets = np.rollaxis(targets, targets.ndim - 1, 0) + else: + assert targets.ndim == 2 + targets = targets[np.newaxis, ...] + if serialize_4d: data = ( np.asanyarray(spatialimage.dataobj, dtype=input_dtype) @@ -294,17 +311,24 @@ def apply( else: data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype) - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim) - ) - + if data_nvols == 1 and xfm_nvols == 1: + targets = np.squeeze(targets) + assert targets.ndim == 2 # Cast 3D data into 4D if 4D nonsequential transform - if data_nvols == 1 and xfm_nvols > 1: + elif data_nvols == 1 and xfm_nvols > 1: data = data[..., np.newaxis] - if transform.ndim == 4: - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + if xfm_nvols > 1: + assert targets.ndim == 3 + n_time, n_dim, n_vox = targets.shape + # Reshape to (3, n_time x n_vox) + ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1)) + time_row = np.repeat(np.arange(n_time), n_vox)[None, :] + + # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k) + # t is the slowest-changing axis, so we put it first + targets = np.vstack((time_row, ijk_targets)) + data = np.rollaxis(data, data.ndim - 1, 0) resampled = ndi.map_coordinates( data, @@ -323,11 +347,19 @@ def apply( ) hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype()) - moved = spatialimage.__class__( - resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)), - _ref.affine, - hdr, - ) + if serialize_4d: + resampled = resampled.reshape( + _ref.shape + if n_resamplings == 1 + else _ref.shape + (resampled.shape[-1],) + ) + else: + resampled = resampled.reshape((-1, *_ref.shape)) + resampled = np.rollaxis(resampled, 0, resampled.ndim) + with suppress(ValueError): + resampled = np.squeeze(resampled, axis=3) + + moved = spatialimage.__class__(resampled, _ref.affine, hdr) return moved output_dtype = output_dtype or input_dtype diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 32634c61..d1e5e47e 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -265,6 +265,9 @@ def test_linear_to_x5(tmpdir, store_inverse): aff.to_filename("export1.x5", x5_inverse=store_inverse) + # Test round trip + assert aff == nitl.Affine.from_filename("export1.x5", fmt="X5") + # Test with Domain img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4)) img_path = Path(tmpdir) / "ref.nii.gz" @@ -275,21 +278,32 @@ def test_linear_to_x5(tmpdir, store_inverse): assert node.domain.size == aff.reference.shape aff.to_filename("export2.x5", x5_inverse=store_inverse) + # Test round trip + assert aff == nitl.Affine.from_filename("export2.x5", fmt="X5") + # Test with Jacobian node.jacobian = np.zeros((2, 2, 2), dtype="float32") io.x5.to_filename("export3.x5", [node]) -def test_mapping_to_x5(): +@pytest.mark.parametrize("store_inverse", [True, False]) +def test_mapping_to_x5(tmp_path, store_inverse): mats = [ np.eye(4), np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]), ] mapping = nitl.LinearTransformsMapping(mats) - node = mapping.to_x5() + node = mapping.to_x5( + metadata={"GeneratedBy": "FreeSurfer 8"}, store_inverse=store_inverse + ) assert node.array_length == 2 assert node.transform.shape == (2, 4, 4) + mapping.to_filename(tmp_path / "export1.x5", x5_inverse=store_inverse) + + # Test round trip + assert mapping == nitl.Affine.from_filename(tmp_path / "export1.x5", fmt="X5") + def test_mulmat_operator(testdata_path): """Check the @ operator.""" diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py index 2384ad97..0e11df5b 100644 --- a/nitransforms/tests/test_resampling.py +++ b/nitransforms/tests/test_resampling.py @@ -363,3 +363,28 @@ def test_LinearTransformsMapping_apply( reference=testdata_path / "sbref.nii.gz", serialize_nvols=2 if serialize_4d else np.inf, ) + + +@pytest.mark.parametrize("serialize_4d", [True, False]) +def test_apply_4d(serialize_4d): + """Regression test for per-volume transforms with serialized resampling.""" + nvols = 9 + shape = (10, 5, 5) + base = np.zeros(shape, dtype=np.float32) + base[9, 2, 2] = 1 + img = nb.Nifti1Image(np.stack([base] * nvols, axis=-1), np.eye(4)) + + transforms = [] + for i in range(nvols): + mat = np.eye(4) + mat[0, 3] = i + transforms.append(nitl.Affine(mat)) + + extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1} + + xfm = nitl.LinearTransformsMapping(transforms, reference=img) + + moved = apply(xfm, img, order=0, **extraparams) + data = np.asanyarray(moved.dataobj) + idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)] + assert idxs == [(9 - i, 2, 2) for i in range(nvols)] diff --git a/nitransforms/tests/test_vis.py b/nitransforms/tests/test_vis.py new file mode 100644 index 00000000..39dae7e6 --- /dev/null +++ b/nitransforms/tests/test_vis.py @@ -0,0 +1,167 @@ +import numpy as np +import matplotlib.pyplot as plt +import pytest + +import nibabel as nb +from nitransforms.nonlinear import DenseFieldTransform +from nitransforms.vis import PlotDenseField, format_axes + + +def test_read_path(data_path): + """Check that filepaths are a supported method for loading + and reading transforms with PlotDenseField""" + PlotDenseField(transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz") + + +def test_slice_values(): + """Check that ValueError is issued if negative slices are provided""" + with pytest.raises(ValueError): + PlotDenseField( + transform=np.zeros((10, 10, 10, 3)), + reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None), + ).test_slices( + xslice=-1, + yslice=-1, + zslice=-1, + ) + + "Check that IndexError is issued if provided slices are beyond range of transform dimensions" + xfm = DenseFieldTransform( + field=np.zeros((10, 10, 10, 3)), + reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None), + ) + for idx in range(0,3): + if idx == 0: + i, j, k = 1, 0, 0 + elif idx == 1: + i, j, k = 0, 1, 0 + elif idx == 2: + i, j, k = 0, 0, 1 + + with pytest.raises(IndexError): + PlotDenseField( + transform=xfm._field, + reference=xfm._reference, + ).test_slices( + xslice=xfm._field.shape[0] + i, + yslice=xfm._field.shape[1] + j, + zslice=xfm._field.shape[2] + k, + ) + + +def test_show_transform(data_path, output_path): + PlotDenseField( + transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).show_transform( + xslice=45, + yslice=50, + zslice=55, + ) + if output_path is not None: + plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + else: + plt.show() + + +def test_plot_distortion(data_path, output_path): + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + PlotDenseField( + transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_distortion( + axes=axes, + xslice=50, + yslice=50, + zslice=50, + show_grid=True, + show_brain=True, + ) + if output_path is not None: + plt.savefig(output_path / "plot_distortion.svg", bbox_inches="tight") + else: + plt.show() + + +def test_empty_quiver(): + fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True) + PlotDenseField( + transform=np.zeros((10, 10, 10, 3)), + reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None), + ).plot_quiverdsm( + axes=axes, + xslice=5, + yslice=5, + zslice=5, + ) + + +def test_plot_quiverdsm(data_path, output_path): + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + PlotDenseField( + transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_quiverdsm( + axes=axes, + xslice=50, + yslice=50, + zslice=50, + ) + + if output_path is not None: + plt.savefig(output_path / "plot_quiverdsm.svg", bbox_inches="tight") + else: + plt.show() + + +def test_3dquiver(data_path, output_path): + with pytest.raises(NotImplementedError): + fig = plt.figure() + axes = fig.add_subplot(projection='3d') + PlotDenseField( + transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz", + ).plot_quiverdsm( + axes=axes, + xslice=None, + yslice=None, + zslice=None, + three_D=True + ) + format_axes(axes) + + if output_path is not None: + plt.savefig(output_path / "plot_3dquiver.svg", bbox_inches="tight") + else: + plt.show() + + +def test_coeffs(data_path, output_path): + fig, axes = plt.subplots(3, 3, figsize=(10, 9)) + PlotDenseField( + transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_coeffs( + fig=fig, + axes=axes, + xslice=50, + yslice=50, + zslice=50, + ) + + if output_path is not None: + plt.savefig(output_path / "plot_coeffs.svg", bbox_inches="tight") + else: + plt.show() + + +def test_plot_jacobian(data_path, output_path): + fig, axes = plt.subplots(1, 3, figsize=(12, 5)) + PlotDenseField( + transform=data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_jacobian( + axes=axes, + xslice=50, + yslice=50, + zslice=50, + ) + + if output_path is not None: + plt.savefig(output_path / "plot_jacobian.svg", bbox_inches="tight") + else: + plt.show() diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index 8502a387..89b49e06 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -1,7 +1,8 @@ import numpy as np +import pytest from h5py import File as H5File -from ..io.x5 import X5Transform, X5Domain, to_filename +from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename def test_x5_transform_defaults(): @@ -39,3 +40,38 @@ def test_to_filename(tmp_path): assert "0" in grp assert grp["0"].attrs["Type"] == "linear" assert grp["0"].attrs["ArrayLength"] == 1 + + +def test_from_filename_roundtrip(tmp_path): + domain = X5Domain(grid=False, size=(5, 5, 5), mapping=np.eye(4)) + node = X5Transform( + type="linear", + transform=np.eye(4), + dimension_kinds=("space", "space", "space", "vector"), + domain=domain, + metadata={"foo": "bar"}, + inverse=np.eye(4), + ) + fname = tmp_path / "test.x5" + to_filename(fname, [node]) + + x5_list = from_filename(fname) + assert len(x5_list) == 1 + x5 = x5_list[0] + assert x5.type == node.type + assert np.allclose(x5.transform, node.transform) + assert x5.dimension_kinds == list(node.dimension_kinds) + assert x5.domain.grid == domain.grid + assert x5.domain.size == tuple(domain.size) + assert np.allclose(x5.domain.mapping, domain.mapping) + assert x5.metadata == node.metadata + assert np.allclose(x5.inverse, node.inverse) + + +def test_from_filename_invalid(tmp_path): + fname = tmp_path / "invalid.h5" + with H5File(fname, "w") as f: + f.attrs["Format"] = "NOTX5" + + with pytest.raises(TypeError): + from_filename(fname) diff --git a/nitransforms/vis.py b/nitransforms/vis.py new file mode 100644 index 00000000..45f52362 --- /dev/null +++ b/nitransforms/vis.py @@ -0,0 +1,779 @@ +import os +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt +import nibabel as nb + +from matplotlib.gridspec import GridSpec +from matplotlib.widgets import Slider + +from nitransforms.nonlinear import DenseFieldTransform + + +class PlotDenseField: + """ + Vizualisation of a transformation file using nitransform's DenseFielTransform module. + Generates four sorts of plots: + i) deformed grid superimposed on the normalised deformation field density map\n + iii) quiver map of the field coloured by its diffusion scalar map\n + iv) quiver map of the field coloured by the jacobian of the coordinate matrices\n + for 3 image projections: + i) axial (fixed z slice)\n + ii) saggital (fixed y slice)\n + iii) coronal (fixed x slice)\n + Outputs the resulting 3 x 3 image grid. + + Parameters + ---------- + + transform: :obj:`str` + Path from which the trasnformation file should be read. + is_deltas: :obj:`bool` + Whether the field is a displacement field or a deformations field. Default = True + reference : :obj:`ImageGrid` + Defines the domain of the transform. If not provided, the domain is defined from + the ``field`` input.""" + + __slots__ = ('_transform', '_xfm', '_voxel_size') + + def __init__(self, transform, is_deltas=True, reference=None): + self._transform = transform + self._xfm = DenseFieldTransform( + field=self._transform, + is_deltas=is_deltas, + reference=reference + ) + try: + """if field provided by path""" + self._voxel_size = nb.load(transform).header.get_zooms()[:3] + assert len(self._voxel_size) == 3 + except TypeError: + """if field provided by numpy array (eg tests)""" + deltas = [] + for i in range(self._xfm.ndim): + deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i])) + / len(self._xfm._field[i])) + + assert np.all(deltas == deltas[0]) + assert len(deltas) == 3 + self._voxel_size = deltas + + def show_transform( + self, + xslice, + yslice, + zslice, + scaling=1, + show_brain=True, + show_grid=True, + lw=0.1, + ): + """ + Plot output field from DenseFieldTransform class. + + Parameters + ---------- + xslice: :obj:`int` + x plane to select for axial projection of the transform. + yslice: :obj:`int` + y plane to select for coronal prjection of the transform. + zslice: :obj:`int` + z plane to select for sagittal prjection of the transform. + scaling: :obj:`float` + Fraction by which the quiver plot arrows are to be scaled (default: 1). + show_brain: :obj:`bool` + Whether to show the brain image with the deformation grid (default: True). + show_grid: :obj:`bool` + Whether to show the deformation grid with the brain deformation (default: True) + + Examples + -------- + PlotDenseField( + test_dir / "someones-displacement-field.nii.gz" + ).show_transform(50, 50, 50) + plt.show() + + PlotDenseField( + transform = test_dir / "someones-displacement-field.nii.gz", + is_deltas = True, + ).show_transform( + xslice = 70, + yslice = 60 + zslice = 90, + scaling = 3, + show_brain=False, + lw = 0.2 + save_to_path = str("test_dir/my_file.jpg"), + ) + plt.show() + """ + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + + fig, axes = format_fig( + figsize=(9,9), + gs_rows=3, + gs_cols=3, + suptitle="Dense Field Transform \n" + os.path.basename(self._transform), + ) + fig.subplots_adjust(bottom=0.15) + + projections = ["Axial", "Coronal", "Sagittal"] + for i, ax in enumerate(axes): + if i < 3: + xlabel = None + ylabel = projections[i] + else: + xlabel = ylabel = None + format_axes(ax, xlabel=xlabel, ylabel=ylabel, labelsize=16) + + self.plot_distortion( + (axes[2], axes[1], axes[0]), + xslice, + yslice, + zslice, + show_grid=show_grid, + show_brain=show_brain, + lw=lw, + show_titles=False, + ) + self.plot_quiverdsm( + (axes[5], axes[4], axes[3]), + xslice, + yslice, + zslice, + scaling=scaling, + show_titles=False, + ) + self.plot_jacobian( + (axes[8],axes[7], axes[6]), + xslice, + yslice, + zslice, + show_titles=False, + ) + + self.sliders(fig, xslice, yslice, zslice) + # NotImplemented: Interactive slider update here: + + def plot_distortion( + self, + axes, + xslice, + yslice, + zslice, + show_brain=True, + show_grid=True, + lw=0.1, + show_titles=True, + ): + """ + Plot the distortion grid. + + Parameters + ---------- + axis :obj:`tuple` + Axes on which the grid should be plotted. Requires 3 axes to illustrate + all projections (eg ax1: Axial, ax2: coronal, ax3: Sagittal) + xslice: :obj:`int` + x plane to select for axial projection of the transform. + yslice: :obj:`int` + y plane to select for coronal prjection of the transform. + zslice: :obj:`int` + z plane to select for sagittal prjection of the transform. + show_brain: :obj:`bool` + Whether the normalised density map of the distortion should be plotted (Default: True). + show_grid: :obj:`bool` + Whether the distorted grid lines should be plotted (Default: True). + lw: :obj:`float` + Line width used for gridlines (Default: 0.1). + show_titles :obj:`bool` + Show plane names as titles (default: True) + + Example: + fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) + PlotDenseField( + transform="test_dir/someones-displacement-field.nii.gz", + is_deltas=True, + ).plot_distortion( + axes=[axes[2], axes[1], axes[0]], + xslice=50, + yslice=75, + zslice=90, + show_brain=True, + show_grid=True, + lw=0.2, + ) + plt.savefig(str("test_dir/deformationgrid.jpg", dpi=300) + plt.show() + """ + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + planes, titles = self.get_planes(xslice, yslice, zslice) + + for index, plane in enumerate(planes): + x,y,z,u,v,w = plane + shape = self._xfm._field.shape[:-1] + + if index == 0: + dim1, dim2, vec1, vec2 = y, z, v, w + len1, len2 = shape[1], shape[2] + elif index == 1: + dim1, dim2, vec1, vec2 = x, z, u, w + len1, len2 = shape[0], shape[2] + else: + dim1, dim2, vec1, vec2 = x, y, u, v + len1, len2 = shape[0], shape[1] + + c = np.sqrt(vec1**2 + vec2**2) + c = c / c.max() + + x_moved = dim1 + vec1 + y_moved = dim2 + vec2 + + if show_grid: + for idx in range(0, len1, 1): + axes[index].plot( + x_moved[idx * len2:(idx + 1) * len2], + y_moved[idx * len2:(idx + 1) * len2], + c='k', + lw=lw, + ) + for idx in range(0, len2, 1): + axes[index].plot( + x_moved[idx::len2], + y_moved[idx::len2], + c='k', + lw=lw, + ) + + if show_brain: + axes[index].scatter(x_moved, y_moved, c=c, cmap='RdPu') + + if show_titles: + axes[index].set_title(titles[index], fontsize=14, weight='bold') + + def plot_quiverdsm( + self, + axes, + xslice, + yslice, + zslice, + scaling=1, + three_D=False, + show_titles=True, + ): + """ + Plot the Diffusion Scalar Map (dsm) as a quiver plot. + + Parameters + ---------- + axis :obj:`tuple` + Axes on which the quiver should be plotted. Requires 3 axes to illustrate + the dsm mapped as a quiver plot for each projection. + xslice: :obj:`int` + x plane to select for axial projection of the transform. + yslice: :obj:`int` + y plane to select for coronal projection of the transform. + zslice: :obj:`int` + z plane to select for sagittal projection of the transform. + scaling: :obj:`float` + Fraction by which the quiver plot arrows are to be scaled (default: 1). + three_D: :obj:`bool` + Whether the quiver plot is to be projected onto a 3D axis (default: False) + show_titles :obj:`bool` + Show plane names as titles (default: True) + + Example: + fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) + PlotDenseField( + transform="test_dir/someones-displacement-field.nii.gz", + is_deltas=True, + ).plot_quiverdsm( + axes=[axes[2], axes[1], axes[0]], + xslice=50, + yslice=75, + zslice=90, + scaling=2, + ) + plt.savefig(str("test_dir/quiverdsm.jpg", dpi=300) + plt.show() + + #Example 2: 3D quiver + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + PlotDenseField(transform, is_deltas=True).plot_quiverdsm( + ax, + xslice=None, + yslice=None, + zslice=None, + scaling=10, + three_D=True, + ) + plt.show() + """ + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + planes, titles = self.get_planes(xslice, yslice, zslice) + + if three_D is not False: + raise NotImplementedError("3d Quiver plot not finalised.") + + # finalise 3d quiver below: + for i, j in enumerate(planes): + x, y, z, u, v, w = j + + magnitude = np.sqrt(u**2 + v**2 + w**2) + clr3d = plt.cm.viridis(magnitude / magnitude.max()) + xyz = axes.quiver(x, y, z, u, v, w, colors=clr3d, length=1 / scaling) + plt.colorbar(xyz) + else: + for index, plane in enumerate(planes): + x, y, z, u, v, w = plane + c_reds, c_greens, c_blues, zeros = [], [], [], [] + + # Optimise here, matrix operations + for idx, (i, j, k, l, m, n) in enumerate(zip(x, y, z, u, v, w)): + if np.abs(u[idx]) > [np.abs(v[idx]) and np.abs(w[idx])]: + c_reds.append((i, j, k, l, m, n, np.abs(u[idx]))) + elif np.abs(v[idx]) > [np.abs(u[idx]) and np.abs(w[idx])]: + c_greens.append((i, j, k, l, m, n, np.abs(v[idx]))) + elif np.abs(w[idx]) > [np.abs(u[idx]) and np.abs(v[idx])]: + c_blues.append((i, j, k, l, m, n, np.abs(w[idx]))) + else: + zeros.append(0) + + '''Check if shape of c_arrays is (0,) ie transform is independent of some dims''' + if np.shape(c_reds) == (0,): + c_reds = np.zeros((1, 7)) + if np.shape(c_greens) == (0,): + c_greens = np.zeros((1, 7)) + if np.shape(c_blues) == (0,): + c_blues = np.zeros((1, 7)) + elif ( + np.shape(c_reds) != (0,) + and np.shape(c_greens) != (0,) + and np.shape(c_blues) != (0,) + ): + assert len(np.concatenate((c_reds, c_greens, c_blues))) == len(x) - len(zeros) + + c_reds = np.asanyarray(c_reds) + c_greens = np.asanyarray(c_greens) + c_blues = np.asanyarray(c_blues) + + if index == 0: + dim1, dim2, vec1, vec2 = 1, 2, 4, 5 + elif index == 1: + dim1, dim2, vec1, vec2 = 0, 2, 3, 5 + elif index == 2: + dim1, dim2, vec1, vec2 = 0, 1, 3, 4 + + axes[index].quiver( + c_reds[:, dim1], + c_reds[:, dim2], + c_reds[:, vec1], + c_reds[:, vec2], + c_reds[:, -1], + cmap='Reds', + ) + axes[index].quiver( + c_greens[:, dim1], + c_greens[:, dim2], + c_greens[:, vec1], + c_greens[:, vec2], + c_greens[:, -1], + cmap='Greens', + ) + axes[index].quiver( + c_blues[:, dim1], + c_blues[:, dim2], + c_blues[:, vec1], + c_blues[:, vec2], + c_blues[:, -1], + cmap='Blues', + ) + + if show_titles: + axes[index].set_title(titles[index], fontsize=14, weight='bold') + + def plot_coeffs(self, fig, axes, xslice, yslice, zslice, s=0.1, show_titles=True): + """ + Plot linear, planar and spherical coefficients. + Parameters + ---------- + fig :obj:`figure` + Figure to use for mapping the coefficients. + axis :obj:`tuple` + Axes on which the quiver should be plotted. Requires 3 axes to illustrate + the dsm mapped as a quiver plot for each projection. + xslice: :obj:`int` + x plane to select for axial projection of the transform. + yslice: :obj:`int` + y plane to select for coronal projection of the transform. + zslice: :obj:`int` + z plane to select for sagittal projection of the transform. + s: :obj:`float` + Size of scatter points (default: 0.1). + show_titles :obj:`bool` + Show plane names as titles (default: True) + + Example: + fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) + PlotDenseField( + transform="test_dir/someones-displacement-field.nii.gz", + is_deltas=True, + ).plot_coeffs( + fig=fig + axes=axes, + xslice=50, + yslice=75, + zslice=90, + ) + plt.show() + """ + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + planes, titles = self.get_planes(xslice, yslice, zslice) + + for index, plane in enumerate(planes): + x, y, z, u, v, w = plane + + if index == 0: + dim1, dim2 = y, z + elif index == 1: + dim1, dim2 = x, z + else: + dim1, dim2 = x, y + + cl_arr, cp_arr, cs_arr = [], [], [] + + for idx, (i, j, k) in enumerate(zip(u, v, w)): + i, j, k = abs(i), abs(j), abs(k) + L1, L2, L3 = sorted([i, j, k], reverse=True) + asum = np.sum([i, j, k]) + + cl = (L1 - L2) / asum + cl_arr.append(cl) if cl != np.nan else cl.append(0) + + cp = 2 * (L2 - L3) / asum + cp_arr.append(cp) if cp != np.nan else cp.append(0) + + cs = 3 * L3 / asum + cs_arr.append(cs) if cs != np.nan else cs.append(0) + + a = axes[0, index].scatter(dim1, dim2, c=cl_arr, cmap='Reds', s=s) + b = axes[1, index].scatter(dim1, dim2, c=cp_arr, cmap='Greens', s=s) + c = axes[2, index].scatter(dim1, dim2, c=cs_arr, cmap='Blues', s=s) + + if show_titles: + axes[0, index].set_title(titles[index], fontsize=14, weight='bold') + + cb = fig.colorbar(a, ax=axes[0,:], location='right') + cb.set_label(label=r"$c_l$",weight='bold', fontsize=14) + + cb = fig.colorbar(b, ax=axes[1,:], location='right') + cb.set_label(label=r"$c_p$",weight='bold', fontsize=14) + + cb = fig.colorbar(c, ax=axes[2,:], location='right') + cb.set_label(label=r"$c_s$",weight='bold', fontsize=14) + + def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True): + """ + Map the divergence of the transformation field using a quiver plot. + + Parameters + ---------- + axis :obj:`tuple` + Axes on which the quiver should be plotted. Requires 3 axes to illustrate + each projection (eg ax1: Axial, ax2: coronal, ax3: Sagittal) + xslice: :obj:`int` + x plane to select for axial projection of the transform. + yslice: :obj:`int` + y plane to select for coronal projection of the transform. + zslice: :obj:`int` + z plane to select for sagittal projection of the transform. + show_titles :obj:`bool` + Show plane names as titles (default: True) + + Example: + fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) + PlotDenseField( + transform="test_dir/someones-displacement-field.nii.gz", + is_deltas=True, + ).plot_jacobian( + axes=[axes[2], axes[1], axes[0]], + xslice=50, + yslice=75, + zslice=90, + ) + plt.savefig(str("test_dir/jacobians.jpg", dpi=300) + plt.show() + """ + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + planes, titles = self.get_planes(xslice, yslice, zslice) + + jacobians = np.zeros((3), dtype=np.ndarray) + + """iterating through the three chosen planes to calculate corresponding coordinates""" + jac = self.get_jacobian().reshape(self._xfm._field[..., -1].shape) + for idx, slicer in enumerate(( + (xslice, slice(None), slice(None), None), + (slice(None), yslice, slice(None), None), + (slice(None), slice(None), zslice, None), + )): + jacobians[idx] = jac[slicer].flatten() + + for index, (ax, plane) in enumerate(zip(axes, planes)): + x, y, z, _, _, _ = plane + + if index == 0: + dim1, dim2 = y, z + elif index == 1: + dim1, dim2 = x, z + else: + dim1, dim2 = x, y + + c = jacobians[index] + plot = ax.scatter(dim1, dim2, c=c, norm=mpl.colors.CenteredNorm(), cmap='seismic') + + if show_titles: + ax.set_title(titles[index], fontsize=14, weight='bold') + plt.colorbar(plot, location='bottom', orientation='horizontal', label=str(r'$J$')) + + def test_slices(self, xslice, yslice, zslice): + """Ensure slices are positive and within range of image dimensions""" + xfm = self._xfm._field + + try: + if xslice < 0 or yslice < 0 or zslice < 0: + raise ValueError("Slice values must be positive integers") + + if int(xslice) > xfm.shape[0]: + raise IndexError(f"x-slice {xslice} out of range for transform object " + f"with x-dimension of length {xfm.shape[0]}") + if int(yslice) > xfm.shape[1]: + raise IndexError(f"y-slice {yslice} out of range for transform object " + f"with y-dimension of length {xfm.shape[1]}") + if int(zslice) > xfm.shape[2]: + raise IndexError(f"z-slice {zslice} out of range for transform object " + f"with z-dimension of length {xfm.shape[2]}") + + return (int(xslice), int(yslice), int(zslice)) + + except TypeError as e: + """exception for case of 3d quiver plot""" + assert str(e) == "'<' not supported between instances of 'NoneType' and 'int'" + + return (xslice, yslice, zslice) + + def get_coords(self): + """Calculate vector components of the field using the reference coordinates""" + x = self._xfm.reference.ndcoords[0].reshape(np.shape(self._xfm._field[...,-1])) + y = self._xfm.reference.ndcoords[1].reshape(np.shape(self._xfm._field[...,-1])) + z = self._xfm.reference.ndcoords[2].reshape(np.shape(self._xfm._field[...,-1])) + u = self._xfm._field[..., 0] - x + v = self._xfm._field[..., 1] - y + w = self._xfm._field[..., 2] - z + return x, y, z, u, v, w + + def get_jacobian(self): + """Calculate the Jacobian matrix of the field""" + x, y, z, u, v, w = self.get_coords() + voxx, voxy, voxz = self._voxel_size + + shape = self._xfm._field[..., -1].shape + zeros = np.zeros(shape) + jacobians = zeros.flatten() + + dxdx = (np.diff(u, axis=0) / voxx) + dydx = (np.diff(v, axis=0) / voxx) + dzdx = (np.diff(w, axis=0) / voxx) + + dxdy = (np.diff(u, axis=1) / voxy) + dydy = (np.diff(v, axis=1) / voxy) + dzdy = (np.diff(w, axis=1) / voxy) + + dxdz = (np.diff(u, axis=2) / voxz) + dydz = (np.diff(v, axis=2) / voxz) + dzdz = (np.diff(w, axis=2) / voxz) + + partials = [dxdx, dydx, dzdx, dxdy, dydy, dzdy, dxdz, dydz, dzdz] + + for idx, j in enumerate(partials): + if idx < 3: + dim = zeros[-1,:,:][None,:,:] + ax = 0 + elif idx >= 3 and idx < 6: + dim = zeros[:,-1,:][:,None,:] + ax = 1 + elif idx >= 6: + dim = zeros[:,:,-1][:,:,None] + ax = 2 + + partials[idx] = np.append(j, dim, axis=ax).flatten() + + dxdx, dydx, dzdx, dxdy, dydy, dzdy, dxdz, dydz, dzdz = partials + + for idx, k in enumerate(jacobians): + jacobians[idx] = np.linalg.det( + np.array( + [ + [dxdx[idx], dxdy[idx], dxdz[idx]], + [dydx[idx], dydy[idx], dydz[idx]], + [dzdx[idx], dzdy[idx], dzdz[idx]] + ] + ) + ) + return jacobians + + def get_planes(self, xslice, yslice, zslice): + """Define slice selection for visualisation""" + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + titles = ["Sagittal", "Coronal", "Axial"] + planes = [0] * 3 + slices = [ + [False, False, False, False], + [False, False, False, False], + [False, False, False, False], + ] + + for idx, s in enumerate(slices): + x, y, z, u, v, w = self.get_coords() + + """indexing for plane selection [x: sagittal, y: coronal, z: axial, dimension]""" + s = [xslice, slice(None), slice(None), None] if idx == 0 else s + s = [slice(None), yslice, slice(None), None] if idx == 1 else s + s = [slice(None), slice(None), zslice, None] if idx == 2 else s + # For 3d quiver: + if xslice == yslice == zslice is None: + s = [slice(None), slice(None), slice(None), None] + + """computing coordinates within each plane""" + x = x[s[0], s[1], s[2]] + y = y[s[0], s[1], s[2]] + z = z[s[0], s[1], s[2]] + u = self._xfm._field[s[0], s[1], s[2], 0] - x + v = self._xfm._field[s[0], s[1], s[2], 1] - y + w = self._xfm._field[s[0], s[1], s[2], 2] - z + + x = x.flatten() + y = y.flatten() + z = z.flatten() + u = u.flatten() + v = v.flatten() + w = w.flatten() + + """check indexing has retrieved correct dimensions""" + if idx == 0 and xslice is not None: + assert x.shape == u.shape == np.shape(self._xfm._field[-1,...,-1].flatten()) + elif idx == 1 and yslice is not None: + assert y.shape == v.shape == np.shape(self._xfm._field[:,-1,:,-1].flatten()) + elif idx == 2 and zslice is not None: + assert z.shape == w.shape == np.shape(self._xfm._field[...,-1,-1].flatten()) + + """store 3 slices of datapoints, with overall shape [3 x [6 x [data]]]""" + planes[idx] = [x, y, z, u, v, w] + return planes, titles + + def sliders(self, fig, xslice, yslice, zslice): + # This successfully generates a slider, but it cannot be used. + # Currently, slider only acts as a label to show slice values. + # raise NotImplementedError("Slider implementation not finalised. + # Static slider can be generated but is not interactive") + + xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) + slices = [ + [zslice, len(self._xfm._field[0][0]), "zslice"], + [yslice, len(self._xfm._field[0]), "yslice"], + [xslice, len(self._xfm._field), "xslice"], + ] + axes = [ + [1 / 7, 0.1, 1 / 7, 0.025], + [3 / 7, 0.1, 1 / 7, 0.025], + [5 / 7, 0.1, 1 / 7, 0.025], + ] + sliders = [] + + for index, slider_axis in enumerate(axes): + slice_dim = slices[index][0] + sax = fig.add_axes(slider_axis) + slider = Slider( + ax=sax, + valmin=0, + valmax=slices[index][1], + valinit=slice_dim, + valstep=1, + valfmt='%d', + label=slices[index][2], + orientation="horizontal" + ) + sliders.append(slider) + + assert sliders[index].val == slices[index][0] + + return sliders + + def update_sliders(self, slider): + raise NotImplementedError("Interactive sliders not implemented.") + + new_slider = slider.val + return new_slider + + +def format_fig(figsize, gs_rows, gs_cols, **kwargs): + params = { + 'gs_wspace': 0, + 'gs_hspace': 1 / 8, + 'suptitle': None, + } + params.update(kwargs) + + fig = plt.figure(figsize=figsize) + fig.suptitle( + params['suptitle'], + fontsize='20', + weight='bold') + + gs = GridSpec( + gs_rows, + gs_cols, + figure=fig, + wspace=params['gs_wspace'], + hspace=params['gs_hspace'] + ) + + axes = [] + for j in range(0, gs_cols): + for i in range(0, gs_rows): + axes.append(fig.add_subplot(gs[i,j])) + return fig, axes + + +def format_axes(axis, **kwargs): + params = { + 'title':None, + 'xlabel':"x", + 'ylabel':"y", + 'zlabel':"z", + 'xticks':[], + 'yticks':[], + 'zticks':[], + 'rotate_3dlabel':False, + 'labelsize':16, + 'ticksize':14, + } + params.update(kwargs) + + '''Format the figure axes. For 2D plots, zlabel and zticks parameters are None.''' + axis.set_title(params['title'], weight='bold') + axis.set_xticks(params['xticks']) + axis.set_yticks(params['yticks']) + axis.set_xlabel(params['xlabel'], fontsize=params['labelsize']) + axis.set_ylabel(params['ylabel'], fontsize=params['labelsize']) + + '''if 3d projection plot''' + try: + axis.set_zticks(params['zticks']) + axis.set_zlabel(params['zlabel']) + axis.xaxis.set_rotate_label(params['rotate_3dlabel']) + axis.yaxis.set_rotate_label(params['rotate_3dlabel']) + axis.zaxis.set_rotate_label(params['rotate_3dlabel']) + except AttributeError: + pass + return diff --git a/pyproject.toml b/pyproject.toml index f11e2e5e..a6ac0859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,3 +98,12 @@ exclude_lines = [ "raise NotImplementedError", "warnings\\.warn", ] + +[tool.flake8] +max-line-length = 99 +doctests = false +ignore = [ + "E266", + "E231", + "W503", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index f355be94..00000000 --- a/setup.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -max-line-length = 99 -doctests = False -ignore = - E266 - E231 - W503 diff --git a/tox.ini b/tox.ini index fe549039..50d167bc 100644 --- a/tox.ini +++ b/tox.ini @@ -59,6 +59,7 @@ description = Check our style guide labels = check deps = flake8 + flake8-pyproject skip_install = true commands = flake8 nitransforms