Skip to content

Commit

Permalink
intermediate commit to save work (will likely choose other approach)
Browse files Browse the repository at this point in the history
  • Loading branch information
Titus-von-Koeller committed Jul 22, 2024
1 parent 1f2ca43 commit 3c4e1c0
Showing 1 changed file with 360 additions and 0 deletions.
360 changes: 360 additions & 0 deletions scripts/install_preview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,360 @@
#!/usr/bin/env python3

import argparse
import atexit
import datetime
import json
import os
import subprocess
import sys
import tempfile
import urllib.request
import zipfile

# Optional GitHub personal access token, read once from the environment.
# When set it is sent as an Authorization header on GitHub API requests.
GITHUB_PAT = os.environ.get("GITHUB_PAT")
# Running tally of HTTP requests issued by this script; reported at exit.
REST_REQUEST_COUNT = 0


def print_rest_request_count():
    """Print the number of GitHub REST API requests recorded so far."""
    total_requests = REST_REQUEST_COUNT
    print(f"\nTotal GitHub REST API requests made: {total_requests}")


# Report the request tally when the interpreter exits, including exits
# triggered by sys.exit() in the error paths below.
atexit.register(print_rest_request_count)


def get_artifacts_json(page=1, per_page=100):
    """Fetch one page of GitHub Actions artifacts for the bitsandbytes repo.

    Args:
        page: 1-based page number to request.
        per_page: Number of artifacts per page.

    Returns:
        The decoded JSON response body.

    Exits the process with status 1 (after printing a message) on HTTP
    errors, network errors, or an undecodable JSON body.
    """
    global REST_REQUEST_COUNT

    repo_owner = "TimDettmers"
    repo_name = "bitsandbytes"
    api_url = (
        f"https://api.github.com/repos/{repo_owner}/{repo_name}/actions/artifacts?per_page={per_page}&page={page}"
    )

    try:
        request = urllib.request.Request(api_url)
        if GITHUB_PAT:
            request.add_header("Authorization", f"token {GITHUB_PAT}")
        # Same media type / API version headers that get_rate_limit() sends.
        request.add_header("Accept", "application/vnd.github+json")
        request.add_header("X-GitHub-Api-Version", "2022-11-28")

        with urllib.request.urlopen(request) as response:
            REST_REQUEST_COUNT += 1
            artifacts_data = response.read().decode("utf-8")
        return json.loads(artifacts_data)

    except urllib.error.HTTPError as e:
        # A rate-limited response may be 403 or 429; recognise it by the
        # remaining-quota header as well as by the reason phrase, so the
        # check does not hinge on the exact wording/casing of the phrase.
        remaining = e.headers.get("X-RateLimit-Remaining") if e.headers else None
        reason_text = str(e.reason).lower()
        rate_limited = e.code in (403, 429) and (remaining == "0" or "rate limit exceeded" in reason_text)

        if rate_limited:
            print("Error: GitHub API rate limit exceeded.")
            print("To increase your rate limit, create a Personal Access Token (PAT) with 'public_repo' scope:")
            print("1. Go to https://github.com/settings/tokens")
            print("2. Click 'Generate new token' and select 'public_repo' scope")
            print("3. Set the GITHUB_PAT environment variable with your new token")
        else:
            print(f"HTTP Error {e.code}: {e.reason}")
        sys.exit(1)

    except urllib.error.URLError as e:
        print(f"Error fetching artifacts: {e}")
        sys.exit(1)

    except json.JSONDecodeError:
        print("Error: Invalid JSON response from GitHub API")
        sys.exit(1)


def parse_artifact_info(artifact):
    """Extract ``(name, branch, commit_sha)`` from an artifact record.

    Branch and commit come from the artifact's ``workflow_run`` entry; both
    are reported as ``"Unknown"`` when that entry is missing or empty.
    """
    artifact_name = artifact["name"]
    workflow_run = artifact.get("workflow_run")

    if not workflow_run:
        return artifact_name, "Unknown", "Unknown"

    return artifact_name, workflow_run["head_branch"], workflow_run["head_sha"]


def is_compatible_platform(artifact_name, current_platform):
    """Return True when *current_platform* occurs inside *artifact_name*."""
    name_mentions_platform = current_platform in artifact_name
    return name_mentions_platform


def get_all_wheel_artifacts(branch="main", platform=None, show_all_platforms=False, limit=None):
    """Collect wheel artifacts (names starting with ``bdist_wheel_``) for *branch*.

    Pages through the artifact listing and keeps artifacts whose workflow-run
    branch equals *branch*.

    Args:
        branch: Branch name the artifact's workflow run must match.
        platform: Platform string to filter on; defaults to the value of
            ``get_platform_string()`` when not given.
        show_all_platforms: When true, skip the platform filter and return
            every wheel belonging to the first matching commit encountered.
            When false, return as soon as one compatible wheel is found.
        limit: Optional cap on the number of artifacts returned. Because of
            the two early exits above, it only takes effect among wheels of
            a single commit.

    Returns:
        A list of artifact records (possibly empty).
    """
    # NOTE(review): nesting reconstructed from a flattened copy of this file;
    # confirm the placement of the early returns against the repository.
    page = 1
    per_page = 100
    all_wheel_artifacts = []
    # Honor an explicit platform; fall back to the host platform otherwise.
    current_platform = platform or get_platform_string()
    last_commit = None

    while True:
        print(f"Fetching page {page} of artifacts...")
        artifacts_data = get_artifacts_json(page, per_page)
        page_artifacts = artifacts_data["artifacts"]

        for artifact in page_artifacts:
            if not artifact["name"].startswith("bdist_wheel_"):
                continue

            artifact_name, artifact_branch, commit_sha = parse_artifact_info(artifact)
            if artifact_branch != branch:
                continue
            if not (show_all_platforms or is_compatible_platform(artifact_name, current_platform)):
                continue

            if last_commit is None:
                last_commit = commit_sha
            elif commit_sha != last_commit and show_all_platforms:
                # Reached an older commit: the newest commit's set is complete.
                return all_wheel_artifacts

            all_wheel_artifacts.append(artifact)
            print(f"Found wheel: {artifact_name} (Branch: {artifact_branch}, Commit: {commit_sha[:7]})")

            if not show_all_platforms and len(all_wheel_artifacts) == 1:
                return all_wheel_artifacts

            if limit and len(all_wheel_artifacts) >= limit:
                return all_wheel_artifacts

        # A short page means there are no further pages to request.
        if len(page_artifacts) < per_page:
            break
        page += 1

    return all_wheel_artifacts


def get_most_recent_wheels(branch="multi-backend-refactor"):
    """Return the newest artifact per artifact name for *branch*.

    The result maps each artifact name to the record with the greatest
    ``created_at`` value among those returned for all platforms.
    """
    newest_by_name = {}

    for candidate in get_all_wheel_artifacts(branch=branch, show_all_platforms=True):
        artifact_name, candidate_branch, _ = parse_artifact_info(candidate)
        if candidate_branch != branch:
            continue

        if artifact_name in newest_by_name:
            if candidate["created_at"] > newest_by_name[artifact_name]["created_at"]:
                newest_by_name[artifact_name] = candidate
        else:
            newest_by_name[artifact_name] = candidate

    return newest_by_name


def get_recent_wheels(branch, limit=7):
    """Return up to *limit* wheel artifacts for *branch*, across all platforms."""
    recent_artifacts = get_all_wheel_artifacts(
        branch=branch,
        show_all_platforms=True,
        limit=limit,
    )
    return recent_artifacts


def print_most_recent_wheels(wheels, branch, show_json=False):
    """Print a summary of the most recent wheel per platform.

    Args:
        wheels: Mapping of platform key to artifact record, as returned by
            ``get_most_recent_wheels``.
        branch: Branch name shown in the heading.
        show_json: When true, also dump each artifact record as JSON.
    """
    print(f"\nMost recent wheels for branch '{branch}':")
    for platform, artifact in wheels.items():
        # parse_artifact_info returns (name, branch, commit_sha). Unpacking
        # four values here raised ValueError; no Python version is available
        # from that helper, so that field is not printed.
        _, _, commit_sha = parse_artifact_info(artifact)
        print(f"Platform: {platform}")
        print(f"Name: {artifact['name']}")
        print(f"Created at: {artifact['created_at']}")
        print(f"Commit: {commit_sha[:7]}")
        print(f"URL: {artifact['archive_download_url']}")

        if show_json:
            print("Raw JSON:")
            print(json.dumps(artifact, indent=2))

        print("-" * 50)


def get_platform_string():
    """Return the ``<runner>_<arch>`` string used to match artifact names.

    Returns:
        ``"windows-latest_x86_64"`` on Windows, ``"macos-latest_<arch>"`` on
        macOS and ``"ubuntu-latest_<arch>"`` on Linux, where ``<arch>`` is
        ``aarch64`` on ARM64 machines and ``x86_64`` otherwise.

    Raises:
        NotImplementedError: If ``sys.platform`` is none of the above.
    """
    # sys.platform is a plain string with no machine() method; the CPU
    # architecture comes from the platform module instead.
    import platform

    if sys.platform.startswith("win"):
        return "windows-latest_x86_64"

    if sys.platform.startswith(("darwin", "linux")):
        is_arm64 = platform.machine().lower() in ("arm64", "aarch64")
        arch = "aarch64" if is_arm64 else "x86_64"
        # NOTE(review): assumes Linux ARM artifacts follow the same
        # "<runner>_aarch64" naming as the macOS ones - confirm against CI.
        runner = "macos-latest" if sys.platform.startswith("darwin") else "ubuntu-latest"
        return f"{runner}_{arch}"

    raise NotImplementedError(f"Unsupported platform: {sys.platform}")


def print_artifacts_info(artifacts, show_json=False):
    """Print a detailed listing of each artifact record in *artifacts*.

    When *show_json* is true, each record is additionally dumped as
    indented JSON.
    """
    heading_rule = "=" * 50
    entry_rule = "-" * 50

    print("\nWheel artifacts from 'build-wheel' job:")
    print(heading_rule)

    for entry in artifacts:
        entry_platform, entry_branch, entry_sha = parse_artifact_info(entry)
        short_sha = entry_sha[:7]

        print(f"Name: {entry['name']}")
        print(f"ID: {entry['id']}")
        print(f"Size: {entry['size_in_bytes']} bytes")
        print(f"Created at: {entry['created_at']}")
        print(f"URL: {entry['archive_download_url']}")
        print(f"Platform: {entry_platform}")
        print(f"Branch: {entry_branch}")
        print(f"Commit: {short_sha}")

        if show_json:
            print("Raw JSON:")
            print(json.dumps(entry, indent=2))

        print(entry_rule)


def get_latest_artifact_url(artifacts, branch, platform=None):
    """Return the download URL of the first artifact matching branch and platform.

    Args:
        artifacts: Iterable of artifact records, scanned in order.
        branch: Branch name the artifact's workflow run must match.
        platform: Platform string to look for in the artifact name; defaults
            to the value of ``get_platform_string()``.

    Returns:
        The ``archive_download_url`` of the first match.

    Exits the process with status 1 when nothing matches.
    """
    platform = platform or get_platform_string()
    for artifact in artifacts:
        # parse_artifact_info returns exactly (name, branch, commit_sha);
        # unpacking four values here raised ValueError.
        artifact_name, artifact_branch, _ = parse_artifact_info(artifact)
        if artifact_branch == branch and is_compatible_platform(artifact_name, platform):
            return artifact["archive_download_url"]

    print(f"No artifact found for branch '{branch}' and platform '{platform}'")
    sys.exit(1)


def download_and_extract_artifact(url, temp_dir):
    """Download the zip archive at *url* and extract it into *temp_dir*.

    The archive is saved as ``artifact.zip`` inside *temp_dir* and then
    unpacked alongside it.

    Args:
        url: Artifact download URL.
        temp_dir: Existing directory that receives the zip and its contents.

    Exits the process with status 1 (after printing a message) on HTTP
    errors, network errors, or when the download is not a valid zip file.
    """
    global REST_REQUEST_COUNT

    artifact_path = os.path.join(temp_dir, "artifact.zip")

    try:
        request = urllib.request.Request(url)
        # Authenticate only against the GitHub API host, and as an
        # *unredirected* header so the token is not forwarded to whatever
        # host the API redirects the download to.
        if GITHUB_PAT and urllib.parse.urlsplit(url).hostname == "api.github.com":
            request.add_unredirected_header("Authorization", f"token {GITHUB_PAT}")

        with urllib.request.urlopen(request) as response, open(artifact_path, "wb") as artifact_file:
            REST_REQUEST_COUNT += 1
            # Stream to disk in chunks rather than buffering the whole
            # archive in memory.
            for chunk in iter(lambda: response.read(1 << 16), b""):
                artifact_file.write(chunk)

        with zipfile.ZipFile(artifact_path, "r") as artifact_zip:
            artifact_zip.extractall(temp_dir)

    except urllib.error.HTTPError as e:
        print(f"HTTP Error {e.code}: {e.reason}")
        print(f"Error occurred while trying to download: {url}")
        if e.code == 403:
            print("This could be due to an expired artifact or insufficient permissions.")
        sys.exit(1)

    except urllib.error.URLError as e:
        print(f"Error downloading artifact: {e}")
        print(f"URL attempted: {url}")
        sys.exit(1)

    except zipfile.BadZipFile:
        print("Error: Downloaded file is not a valid zip file.")
        print(f"This might indicate that the artifact at {url} has expired or is corrupted.")
        sys.exit(1)


def _wheel_matches_platform(wheel_name, platform):
    """Return True if *wheel_name* looks built for the CI-style *platform* string."""
    if platform in wheel_name:
        return True

    # Wheel filenames end in a platform tag such as "manylinux_2_24_x86_64",
    # "win_amd64" or "macosx_13_1_arm64", which never contains the CI runner
    # string itself; translate runner/arch into the markers such tags use.
    runner, _, arch = platform.partition("_")
    os_marker = {
        "ubuntu-latest": "linux",
        "windows-latest": "win",
        "macos-latest": "macosx",
    }.get(runner)
    arch_markers = {
        "x86_64": ("x86_64", "amd64"),
        "aarch64": ("aarch64", "arm64"),
    }.get(arch)
    if os_marker is None or arch_markers is None:
        return False

    platform_tag = wheel_name[: -len(".whl")].rsplit("-", 1)[-1]
    return os_marker in platform_tag and any(marker in platform_tag for marker in arch_markers)


def get_wheel_file(temp_dir):
    """Return the path of the wheel in *temp_dir* to install.

    If the directory holds exactly one ``.whl`` file it is returned as-is
    (pip itself rejects a wheel that does not fit the interpreter). With
    several wheels, the first one (in sorted order) matching the host
    platform is returned.

    Exits the process with status 1 when no wheel, or no matching wheel,
    is present.
    """
    wheel_files = sorted(f for f in os.listdir(temp_dir) if f.endswith(".whl"))

    if not wheel_files:
        print("No wheel file found in the artifact")
        sys.exit(1)

    if len(wheel_files) == 1:
        return os.path.join(temp_dir, wheel_files[0])

    platform = get_platform_string()

    for wheel_file in wheel_files:
        if _wheel_matches_platform(wheel_file, platform):
            return os.path.join(temp_dir, wheel_file)

    print(f"No wheel file found for platform: {platform}")
    sys.exit(1)


def install_wheel(wheel_path):
    """Install *wheel_path* with pip under the running interpreter.

    Exits the process with status 1 if the pip command fails.
    """
    pip_command = [sys.executable, "-m", "pip", "install", wheel_path]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError as install_error:
        print(f"Error installing wheel: {install_error}")
        sys.exit(1)


def get_rate_limit():
    """Query the GitHub rate-limit endpoint.

    Returns:
        The ``"rate"`` entry of the decoded response, or ``None`` when the
        request or decoding fails (an error message is printed).
    """
    global REST_REQUEST_COUNT

    request_headers = {}
    if GITHUB_PAT:
        request_headers["Authorization"] = f"token {GITHUB_PAT}"
    request_headers["Accept"] = "application/vnd.github+json"
    request_headers["X-GitHub-Api-Version"] = "2022-11-28"

    request = urllib.request.Request("https://api.github.com/rate_limit", headers=request_headers)

    try:
        with urllib.request.urlopen(request) as response:
            REST_REQUEST_COUNT += 1
            payload = json.loads(response.read().decode())
            return payload["rate"]
    except urllib.error.HTTPError as http_error:
        print(f"HTTP Error {http_error.code}: {http_error.reason}")
        return None
    except Exception as error:
        print(f"Error fetching rate limit: {error}")
        return None


def _install_from_artifact_url(url):
    """Download the artifact at *url*, locate its wheel and pip-install it."""
    with tempfile.TemporaryDirectory() as temp_dir:
        download_and_extract_artifact(url, temp_dir)
        wheel_path = get_wheel_file(temp_dir)
        install_wheel(wheel_path)
        print("Installation completed successfully!")


def main():
    """Command-line entry point: list or install bitsandbytes wheel artifacts.

    Exactly one action runs, chosen in this order of precedence:
    ``--rate-limit``, ``--more``, ``--most-recent``, ``--list``/``--all``,
    ``--url``, and otherwise install the newest wheel for the selected
    branch and the host platform.
    """
    parser = argparse.ArgumentParser(description="Install or list bitsandbytes wheel artifacts")
    parser.add_argument("--rate-limit", action="store_true", help="Display the current GitHub API rate limit")
    parser.add_argument("--list", action="store_true", help="List available wheel artifacts")
    parser.add_argument("--all", action="store_true", help="Show all available wheels, regardless of platform")
    parser.add_argument(
        "--branch",
        default="multi-backend-refactor",
        help="Specify the branch to use (default: multi-backend-refactor)",
    )
    parser.add_argument("--main", action="store_true", help="Use the 'main' branch")
    parser.add_argument(
        "--mpr", "--multi-backend-refactor", action="store_true", help="Use the 'multi-backend-refactor' branch"
    )
    parser.add_argument("--url", help="Specify a direct URL to a wheel artifact")
    parser.add_argument(
        "--install", action="store_true", help="Install the wheel (default action if no other option is specified)"
    )
    parser.add_argument("--json", action="store_true", help="Output raw JSON for each artifact")
    parser.add_argument(
        "--most-recent",
        action="store_true",
        help="Get the most recent wheels for all platforms for the specified branch",
    )
    parser.add_argument("--more", action="store_true", help="Show the 7 most recent wheels")

    args = parser.parse_args()

    if not GITHUB_PAT:
        print("Warning: GITHUB_PAT environment variable not set. You may encounter rate limiting.")

    if args.rate_limit:
        rate_limit = get_rate_limit()
        if rate_limit:
            print("GitHub API Rate Limit:")
            print(f"Limit: {rate_limit['limit']}")
            print(f"Remaining: {rate_limit['remaining']}")
            print(f"Reset time: {datetime.datetime.fromtimestamp(rate_limit['reset']).strftime('%Y-%m-%d %H:%M:%S')}")

        sys.exit(0)

    # --main and --mpr are shortcuts that take precedence over --branch.
    if args.main:
        branch = "main"
    elif args.mpr:
        branch = "multi-backend-refactor"
    else:
        branch = args.branch

    print(f"Configured branch for wheel search: {branch}")

    if args.more:
        recent_wheels = get_recent_wheels(branch)
        print_artifacts_info(recent_wheels, show_json=args.json)

    elif args.most_recent:
        most_recent_wheels = get_most_recent_wheels(branch)
        print_most_recent_wheels(most_recent_wheels, branch, show_json=args.json)

    elif args.list or args.all:
        all_wheel_artifacts = get_all_wheel_artifacts(branch, show_all_platforms=args.all)
        print_artifacts_info(all_wheel_artifacts, show_json=args.json)

        # With --all the list mixes platforms, so the first entry is not
        # necessarily installable here; suggest the first compatible one.
        current_platform = get_platform_string()
        latest_artifact = next(
            (a for a in all_wheel_artifacts if is_compatible_platform(a["name"], current_platform)),
            None,
        )
        if latest_artifact is not None:
            print("\nTo install the most recent wheel for your platform, run:")
            print(f"python {sys.argv[0]} --url {latest_artifact['archive_download_url']}")

    elif args.url:
        _install_from_artifact_url(args.url)

    else:
        # Default action (also reached with an explicit --install): every
        # other action flag has already been ruled out above.
        all_wheel_artifacts = get_all_wheel_artifacts(branch)
        if all_wheel_artifacts:
            _install_from_artifact_url(all_wheel_artifacts[0]["archive_download_url"])
        else:
            print(f"No suitable wheel found for branch '{branch}' and platform '{get_platform_string()}'")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 comments on commit 3c4e1c0

Please sign in to comment.