diff --git a/benchmark/dataset.py b/benchmark/dataset.py index 76bdad99..593cad42 100644 --- a/benchmark/dataset.py +++ b/benchmark/dataset.py @@ -3,9 +3,11 @@ import tarfile import urllib.request from dataclasses import dataclass, field -from typing import Dict, Optional +from typing import Callable, Dict, Optional from urllib.request import build_opener, install_opener +import tqdm + from benchmark import DATASETS_DIR from dataset_reader.ann_compound_reader import AnnCompoundReader from dataset_reader.ann_h5_reader import AnnH5Reader @@ -54,7 +56,12 @@ def download(self): if self.config.link: print(f"Downloading {self.config.link}...") - tmp_path, _ = urllib.request.urlretrieve(self.config.link) + with tqdm.tqdm( + unit="B", unit_scale=True, miniters=1, dynamic_ncols=True, disable=None + ) as t: + tmp_path, _ = urllib.request.urlretrieve( + self.config.link, reporthook=_tqdm_reporthook(t) + ) if self.config.link.endswith(".tgz") or self.config.link.endswith( ".tar.gz" @@ -76,6 +83,15 @@ def get_reader(self, normalize: bool) -> BaseReader: return reader_class(DATASETS_DIR / self.config.path, normalize=normalize) +def _tqdm_reporthook(t: tqdm.tqdm) -> Callable[[int, int, int], None]: + def reporthook(blocknum: int, block_size: int, total_size: int) -> None: + if total_size > 0: + t.total = total_size + t.update(blocknum * block_size - t.n) + + return reporthook + + if __name__ == "__main__": dataset = Dataset( {