The 1.4B model takes 10-11GB of RAM at inference (own test, M2 Pro 16GB).
The 2.8B model takes around 50GB at inference (https://twitter.com/awnihannun/status/1749515431336112275).
This is not due to loading the model from HF (same memory footprint if the model is initialized with random weights).
It is not due to the ssm_step either.
However, turning off the convolution at inference reduces the memory footprint (by 3GB for the 1.4B model: from 10GB to around 7GB). It also greatly speeds up the inference (but of course, the forward pass is then not correct).
Files concerned:
mamba_mlx.py (step functions)
misc.py
The depthwise conv implemented in misc.py seems to be part of the problem. As noted in that file, the PyTorch version uses groups=channels (a true depthwise conv), while the MLX depthwise conv in misc.py uses groups=1 with some weights set to 0 (the only workaround found). This results in a filter of size (d_model, 4, d_model), against (d_model, 4) for the "true" depthwise conv.
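For illustration, here is a minimal sketch of what such a masked groups=1 workaround looks like and why it blows up the weight count. This is a reconstruction, not the actual code from misc.py, and it assumes mx.conv1d's (batch, length, channels) input layout and (C_out, K, C_in) weight layout:

```python
import mlx.core as mx

d_model, k = 2560, 4  # d_model value quoted in this issue, kernel size 4

# What a true depthwise conv (groups=channels) needs: one length-k filter per channel.
w_depthwise = mx.random.normal((d_model, k))

# groups=1 workaround: a full (C_out, K, C_in) filter where only the "diagonal"
# entries (out_channel == in_channel) are non-zero.
eye = mx.eye(d_model)                                  # (d_model, d_model)
w_masked = w_depthwise[:, :, None] * eye[:, None, :]   # (d_model, 4, d_model)

print(w_depthwise.size, w_masked.size)  # 10_240 vs 26_214_400 weights per conv

# The masked filter is then used with an ordinary groups=1 conv.
x = mx.random.normal((1, 16, d_model))                 # (batch, length, channels)
y = mx.conv1d(x, w_masked, padding=k - 1)
```

The masked filter stores d_model times more weights than the true depthwise filter, and the conv does the corresponding extra work, which is consistent with the memory and speed observations above.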
Either:
- wait for MLX to implement groups=channels for conv1d
- find another workaround (one possibility, sketched after this list, is to create d_model conv objects, each with 1 input and 1 output channel; but this results in a big for loop which is around 45x slower than the current workaround, even though memory usage is greatly reduced, by a factor of d_model = 2560)
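A minimal sketch of that second option, under the same layout assumptions as above (illustrative only; the function name is hypothetical and the ~45x slowdown is the measurement quoted in the list):

```python
import mlx.core as mx

def depthwise_conv1d_loop(x, w, padding=0):
    """Depthwise conv built from d_model independent 1-in / 1-out convolutions.

    x: (batch, length, d_model), w: (d_model, k).
    Only d_model * k weights are stored, but the Python loop over channels
    is what makes this roughly 45x slower than the masked groups=1 workaround.
    """
    outs = []
    for c in range(x.shape[-1]):
        xc = x[:, :, c:c + 1]             # (batch, length, 1)
        wc = w[c][None, :, None]          # (1, k, 1): one filter, one input channel
        outs.append(mx.conv1d(xc, wc, padding=padding))
    return mx.concatenate(outs, axis=-1)  # (batch, length_out, d_model)
```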