Skip to content

Commit 46a2a1f

Browse files
authored
Upload improvements (#171)
* `upload --repo_id` -> `upload --repo-id` Existing options use a dash rather than an underscore. * Support uploading directly from the `build` or `result` directory
1 parent ed04861 commit 46a2a1f

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/kernels/cli.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import dataclasses
33
import json
4+
import re
45
import sys
56
from pathlib import Path
67

@@ -13,6 +14,8 @@
1314
from .doc import generate_readme_for_kernel
1415
from .wheel import build_variant_to_wheel
1516

17+
BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-universal)")
18+
1619

1720
def main():
1821
parser = argparse.ArgumentParser(
@@ -65,14 +68,14 @@ def main():
6568
help="Directory of the kernel build",
6669
)
6770
upload_parser.add_argument(
68-
"--repo_id",
71+
"--repo-id",
6972
type=str,
7073
help="Repository ID to use to upload to the Hugging Face Hub",
7174
)
7275
upload_parser.add_argument(
7376
"--branch",
7477
type=None,
75-
help="If set, the upload will be made to a particular branch of the provided `repo_id`.",
78+
help="If set, the upload will be made to a particular branch of the provided `repo-id`.",
7679
)
7780
upload_parser.add_argument(
7881
"--private",
@@ -206,11 +209,21 @@ def lock_kernels(args):
206209
def upload_kernels(args):
207210
# Resolve `kernel_dir` to be uploaded.
208211
kernel_dir = Path(args.kernel_dir).resolve()
209-
build_dir = kernel_dir / "build"
210-
if not kernel_dir.is_dir():
211-
raise ValueError(f"{kernel_dir} is not a directory")
212-
if not build_dir.is_dir():
213-
raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
212+
213+
build_dir = None
214+
for candidate in [kernel_dir / "build", kernel_dir]:
215+
variants = [
216+
variant_path
217+
for variant_path in candidate.glob("torch*")
218+
if BUILD_VARIANT_REGEX.match(variant_path.name) is not None
219+
]
220+
if variants:
221+
build_dir = candidate
222+
break
223+
if build_dir is None:
224+
raise ValueError(
225+
f"Couldn't find any build variants in: {kernel_dir.absolute()} or {(kernel_dir / 'build').absolute()}"
226+
)
214227

215228
repo_id = create_repo(
216229
repo_id=args.repo_id, private=args.private, exist_ok=True

0 commit comments

Comments
 (0)