-
Hey @srossi93, I was playing around with your code. Since you didn't provide a full snippet, I had to guess how you initialized the model, and I also changed a few things. I currently can't offer any advice, since profiling is hard. Here is my code, in case it's of any use:
```python
#%%
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import flax.linen as nn

# 2-D device mesh: one 'model' entry per local device, trivial 'data' axis.
device_mesh = mesh_utils.create_device_mesh((jax.local_device_count(), 1))
mesh = Mesh(devices=device_mesh, axis_names=('model', 'data'))
print(mesh)

#%%
class MLP(nn.Module):
    features: int = 5000

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(
            self.features,
            use_bias=False,
            kernel_init=nn.spmd.with_logical_partitioning(
                nn.initializers.xavier_normal(), ('input', 'hidden')),
        )(x)
        y = nn.spmd.with_logical_constraint(y, ('batch', None))
        y = nn.Dense(
            1,
            use_bias=False,
            kernel_init=nn.spmd.with_logical_partitioning(
                nn.initializers.xavier_normal(), ('hidden', 'output')),
        )(y)
        y = y.reshape(-1)
        y = nn.spmd.with_logical_constraint(y, ('batch',))
        return y

# Lifted vmap: one set of parameters per ensemble member, input replicated.
EnsembleMLP = nn.vmap(
    MLP,
    in_axes=None,
    out_axes=1,
    axis_size=jax.local_device_count(),
    variable_axes={'params': 0},
    metadata_params={nn.PARTITION_NAME: 'ensemble'},
    split_rngs={'params': True},
)
model = EnsembleMLP()

# Map logical axis names to mesh axes.
rules = (('ensemble', 'model'), ('batch', 'data'))

#%% Partition Data
data_spec = PartitionSpec()  # fully replicated input
x = jax.random.normal(jax.random.PRNGKey(0), (8, 20))
rng = jax.random.PRNGKey(0)
x = jax.device_put(x, NamedSharding(mesh, data_spec))
print(x.shape)
jax.debug.visualize_array_sharding(x)

#%% Partition Model
@jax.jit
def create_variables():
    variables = model.init(rng, x)
    spec = nn.get_partition_spec(variables)
    mesh_spec = nn.logical_to_mesh(spec, rules)
    variables = nn.unbox(variables)
    variables = jax.tree_map(
        lambda p, s: jax.lax.with_sharding_constraint(p, NamedSharding(mesh, s)),
        variables, mesh_spec)
    return variables

variables = create_variables()
kernel = variables['params']['Dense_0']['kernel']
kernel = kernel.reshape(kernel.shape[0], -1)
print(kernel.shape)
jax.debug.visualize_array_sharding(kernel)

# %%
@jax.jit
def forward(variables, x):
    return model.apply(variables, x)

y = forward(variables, x)
print(y.shape)
jax.debug.visualize_array_sharding(y)

# %%
```
-
Hi,
So, I'm trying to play around with the sharding API (#2730), but I'm seeing unexpected results regarding the memory consumption of the GPUs.
The simple setup I want to investigate is ensembling a few models, where the sharding happens at the model level (i.e. all model parameters and outputs are sharded across the GPUs, while the input is replicated, see below).
This is my setup. Let's start by creating the device mesh and the `Mesh`:
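(The code blocks that follow are minimal sketches based on the reply above; the mesh shape and axis names are assumptions.)

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Assumed layout: one 'model' entry per local device, a trivial 'data' axis of size 1.
device_mesh = mesh_utils.create_device_mesh((jax.local_device_count(), 1))
mesh = Mesh(devices=device_mesh, axis_names=('model', 'data'))
```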
and a simple MLP
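(A sketch of the MLP, again based on the reply; the logical axis names 'input', 'hidden', 'output', 'batch' are assumptions.)

```python
import flax.linen as nn

class MLP(nn.Module):
    features: int = 5000

    @nn.compact
    def __call__(self, x):
        # Logical axis names on the kernels let them be mapped to mesh axes later.
        y = nn.Dense(
            self.features, use_bias=False,
            kernel_init=nn.with_logical_partitioning(
                nn.initializers.xavier_normal(), ('input', 'hidden')))(x)
        y = nn.with_logical_constraint(y, ('batch', None))
        y = nn.Dense(
            1, use_bias=False,
            kernel_init=nn.with_logical_partitioning(
                nn.initializers.xavier_normal(), ('hidden', 'output')))(y)
        return nn.with_logical_constraint(y.reshape(-1), ('batch',))
```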
Now let's use the lifted vmap to ensemble models
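(A sketch of the lifted vmap; the 'ensemble' axis name and ensemble size are taken from the reply, not from the original post.)

```python
EnsembleMLP = nn.vmap(
    MLP,
    in_axes=None,                    # replicate the input to every ensemble member
    out_axes=1,
    axis_size=jax.local_device_count(),
    variable_axes={'params': 0},     # each member gets its own parameters
    split_rngs={'params': True},
    metadata_params={nn.PARTITION_NAME: 'ensemble'},  # name of the new logical axis
)
model = EnsembleMLP()
```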
and a few extra bits for mapping the logical axes.
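(A sketch of the axis-mapping rules plus a sharded init helper; the rule names and the helper `create_variables` are reconstructed from the reply.)

```python
from jax.sharding import NamedSharding

# 'ensemble' is sharded over the 'model' mesh axis, 'batch' over 'data'.
rules = (('ensemble', 'model'), ('batch', 'data'))

@jax.jit
def create_variables(rng, x):
    variables = model.init(rng, x)
    logical_spec = nn.get_partition_spec(variables)
    mesh_spec = nn.logical_to_mesh(logical_spec, rules)
    variables = nn.unbox(variables)
    return jax.tree_map(
        lambda p, s: jax.lax.with_sharding_constraint(p, NamedSharding(mesh, s)),
        variables, mesh_spec)
```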
Now, if I inspect the various parameters and variables with `jax.debug.visualize_array_sharding`, I see that they are correctly sharded (for example, the output of the model is sharded across the devices). If I understood everything right, I should see reduced memory usage per GPU when I move the ensemble from 1 GPU to 2 GPUs, but this is not really the case (I only go from 2447MB to 2193MB on each).
This becomes clearer when the ensemble doesn't fit on a single device but should fit across two: counterintuitively, in that scenario (p)jit fails with OOM in both cases.
Can you help me understand this? Am I missing something obvious?
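(One way to pin down the per-device numbers, as a sketch; `Device.memory_stats()` is backend-dependent and may be unavailable, e.g. on CPU.)

```python
for d in jax.local_devices():
    stats = d.memory_stats()  # may be None on backends without support
    if stats:
        print(d, stats.get('bytes_in_use'), stats.get('peak_bytes_in_use'))
```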