
Commit 08b55f1

Enable exporting Voxtral on MPS (#142)
* Enable exporting Voxtral on MPS
* Take care of example input dtype
* Fix warning
* Reformat

---------

Co-authored-by: Mengwei Liu <[email protected]>
Co-authored-by: Mengwei Liu <[email protected]>
1 parent 1a582c6 commit 08b55f1


3 files changed: +30 −9 lines changed


optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 15 additions & 3 deletions
@@ -51,7 +51,11 @@ def __init__(
             batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
         )
 
-        assert device is None or device == "cpu", "Device must be None or 'cpu'"
+        assert device is None or device in [
+            "cpu",
+            "cuda",
+            "mps",
+        ], "Device must be None or one of 'cpu', 'cuda' or 'mps'."
 
         # Create a list of CustomKVCache instances derived from each layer of the original Transformers cache, one per layer.
         self.kv_cache = torch.nn.ModuleList()
@@ -63,6 +67,8 @@ def __init__(
                 head_dim=layer.head_dim,
                 dtype=dtype,
             )
+            layer_cache.k_cache = layer_cache.k_cache.to(device)
+            layer_cache.v_cache = layer_cache.v_cache.to(device)
             self.kv_cache.append(layer_cache)
 
     def update(
@@ -160,7 +166,7 @@ def from_legacy_cache(
         elif dtype is None and hasattr(legacy_cache.k_cache, "dtype"):
             dtype = legacy_cache.k_cache.dtype
 
-        assert device is None or device == "cpu"
+        # assert device is None or device == "cpu"
         assert dtype is None or dtype == torch.float32
 
         # Use the legacy cache's max_seq_len if max_cache_len is not specified
@@ -206,7 +212,11 @@ def __init__(
             batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
         )
 
-        assert device is None or device == "cpu", "Device must be None or 'cpu'"
+        assert device is None or device in [
+            "cpu",
+            "cuda",
+            "mps",
+        ], "Device must be None or one of 'cpu', 'cuda' or 'mps'."
 
         self.cache_position = None
         # Create a list of cache instances, one per layer.
@@ -230,6 +240,8 @@ def __init__(
                 head_dim=layer.head_dim,
                 dtype=dtype,
             )
+            layer_cache.k_cache = layer_cache.k_cache.to(device)
+            layer_cache.v_cache = layer_cache.v_cache.to(device)
             self.kv_cache.append(layer_cache)
 
     def update(
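
The relocation pattern the hunks above apply per layer can be illustrated with a small standalone sketch. TinyKVCache below is a hypothetical stand-in, not the real CustomKVCache; the only assumption carried over is that each layer cache exposes k_cache/v_cache buffers that start out on CPU and are moved to the requested device after construction.

import torch


class TinyKVCache(torch.nn.Module):
    """Stand-in for a per-layer KV cache whose buffers are allocated on CPU."""

    def __init__(self, max_seq_len: int, num_heads: int, head_dim: int, dtype=torch.float32):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, max_seq_len, num_heads, head_dim, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(1, max_seq_len, num_heads, head_dim, dtype=dtype))


device = "mps" if torch.backends.mps.is_available() else "cpu"
layer_cache = TinyKVCache(max_seq_len=128, num_heads=8, head_dim=64)
# Same pattern as the diff: relocate the pre-allocated buffers onto the target device.
layer_cache.k_cache = layer_cache.k_cache.to(device)
layer_cache.v_cache = layer_cache.v_cache.to(device)
print(layer_cache.k_cache.device)  # mps:0 on Apple silicon, cpu otherwise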

optimum/exporters/executorch/integrations.py

Lines changed: 12 additions & 4 deletions
@@ -65,7 +65,7 @@ def prepare_export_inputs(self):
             raise ValueError(
                 f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'pixel_values' key: {processed_inputs}"
             )
-        export_inputs = processed_inputs["pixel_values"]
+        export_inputs = processed_inputs["pixel_values"].to(dtype=self.model.dtype)
 
         # 2. Get export dynamic shapes
         dynamic_shapes = None  # No batching for now.
@@ -126,7 +126,7 @@ def prepare_export_inputs(self):
             raise ValueError(
                 f"Unable to obtain sample audio encoder inputs for export for {model_id} - the processor did not return formatted inputs with the 'input_features' key: {processed_inputs}"
             )
-        export_inputs = processed_inputs["input_features"]
+        export_inputs = processed_inputs["input_features"].to(dtype=self.model.dtype)
         # Make sure the export inputs has a batch size > 1 so that it doesn't 0/1 specialize.
         if export_inputs.shape[0] == 1:
             export_inputs = export_inputs.repeat(2, 1, 1)
@@ -242,7 +242,9 @@ def _prepare_decoder_only_export_inputs(self, max_seq_len: int):
 
         # Prepare inputs with dynamic shapes
         seq_length = 3
-        example_inputs_embeds = torch.zeros((1, seq_length, self.config.text_config.hidden_size), dtype=torch.float)
+        example_inputs_embeds = torch.zeros(
+            (1, seq_length, self.config.text_config.hidden_size), dtype=self.model.dtype
+        )
         example_cache_position = torch.arange(seq_length, dtype=torch.long)
 
         seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_len)
@@ -311,6 +313,9 @@ def export(
         logging.info(
             f"Exporting decoder using inputs_embeds({inputs_embeds.shape}), cache_position({cache_position.shape})={cache_position}, dynamic_shapes={dynamic_shapes}"
         )
+        # Move inputs to the same device as the model
+        inputs_embeds = inputs_embeds.to(self.model.device)
+        cache_position = cache_position.to(self.model.device)
         exported_program = exportable_module.export(
             inputs_embeds=inputs_embeds,
             cache_position=cache_position,
@@ -341,7 +346,8 @@ def export(
         logging.info(
             f"Exporting token embeddings using input_ids({input_ids.shape}), dynamic_shapes={dynamic_shapes}"
         )
-
+        # Move inputs to the same device as the model
+        input_ids = input_ids.to(self.model.device)
         token_embedding_exported_program = torch.export.export(
             self.model.get_input_embeddings(),
             args=(input_ids,),
@@ -369,6 +375,8 @@ def export(
             f"Exporting {self.modality} encoder using input_features({input_features.shape}), dynamic_shapes={dynamic_shapes}"
         )
 
+        # Move inputs to the same device as the model
+        input_features = input_features.to(self.model.device)
         encoder_exported_program = torch.export.export(
             encoder,
             args=(),
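
All of the dtype and device tweaks above follow the same rule: example inputs handed to torch.export should match the dtype and device of the model being traced. Below is a minimal sketch of that rule with a hypothetical align_example_input helper; the helper is not part of optimum.executorch, and the toy Linear model only stands in for the Voxtral encoder/decoder.

import torch


def align_example_input(example: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    # Match the example tensor to the dtype/device of the model's parameters,
    # mirroring the .to(dtype=self.model.dtype) / .to(self.model.device) calls above.
    param = next(model.parameters())
    return example.to(dtype=param.dtype, device=param.device)


model = torch.nn.Linear(16, 16)
example = align_example_input(torch.zeros(2, 16, dtype=torch.float64), model)
exported_program = torch.export.export(model, args=(example,))
print(exported_program.graph_signature)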

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 3 additions & 2 deletions
@@ -115,7 +115,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
         MultiModalTextToTextExportableModule:
             An instance of `MultiModalTextToTextExportableModule` for exporting and lowering to ExecuTorch.
     """
-    device = "cpu"
+    device = kwargs.get("device", "cpu")
     batch_size = 1
     dtype = kwargs.get("dtype", "float32")
     use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
@@ -166,7 +166,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     eager_model = AutoModelForPreTraining.from_pretrained(
         model_name_or_path,
         device_map=device,
-        torch_dtype=dtype,
+        dtype=dtype,
         config=config,
         attn_implementation=attn_implementation,
     )
@@ -177,6 +177,7 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
         cache_config={
             "batch_size": batch_size,
             "max_cache_len": max_length,
+            "device": device,
         },
     )
     decoder_name, audio_encoder_name, vision_encoder_name = _validate_multimodal_components(eager_model)
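
With device now read from kwargs, an MPS export could presumably be requested as below. This is a hedged sketch, not a confirmed invocation: the checkpoint id and the reliance on defaults for the remaining kwargs are assumptions, and only the device and dtype keys come from this diff (the import path is inferred from the file name above).

from optimum.exporters.executorch.tasks.multimodal_text_to_text import (
    load_multimodal_text_to_text_model,
)

# "device" is picked up via kwargs.get("device", "cpu") and forwarded to
# device_map= and cache_config; "dtype" is forwarded to from_pretrained(dtype=...).
module = load_multimodal_text_to_text_model(
    "mistralai/Voxtral-Mini-3B-2507",  # assumed Voxtral checkpoint id
    device="mps",
    dtype="bfloat16",
)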
