Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] switch tqdm for rich progress bar #95

Merged
merged 4 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ git checkout -b your_branch

4. Commit your changes on this branch.

If you want to make sure all the tests will be run by github continuous integration,
make sure that your commit message contains `full_test`.

5. Run the tests locally; you can run spectfic tests to speed up the process:

```bash
Expand Down
39 changes: 24 additions & 15 deletions giga_connectome/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Union, List

from pathlib import Path
from tqdm import tqdm
import nibabel as nib
from nilearn.image import resample_to_img
from nibabel import Nifti1Image
from pkg_resources import resource_filename

from giga_connectome.logger import gc_logger
from giga_connectome.utils import progress_bar

gc_log = gc_logger()

Expand Down Expand Up @@ -110,21 +110,30 @@ def resample_atlas_collection(
"""
gc_log.info("Resample atlas to group grey matter mask.")
resampled_atlases = []
for desc in tqdm(atlas_config["file_paths"]):
parcellation = atlas_config["file_paths"][desc]
parcellation_resampled = resample_to_img(
parcellation, group_mask, interpolation="nearest"
)
filename = (
f"tpl-{template}_"
f"atlas-{atlas_config['name']}_"
"res-dataset_"
f"desc-{desc}_"
f"{atlas_config['type']}.nii.gz"

with progress_bar(text="Resampling atlases") as progress:
task = progress.add_task(
description="resampling", total=len(atlas_config["file_paths"])
)
save_path = group_mask_dir / filename
nib.save(parcellation_resampled, save_path)
resampled_atlases.append(save_path)

for desc in atlas_config["file_paths"]:
parcellation = atlas_config["file_paths"][desc]
parcellation_resampled = resample_to_img(
parcellation, group_mask, interpolation="nearest"
)
filename = (
f"tpl-{template}_"
f"atlas-{atlas_config['name']}_"
"res-dataset_"
f"desc-{desc}_"
f"{atlas_config['type']}.nii.gz"
)
save_path = group_mask_dir / filename
nib.save(parcellation_resampled, save_path)
resampled_atlases.append(save_path)

progress.update(task, advance=1)

return resampled_atlases


Expand Down
1 change: 0 additions & 1 deletion giga_connectome/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def generate_method_section(
average_correlation: bool,
analysis_level: bool,
) -> None:

env = Environment(
loader=FileSystemLoader(Path(__file__).parent),
autoescape=select_autoescape(),
Expand Down
94 changes: 51 additions & 43 deletions giga_connectome/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path

import h5py
from tqdm import tqdm
import numpy as np
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker
Expand All @@ -12,6 +11,7 @@
from giga_connectome.connectome import generate_timeseries_connectomes
from giga_connectome.denoise import denoise_nifti_voxel
from giga_connectome.logger import gc_logger
from giga_connectome.utils import progress_bar

gc_log = gc_logger()

Expand Down Expand Up @@ -101,53 +101,61 @@ def run_postprocessing_dataset(
)

# transform data
gc_log.info("Processing subject")
with progress_bar(text="Processing subject") as progress:
task = progress.add_task(
description="processing subject", total=len(images)
)

for img in tqdm(images):
print()
gc_log.info(f"Processing image:\n{img.filename}")
for img in images:
print()
gc_log.info(f"Processing image:\n{img.filename}")

# process timeseries
denoised_img = denoise_nifti_voxel(
strategy, group_mask, standardize, smoothing_fwhm, img.path
)
# parse file name
subject, session, specifier = utils.parse_bids_name(img.path)
for desc, masker in atlas_maskers.items():
attribute_name = f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
if not denoised_img:
time_series_atlas, correlation_matrix = None, None

gc_log.info(f"{attribute_name}: no volume after scrubbing")

continue

# extract timeseries and connectomes
(
correlation_matrix,
time_series_atlas,
) = generate_timeseries_connectomes(
masker,
denoised_img,
group_mask,
correlation_measure,
calculate_average_correlation,
# process timeseries
denoised_img = denoise_nifti_voxel(
strategy, group_mask, standardize, smoothing_fwhm, img.path
)
connectomes[desc].append(correlation_matrix)

# dump to h5
flag = _set_file_flag(output_path)
with h5py.File(output_path, flag) as f:
group = _fetch_h5_group(f, subject, session)
timeseries_dset = group.create_dataset(
f"{attribute_name}_timeseries", data=time_series_atlas
# parse file name
subject, session, specifier = utils.parse_bids_name(img.path)
for desc, masker in atlas_maskers.items():
attribute_name = (
f"{subject}_{specifier}_atlas-{atlas}_desc-{desc}"
)
timeseries_dset.attrs["RepetitionTime"] = img.entities[
"RepetitionTime"
]
group.create_dataset(
f"{attribute_name}_connectome", data=correlation_matrix
if not denoised_img:
time_series_atlas, correlation_matrix = None, None

gc_log.info(f"{attribute_name}: no volume after scrubbing")

progress.update(task, advance=1)
continue

# extract timeseries and connectomes
(
correlation_matrix,
time_series_atlas,
) = generate_timeseries_connectomes(
masker,
denoised_img,
group_mask,
correlation_measure,
calculate_average_correlation,
)
connectomes[desc].append(correlation_matrix)

# dump to h5
flag = _set_file_flag(output_path)
with h5py.File(output_path, flag) as f:
group = _fetch_h5_group(f, subject, session)
timeseries_dset = group.create_dataset(
f"{attribute_name}_timeseries", data=time_series_atlas
)
timeseries_dset.attrs["RepetitionTime"] = img.entities[
"RepetitionTime"
]
group.create_dataset(
f"{attribute_name}_connectome", data=correlation_matrix
)

progress.update(task, advance=1)

gc_log.info(f"Saved to:\n{output_path}")

Expand Down
23 changes: 23 additions & 0 deletions giga_connectome/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
from bids.layout import Query
from bids import BIDSLayout

from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

from giga_connectome.logger import gc_logger

gc_log = gc_logger()
Expand Down Expand Up @@ -192,3 +203,15 @@ def check_path(path: Path):
f"Specified path already exists:\n\t{path}\n"
"Old file will be overwritten"
)


def progress_bar(text: str, color: str = "green") -> Progress:
return Progress(
TextColumn(f"[{color}]{text}"),
SpinnerColumn("dots"),
TimeElapsedColumn(),
BarColumn(),
MofNCompleteColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"nilearn >=0.10.2",
"pybids >=0.15.0, <0.16.0",
"templateflow < 23.0.0",
"tqdm",
"setuptools",
"jinja2 >= 2.0",
"rich",
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ six==1.16.0
SQLAlchemy==1.3.24
templateflow==0.8.1
threadpoolctl==3.2.0
tqdm==4.66.1
typing_extensions==4.9.0
tzdata==2023.3
urllib3==2.1.0
Expand Down
Loading