|
1 | 1 | import argparse |
2 | 2 | import dataclasses |
3 | 3 | import json |
| 4 | +import re |
4 | 5 | import sys |
5 | 6 | from pathlib import Path |
6 | 7 |
|
|
13 | 14 | from .doc import generate_readme_for_kernel |
14 | 15 | from .wheel import build_variant_to_wheel |
15 | 16 |
|
| 17 | +BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-universal)") |
| 18 | + |
16 | 19 |
|
17 | 20 | def main(): |
18 | 21 | parser = argparse.ArgumentParser( |
@@ -65,14 +68,14 @@ def main(): |
65 | 68 | help="Directory of the kernel build", |
66 | 69 | ) |
67 | 70 | upload_parser.add_argument( |
68 | | - "--repo_id", |
| 71 | + "--repo-id", |
69 | 72 | type=str, |
70 | 73 | help="Repository ID to use to upload to the Hugging Face Hub", |
71 | 74 | ) |
72 | 75 | upload_parser.add_argument( |
73 | 76 | "--branch", |
74 | 77 | type=None, |
75 | | - help="If set, the upload will be made to a particular branch of the provided `repo_id`.", |
| 78 | + help="If set, the upload will be made to a particular branch of the provided `repo-id`.", |
76 | 79 | ) |
77 | 80 | upload_parser.add_argument( |
78 | 81 | "--private", |
@@ -206,11 +209,21 @@ def lock_kernels(args): |
206 | 209 | def upload_kernels(args): |
207 | 210 | # Resolve `kernel_dir` to be uploaded. |
208 | 211 | kernel_dir = Path(args.kernel_dir).resolve() |
209 | | - build_dir = kernel_dir / "build" |
210 | | - if not kernel_dir.is_dir(): |
211 | | - raise ValueError(f"{kernel_dir} is not a directory") |
212 | | - if not build_dir.is_dir(): |
213 | | - raise ValueError("Couldn't find `build` directory inside `kernel_dir`") |
| 212 | + |
| 213 | + build_dir = None |
| 214 | + for candidate in [kernel_dir / "build", kernel_dir]: |
| 215 | + variants = [ |
| 216 | + variant_path |
| 217 | + for variant_path in candidate.glob("torch*") |
| 218 | + if BUILD_VARIANT_REGEX.match(variant_path.name) is not None |
| 219 | + ] |
| 220 | + if variants: |
| 221 | + build_dir = candidate |
| 222 | + break |
| 223 | + if build_dir is None: |
| 224 | + raise ValueError( |
| 225 | + f"Couldn't find any build variants in: {kernel_dir.absolute()} or {(kernel_dir / 'build').absolute()}" |
| 226 | + ) |
214 | 227 |
|
215 | 228 | repo_id = create_repo( |
216 | 229 | repo_id=args.repo_id, private=args.private, exist_ok=True |
|
0 commit comments