11import os
22import shutil
33import tarfile
4+ import tqdm
45import urllib .request
56from dataclasses import dataclass , field
6- from typing import Dict , Optional
7+ from typing import Dict , Optional , Callable
78from urllib .request import build_opener , install_opener
89
910from benchmark import DATASETS_DIR
@@ -54,7 +55,12 @@ def download(self):
5455
5556 if self .config .link :
5657 print (f"Downloading { self .config .link } ..." )
57- tmp_path , _ = urllib .request .urlretrieve (self .config .link )
58+ with tqdm .tqdm (
59+ unit = "B" , unit_scale = True , miniters = 1 , dynamic_ncols = True , disable = None
60+ ) as t :
61+ tmp_path , _ = urllib .request .urlretrieve (
62+ self .config .link , reporthook = _tqdm_reporthook (t )
63+ )
5864
5965 if self .config .link .endswith (".tgz" ) or self .config .link .endswith (
6066 ".tar.gz"
@@ -76,6 +82,15 @@ def get_reader(self, normalize: bool) -> BaseReader:
7682 return reader_class (DATASETS_DIR / self .config .path , normalize = normalize )
7783
7884
def _tqdm_reporthook(t: "tqdm.tqdm") -> Callable[[int, int, int], None]:
    """Build a ``urllib.request.urlretrieve`` report hook that drives ``t``.

    Args:
        t: Progress bar to advance. Its ``total`` is assigned whenever the
            hook is called with a positive ``total_size``.

    Returns:
        A callback taking ``(blocknum, block_size, total_size)`` that moves
        ``t`` to the number of bytes transferred so far.
    """

    def reporthook(blocknum: int, block_size: int, total_size: int) -> None:
        downloaded = blocknum * block_size
        if total_size > 0:
            t.total = total_size
            # The final block is usually partial, so blocknum * block_size
            # overshoots the real size; cap it so the bar stops at the total
            # instead of running past it.
            downloaded = min(downloaded, total_size)
        # update() takes an increment, so convert the absolute position
        # into a delta against what the bar has already counted.
        t.update(downloaded - t.n)

    return reporthook
92+
93+
7994if __name__ == "__main__" :
8095 dataset = Dataset (
8196 {
0 commit comments