diff --git a/vast.py b/vast.py index c9af6aa0..85afa348 100644 --- a/vast.py +++ b/vast.py @@ -3468,11 +3468,13 @@ def _parse_region(region): return region @parser.command( - argument("-g", "--gpu-name", type=str, required=True, choices=_get_gpu_names(), help="Name of the GPU model, replace spaces with underscores"), - argument("-n", "--num-gpus", type=str, required=True, choices=["1", "2", "4", "8", "12", "14"], help="Number of GPUs required"), + argument("-g", "--gpu-name", type=str, default=None, choices=_get_gpu_names(), help="Name of the GPU model, replace spaces with underscores (required unless --query is set)"), + argument("-n", "--num-gpus", type=str, default=None, choices=["1", "2", "4", "8", "12", "14"], help="Number of GPUs required (required unless --query is set)"), argument("-r", "--region", type=str, help="Geographical location of the instance"), argument("-i", "--image", required=True, help="Name of the image to use for instance"), + argument("--query", type=str, default=None, help="Search query in simple query syntax (e.g. 'gpu_name=RTX_3090 num_gpus=1'). If set, overrides -g/-n/-r/--disk/--cpu-ram for offer selection."), argument("-d", "--disk", type=float, default=16.0, help="Disk space required in GB"), + argument("--cpu-ram", type=float, default=None, help="Minimum system RAM required in GB"), argument("--limit", default=3, type=int, help=""), argument("-o", "--order", type=str, help="Comma-separated list of fields to sort on. postfix field with - to sort desc. ex: -o 'num_gpus,total_flops-'. default='score-'", default='score-'), argument("--login", help="docker login arguments for private repo authentication, surround with '' ", type=str), @@ -3491,7 +3493,7 @@ def _parse_region(region): argument("--env", help="env variables and port mapping options, surround with '' ", type=str), argument("--args", nargs=argparse.REMAINDER, help="list of arguments passed to container ENTRYPOINT. Onstart is recommended for this purpose. (must be last argument)"), argument("--force", help="Skip sanity checks when creating from an existing instance", action="store_true"), - argument("--cancel-unavail", help="Return error if scheduling fails (rather than creating a stopped instance)", action="store_true"), + argument("--cancel-unavail", default=True, help="Return error if scheduling fails (rather than creating a stopped instance)", action="store_true"), argument("--template_hash", help="template hash which contains all relevant information about an instance. This can be used as a replacement for other parameters describing the instance configuration", type=str), usage="vastai launch instance [--help] [--api-key API_KEY] [geolocation] [disk_space]", help="Launch the top instance from the search offers based on the given parameters", @@ -3531,20 +3533,41 @@ def launch__instance(args): :param argparse.Namespace args: Namespace with many fields relevant to the endpoint. """ - args_query = f"num_gpus={args.num_gpus} gpu_name={args.gpu_name}" + base_query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}} - if args.region: - if not _is_valid_region(args.region): - print("Invalid region or country codes provided.") - return - region_query = _parse_region(args.region) - args_query += f" geolocation in {region_query}" + if args.query is not None: + # --query mode: parse the raw query string directly, ignoring individual offer-selection args + try: + query = parse_query(args.query, base_query, offers_fields, offers_alias, offers_mult) + except ValueError as e: + print("Error parsing --query:", e) + return 1 + else: + # individual-arg mode: --gpu-name and --num-gpus are required + if args.gpu_name is None or args.num_gpus is None: + print("Error: --gpu-name and --num-gpus are required when --query is not set.") + return 1 - if args.disk: - args_query += f" disk_space>={args.disk}" + args_query = f"num_gpus={args.num_gpus} gpu_name={args.gpu_name}" - base_query = {"verified": {"eq": True}, "external": {"eq": False}, "rentable": {"eq": True}, "rented": {"eq": False}} - query = parse_query(args_query, base_query, offers_fields, offers_alias, offers_mult) + if args.region: + if not _is_valid_region(args.region): + print("Invalid region or country codes provided.") + return + region_query = _parse_region(args.region) + args_query += f" geolocation in {region_query}" + + if args.disk: + args_query += f" disk_space>={args.disk}" + + if args.cpu_ram: + args_query += f" cpu_ram>={args.cpu_ram}" + + try: + query = parse_query(args_query, base_query, offers_fields, offers_alias, offers_mult) + except ValueError as e: + print("Error:", e) + return 1 order = [] for name in args.order.split(","): @@ -3598,7 +3621,8 @@ def launch__instance(args): "jupyter_dir": args.jupyter_dir, "force": args.force, "cancel_unavail": args.cancel_unavail, - "template_hash_id" : args.template_hash + "template_hash_id" : args.template_hash, + "cpu_ram": args.cpu_ram } # don't send runtype with template_hash if args.template_hash is None: