Skip to content

Commit 527625c

Browse files
committed
Add comprehensive documentation for EasyDeL, eformer, and ejkernel environment variables
This commit introduces a new markdown document detailing the environment variables utilized by EasyDeL, eformer, and ejkernel. It covers startup auto-tuning, distributed initialization, mesh and sharding configurations, compilation caching, autotuning and profiling settings, training and inference parameters, and more. Each variable is described with its default value, purpose, and use case to aid developers in configuring their environments effectively.
1 parent 95265b7 commit 527625c

6 files changed

Lines changed: 2883 additions & 328 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1529,6 +1529,7 @@ For comprehensive documentation, tutorials, and API reference:
15291529
- Installation Guide
15301530
- Training Tutorials (SFT, DPO, GRPO, etc.)
15311531
- eSurge Deployment Guide
1532+
- Environment Flags: [`docs/environment_variables.md`](docs/environment_variables.md)
15321533
- Model Architecture Details
15331534
- API Reference
15341535
- Advanced Topics (Sharding, MoE, Quantization)

docs/environment_variables.md

Lines changed: 259 additions & 0 deletions
Large diffs are not rendered by default.

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ Zare Chavoshi, Erfan. "EasyDeL, an open-source library, is specifically designed
106106

107107
infra/index
108108
infra/overview.md
109+
environment_variables.md
109110
infra/base_config.md
110111
infra/base_module.md
111112
infra/customization.md

easydel/infra/base_config.py

Lines changed: 158 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@
6262
from eformer.pytree import auto_pytree
6363
from huggingface_hub.file_download import REGEX_COMMIT_HASH
6464
from jax import numpy as jnp
65+
from jax.sharding import AxisType
6566
from jax.sharding import NamedSharding as Ns
6667
from jax.sharding import PartitionSpec as Ps
6768
from jaxtyping import Array
@@ -429,6 +430,8 @@ class EasyDeLBaseConfig(PretrainedConfig):
429430

430431
_show_private_attrs: bool = False
431432
_hidden_mesh: common_types.Mesh | None = None
433+
_hidden_explicit_mesh: common_types.Mesh | None = None
434+
_hidden_manual_mesh: common_types.Mesh | None = None
432435

433436
def __init__(
434437
self,
@@ -566,6 +569,12 @@ def create_mesh(
566569
should_sort_granules_by_key: bool = True,
567570
allow_split_physical_axes: bool = True,
568571
backend: str | None = None,
572+
eformer_craft_mesh: bool | None = None,
573+
axis_types: tp.Sequence[AxisType | str]
574+
| AxisType
575+
| str
576+
| None
577+
| tp.Literal["auto", "explicit", "manual"] = None,
569578
):
570579
"""Creates a JAX device mesh for distributed model execution.
571580
@@ -594,6 +603,13 @@ def create_mesh(
594603
when mapping to logical mesh axes. Default: True.
595604
backend: Backend platform to create mesh for ('gpu', 'tpu', etc.).
596605
If None or empty string, uses default backend.
606+
eformer_craft_mesh: If True, use eformer's mesh creation path
607+
(mesh_utils-based, supports multi-slice/multi-process). If False, use
608+
JAX's `make_mesh` path when possible. Default: reads
609+
the `EFORMER_CREATE_MESH` environment flag (treated as True when unset).
610+
axis_types: Optional axis type(s) for mesh axes. Accepts `AxisType` values
611+
or "auto", "explicit", "manual" strings. A single value applies to all
612+
axes; a sequence must match `sharding_axis_names`. Default: None, which is treated as "auto".
597613
598614
Returns:
599615
A JAX Mesh object configured for distributed execution with the specified
@@ -611,18 +627,66 @@ def create_mesh(
611627

612628
if backend == "":
613629
backend = None
614-
630+
if axis_types is None:
631+
axis_types = "auto"
632+
if eformer_craft_mesh is None:
633+
eformer_craft_mesh = check_bool_flag("EFORMER_CREATE_MESH", default=True)
615634
mesh = create_mesh(
616635
axis_dims=sharding_axis_dims,
617636
axis_names=sharding_axis_names,
618637
dcn_mesh_dims=sharding_dcn_axis_dims,
619638
should_sort_granules_by_key=should_sort_granules_by_key,
620639
allow_split_physical_axes=allow_split_physical_axes,
621640
backend=backend,
622-
use_jax=not check_bool_flag("ED_CREATE_MESH", default=False),
641+
use_jax=not eformer_craft_mesh,
642+
axis_types=axis_types,
623643
)
624644
return mesh
625645

646+
def _build_mesh(
647+
self,
648+
axis_types: tp.Sequence[AxisType | str]
649+
| AxisType
650+
| str
651+
| None
652+
| tp.Literal["auto", "explicit", "manual"] = None,
653+
) -> common_types.Mesh:
654+
"""Create a JAX mesh using the config sharding settings."""
655+
sharding_axis_dims = (
656+
[v for k, v in self.sharding_axis_dims.items()]
657+
if isinstance(self.sharding_axis_dims, dict)
658+
else self.sharding_axis_dims
659+
)
660+
sharding_axis_names = (
661+
[v for k, v in self.sharding_axis_names.items()]
662+
if isinstance(self.sharding_axis_names, dict)
663+
else self.sharding_axis_names
664+
)
665+
sharding_dcn_axis_dims = (
666+
[v for k, v in self.sharding_dcn_axis_dims.items()]
667+
if isinstance(self.sharding_dcn_axis_dims, dict)
668+
else self.sharding_dcn_axis_dims
669+
)
670+
return self.create_mesh(
671+
sharding_axis_dims=tuple(sharding_axis_dims) if sharding_axis_dims is not None else sharding_axis_dims,
672+
sharding_axis_names=tuple(sharding_axis_names) if sharding_axis_names is not None else sharding_axis_names,
673+
sharding_dcn_axis_dims=tuple(sharding_dcn_axis_dims)
674+
if sharding_dcn_axis_dims is not None
675+
else sharding_dcn_axis_dims,
676+
should_sort_granules_by_key=(
677+
(self.should_sort_granules_by_key if self.should_sort_granules_by_key is not None else True)
678+
if hasattr(self, "should_sort_granules_by_key")
679+
else True
680+
),
681+
allow_split_physical_axes=(
682+
(self.allow_split_physical_axes if self.allow_split_physical_axes is not None else True)
683+
if hasattr(self, "allow_split_physical_axes")
684+
else True
685+
),
686+
backend=((self.backend if self.backend is not None else "") if hasattr(self, "backend") else ""),
687+
axis_types=axis_types,
688+
)
689+
626690
@property
627691
def mesh(self):
628692
"""Gets or creates the JAX device mesh for this configuration.
@@ -656,42 +720,38 @@ def mesh(self):
656720
if self._hidden_mesh is not None:
657721
return self._hidden_mesh
658722

659-
sharding_axis_dims = (
660-
[v for k, v in self.sharding_axis_dims.items()]
661-
if isinstance(self.sharding_axis_dims, dict)
662-
else self.sharding_axis_dims
663-
)
664-
sharding_axis_names = (
665-
[v for k, v in self.sharding_axis_names.items()]
666-
if isinstance(self.sharding_axis_names, dict)
667-
else self.sharding_axis_names
668-
)
669-
sharding_dcn_axis_dims = (
670-
[v for k, v in self.sharding_dcn_axis_dims.items()]
671-
if isinstance(self.sharding_dcn_axis_dims, dict)
672-
else self.sharding_dcn_axis_dims
673-
)
674-
mesh = self.create_mesh(
675-
sharding_axis_dims=tuple(sharding_axis_dims) if sharding_axis_dims is not None else sharding_axis_dims,
676-
sharding_axis_names=tuple(sharding_axis_names) if sharding_axis_names is not None else sharding_axis_names,
677-
sharding_dcn_axis_dims=tuple(sharding_dcn_axis_dims)
678-
if sharding_dcn_axis_dims is not None
679-
else sharding_dcn_axis_dims,
680-
should_sort_granules_by_key=(
681-
(self.should_sort_granules_by_key if self.should_sort_granules_by_key is not None else True)
682-
if hasattr(self, "should_sort_granules_by_key")
683-
else True
684-
),
685-
allow_split_physical_axes=(
686-
(self.allow_split_physical_axes if self.allow_split_physical_axes is not None else True)
687-
if hasattr(self, "allow_split_physical_axes")
688-
else True
689-
),
690-
backend=((self.backend if self.backend is not None else "") if hasattr(self, "backend") else ""),
691-
)
723+
mesh = self._build_mesh()
692724
self.set_model_mesh(mesh)
693725
return self._hidden_mesh
694726

727+
@property
728+
def explicit_mesh(self):
729+
"""Gets or creates the JAX device mesh with explicit axis types.
730+
731+
This property mirrors `mesh`, but requests AxisType.Explicit for all axes.
732+
The mesh can be overridden with `set_explicit_mesh()`.
733+
"""
734+
if self._hidden_explicit_mesh is not None:
735+
return self._hidden_explicit_mesh
736+
737+
mesh = self._build_mesh(axis_types="explicit")
738+
self.set_explicit_mesh(mesh)
739+
return self._hidden_explicit_mesh
740+
741+
@property
742+
def manual_mesh(self):
743+
"""Gets or creates the JAX device mesh with manual axis types.
744+
745+
This property mirrors `mesh`, but requests AxisType.Manual for all axes.
746+
The mesh can be overridden with `set_manual_mesh()`.
747+
"""
748+
if self._hidden_manual_mesh is not None:
749+
return self._hidden_manual_mesh
750+
751+
mesh = self._build_mesh(axis_types="manual")
752+
self.set_manual_mesh(mesh)
753+
return self._hidden_manual_mesh
754+
695755
@property
696756
def expert_mesh(self) -> jax.sharding.Mesh:
697757
"""Get the mesh configuration for expert parallelism.
@@ -791,6 +851,60 @@ def set_model_mesh(self, mesh: common_types.Mesh):
791851
except Exception:
792852
pass
793853

854+
def set_explicit_mesh(self, mesh: common_types.Mesh):
855+
"""Sets a custom explicit-axis mesh for the model.
856+
857+
Args:
858+
mesh: JAX device mesh to use for this model.
859+
"""
860+
self._hidden_explicit_mesh = mesh
861+
862+
sub_configs = getattr(self, "sub_configs", None)
863+
if not isinstance(sub_configs, dict):
864+
return
865+
866+
for attr_name in sub_configs.keys():
867+
sub_cfg = getattr(self, attr_name, None)
868+
if sub_cfg is None:
869+
continue
870+
try:
871+
if hasattr(sub_cfg, "set_explicit_mesh"):
872+
sub_cfg.set_explicit_mesh(mesh)
873+
else:
874+
sub_cfg._hidden_explicit_mesh = mesh
875+
except Exception:
876+
try:
877+
sub_cfg._hidden_explicit_mesh = mesh
878+
except Exception:
879+
pass
880+
881+
def set_manual_mesh(self, mesh: common_types.Mesh):
882+
"""Sets a custom manual-axis mesh for the model.
883+
884+
Args:
885+
mesh: JAX device mesh to use for this model.
886+
"""
887+
self._hidden_manual_mesh = mesh
888+
889+
sub_configs = getattr(self, "sub_configs", None)
890+
if not isinstance(sub_configs, dict):
891+
return
892+
893+
for attr_name in sub_configs.keys():
894+
sub_cfg = getattr(self, attr_name, None)
895+
if sub_cfg is None:
896+
continue
897+
try:
898+
if hasattr(sub_cfg, "set_manual_mesh"):
899+
sub_cfg.set_manual_mesh(mesh)
900+
else:
901+
sub_cfg._hidden_manual_mesh = mesh
902+
except Exception:
903+
try:
904+
sub_cfg._hidden_manual_mesh = mesh
905+
except Exception:
906+
pass
907+
794908
def jax_mesh(self):
795909
"""Deprecated method for getting the JAX mesh.
796910
@@ -1278,17 +1392,20 @@ def to_dict(self) -> dict[str, tp.Any]:
12781392
"""Serialize config to a dictionary while temporarily hiding forbidden types.
12791393
12801394
Notes:
1281-
EasyDeL caches the active JAX mesh on the config (``_hidden_mesh``) for runtime use.
1282-
That object contains non-picklable JAX devices, so we must exclude it from any deep
1283-
copies performed during serialization.
1395+
EasyDeL caches the active JAX meshes on the config (``_hidden_mesh``,
1396+
``_hidden_explicit_mesh``, ``_hidden_manual_mesh``) for runtime use.
1397+
Those objects contain non-picklable JAX devices, so we must exclude them
1398+
from any deep copies performed during serialization.
12841399
"""
12851400
sd = self.__dict__
12861401
forbidden_types = {"_ScalarMeta"}
12871402
extracted_values: dict[str, tp.Any] = {}
12881403

12891404
for key in list(sd.keys()):
12901405
value = sd.get(key)
1291-
if key == "_hidden_mesh" or value.__class__.__name__ in forbidden_types:
1406+
if key in {"_hidden_mesh", "_hidden_explicit_mesh", "_hidden_manual_mesh"} or value.__class__.__name__ in (
1407+
forbidden_types
1408+
):
12921409
extracted_values[key] = sd.pop(key)
12931410

12941411
try:
@@ -1313,13 +1430,13 @@ def to_dict(self) -> dict[str, tp.Any]:
13131430
sd[key] = value
13141431

13151432
def __deepcopy__(self, memo):
1316-
"""Deep copy the config while keeping the cached runtime mesh by reference."""
1433+
"""Deep copy the config while keeping the cached runtime meshes by reference."""
13171434
cls = self.__class__
13181435
result = cls.__new__(cls)
13191436
memo[id(self)] = result
13201437

13211438
for key, value in self.__dict__.items():
1322-
if key == "_hidden_mesh":
1439+
if key in {"_hidden_mesh", "_hidden_explicit_mesh", "_hidden_manual_mesh"}:
13231440
setattr(result, key, value)
13241441
else:
13251442
setattr(result, key, copy.deepcopy(value, memo))

easydel/infra/base_module.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -332,6 +332,37 @@ def mesh(self: Self) -> jax.sharding.Mesh:
332332
"""
333333
return self.config.mesh
334334

335+
@property
336+
def explicit_mesh(self: Self) -> jax.sharding.Mesh:
337+
"""
338+
Retrieves the explicit-axis JAX device mesh from the module's configuration.
339+
340+
Returns:
341+
jax.sharding.Mesh: The device mesh defined in `self.config.explicit_mesh`.
342+
"""
343+
return self.config.explicit_mesh
344+
345+
@property
346+
def manual_mesh(self: Self) -> jax.sharding.Mesh:
347+
"""
348+
Retrieves the manual-axis JAX device mesh from the module's configuration.
349+
350+
Returns:
351+
jax.sharding.Mesh: The device mesh defined in `self.config.manual_mesh`.
352+
"""
353+
return self.config.manual_mesh
354+
355+
def mesh_call(self: Self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
356+
"""
357+
Calls the module under the configured JAX mesh.
358+
359+
This is equivalent to `with self.mesh: self(*args, **kwargs)` and uses
360+
the same arguments/return types as `__call__`. It does not use
361+
`explicit_mesh` or `manual_mesh`; enter those contexts explicitly when needed.
362+
"""
363+
with self.mesh:
364+
return self(*args, **kwargs)
365+
335366
@property
336367
def model_task(self: Self) -> str | None:
337368
"""

0 commit comments

Comments
 (0)