6262from eformer .pytree import auto_pytree
6363from huggingface_hub .file_download import REGEX_COMMIT_HASH
6464from jax import numpy as jnp
65+ from jax .sharding import AxisType
6566from jax .sharding import NamedSharding as Ns
6667from jax .sharding import PartitionSpec as Ps
6768from jaxtyping import Array
@@ -429,6 +430,8 @@ class EasyDeLBaseConfig(PretrainedConfig):
429430
430431 _show_private_attrs : bool = False
431432 _hidden_mesh : common_types .Mesh | None = None
433+ _hidden_explicit_mesh : common_types .Mesh | None = None
434+ _hidden_manual_mesh : common_types .Mesh | None = None
432435
433436 def __init__ (
434437 self ,
@@ -566,6 +569,12 @@ def create_mesh(
566569 should_sort_granules_by_key : bool = True ,
567570 allow_split_physical_axes : bool = True ,
568571 backend : str | None = None ,
572+ eformer_craft_mesh : bool | None = None ,
573+ axis_types : tp .Sequence [AxisType | str ]
574+ | AxisType
575+ | str
576+ | None
577+ | tp .Literal ["auto" , "explicit" , "manual" ] = None ,
569578 ):
570579 """Creates a JAX device mesh for distributed model execution.
571580
@@ -594,6 +603,13 @@ def create_mesh(
594603 when mapping to logical mesh axes. Default: True.
595604 backend: Backend platform to create mesh for ('gpu', 'tpu', etc.).
596605 If None or empty string, uses default backend.
606+ eformer_craft_mesh: If True, use eformer's mesh creation path
607+ (mesh_utils-based, supports multi-slice/multi-process). If False, use
608+ JAX's `make_mesh` path when possible. Default: reads
609+ `EFORMER_CREATE_MESH` (True).
610+ axis_types: Optional axis type(s) for mesh axes. Accepts `AxisType` values
611+ or "auto", "explicit", "manual" strings. A single value applies to all
612+ axes; a sequence must match `sharding_axis_names`. Default: "auto".
597613
598614 Returns:
599615 A JAX Mesh object configured for distributed execution with the specified
@@ -611,18 +627,66 @@ def create_mesh(
611627
612628 if backend == "" :
613629 backend = None
614-
630+ if axis_types is None :
631+ axis_types = "auto"
632+ if eformer_craft_mesh is None :
633+ eformer_craft_mesh = check_bool_flag ("EFORMER_CREATE_MESH" , default = True )
615634 mesh = create_mesh (
616635 axis_dims = sharding_axis_dims ,
617636 axis_names = sharding_axis_names ,
618637 dcn_mesh_dims = sharding_dcn_axis_dims ,
619638 should_sort_granules_by_key = should_sort_granules_by_key ,
620639 allow_split_physical_axes = allow_split_physical_axes ,
621640 backend = backend ,
622- use_jax = not check_bool_flag ("ED_CREATE_MESH" , default = False ),
641+ use_jax = not eformer_craft_mesh ,
642+ axis_types = axis_types ,
623643 )
624644 return mesh
625645
646+ def _build_mesh (
647+ self ,
648+ axis_types : tp .Sequence [AxisType | str ]
649+ | AxisType
650+ | str
651+ | None
652+ | tp .Literal ["auto" , "explicit" , "manual" ] = None ,
653+ ) -> common_types .Mesh :
654+ """Create a JAX mesh using the config sharding settings."""
655+ sharding_axis_dims = (
656+ [v for k , v in self .sharding_axis_dims .items ()]
657+ if isinstance (self .sharding_axis_dims , dict )
658+ else self .sharding_axis_dims
659+ )
660+ sharding_axis_names = (
661+ [v for k , v in self .sharding_axis_names .items ()]
662+ if isinstance (self .sharding_axis_names , dict )
663+ else self .sharding_axis_names
664+ )
665+ sharding_dcn_axis_dims = (
666+ [v for k , v in self .sharding_dcn_axis_dims .items ()]
667+ if isinstance (self .sharding_dcn_axis_dims , dict )
668+ else self .sharding_dcn_axis_dims
669+ )
670+ return self .create_mesh (
671+ sharding_axis_dims = tuple (sharding_axis_dims ) if sharding_axis_dims is not None else sharding_axis_dims ,
672+ sharding_axis_names = tuple (sharding_axis_names ) if sharding_axis_names is not None else sharding_axis_names ,
673+ sharding_dcn_axis_dims = tuple (sharding_dcn_axis_dims )
674+ if sharding_dcn_axis_dims is not None
675+ else sharding_dcn_axis_dims ,
676+ should_sort_granules_by_key = (
677+ (self .should_sort_granules_by_key if self .should_sort_granules_by_key is not None else True )
678+ if hasattr (self , "should_sort_granules_by_key" )
679+ else True
680+ ),
681+ allow_split_physical_axes = (
682+ (self .allow_split_physical_axes if self .allow_split_physical_axes is not None else True )
683+ if hasattr (self , "allow_split_physical_axes" )
684+ else True
685+ ),
686+ backend = ((self .backend if self .backend is not None else "" ) if hasattr (self , "backend" ) else "" ),
687+ axis_types = axis_types ,
688+ )
689+
626690 @property
627691 def mesh (self ):
628692 """Gets or creates the JAX device mesh for this configuration.
@@ -656,42 +720,38 @@ def mesh(self):
656720 if self ._hidden_mesh is not None :
657721 return self ._hidden_mesh
658722
659- sharding_axis_dims = (
660- [v for k , v in self .sharding_axis_dims .items ()]
661- if isinstance (self .sharding_axis_dims , dict )
662- else self .sharding_axis_dims
663- )
664- sharding_axis_names = (
665- [v for k , v in self .sharding_axis_names .items ()]
666- if isinstance (self .sharding_axis_names , dict )
667- else self .sharding_axis_names
668- )
669- sharding_dcn_axis_dims = (
670- [v for k , v in self .sharding_dcn_axis_dims .items ()]
671- if isinstance (self .sharding_dcn_axis_dims , dict )
672- else self .sharding_dcn_axis_dims
673- )
674- mesh = self .create_mesh (
675- sharding_axis_dims = tuple (sharding_axis_dims ) if sharding_axis_dims is not None else sharding_axis_dims ,
676- sharding_axis_names = tuple (sharding_axis_names ) if sharding_axis_names is not None else sharding_axis_names ,
677- sharding_dcn_axis_dims = tuple (sharding_dcn_axis_dims )
678- if sharding_dcn_axis_dims is not None
679- else sharding_dcn_axis_dims ,
680- should_sort_granules_by_key = (
681- (self .should_sort_granules_by_key if self .should_sort_granules_by_key is not None else True )
682- if hasattr (self , "should_sort_granules_by_key" )
683- else True
684- ),
685- allow_split_physical_axes = (
686- (self .allow_split_physical_axes if self .allow_split_physical_axes is not None else True )
687- if hasattr (self , "allow_split_physical_axes" )
688- else True
689- ),
690- backend = ((self .backend if self .backend is not None else "" ) if hasattr (self , "backend" ) else "" ),
691- )
723+ mesh = self ._build_mesh ()
692724 self .set_model_mesh (mesh )
693725 return self ._hidden_mesh
694726
@property
def explicit_mesh(self):
    """Return the cached explicit-axis mesh, building it on first access.

    When ``_hidden_explicit_mesh`` is already set it is returned as-is.
    Otherwise a mesh is built via ``_build_mesh(axis_types="explicit")``,
    registered through ``set_explicit_mesh()`` and the cached value is
    returned. A mesh installed with ``set_explicit_mesh()`` takes precedence
    over building a new one.
    """
    cached = self._hidden_explicit_mesh
    if cached is None:
        # Nothing cached yet: build once and let the setter store it.
        self.set_explicit_mesh(self._build_mesh(axis_types="explicit"))
        cached = self._hidden_explicit_mesh
    return cached
740+
@property
def manual_mesh(self):
    """Return the manual-axis mesh, creating and caching it when absent.

    A value already stored in ``_hidden_manual_mesh`` is returned directly.
    If none is stored, ``_build_mesh(axis_types="manual")`` produces a mesh
    that is handed to ``set_manual_mesh()``; the cached attribute is then
    returned. Use ``set_manual_mesh()`` to override the mesh.
    """
    existing = self._hidden_manual_mesh
    if existing is not None:
        return existing

    built = self._build_mesh(axis_types="manual")
    # The setter owns the caching, so read the attribute back afterwards.
    self.set_manual_mesh(built)
    return self._hidden_manual_mesh
754+
695755 @property
696756 def expert_mesh (self ) -> jax .sharding .Mesh :
697757 """Get the mesh configuration for expert parallelism.
@@ -791,6 +851,60 @@ def set_model_mesh(self, mesh: common_types.Mesh):
791851 except Exception :
792852 pass
793853
def set_explicit_mesh(self, mesh: common_types.Mesh):
    """Store ``mesh`` as the explicit-axis mesh and pass it on to sub-configs.

    The mesh is cached on ``_hidden_explicit_mesh``. When ``sub_configs`` is a
    dict, each attribute named by one of its keys (and not ``None``) receives
    the same mesh, through its own ``set_explicit_mesh`` when it has one and
    by direct attribute assignment otherwise.

    Args:
        mesh: JAX device mesh to use for this model.
    """
    self._hidden_explicit_mesh = mesh

    declared = getattr(self, "sub_configs", None)
    if isinstance(declared, dict):
        for name in declared.keys():
            child = getattr(self, name, None)
            if child is None:
                continue
            try:
                if not hasattr(child, "set_explicit_mesh"):
                    child._hidden_explicit_mesh = mesh
                else:
                    child.set_explicit_mesh(mesh)
            except Exception:
                # Propagation is best-effort: fall back to a plain attribute
                # write and ignore sub-configs that reject even that.
                try:
                    child._hidden_explicit_mesh = mesh
                except Exception:
                    pass
880+
def set_manual_mesh(self, mesh: common_types.Mesh):
    """Install a manual-axis mesh on this config and on its sub-configs.

    ``_hidden_manual_mesh`` is always updated. If ``sub_configs`` exists and
    is a dict, every non-``None`` attribute named by its keys is updated too:
    via ``set_manual_mesh`` when the sub-config has that attribute, otherwise
    by writing ``_hidden_manual_mesh`` directly.

    Args:
        mesh: JAX device mesh to use for this model.
    """
    self._hidden_manual_mesh = mesh

    children = getattr(self, "sub_configs", None)
    if not isinstance(children, dict):
        return

    for key in children.keys():
        target = getattr(self, key, None)
        if target is not None:
            try:
                if hasattr(target, "set_manual_mesh"):
                    target.set_manual_mesh(mesh)
                else:
                    target._hidden_manual_mesh = mesh
            except Exception:
                # Best-effort propagation: retry with a raw attribute write
                # and skip this sub-config if that fails as well.
                try:
                    target._hidden_manual_mesh = mesh
                except Exception:
                    pass
907+
794908 def jax_mesh (self ):
795909 """Deprecated method for getting the JAX mesh.
796910
@@ -1278,17 +1392,20 @@ def to_dict(self) -> dict[str, tp.Any]:
12781392 """Serialize config to a dictionary while temporarily hiding forbidden types.
12791393
12801394 Notes:
1281- EasyDeL caches the active JAX mesh on the config (``_hidden_mesh``) for runtime use.
1282- That object contains non-picklable JAX devices, so we must exclude it from any deep
1283- copies performed during serialization.
1395+ EasyDeL caches the active JAX meshes on the config (``_hidden_mesh``,
1396+ ``_hidden_explicit_mesh``, ``_hidden_manual_mesh``) for runtime use.
1397+ Those objects contain non-picklable JAX devices, so we must exclude them
1398+ from any deep copies performed during serialization.
12841399 """
12851400 sd = self .__dict__
12861401 forbidden_types = {"_ScalarMeta" }
12871402 extracted_values : dict [str , tp .Any ] = {}
12881403
12891404 for key in list (sd .keys ()):
12901405 value = sd .get (key )
1291- if key == "_hidden_mesh" or value .__class__ .__name__ in forbidden_types :
1406+ if key in {"_hidden_mesh" , "_hidden_explicit_mesh" , "_hidden_manual_mesh" } or value .__class__ .__name__ in (
1407+ forbidden_types
1408+ ):
12921409 extracted_values [key ] = sd .pop (key )
12931410
12941411 try :
@@ -1313,13 +1430,13 @@ def to_dict(self) -> dict[str, tp.Any]:
13131430 sd [key ] = value
13141431
13151432 def __deepcopy__ (self , memo ):
1316- """Deep copy the config while keeping the cached runtime mesh by reference."""
1433+ """Deep copy the config while keeping the cached runtime meshes by reference."""
13171434 cls = self .__class__
13181435 result = cls .__new__ (cls )
13191436 memo [id (self )] = result
13201437
13211438 for key , value in self .__dict__ .items ():
1322- if key == "_hidden_mesh" :
1439+ if key in { "_hidden_mesh" , "_hidden_explicit_mesh" , "_hidden_manual_mesh" } :
13231440 setattr (result , key , value )
13241441 else :
13251442 setattr (result , key , copy .deepcopy (value , memo ))
0 commit comments