From 726ba1e169e42a183ee1d4b328d47b02970ebccb Mon Sep 17 00:00:00 2001
From: Remi Gau
Date: Mon, 25 Mar 2024 09:47:23 -0400
Subject: [PATCH] [ENH] switch tqdm for rich progress bar (#95)

* switch tqdm for rich progress bar

* lint

* edit doc - full_test
---
 docs/source/contributing.md    |  3 ++
 giga_connectome/atlas.py       | 39 ++++++++------
 giga_connectome/postprocess.py | 94 ++++++++++++++++++----------------
 giga_connectome/utils.py       | 23 +++++++++
 pyproject.toml                 |  1 -
 requirements.txt               |  1 -
 6 files changed, 101 insertions(+), 60 deletions(-)

diff --git a/docs/source/contributing.md b/docs/source/contributing.md
index 8e1fbb4..c8ddd74 100644
--- a/docs/source/contributing.md
+++ b/docs/source/contributing.md
@@ -56,6 +56,9 @@ git checkout -b your_branch
 
 4. Commit your changes on this branch.
 
+If you want to make sure all the tests will be run by github continuous integration,
+make sure that your commit message contains `full_test`.
+
 5. Run the tests locally; you can run spectfic tests to speed up the process:
 
 ```bash
diff --git a/giga_connectome/atlas.py b/giga_connectome/atlas.py
index d0c7515..063e356 100644
--- a/giga_connectome/atlas.py
+++ b/giga_connectome/atlas.py
@@ -3,13 +3,13 @@
 from typing import Union, List
 from pathlib import Path
 
-from tqdm import tqdm
 import nibabel as nib
 from nilearn.image import resample_to_img
 from nibabel import Nifti1Image
 from pkg_resources import resource_filename
 
 from giga_connectome.logger import gc_logger
+from giga_connectome.utils import progress_bar
 
 gc_log = gc_logger()
 
@@ -110,21 +110,30 @@ def resample_atlas_collection(
     """
     gc_log.info("Resample atlas to group grey matter mask.")
     resampled_atlases = []
-    for desc in tqdm(atlas_config["file_paths"]):
-        parcellation = atlas_config["file_paths"][desc]
-        parcellation_resampled = resample_to_img(
-            parcellation, group_mask, interpolation="nearest"
-        )
-        filename = (
-            f"tpl-{template}_"
-            f"atlas-{atlas_config['name']}_"
-            "res-dataset_"
-            f"desc-{desc}_"
-            f"{atlas_config['type']}.nii.gz"
+
+    with progress_bar(text="Resampling atlases") as progress:
+        task = progress.add_task(
+            description="resampling", total=len(atlas_config["file_paths"])
         )
-        save_path = group_mask_dir / filename
-        nib.save(parcellation_resampled, save_path)
-        resampled_atlases.append(save_path)
+
+        for desc in atlas_config["file_paths"]:
+            parcellation = atlas_config["file_paths"][desc]
+            parcellation_resampled = resample_to_img(
+                parcellation, group_mask, interpolation="nearest"
+            )
+            filename = (
+                f"tpl-{template}_"
+                f"atlas-{atlas_config['name']}_"
+                "res-dataset_"
+                f"desc-{desc}_"
+                f"{atlas_config['type']}.nii.gz"
+            )
+            save_path = group_mask_dir / filename
+            nib.save(parcellation_resampled, save_path)
+            resampled_atlases.append(save_path)
+
+            progress.update(task, advance=1)
+
     return resampled_atlases
 
 
diff --git a/giga_connectome/postprocess.py b/giga_connectome/postprocess.py
index 975d1c2..488b3f7 100644
--- a/giga_connectome/postprocess.py
+++ b/giga_connectome/postprocess.py
@@ -2,7 +2,6 @@
 from pathlib import Path
 
 import h5py
-from tqdm import tqdm
 import numpy as np
 from nilearn.connectome import ConnectivityMeasure
 from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker
@@ -12,6 +11,7 @@
 from giga_connectome.connectome import generate_timeseries_connectomes
 from giga_connectome.denoise import denoise_nifti_voxel
 from giga_connectome.logger import gc_logger
+from giga_connectome.utils import progress_bar
 
 gc_log = gc_logger()
 
@@ -101,53 +101,61 @@ def run_postprocessing_dataset(
     )
 
     # transform data
-    gc_log.info("Processing subject")
+    with progress_bar(text="Processing subject") as progress:
+        task = progress.add_task(
+            description="processing subject", total=len(images)
+        )
 
-    for img in tqdm(images):
-        print()
-        gc_log.info(f"Processing image:\n{img.filename}")
+        for img in images:
+            print()
+            gc_log.info(f"Processing image:\n{img.filename}")
 
-        # process timeseries
-        denoised_img = denoise_nifti_voxel(
-            strategy, group_mask, standardize, smoothing_fwhm, img.path
-        )
-        # parse file name
-        subject, session, specifier = utils.parse_bids_name(img.path)
-        for desc, masker in atlas_maskers.items():
-            attribute_name = f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
-            if not denoised_img:
-                time_series_atlas, correlation_matrix = None, None
-
-                gc_log.info(f"{attribute_name}: no volume after scrubbing")
-
-                continue
-
-            # extract timeseries and connectomes
-            (
-                correlation_matrix,
-                time_series_atlas,
-            ) = generate_timeseries_connectomes(
-                masker,
-                denoised_img,
-                group_mask,
-                correlation_measure,
-                calculate_average_correlation,
+            # process timeseries
+            denoised_img = denoise_nifti_voxel(
+                strategy, group_mask, standardize, smoothing_fwhm, img.path
             )
-            connectomes[desc].append(correlation_matrix)
-
-            # dump to h5
-            flag = _set_file_flag(output_path)
-            with h5py.File(output_path, flag) as f:
-                group = _fetch_h5_group(f, subject, session)
-                timeseries_dset = group.create_dataset(
-                    f"{attribute_name}_timeseries", data=time_series_atlas
+            # parse file name
+            subject, session, specifier = utils.parse_bids_name(img.path)
+            for desc, masker in atlas_maskers.items():
+                attribute_name = (
+                    f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
                 )
-                timeseries_dset.attrs["RepetitionTime"] = img.entities[
-                    "RepetitionTime"
-                ]
-                group.create_dataset(
-                    f"{attribute_name}_connectome", data=correlation_matrix
+                if not denoised_img:
+                    time_series_atlas, correlation_matrix = None, None
+
+                    gc_log.info(f"{attribute_name}: no volume after scrubbing")
+
+                    progress.update(task, advance=1)
+                    continue
+
+                # extract timeseries and connectomes
+                (
+                    correlation_matrix,
+                    time_series_atlas,
+                ) = generate_timeseries_connectomes(
+                    masker,
+                    denoised_img,
+                    group_mask,
+                    correlation_measure,
+                    calculate_average_correlation,
                 )
+                connectomes[desc].append(correlation_matrix)
+
+                # dump to h5
+                flag = _set_file_flag(output_path)
+                with h5py.File(output_path, flag) as f:
+                    group = _fetch_h5_group(f, subject, session)
+                    timeseries_dset = group.create_dataset(
+                        f"{attribute_name}_timeseries", data=time_series_atlas
+                    )
+                    timeseries_dset.attrs["RepetitionTime"] = img.entities[
+                        "RepetitionTime"
+                    ]
+                    group.create_dataset(
+                        f"{attribute_name}_connectome", data=correlation_matrix
+                    )
+
+            progress.update(task, advance=1)
 
     gc_log.info(f"Saved to:\n{output_path}")
 
diff --git a/giga_connectome/utils.py b/giga_connectome/utils.py
index 74149bd..02ea57a 100644
--- a/giga_connectome/utils.py
+++ b/giga_connectome/utils.py
@@ -4,6 +4,17 @@
 from bids.layout import Query
 from bids import BIDSLayout
 
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+)
+
 from giga_connectome.logger import gc_logger
 
 gc_log = gc_logger()
@@ -192,3 +203,15 @@ def check_path(path: Path):
         f"Specified path already exists:\n\t{path}\n"
         "Old file will be overwritten"
     )
+
+
+def progress_bar(text: str, color: str = "green") -> Progress:
+    return Progress(
+        TextColumn(f"[{color}]{text}"),
+        SpinnerColumn("dots"),
+        TimeElapsedColumn(),
+        BarColumn(),
+        MofNCompleteColumn(),
+        TaskProgressColumn(),
+        TimeRemainingColumn(),
+    )
diff --git a/pyproject.toml b/pyproject.toml
index 79e829c..df8608f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,6 @@ dependencies = [
     "nilearn >=0.10.2",
     "pybids >=0.15.0, <0.16.0",
     "templateflow < 23.0.0",
-    "tqdm",
     "setuptools",
     "jinja2 >= 2.0",
     "rich",
diff --git a/requirements.txt b/requirements.txt
index 00e9d5b..71e0b0e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,7 +26,6 @@ six==1.16.0
 SQLAlchemy==1.3.24
 templateflow==0.8.1
 threadpoolctl==3.2.0
-tqdm==4.66.1
 typing_extensions==4.9.0
 tzdata==2023.3
 urllib3==2.1.0
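
The sketch below is not part of the patch: it reproduces the column layout of the `progress_bar` helper added to `giga_connectome/utils.py` and drives it the same way `atlas.py` now does (`add_task` once, then `update(..., advance=1)` per item). It only assumes that `rich` is installed; the `descriptions` list and the `sleep` call are placeholders for the real resampling work.

```python
from time import sleep

from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)


def progress_bar(text: str, color: str = "green") -> Progress:
    # Same column layout as the helper added to giga_connectome/utils.py.
    return Progress(
        TextColumn(f"[{color}]{text}"),
        SpinnerColumn("dots"),
        TimeElapsedColumn(),
        BarColumn(),
        MofNCompleteColumn(),
        TaskProgressColumn(),
        TimeRemainingColumn(),
    )


# Placeholder work items standing in for atlas_config["file_paths"].
descriptions = ["desc-100", "desc-200", "desc-400"]

# Progress is a context manager: entering the `with` block starts the live
# display and leaving it stops the refresh, so no manual cleanup is needed.
with progress_bar(text="Resampling atlases") as progress:
    task = progress.add_task(description="resampling", total=len(descriptions))
    for desc in descriptions:
        sleep(0.2)  # stand-in for resample_to_img(...) + nib.save(...)
        progress.update(task, advance=1)
```

Because `Progress` is a context manager, `postprocess.py` can nest its whole subject loop inside `with progress_bar(...)`; note that the early-`continue` branch there also calls `progress.update(task, advance=1)` so that skipped images still advance the bar.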
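
The `postprocess.py` hunk also shows the naming scheme of the HDF5 output: for each atlas description it writes a `{attribute_name}_timeseries` dataset carrying a `RepetitionTime` attribute plus a `{attribute_name}_connectome` dataset, grouped per subject/session by `_fetch_h5_group`. A minimal, hypothetical reading sketch (the file name is illustrative, and the exact group layout comes from `_fetch_h5_group`, which is not part of this diff):

```python
import h5py

# Hypothetical output file; the real name comes from output_path upstream.
with h5py.File("connectome.h5", "r") as f:

    def show(name, obj):
        # Datasets end in "_timeseries" (with a RepetitionTime attribute)
        # or "_connectome", one pair per atlas description.
        if isinstance(obj, h5py.Dataset):
            tr = obj.attrs.get("RepetitionTime")
            print(name, obj.shape, "" if tr is None else f"TR={tr}")

    f.visititems(show)
```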