Skip to content

Commit

Permalink
switch tqdm for rich progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau committed Jan 22, 2024
1 parent 22a4ae0 commit 0715106
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 60 deletions.
39 changes: 24 additions & 15 deletions giga_connectome/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Union, List

from pathlib import Path
from tqdm import tqdm
import nibabel as nib
from nilearn.image import resample_to_img
from nibabel import Nifti1Image
from pkg_resources import resource_filename

from giga_connectome.logger import gc_logger
from giga_connectome.utils import progress_bar

gc_log = gc_logger()

Expand Down Expand Up @@ -110,21 +110,30 @@ def resample_atlas_collection(
"""
gc_log.info("Resample atlas to group grey matter mask.")
resampled_atlases = []
for desc in tqdm(atlas_config["file_paths"]):
parcellation = atlas_config["file_paths"][desc]
parcellation_resampled = resample_to_img(
parcellation, group_mask, interpolation="nearest"
)
filename = (
f"tpl-{template}_"
f"atlas-{atlas_config['name']}_"
"res-dataset_"
f"desc-{desc}_"
f"{atlas_config['type']}.nii.gz"

with progress_bar(text="Resampling atlases") as progress:
task = progress.add_task(
description="resampling", total=len(atlas_config["file_paths"])
)
save_path = group_mask_dir / filename
nib.save(parcellation_resampled, save_path)
resampled_atlases.append(save_path)

for desc in atlas_config["file_paths"]:
parcellation = atlas_config["file_paths"][desc]
parcellation_resampled = resample_to_img(
parcellation, group_mask, interpolation="nearest"
)
filename = (
f"tpl-{template}_"
f"atlas-{atlas_config['name']}_"
"res-dataset_"
f"desc-{desc}_"
f"{atlas_config['type']}.nii.gz"
)
save_path = group_mask_dir / filename
nib.save(parcellation_resampled, save_path)
resampled_atlases.append(save_path)

progress.update(task, advance=1)

return resampled_atlases


Expand Down
94 changes: 51 additions & 43 deletions giga_connectome/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path

import h5py
from tqdm import tqdm
import numpy as np
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker
Expand All @@ -12,6 +11,7 @@
from giga_connectome.connectome import generate_timeseries_connectomes
from giga_connectome.denoise import denoise_nifti_voxel
from giga_connectome.logger import gc_logger
from giga_connectome.utils import progress_bar

gc_log = gc_logger()

Expand Down Expand Up @@ -101,53 +101,61 @@ def run_postprocessing_dataset(
)

# transform data
gc_log.info("Processing subject")
with progress_bar(text="Processing subject") as progress:
task = progress.add_task(
description="processing subject", total=len(images)
)

for img in tqdm(images):
print()
gc_log.info(f"Processing image:\n{img.filename}")
for img in images:
print()
gc_log.info(f"Processing image:\n{img.filename}")

# process timeseries
denoised_img = denoise_nifti_voxel(
strategy, group_mask, standardize, smoothing_fwhm, img.path
)
# parse file name
subject, session, specifier = utils.parse_bids_name(img.path)
for desc, masker in atlas_maskers.items():
attribute_name = f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
if not denoised_img:
time_series_atlas, correlation_matrix = None, None

gc_log.info(f"{attribute_name}: no volume after scrubbing")

continue

# extract timeseries and connectomes
(
correlation_matrix,
time_series_atlas,
) = generate_timeseries_connectomes(
masker,
denoised_img,
group_mask,
correlation_measure,
calculate_average_correlation,
# process timeseries
denoised_img = denoise_nifti_voxel(
strategy, group_mask, standardize, smoothing_fwhm, img.path
)
connectomes[desc].append(correlation_matrix)

# dump to h5
flag = _set_file_flag(output_path)
with h5py.File(output_path, flag) as f:
group = _fetch_h5_group(f, subject, session)
timeseries_dset = group.create_dataset(
f"{attribute_name}_timeseries", data=time_series_atlas
# parse file name
subject, session, specifier = utils.parse_bids_name(img.path)
for desc, masker in atlas_maskers.items():
attribute_name = (
f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
)
timeseries_dset.attrs["RepetitionTime"] = img.entities[
"RepetitionTime"
]
group.create_dataset(
f"{attribute_name}_connectome", data=correlation_matrix
if not denoised_img:
time_series_atlas, correlation_matrix = None, None

gc_log.info(f"{attribute_name}: no volume after scrubbing")

progress.update(task, advance=1)
continue

# extract timeseries and connectomes
(
correlation_matrix,
time_series_atlas,
) = generate_timeseries_connectomes(
masker,
denoised_img,
group_mask,
correlation_measure,
calculate_average_correlation,
)
connectomes[desc].append(correlation_matrix)

# dump to h5
flag = _set_file_flag(output_path)
with h5py.File(output_path, flag) as f:
group = _fetch_h5_group(f, subject, session)
timeseries_dset = group.create_dataset(
f"{attribute_name}_timeseries", data=time_series_atlas
)
timeseries_dset.attrs["RepetitionTime"] = img.entities[
"RepetitionTime"
]
group.create_dataset(
f"{attribute_name}_connectome", data=correlation_matrix
)

progress.update(task, advance=1)

gc_log.info(f"Saved to:\n{output_path}")

Expand Down
23 changes: 23 additions & 0 deletions giga_connectome/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
from bids.layout import Query
from bids import BIDSLayout

from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

from giga_connectome.logger import gc_logger

gc_log = gc_logger()
Expand Down Expand Up @@ -190,3 +201,15 @@ def check_path(path: Path):
f"Specified path already exists:\n\t{path}\n"
"Old file will be overwritten"
)


def progress_bar(text: str, color: str = "green") -> Progress:
return Progress(
TextColumn(f"[{color}]{text}"),
SpinnerColumn("dots"),
TimeElapsedColumn(),
BarColumn(),
MofNCompleteColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"nilearn >=0.10.2",
"pybids >=0.15.0, <0.16.0",
"templateflow < 23.0.0",
"tqdm",
"setuptools",
"jinja2 >= 2.0",
"rich",
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ six==1.16.0
SQLAlchemy==1.3.24
templateflow==0.8.1
threadpoolctl==3.2.0
tqdm==4.66.1
typing_extensions==4.9.0
tzdata==2023.3
urllib3==2.1.0
Expand Down

0 comments on commit 0715106

Please sign in to comment.