diff --git a/duvidnn/autoclass.py b/duvidnn/autoclass.py index e6ec22a..5ec933e 100644 --- a/duvidnn/autoclass.py +++ b/duvidnn/autoclass.py @@ -25,8 +25,10 @@ from .base.modelboxes import ModelBoxBase from .base.modelbox_registry import DEFAULT_MODELBOX, MODELBOX_REGISTRY -class AutoModelBox: +from .utils.package_data import CACHE_DIR + +class AutoModelBox: _init_kwargs_file: str = ModelBoxBase._init_kwargs_filename _model_config_file: str = ModelBoxBase._model_config_filename _model_class_key: str = "class_name" @@ -49,6 +51,7 @@ def from_pretrained( cache_dir: Optional[str] = None, **kwargs ) -> ModelBoxBase: + cache_dir = cache_dir or CACHE_DIR config_file = cls._init_kwargs_file config = load_checkpoint_file( checkpoint=checkpoint, diff --git a/duvidnn/utils/package_data.py b/duvidnn/utils/package_data.py index 3a10656..c585590 100644 --- a/duvidnn/utils/package_data.py +++ b/duvidnn/utils/package_data.py @@ -25,6 +25,7 @@ def _get_data_path( Copies it from the package resources if not present. """ + from carabiner import print_err cache_dir = os.environ.get( env_key, os.path.expanduser(default), @@ -32,4 +33,5 @@ def _get_data_path( os.makedirs(cache_dir, exist_ok=True) os.environ["HF_HOME"] = cache_dir os.environ["HF_DATASETS_CACHE"] = cache_dir + print_err(f"[INFO] Cache directory set to HF_HOME={os.environ['HF_HOME']}, HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}.") return cache_dir, os.path.join(cache_dir, filename)