Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion duvidnn/autoclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
from .base.modelboxes import ModelBoxBase
from .base.modelbox_registry import DEFAULT_MODELBOX, MODELBOX_REGISTRY

class AutoModelBox:
from .utils.package_data import CACHE_DIR


class AutoModelBox:
_init_kwargs_file: str = ModelBoxBase._init_kwargs_filename
_model_config_file: str = ModelBoxBase._model_config_filename
_model_class_key: str = "class_name"
Expand All @@ -49,6 +51,7 @@ def from_pretrained(
cache_dir: Optional[str] = None,
**kwargs
) -> ModelBoxBase:
cache_dir = cache_dir or CACHE_DIR
config_file = cls._init_kwargs_file
config = load_checkpoint_file(
checkpoint=checkpoint,
Expand Down
2 changes: 2 additions & 0 deletions duvidnn/utils/package_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ def _get_data_path(
Copies it from the package resources if not present.

"""
from carabiner import print_err
cache_dir = os.environ.get(
env_key,
os.path.expanduser(default),
)
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
os.environ["HF_DATASETS_CACHE"] = cache_dir
print_err(f"[INFO] Cache directory set to HF_HOME={os.environ['HF_HOME']}, HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}.")
return cache_dir, os.path.join(cache_dir, filename)
Loading