diff --git a/src/heretic/main.py b/src/heretic/main.py
index 016c3920..5c4747a5 100644
--- a/src/heretic/main.py
+++ b/src/heretic/main.py
@@ -10,6 +10,7 @@
 from importlib.metadata import version
 from os.path import commonprefix
 from pathlib import Path
+from typing import Any
 
 import huggingface_hub
 import optuna
@@ -171,6 +172,12 @@ def run():
         )
         return
 
+    # Keep Hugging Face credentials in memory for this process only.
+    # We don't use huggingface_hub.login() because that stores the token on disk.
+    # Since this program will often be run on rented or shared GPU servers,
+    # it is better to not persist credentials.
+    hf_token = huggingface_hub.get_token()
+
     # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
     if torch.cuda.is_available():
         count = torch.cuda.device_count()
@@ -598,6 +605,48 @@ def count_completed_trials() -> int:
     if count_completed_trials() == settings.n_trials:
         study.set_user_attr("finished", True)
 
+    def print_hf_user_info(user: dict[str, Any]) -> None:
+        fullname = user.get(
+            "fullname",
+            user.get("name", "unknown user"),
+        )
+        email = user.get("email", "no email found")
+        print(f"Logged in as [bold]{fullname} ({email})[/]")
+
+    def validate_hf_token(t: str) -> dict[str, Any] | None:
+        try:
+            user = huggingface_hub.whoami(t)
+            print_hf_user_info(user)
+            return user
+        except huggingface_hub.errors.HfHubHTTPError as error:
+            print(f"[red]Failed to validate the Hugging Face token: ({error})[/]")
+            return None
+
+    def authenticate_hf(token: str | None) -> tuple[dict[str, Any], str]:
+        # Try to use an existing token (from env, hf auth login, or a previous upload).
+        if token:
+            user = validate_hf_token(token)
+            if user:
+                choice = prompt_select(
+                    "How do you want to proceed?",
+                    ["Use this account", "Switch account"],
+                )
+                if choice is None:
+                    raise KeyboardInterrupt
+                if choice == "Use this account":
+                    return user, token
+                # User chose "Switch account"; fall through to prompt for new token.
+
+        # No valid token yet (first time, switch account, or validation failed).
+        # Prompt until a token validates; cancelling or entering nothing aborts.
+        while True:
+            new_token = prompt_password("Hugging Face access token:")
+            if not new_token:
+                raise KeyboardInterrupt
+            user = validate_hf_token(new_token)
+            if user:
+                return user, new_token
+
     while True:
         # If no trials at all have been evaluated, the study must have been stopped
         # by pressing Ctrl+C while the first trial was running. In this case, we just
@@ -765,23 +814,11 @@ def count_completed_trials() -> int:
                             print(f"Model saved to [bold]{save_directory}[/].")
 
                         case "Upload the model to Hugging Face":
-                            # We don't use huggingface_hub.login() because that stores the token on disk,
-                            # and since this program will often be run on rented or shared GPU servers,
-                            # it's better to not persist credentials.
-                            token = huggingface_hub.get_token()
-                            if not token:
-                                token = prompt_password("Hugging Face access token:")
-                            if not token:
+                            try:
+                                user, hf_token = authenticate_hf(hf_token)
+                            except KeyboardInterrupt:
                                 continue
 
-                            user = huggingface_hub.whoami(token)
-                            fullname = user.get(
-                                "fullname",
-                                user.get("name", "unknown user"),
-                            )
-                            email = user.get("email", "no email found")
-                            print(f"Logged in as [bold]{fullname} ({email})[/]")
-
                             repo_id = prompt_text(
                                 "Name of repository:",
                                 default=f"{user['name']}/{Path(settings.model).name}-heretic",
@@ -805,7 +842,7 @@ def count_completed_trials() -> int:
                                 model.model.push_to_hub(
                                     repo_id,
                                     private=private,
-                                    token=token,
+                                    token=hf_token,
                                 )
                             else:
                                 print("Uploading merged model...")
@@ -813,14 +850,14 @@ def count_completed_trials() -> int:
                                 merged_model.push_to_hub(
                                     repo_id,
                                     private=private,
-                                    token=token,
+                                    token=hf_token,
                                 )
                                 del merged_model
                                 empty_cache()
                             model.tokenizer.push_to_hub(
                                 repo_id,
                                 private=private,
-                                token=token,
+                                token=hf_token,
                             )
 
                             # If the model path exists locally and includes the
@@ -857,7 +894,7 @@ def count_completed_trials() -> int:
                                     )
                                     + card.text
                                 )
-                                card.push_to_hub(repo_id, token=token)
+                                card.push_to_hub(repo_id, token=hf_token)
 
                             print(f"Model uploaded to [bold]{repo_id}[/].")
 
diff --git a/src/heretic/utils.py b/src/heretic/utils.py
index a0d5f35f..41d9f0f9 100644
--- a/src/heretic/utils.py
+++ b/src/heretic/utils.py
@@ -133,7 +133,7 @@ def prompt_path(message: str) -> str:
         return questionary.path(message, only_directories=True).ask()
 
 
-def prompt_password(message: str) -> str:
+def prompt_password(message: str) -> str | None:
     if is_notebook():
         print()
         return getpass.getpass(message)