-
Notifications
You must be signed in to change notification settings - Fork 2.2k
fix: persist hf token between uploads #148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
674d54b
9cb806f
8155767
3ccd69a
07bda90
f48fc7e
446ac06
0ac1272
6fc8619
cbd91c0
88ff356
6777111
6b7657c
f3e3766
cfddcde
c7d3942
59a3c3b
4a13925
184622d
87f95f7
8141315
3c10801
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| from importlib.metadata import version | ||
| from os.path import commonprefix | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import huggingface_hub | ||
| import optuna | ||
|
|
@@ -171,6 +172,12 @@ def run(): | |
| ) | ||
| return | ||
|
|
||
| # Keep Hugging Face credentials in memory for this process only. | ||
| # We don't use huggingface_hub.login() because that stores the token on disk. | ||
| # Since this program will often be run on rented or shared GPU servers, | ||
| # it is better to not persist credentials. | ||
| hf_token = huggingface_hub.get_token() | ||
|
|
||
| # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py | ||
| if torch.cuda.is_available(): | ||
| count = torch.cuda.device_count() | ||
|
|
@@ -598,6 +605,48 @@ def count_completed_trials() -> int: | |
| if count_completed_trials() == settings.n_trials: | ||
| study.set_user_attr("finished", True) | ||
|
|
||
| def print_hf_user_info(user: dict[str, Any]) -> None: | ||
| fullname = user.get( | ||
| "fullname", | ||
| user.get("name", "unknown user"), | ||
| ) | ||
| email = user.get("email", "no email found") | ||
| print(f"Logged in as [bold]{fullname} ({email})[/]") | ||
|
|
||
| def validate_hf_token(t: str) -> dict[str, Any] | None: | ||
| try: | ||
| user = huggingface_hub.whoami(t) | ||
| print_hf_user_info(user) | ||
| return user | ||
| except huggingface_hub.errors.HfHubHTTPError as error: | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Surely this isn't the only possible error here? |
||
| print(f"[red]Failed to validate the Hugging Face token: ({error})[/]") | ||
| return None | ||
|
|
||
| def authenticate_hf(token: str | None) -> tuple[dict[str, Any], str]: | ||
|
red40maxxer marked this conversation as resolved.
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function should not accept |
||
| # Try to use an existing token (from env, hf auth login, or a previous upload). | ||
| if token: | ||
| user = validate_hf_token(token) | ||
| if user: | ||
| choice = prompt_select( | ||
| "How do you want to proceed?", | ||
| ["Use this account", "Switch account"], | ||
| ) | ||
| if choice is None: | ||
| raise KeyboardInterrupt | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's not an acceptable mechanism for bailing out. |
||
| if choice == "Use this account": | ||
| return user, token | ||
| # User chose "Switch account"; fall through to prompt for new token. | ||
|
|
||
| # No valid token yet (first time, switch account, or validation failed). | ||
| # Prompt for a token until we get a valid one or the user cancels. | ||
| while True: | ||
| new_token = prompt_password("Hugging Face access token:") | ||
| if new_token is None: | ||
| raise KeyboardInterrupt | ||
| user = validate_hf_token(new_token) | ||
| if user: | ||
| return user, new_token | ||
|
|
||
| while True: | ||
| # If no trials at all have been evaluated, the study must have been stopped | ||
| # by pressing Ctrl+C while the first trial was running. In this case, we just | ||
|
|
@@ -765,23 +814,11 @@ def count_completed_trials() -> int: | |
| print(f"Model saved to [bold]{save_directory}[/].") | ||
|
|
||
| case "Upload the model to Hugging Face": | ||
| # We don't use huggingface_hub.login() because that stores the token on disk, | ||
| # and since this program will often be run on rented or shared GPU servers, | ||
| # it's better to not persist credentials. | ||
| token = huggingface_hub.get_token() | ||
| if not token: | ||
| token = prompt_password("Hugging Face access token:") | ||
| if not token: | ||
| try: | ||
| user, hf_token = authenticate_hf(hf_token) | ||
| except KeyboardInterrupt: | ||
| continue | ||
|
|
||
| user = huggingface_hub.whoami(token) | ||
| fullname = user.get( | ||
| "fullname", | ||
| user.get("name", "unknown user"), | ||
| ) | ||
| email = user.get("email", "no email found") | ||
| print(f"Logged in as [bold]{fullname} ({email})[/]") | ||
|
|
||
| repo_id = prompt_text( | ||
| "Name of repository:", | ||
| default=f"{user['name']}/{Path(settings.model).name}-heretic", | ||
|
|
@@ -805,22 +842,22 @@ def count_completed_trials() -> int: | |
| model.model.push_to_hub( | ||
| repo_id, | ||
| private=private, | ||
| token=token, | ||
| token=hf_token, | ||
| ) | ||
| else: | ||
| print("Uploading merged model...") | ||
| merged_model = model.get_merged_model() | ||
| merged_model.push_to_hub( | ||
| repo_id, | ||
| private=private, | ||
| token=token, | ||
| token=hf_token, | ||
| ) | ||
| del merged_model | ||
| empty_cache() | ||
| model.tokenizer.push_to_hub( | ||
| repo_id, | ||
| private=private, | ||
| token=token, | ||
| token=hf_token, | ||
| ) | ||
|
|
||
| # If the model path exists locally and includes the | ||
|
|
@@ -857,7 +894,7 @@ def count_completed_trials() -> int: | |
| ) | ||
| + card.text | ||
| ) | ||
| card.push_to_hub(repo_id, token=token) | ||
| card.push_to_hub(repo_id, token=hf_token) | ||
|
|
||
| print(f"Model uploaded to [bold]{repo_id}[/].") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -133,7 +133,7 @@ def prompt_path(message: str) -> str: | |
| return questionary.path(message, only_directories=True).ask() | ||
|
|
||
|
|
||
| def prompt_password(message: str) -> str: | ||
| def prompt_password(message: str) -> str | None: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If the user CTRL+Cs during password entry,
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is true for all other |
||
| if is_notebook(): | ||
| print() | ||
| return getpass.getpass(message) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A function called "validate" shouldn't print anything.