[ENH] switch tqdm for rich progress bar (#95)
* switch tqdm for rich progress bar

* lint

* edit doc - full_test
Remi-Gau authored Mar 25, 2024
1 parent 4f79846 commit 726ba1e
Showing 6 changed files with 101 additions and 60 deletions.
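
In short, the commit swaps `tqdm`-wrapped loops for a `rich`-based `progress_bar` helper added to `giga_connectome.utils`. A minimal sketch of the before/after pattern applied in `atlas.py` and `postprocess.py` (the `items` list and `process` function below are illustrative stand-ins, not part of the codebase):

```python
from giga_connectome.utils import progress_bar


def process(item: str) -> None:
    """Stand-in for the real per-item work (resampling, denoising, ...)."""


items = ["a", "b", "c"]

# Old pattern, removed by this commit:
#
#     from tqdm import tqdm
#     for item in tqdm(items):
#         process(item)

# New pattern: an explicit rich progress bar driven through a task handle.
with progress_bar(text="Processing items") as progress:
    task = progress.add_task(description="processing", total=len(items))
    for item in items:
        process(item)
        progress.update(task, advance=1)
```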
3 changes: 3 additions & 0 deletions docs/source/contributing.md
@@ -56,6 +56,9 @@ git checkout -b your_branch

4. Commit your changes on this branch.

If you want to make sure all the tests are run by the GitHub continuous integration,
make sure that your commit message contains `full_test`.

5. Run the tests locally; you can run specific tests to speed up the process:

```bash
39 changes: 24 additions & 15 deletions giga_connectome/atlas.py
@@ -3,13 +3,13 @@
from typing import Union, List

from pathlib import Path
from tqdm import tqdm
import nibabel as nib
from nilearn.image import resample_to_img
from nibabel import Nifti1Image
from pkg_resources import resource_filename

from giga_connectome.logger import gc_logger
from giga_connectome.utils import progress_bar

gc_log = gc_logger()

@@ -110,21 +110,30 @@ def resample_atlas_collection(
"""
gc_log.info("Resample atlas to group grey matter mask.")
resampled_atlases = []
for desc in tqdm(atlas_config["file_paths"]):
parcellation = atlas_config["file_paths"][desc]
parcellation_resampled = resample_to_img(
parcellation, group_mask, interpolation="nearest"
)
filename = (
f"tpl-{template}_"
f"atlas-{atlas_config['name']}_"
"res-dataset_"
f"desc-{desc}_"
f"{atlas_config['type']}.nii.gz"

with progress_bar(text="Resampling atlases") as progress:
task = progress.add_task(
description="resampling", total=len(atlas_config["file_paths"])
)
save_path = group_mask_dir / filename
nib.save(parcellation_resampled, save_path)
resampled_atlases.append(save_path)

for desc in atlas_config["file_paths"]:
parcellation = atlas_config["file_paths"][desc]
parcellation_resampled = resample_to_img(
parcellation, group_mask, interpolation="nearest"
)
filename = (
f"tpl-{template}_"
f"atlas-{atlas_config['name']}_"
"res-dataset_"
f"desc-{desc}_"
f"{atlas_config['type']}.nii.gz"
)
save_path = group_mask_dir / filename
nib.save(parcellation_resampled, save_path)
resampled_atlases.append(save_path)

progress.update(task, advance=1)

return resampled_atlases


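For context on the loop body above: `resample_to_img(..., interpolation="nearest")` is what keeps the atlas labels intact when moving them onto the group mask grid, since nearest-neighbour resampling never blends neighbouring parcel values. A small self-contained sketch with synthetic images (shapes and affines are illustrative):

```python
import numpy as np
import nibabel as nib
from nilearn.image import resample_to_img

# Tiny synthetic images standing in for the real atlas and group mask.
atlas_img = nib.Nifti1Image(
    np.arange(27).reshape(3, 3, 3).astype("int16"), np.eye(4)
)
mask_img = nib.Nifti1Image(
    np.ones((2, 2, 2), dtype="uint8"), np.diag([2.0, 2.0, 2.0, 1.0])
)

# "nearest" assigns each target voxel the value of the closest source voxel,
# so only original integer labels appear in the resampled parcellation.
resampled = resample_to_img(atlas_img, mask_img, interpolation="nearest")
print(np.unique(resampled.get_fdata()))
```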
94 changes: 51 additions & 43 deletions giga_connectome/postprocess.py
@@ -2,7 +2,6 @@
from pathlib import Path

import h5py
from tqdm import tqdm
import numpy as np
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker
@@ -12,6 +11,7 @@
from giga_connectome.connectome import generate_timeseries_connectomes
from giga_connectome.denoise import denoise_nifti_voxel
from giga_connectome.logger import gc_logger
from giga_connectome.utils import progress_bar

gc_log = gc_logger()

@@ -101,53 +101,61 @@ def run_postprocessing_dataset(
)

# transform data
gc_log.info("Processing subject")
with progress_bar(text="Processing subject") as progress:
task = progress.add_task(
description="processing subject", total=len(images)
)

for img in tqdm(images):
print()
gc_log.info(f"Processing image:\n{img.filename}")
for img in images:
print()
gc_log.info(f"Processing image:\n{img.filename}")

# process timeseries
denoised_img = denoise_nifti_voxel(
strategy, group_mask, standardize, smoothing_fwhm, img.path
)
# parse file name
subject, session, specifier = utils.parse_bids_name(img.path)
for desc, masker in atlas_maskers.items():
attribute_name = f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
if not denoised_img:
time_series_atlas, correlation_matrix = None, None

gc_log.info(f"{attribute_name}: no volume after scrubbing")

continue

# extract timeseries and connectomes
(
correlation_matrix,
time_series_atlas,
) = generate_timeseries_connectomes(
masker,
denoised_img,
group_mask,
correlation_measure,
calculate_average_correlation,
# process timeseries
denoised_img = denoise_nifti_voxel(
strategy, group_mask, standardize, smoothing_fwhm, img.path
)
connectomes[desc].append(correlation_matrix)

# dump to h5
flag = _set_file_flag(output_path)
with h5py.File(output_path, flag) as f:
group = _fetch_h5_group(f, subject, session)
timeseries_dset = group.create_dataset(
f"{attribute_name}_timeseries", data=time_series_atlas
# parse file name
subject, session, specifier = utils.parse_bids_name(img.path)
for desc, masker in atlas_maskers.items():
attribute_name = (
f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
)
timeseries_dset.attrs["RepetitionTime"] = img.entities[
"RepetitionTime"
]
group.create_dataset(
f"{attribute_name}_connectome", data=correlation_matrix
if not denoised_img:
time_series_atlas, correlation_matrix = None, None

gc_log.info(f"{attribute_name}: no volume after scrubbing")

progress.update(task, advance=1)
continue

# extract timeseries and connectomes
(
correlation_matrix,
time_series_atlas,
) = generate_timeseries_connectomes(
masker,
denoised_img,
group_mask,
correlation_measure,
calculate_average_correlation,
)
connectomes[desc].append(correlation_matrix)

# dump to h5
flag = _set_file_flag(output_path)
with h5py.File(output_path, flag) as f:
group = _fetch_h5_group(f, subject, session)
timeseries_dset = group.create_dataset(
f"{attribute_name}_timeseries", data=time_series_atlas
)
timeseries_dset.attrs["RepetitionTime"] = img.entities[
"RepetitionTime"
]
group.create_dataset(
f"{attribute_name}_connectome", data=correlation_matrix
)

progress.update(task, advance=1)

gc_log.info(f"Saved to:\n{output_path}")

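For context, the inner loop above writes one `*_timeseries` dataset (carrying a `RepetitionTime` attribute) and one `*_connectome` dataset per atlas description into a per-subject/session HDF5 group. A rough sketch of that layout with plain `h5py` (the file name, group path, and example values are illustrative; the real code goes through the `_set_file_flag` and `_fetch_h5_group` helpers):

```python
import h5py
import numpy as np

# Illustrative data: 200 volumes x 100 parcels, and its parcel-wise correlation.
time_series_atlas = np.random.rand(200, 100)
correlation_matrix = np.corrcoef(time_series_atlas.T)

with h5py.File("connectome.h5", "a") as f:
    group = f.require_group("sub-01/ses-01")
    dset = group.create_dataset(
        "sub-01_task-rest_atlas-Schaefer_desc-100_timeseries",
        data=time_series_atlas,
    )
    dset.attrs["RepetitionTime"] = 2.0
    group.create_dataset(
        "sub-01_task-rest_atlas-Schaefer_desc-100_connectome",
        data=correlation_matrix,
    )
```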
23 changes: 23 additions & 0 deletions giga_connectome/utils.py
@@ -4,6 +4,17 @@
from bids.layout import Query
from bids import BIDSLayout

from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

from giga_connectome.logger import gc_logger

gc_log = gc_logger()
@@ -192,3 +203,15 @@ def check_path(path: Path):
f"Specified path already exists:\n\t{path}\n"
"Old file will be overwritten"
)


def progress_bar(text: str, color: str = "green") -> Progress:
return Progress(
TextColumn(f"[{color}]{text}"),
SpinnerColumn("dots"),
TimeElapsedColumn(),
BarColumn(),
MofNCompleteColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
)
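
The helper returns a preconfigured `rich.progress.Progress`, so callers open it as a context manager, register a task, and advance it themselves. A short usage sketch, including the optional label colour (the task name and total below are illustrative):

```python
from giga_connectome.utils import progress_bar

# Rendered columns, in order: coloured text label, spinner, elapsed time,
# bar, completed/total, percentage, estimated time remaining.
with progress_bar(text="Downloading atlases", color="blue") as progress:
    task = progress.add_task(description="downloading", total=4)
    for _ in range(4):
        progress.update(task, advance=1)
```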
1 change: 0 additions & 1 deletion pyproject.toml
@@ -24,7 +24,6 @@ dependencies = [
"nilearn >=0.10.2",
"pybids >=0.15.0, <0.16.0",
"templateflow < 23.0.0",
"tqdm",
"setuptools",
"jinja2 >= 2.0",
"rich",
1 change: 0 additions & 1 deletion requirements.txt
@@ -26,7 +26,6 @@ six==1.16.0
SQLAlchemy==1.3.24
templateflow==0.8.1
threadpoolctl==3.2.0
tqdm==4.66.1
typing_extensions==4.9.0
tzdata==2023.3
urllib3==2.1.0
