Skip to content

Commit

Permalink
Add --num-threads option
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Feb 24, 2025
1 parent 4bce2b0 commit 5dd2a31
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mergekit/scripts/extract_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@
help="Threshold for singular values to discard",
show_default=True,
)
@click.option(
"--num-threads",
type=int,
help="Number of threads to use for parallel CPU operations",
default=None,
)
@add_merge_options
def main(
base_model: str,
Expand All @@ -111,9 +117,12 @@ def main(
include_regexes: List[str],
verbose: bool,
sv_epsilon: float,
num_threads: Optional[int],
merge_options: MergeOptions,
):
logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)
if num_threads is not None:
torch.set_num_threads(num_threads)

if not modules_to_save:
modules_to_save = []
Expand Down
13 changes: 13 additions & 0 deletions mergekit/scripts/merge_raw_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,19 +223,32 @@ def construct_param_dicts(
is_flag=True,
help="Merge all tensors present in any input model",
)
@click.option(
"--num-threads",
type=int,
help="Number of threads to use for parallel CPU operations",
default=None,
)
@add_merge_options
def main(
config_path: str,
out_path: str,
tensor_union: bool,
tensor_intersection: bool,
num_threads: Optional[int],
merge_options: MergeOptions,
):
"""Merge arbitrary PyTorch models.
Uses similar configuration syntax to `mergekit-yaml`, minus the
`slices` sections. Each input model should be the path on disk to a
pytorch pickle file or safetensors file."""
logging.basicConfig(
level=logging.INFO if not merge_options.quiet else logging.WARNING
)
if num_threads is not None:
torch.set_num_threads(num_threads)

with open(config_path, "r", encoding="utf-8") as file:
config_source = file.read()

Expand Down
11 changes: 11 additions & 0 deletions mergekit/scripts/multimerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Optional, Set, Tuple, Union

import click
import torch
import yaml

from mergekit.common import ImmutableMap, ModelReference
Expand Down Expand Up @@ -83,13 +84,20 @@ def execute(self, **kwargs):
default=True,
help="Skip merges that already exist",
)
@click.option(
"--num-threads",
type=int,
help="Number of threads to use for parallel CPU operations",
default=None,
)
@add_merge_options
def main(
config_file: str,
intermediate_dir: str,
out_path: Optional[str],
verbose: bool,
lazy: bool,
num_threads: Optional[int],
merge_options: MergeOptions,
):
"""Execute a set of potentially interdependent merge recipes.
Expand All @@ -103,6 +111,9 @@ def main(
directory. If an unnamed merge configuration is present, it will be
saved to `out_path` (which is required in this case)."""
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)
if num_threads is not None:
torch.set_num_threads(num_threads)

os.makedirs(intermediate_dir, exist_ok=True)

with open(config_file, "r", encoding="utf-8") as file:
Expand Down
11 changes: 11 additions & 0 deletions mergekit/scripts/run_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# SPDX-License-Identifier: BUSL-1.1

import logging
from typing import Optional

import click
import torch
import yaml

from mergekit.config import MergeConfiguration
Expand All @@ -17,14 +19,23 @@
@click.option(
"--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
"--num-threads",
type=int,
help="Number of threads to use for parallel CPU operations",
default=None,
)
@add_merge_options
def main(
merge_options: MergeOptions,
config_file: str,
out_path: str,
verbose: bool,
num_threads: Optional[int],
):
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)
if num_threads is not None:
torch.set_num_threads(num_threads)

with open(config_file, "r", encoding="utf-8") as file:
config_source = file.read()
Expand Down

0 comments on commit 5dd2a31

Please sign in to comment.