forked from bitsandbytes-foundation/bitsandbytes
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
intermediate commit to save work (will likely choose other approach)
- Loading branch information
1 parent
1f2ca43
commit 3c4e1c0
Showing
1 changed file
with
360 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,360 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
import atexit | ||
import datetime | ||
import json | ||
import os | ||
import subprocess | ||
import sys | ||
import tempfile | ||
import urllib.request | ||
import zipfile | ||
|
||
GITHUB_PAT = os.environ.get("GITHUB_PAT") | ||
REST_REQUEST_COUNT = 0 | ||
|
||
|
||
def print_rest_request_count(): | ||
print(f"\nTotal GitHub REST API requests made: {REST_REQUEST_COUNT}") | ||
|
||
|
||
atexit.register(print_rest_request_count) | ||
|
||
|
||
def get_artifacts_json(page=1, per_page=100): | ||
repo_owner = "TimDettmers" | ||
repo_name = "bitsandbytes" | ||
api_url = ( | ||
f"https://api.github.com/repos/{repo_owner}/{repo_name}/actions/artifacts?per_page={per_page}&page={page}" | ||
) | ||
|
||
try: | ||
request = urllib.request.Request(api_url) | ||
if GITHUB_PAT: | ||
request.add_header("Authorization", f"token {GITHUB_PAT}") | ||
with urllib.request.urlopen(request) as response: | ||
global REST_REQUEST_COUNT | ||
REST_REQUEST_COUNT += 1 | ||
artifacts_data = response.read().decode("utf-8") | ||
return json.loads(artifacts_data) | ||
|
||
except urllib.error.HTTPError as e: | ||
if e.code == 403 and "rate limit exceeded" in str(e.reason): | ||
print("Error: GitHub API rate limit exceeded.") | ||
print("To increase your rate limit, create a Personal Access Token (PAT) with 'public_repo' scope:") | ||
print("1. Go to https://github.com/settings/tokens") | ||
print("2. Click 'Generate new token' and select 'public_repo' scope") | ||
print("3. Set the GITHUB_PAT environment variable with your new token") | ||
sys.exit(1) | ||
else: | ||
print(f"HTTP Error {e.code}: {e.reason}") | ||
sys.exit(1) | ||
|
||
except urllib.error.URLError as e: | ||
print(f"Error fetching artifacts: {e}") | ||
sys.exit(1) | ||
|
||
except json.JSONDecodeError: | ||
print("Error: Invalid JSON response from GitHub API") | ||
sys.exit(1) | ||
|
||
|
||
def parse_artifact_info(artifact): | ||
name = artifact["name"] | ||
branch = artifact["workflow_run"]["head_branch"] if artifact.get("workflow_run") else "Unknown" | ||
commit_sha = artifact["workflow_run"]["head_sha"] if artifact.get("workflow_run") else "Unknown" | ||
return name, branch, commit_sha | ||
|
||
|
||
def is_compatible_platform(artifact_name, current_platform): | ||
return current_platform in artifact_name | ||
|
||
|
||
def get_all_wheel_artifacts(branch="main", platform=None, show_all_platforms=False, limit=None): | ||
page = 1 | ||
per_page = 100 | ||
all_wheel_artifacts = [] | ||
current_platform = get_platform_string() | ||
last_commit = None | ||
|
||
while True: | ||
print(f"Fetching page {page} of artifacts...") | ||
artifacts_data = get_artifacts_json(page, per_page) | ||
wheel_artifacts = [a for a in artifacts_data["artifacts"] if a["name"].startswith("bdist_wheel_")] | ||
|
||
for artifact in wheel_artifacts: | ||
artifact_name, artifact_branch, commit_sha = parse_artifact_info(artifact) | ||
|
||
if artifact_branch == branch: | ||
if show_all_platforms or is_compatible_platform(artifact_name, current_platform): | ||
if last_commit is None: | ||
last_commit = commit_sha | ||
elif commit_sha != last_commit and show_all_platforms: | ||
return all_wheel_artifacts | ||
|
||
all_wheel_artifacts.append(artifact) | ||
print(f"Found wheel: {artifact_name} (Branch: {artifact_branch}, Commit: {commit_sha[:7]})") | ||
|
||
if not show_all_platforms and len(all_wheel_artifacts) == 1: | ||
return all_wheel_artifacts | ||
|
||
if limit and len(all_wheel_artifacts) >= limit: | ||
return all_wheel_artifacts | ||
|
||
if len(artifacts_data["artifacts"]) < per_page: | ||
break | ||
page += 1 | ||
|
||
return all_wheel_artifacts | ||
|
||
|
||
def get_most_recent_wheels(branch="multi-backend-refactor"): | ||
all_artifacts = get_all_wheel_artifacts(branch=branch, show_all_platforms=True) | ||
most_recent_wheels = {} | ||
|
||
for artifact in all_artifacts: | ||
platform, artifact_branch, _ = parse_artifact_info(artifact) | ||
if artifact_branch == branch: | ||
key = platform | ||
if key not in most_recent_wheels or artifact["created_at"] > most_recent_wheels[key]["created_at"]: | ||
most_recent_wheels[key] = artifact | ||
|
||
return most_recent_wheels | ||
|
||
|
||
def get_recent_wheels(branch, limit=7): | ||
return get_all_wheel_artifacts(branch=branch, show_all_platforms=True, limit=limit) | ||
|
||
|
||
def print_most_recent_wheels(wheels, branch, show_json=False): | ||
print(f"\nMost recent wheels for branch '{branch}':") | ||
for platform, artifact in wheels.items(): | ||
_, python_version, _, commit_sha = parse_artifact_info(artifact) | ||
print(f"Platform: {platform}") | ||
print(f"Name: {artifact['name']}") | ||
print(f"Python Version: {python_version}") | ||
print(f"Created at: {artifact['created_at']}") | ||
print(f"Commit: {commit_sha[:7]}") | ||
print(f"URL: {artifact['archive_download_url']}") | ||
|
||
if show_json: | ||
print("Raw JSON:") | ||
print(json.dumps(artifact, indent=2)) | ||
|
||
print("-" * 50) | ||
|
||
|
||
def get_platform_string(): | ||
if sys.platform.startswith("win"): | ||
return "windows-latest_x86_64" | ||
elif sys.platform.startswith("darwin"): | ||
return "macos-latest_x86_64" if sys.platform.machine() != "arm64" else "macos-latest_aarch64" | ||
elif sys.platform.startswith("linux"): | ||
return "ubuntu-latest_x86_64" | ||
else: | ||
raise NotImplementedError(f"Unsupported platform: {sys.platform}") | ||
|
||
|
||
def print_artifacts_info(artifacts, show_json=False): | ||
print("\nWheel artifacts from 'build-wheel' job:") | ||
print("=" * 50) | ||
for artifact in artifacts: | ||
artifact_platform, artifact_branch, commit_sha = parse_artifact_info(artifact) | ||
|
||
print(f"Name: {artifact['name']}") | ||
print(f"ID: {artifact['id']}") | ||
print(f"Size: {artifact['size_in_bytes']} bytes") | ||
print(f"Created at: {artifact['created_at']}") | ||
print(f"URL: {artifact['archive_download_url']}") | ||
print(f"Platform: {artifact_platform}") | ||
print(f"Branch: {artifact_branch}") | ||
print(f"Commit: {commit_sha[:7]}") | ||
|
||
if show_json: | ||
print("Raw JSON:") | ||
print(json.dumps(artifact, indent=2)) | ||
|
||
print("-" * 50) | ||
|
||
|
||
def get_latest_artifact_url(artifacts, branch, platform=None): | ||
platform = platform or get_platform_string() | ||
for artifact in artifacts: | ||
artifact_platform, _, artifact_branch, _ = parse_artifact_info(artifact) | ||
if artifact_branch == branch and is_compatible_platform(artifact_platform, platform): | ||
return artifact["archive_download_url"] | ||
|
||
print(f"No artifact found for branch '{branch}' and platform '{platform}'") | ||
sys.exit(1) | ||
|
||
|
||
def download_and_extract_artifact(url, temp_dir): | ||
artifact_path = os.path.join(temp_dir, "artifact.zip") | ||
|
||
try: | ||
request = urllib.request.Request(url) | ||
# if GITHUB_PAT: | ||
# request.add_header("Authorization", f"token {GITHUB_PAT}") | ||
|
||
with urllib.request.urlopen(request) as response, open(artifact_path, "wb") as artifact_file: | ||
global REST_REQUEST_COUNT | ||
REST_REQUEST_COUNT += 1 | ||
artifact_file.write(response.read()) | ||
|
||
with zipfile.ZipFile(artifact_path, "r") as artifact_zip: | ||
artifact_zip.extractall(temp_dir) | ||
|
||
except urllib.error.HTTPError as e: | ||
print(f"HTTP Error {e.code}: {e.reason}") | ||
print(f"Error occurred while trying to download: {url}") | ||
if e.code == 403: | ||
print("This could be due to an expired artifact or insufficient permissions.") | ||
sys.exit(1) | ||
|
||
except urllib.error.URLError as e: | ||
print(f"Error downloading artifact: {e}") | ||
print(f"URL attempted: {url}") | ||
sys.exit(1) | ||
|
||
except zipfile.BadZipFile: | ||
print("Error: Downloaded file is not a valid zip file.") | ||
print(f"This might indicate that the artifact at {url} has expired or is corrupted.") | ||
sys.exit(1) | ||
|
||
|
||
def get_wheel_file(temp_dir): | ||
wheel_files = [f for f in os.listdir(temp_dir) if f.endswith(".whl")] | ||
|
||
if not wheel_files: | ||
print("No wheel file found in the artifact") | ||
sys.exit(1) | ||
|
||
platform = get_platform_string() | ||
|
||
for wheel_file in wheel_files: | ||
if platform in wheel_file: | ||
return os.path.join(temp_dir, wheel_file) | ||
|
||
print(f"No wheel file found for platform: {platform}") | ||
sys.exit(1) | ||
|
||
|
||
def install_wheel(wheel_path): | ||
try: | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", wheel_path]) | ||
except subprocess.CalledProcessError as e: | ||
print(f"Error installing wheel: {e}") | ||
sys.exit(1) | ||
|
||
|
||
def get_rate_limit(): | ||
url = "https://api.github.com/rate_limit" | ||
request = urllib.request.Request(url) | ||
if GITHUB_PAT: | ||
request.add_header("Authorization", f"token {GITHUB_PAT}") | ||
request.add_header("Accept", "application/vnd.github+json") | ||
request.add_header("X-GitHub-Api-Version", "2022-11-28") | ||
|
||
try: | ||
with urllib.request.urlopen(request) as response: | ||
global REST_REQUEST_COUNT | ||
REST_REQUEST_COUNT += 1 | ||
data = json.loads(response.read().decode()) | ||
return data["rate"] | ||
except urllib.error.HTTPError as e: | ||
print(f"HTTP Error {e.code}: {e.reason}") | ||
return None | ||
except Exception as e: | ||
print(f"Error fetching rate limit: {e}") | ||
return None | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Install or list bitsandbytes wheel artifacts") | ||
parser.add_argument("--rate-limit", action="store_true", help="Display the current GitHub API rate limit") | ||
parser.add_argument("--list", action="store_true", help="List available wheel artifacts") | ||
parser.add_argument("--all", action="store_true", help="Show all available wheels, regardless of platform") | ||
parser.add_argument( | ||
"--branch", | ||
default="multi-backend-refactor", | ||
help="Specify the branch to use (default: multi-backend-refactor)", | ||
) | ||
parser.add_argument("--main", action="store_true", help="Use the 'main' branch") | ||
parser.add_argument( | ||
"--mpr", "--multi-backend-refactor", action="store_true", help="Use the 'multi-backend-refactor' branch" | ||
) | ||
parser.add_argument("--url", help="Specify a direct URL to a wheel artifact") | ||
parser.add_argument( | ||
"--install", action="store_true", help="Install the wheel (default action if no other option is specified)" | ||
) | ||
parser.add_argument("--json", action="store_true", help="Output raw JSON for each artifact") | ||
parser.add_argument( | ||
"--most-recent", | ||
action="store_true", | ||
help="Get the most recent wheels for all platforms for the specified branch", | ||
) | ||
parser.add_argument("--more", action="store_true", help="Show the 7 most recent wheels") | ||
|
||
args = parser.parse_args() | ||
|
||
if not GITHUB_PAT: | ||
print("Warning: GITHUB_PAT environment variable not set. You may encounter rate limiting.") | ||
|
||
if args.rate_limit: | ||
rate_limit = get_rate_limit() | ||
if rate_limit: | ||
print("GitHub API Rate Limit:") | ||
print(f"Limit: {rate_limit['limit']}") | ||
print(f"Remaining: {rate_limit['remaining']}") | ||
print(f"Reset time: {datetime.datetime.fromtimestamp(rate_limit['reset']).strftime('%Y-%m-%d %H:%M:%S')}") | ||
|
||
sys.exit(0) | ||
|
||
if args.main: | ||
branch = "main" | ||
elif args.mpr: | ||
branch = "multi-backend-refactor" | ||
else: | ||
branch = args.branch | ||
|
||
print(f"Configured branch for wheel search: {branch}") | ||
|
||
if args.more: | ||
recent_wheels = get_recent_wheels(branch) | ||
print_artifacts_info(recent_wheels, show_json=args.json) | ||
|
||
elif args.most_recent: | ||
most_recent_wheels = get_most_recent_wheels(branch) | ||
print_most_recent_wheels(most_recent_wheels, branch, show_json=args.json) | ||
|
||
elif args.list or args.all: | ||
all_wheel_artifacts = get_all_wheel_artifacts(branch, show_all_platforms=args.all) | ||
print_artifacts_info(all_wheel_artifacts, show_json=args.json) | ||
if all_wheel_artifacts: | ||
latest_artifact = all_wheel_artifacts[0] | ||
print("\nTo install the most recent wheel for your platform, run:") | ||
print(f"python {sys.argv[0]} --url {latest_artifact['archive_download_url']}") | ||
|
||
elif args.url: | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
download_and_extract_artifact(args.url, temp_dir) | ||
wheel_path = get_wheel_file(temp_dir) | ||
install_wheel(wheel_path) | ||
print("Installation completed successfully!") | ||
|
||
elif args.install or (not args.list and not args.url and not args.most_recent and not args.more): | ||
all_wheel_artifacts = get_all_wheel_artifacts(branch) | ||
if all_wheel_artifacts: | ||
url = all_wheel_artifacts[0]["archive_download_url"] | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
download_and_extract_artifact(url, temp_dir) | ||
wheel_path = get_wheel_file(temp_dir) | ||
install_wheel(wheel_path) | ||
print("Installation completed successfully!") | ||
else: | ||
print(f"No suitable wheel found for branch '{branch}' and platform '{get_platform_string()}'") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |