Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions pysr/julia_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@

from typing import Literal

from .julia_helpers import KNOWN_CLUSTERMANAGER_BACKENDS
from .julia_import import Pkg, jl
from .julia_registry_helpers import try_with_registry_fallback
from .logger_specs import AbstractLoggerSpec, TensorBoardLoggerSpec

# Julia package name -> registry UUID for every optional extension PySR can
# load. Used as the fallback lookup by `load_package` when no UUID is passed
# explicitly, and iterated by `load_all_packages` to install everything.
PACKAGE_UUIDS = {
    "LoopVectorization": "bdcacae8-1622-11e9-2a5c-532679323890",
    "Bumper": "8ce10254-0962-460f-a3d8-1f77fea1446e",
    "Zygote": "e88e6eb3-aa80-5325-afca-941959d7151f",
    "SlurmClusterManager": "c82cd089-7bf7-41d7-976b-6b5d413cbe0a",
    "ClusterManagers": "34f1f09b-3a8b-5176-ab39-66d58a4d544e",
    "TensorBoardLogger": "899adc3e-224a-11e9-021f-63837185c80f",
}


def load_required_packages(
*,
Expand All @@ -18,26 +28,24 @@
logger_spec: AbstractLoggerSpec | None = None,
):
if turbo:
load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
load_package("LoopVectorization")
if bumper:
load_package("Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e")
load_package("Bumper")
if autodiff_backend is not None:
load_package("Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f")
load_package("Zygote")

Check warning on line 35 in pysr/julia_extensions.py

View check run for this annotation

Codecov / codecov/patch

pysr/julia_extensions.py#L35

Added line #L35 was not covered by tests
if cluster_manager is not None:
load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")
if cluster_manager == "slurm_native":
load_package("SlurmClusterManager")
elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS:
load_package("ClusterManagers")

Check warning on line 40 in pysr/julia_extensions.py

View check run for this annotation

Codecov / codecov/patch

pysr/julia_extensions.py#L37-L40

Added lines #L37 - L40 were not covered by tests
if isinstance(logger_spec, TensorBoardLoggerSpec):
load_package("TensorBoardLogger", "899adc3e-224a-11e9-021f-63837185c80f")
load_package("TensorBoardLogger")


def load_all_packages():
    """Install and load all Julia extensions available to PySR.

    Iterates the full `PACKAGE_UUIDS` table rather than toggling individual
    feature flags, so any extension registered in that table is installed
    and loaded automatically.
    """
    for package_name, uuid_s in PACKAGE_UUIDS.items():
        load_package(package_name, uuid_s)


# TODO: Refactor this file so we can install all packages at once using `juliapkg`,
Expand All @@ -48,7 +56,8 @@
return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s))


def load_package(package_name: str, uuid_s: str) -> None:
def load_package(package_name: str, uuid_s: str | None = None) -> None:
uuid_s = uuid_s or PACKAGE_UUIDS[package_name]
if not isinstalled(uuid_s):

def _add_package():
Expand Down
19 changes: 16 additions & 3 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,22 @@
return str_repr


def _load_cluster_manager(cluster_manager: str):
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
return jl.seval(f"addprocs_{cluster_manager}")
KNOWN_CLUSTERMANAGER_BACKENDS = ["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"]


def load_cluster_manager(cluster_manager: str) -> AnyValue:
    """Return a Julia `addprocs`-like function for the requested backend.

    Parameters
    ----------
    cluster_manager : str
        Either ``"slurm_native"`` (served by SlurmClusterManager.jl), one of
        ``KNOWN_CLUSTERMANAGER_BACKENDS`` (served by ClusterManagers.jl), or
        an arbitrary Julia expression that evaluates to an addprocs-like
        function.

    Returns
    -------
    AnyValue
        A Julia function that, when called, launches worker processes.
    """
    if cluster_manager == "slurm_native":
        jl.seval("using SlurmClusterManager: SlurmManager")
        # TODO: Is this the right way to do this?
        jl.seval("using Distributed: addprocs")
        # Wrap `addprocs(SlurmManager())` so the returned callable matches the
        # `addprocs_*(args...; kws...)` shape of the ClusterManagers.jl helpers.
        jl.seval("addprocs_slurm_native(args...; kws...) = addprocs(SlurmManager())")
        return jl.addprocs_slurm_native
    elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS:
        jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
        return jl.seval(f"addprocs_{cluster_manager}")
    else:
        # Assume it's a function
        return jl.seval(cluster_manager)

Check warning on line 47 in pysr/julia_helpers.py

View check run for this annotation

Codecov / codecov/patch

pysr/julia_helpers.py#L47

Added line #L47 was not covered by tests


def jl_array(x, dtype=None):
Expand Down
18 changes: 8 additions & 10 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
from .julia_extensions import load_required_packages
from .julia_helpers import (
_escape_filename,
_load_cluster_manager,
jl_array,
jl_deserialize,
jl_is_function,
jl_serialize,
load_cluster_manager,
)
from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl
from .logger_specs import AbstractLoggerSpec
Expand Down Expand Up @@ -574,8 +574,8 @@
Default is `None`.
cluster_manager : str
For distributed computing, this sets the job queue system. Set
to one of "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", or
"htc". If set to one of these, PySR will run in distributed
to one of "slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld",
or "htc". If set to one of these, PySR will run in distributed
mode, and use `procs` to figure out how many processes to launch.
Default is `None`.
heap_size_hint_in_bytes : int
Expand Down Expand Up @@ -876,13 +876,11 @@
probability_negate_constant: float = 0.00743,
tournament_selection_n: int = 15,
tournament_selection_p: float = 0.982,
parallelism: (
Literal["serial", "multithreading", "multiprocessing"] | None
) = None,
# fmt: off
parallelism: Literal["serial", "multithreading", "multiprocessing"] | None = None,
procs: int | None = None,
cluster_manager: (
Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | None
) = None,
cluster_manager: Literal["slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | str | None = None,
# fmt: on
heap_size_hint_in_bytes: int | None = None,
batching: bool = False,
batch_size: int = 50,
Expand Down Expand Up @@ -1880,7 +1878,7 @@
raise ValueError(
"To use cluster managers, you must set `parallelism='multiprocessing'`."
)
cluster_manager = _load_cluster_manager(cluster_manager)
cluster_manager = load_cluster_manager(cluster_manager)

Check warning on line 1881 in pysr/sr.py

View check run for this annotation

Codecov / codecov/patch

pysr/sr.py#L1881

Added line #L1881 was not covered by tests

# TODO(mcranmer): These functions should be part of this class.
binary_operators, unary_operators = _maybe_create_inline_operators(
Expand Down
Loading