diff --git a/src/datasets/commands/convert_to_parquet.py b/src/datasets/commands/convert_to_parquet.py index 43e82de802b..b64f8bbfac4 100644 --- a/src/datasets/commands/convert_to_parquet.py +++ b/src/datasets/commands/convert_to_parquet.py @@ -1,6 +1,9 @@ from argparse import ArgumentParser from typing import Optional +from huggingface_hub import HfApi + +import datasets.config from datasets.commands import BaseDatasetsCLICommand from datasets.hub import convert_to_parquet @@ -11,6 +14,7 @@ def _command_factory(args): args.token, args.revision, args.trust_remote_code, + args.merge_pull_request, ) @@ -26,6 +30,11 @@ def register_subcommand(parser): parser.add_argument( "--trust_remote_code", action="store_true", help="whether to trust the code execution of the load script" ) + parser.add_argument( + "--merge-pull-request", + action="store_true", + help="whether to automatically merge the pull request after conversion", + ) parser.set_defaults(func=_command_factory) def __init__( @@ -34,13 +43,24 @@ def __init__( token: Optional[str], revision: Optional[str], trust_remote_code: bool, + merge_pull_request: bool, ): self._dataset_id = dataset_id self._token = token self._revision = revision self._trust_remote_code = trust_remote_code + self._merge_pull_request = merge_pull_request def run(self) -> None: - _ = convert_to_parquet( + commit_info = convert_to_parquet( self._dataset_id, revision=self._revision, token=self._token, trust_remote_code=self._trust_remote_code ) + + if self._merge_pull_request: + api = HfApi(endpoint=datasets.config.HF_ENDPOINT, token=self._token) + api.merge_pull_request( + repo_id=self._dataset_id, + discussion_num=int(commit_info.pr_num), + token=self._token, + repo_type="dataset", + )