
Commit b433ec8

3outeille, ydshieh, nithinraok, and eustlb authored
test tensor parallel: make tests for dense model more robust (#41968)
* make test forward and backward more robust

* refactor compile part of test tensor parallel

* linting

* pass rank around instead of calling it over and over

* Run slow v2 (#41914)

  * Super
  * Super
  * Super
  * Super

  ---------

  Co-authored-by: ydshieh <[email protected]>

* Fix `detectron2` installation in docker files (#41975)

  * detectron2 - part 1
  * detectron2 - part 2

  ---------

  Co-authored-by: ydshieh <[email protected]>

* Fix `autoawq[kernels]` installation in quantization docker file (#41978)

  fix autoawq[kernels]

  Co-authored-by: ydshieh <[email protected]>

* add support for saving encoder only so any parakeet model can be loaded for inference (#41969)

  * add support for saving encoder only so any decoder model can be loaded

    Signed-off-by: nithinraok <[email protected]>

  * use convolution_bias
  * convert modular
  * convolution_bias in conversion script

  ---------

  Signed-off-by: nithinraok <[email protected]>
  Co-authored-by: Eustache Le Bihan <[email protected]>
  Co-authored-by: eustlb <[email protected]>

---------

Signed-off-by: nithinraok <[email protected]>
Co-authored-by: Yih-Dar <[email protected]>
Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
Co-authored-by: Eustache Le Bihan <[email protected]>
Co-authored-by: eustlb <[email protected]>
1 parent 3c16c1a commit b433ec8

File tree: 1 file changed (+169, -72 lines)


tests/tensor_parallel/test_tensor_parallel.py

Lines changed: 169 additions & 72 deletions
@@ -15,16 +15,16 @@
 # Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
 # Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
 # Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
-# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_forward
-
+# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
+# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
 import os
 import tempfile
 import warnings

 from safetensors import safe_open

 from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_available
-from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
+from transformers.integrations.tensor_parallel import get_packed_weights, get_tensor_shard, repack_weights
 from transformers.testing_utils import (
     TestCasePlus,
     backend_device_count,
@@ -37,6 +37,7 @@

 if is_torch_available():
     import torch
+    import torch.distributed as dist
     import torch.multiprocessing as mp

@@ -53,14 +54,14 @@ def setup_dist_env(rank, world_size, port):

     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
-        torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
     else:
-        torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

     func(rank, *func_args, **func_kwargs)

-    torch.distributed.barrier()
-    torch.distributed.destroy_process_group()
+    dist.barrier()
+    dist.destroy_process_group()


 def init_distributed(tp: int):
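For context, `setup_dist_env` (partially shown above) wraps each spawned test in the standard process-group lifecycle: init, run, barrier, destroy. A minimal self-contained sketch of that lifecycle, runnable on CPU with the gloo backend; the address, port, and world size here are illustrative assumptions, not values taken from the test file:

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Hypothetical rendezvous settings, for illustration only
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # ... the distributed test body would run here ...

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    # One process per tensor-parallel rank, mirroring init_distributed(tp=2);
    # mp.spawn supplies the rank as the first argument.
    mp.spawn(worker, args=(2,), nprocs=2)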
@@ -211,95 +212,169 @@ def test_tp_plan_none_handling(self):


 # ====== TEST FUNCTIONS ======
-def _test_model_forward_impl(rank):
-    """Implementation of test_model_forward for distributed execution."""
+def _test_model_dense_forward_impl(rank, mode):
+    """Implementation for comparing TP and non-TP model outputs."""
     model_id = "JackFram/llama-68m"

-    int(os.environ["RANK"])
-    int(os.environ["WORLD_SIZE"])
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
-    torch.distributed.barrier()
-
-    has_dtensor = 0
-    for name, parameter in model.named_parameters():
-        if isinstance(parameter.data, torch.distributed.tensor.DTensor):
-            has_dtensor = 1
-            break
-
-    assert has_dtensor == 1, "TP model must has DTensor"
+    # Ensure same random seed for reproducibility
+    torch.manual_seed(0)

+    # Load tokenizer and prepare inputs - same for both models
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    # Load TP model first to determine device
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()
+
+    # Load non-TP model and move to same device as TP model
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
+
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()
+
+    # Prepare inputs on the same device
+    input_ids = inputs.input_ids.to(device)
+
+    # Run forward pass on both models
+    with torch.no_grad():
+        # Non-TP model output
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        # TP model output
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits

-    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    outputs = model(inputs)
+    # Compare outputs - they should match
+    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+    )

-    next_token_logits = outputs[0][:, -1, :]
-    next_token = torch.argmax(next_token_logits, dim=-1)
-    response = tokenizer.decode(next_token)
-    assert response == "with"
-    print("response:", response)
-    torch.distributed.barrier()
+    dist.barrier()

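The rewritten forward test compares TP logits against a single-process reference instead of merely asserting that DTensors exist. To inspect the sharding that `tp_plan="auto"` produces, a sketch using the same DTensor attributes the test relies on (`placements`, `device_mesh`); it assumes a recent PyTorch with public DTensor and a distributed launch, e.g. `torchrun --nproc-per-node 2 inspect_tp.py`:

# Illustrative sketch, not part of the test file.
from torch.distributed.tensor import DTensor

from transformers import AutoModelForCausalLM

model_tp = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", dtype="auto", tp_plan="auto")
for name, param in model_tp.named_parameters():
    if isinstance(param.data, DTensor):
        # e.g. (Shard(dim=0),) for column-parallel weights, (Shard(dim=1),) for row-parallel
        print(name, param.data.placements, param.data.device_mesh)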

-def _test_model_backward_pass_impl(rank):
-    """Implementation of test_model_backward_pass for distributed execution."""
+def _test_model_dense_backward_pass_impl(rank):
+    """Implementation for comparing TP and non-TP model backward passes."""
     model_id = "JackFram/llama-68m"

-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
-    torch.distributed.barrier()
+    torch.manual_seed(0)

-    # Dummy forward and backward pass
-    # Note that loss.backward() will fail if there is a bug in the TP implementation
-    inputs = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
-    labels = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
-    loss = model(inputs, labels=labels).loss
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
+    dist.barrier()
+    model_tp.train()
+
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
+    model = model.to(device)
+    model.train()
+
+    batch_size, seq_length = 2, 10
+    torch.manual_seed(42)  # Different seed for inputs to ensure they're deterministic
+    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
+    labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
+
+    outputs = model(input_ids, labels=labels)
+    loss = outputs.loss
     loss.backward()

-    torch.distributed.barrier()
+    outputs_tp = model_tp(input_ids, labels=labels)
+    loss_tp = outputs_tp.loss
+    loss_tp.backward()

+    assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
+    )

-def _test_model_generate_impl(rank):
-    """Implementation of test_model_generate for distributed execution."""
-    model_id = "JackFram/llama-68m"
+    # Compare gradients for matching parameters
+    # Note: TP model may have sharded parameters (DTensors), so we slice the reference gradient to match
+    for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
+        if param.grad is not None and param_tp.grad is not None:
+            grad = param.grad
+            grad_tp = param_tp.grad

-    int(os.environ["RANK"])
-    int(os.environ["WORLD_SIZE"])
+            if isinstance(param_tp.data, dist.tensor.DTensor):
+                placement = param_tp.data.placements[0]
+                if hasattr(placement, "dim") and placement.dim is not None:
+                    grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
+                else:
+                    grad_shard = grad
+            else:
+                grad_shard = grad

-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
-    torch.distributed.barrier()
+            grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp

-    model.forward = torch.compile(model.forward)
+            assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
+                f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
+            )

-    has_dtensor = 0
-    for name, parameter in model.named_parameters():
-        if isinstance(parameter.data, torch.distributed.tensor.DTensor):
-            has_dtensor = 1
-            break
+    dist.barrier()

-    assert has_dtensor == 1, "TP model must has DTensor"
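`get_tensor_shard` comes from `transformers.integrations.tensor_parallel`; conceptually, for a `Shard(dim)` placement on a one-dimensional mesh it keeps this rank's chunk of the full (replicated) gradient so it can be compared against the DTensor's local shard from `to_local()`. A rough, self-contained sketch of that idea; the real helper also handles packed and uneven shards, so treat this as an approximation:

import torch


def shard_for_rank(full: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    # Keep the chunk of `full` that rank `rank` owns under a Shard(dim) placement.
    return torch.chunk(full, world_size, dim=dim)[rank]


# Example: a [8, 4] "gradient" sharded over 2 ranks on dim 0.
full_grad = torch.arange(32.0).reshape(8, 4)
print(shard_for_rank(full_grad, rank=0, world_size=2, dim=0).shape)  # torch.Size([4, 4])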

-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+def _test_model_dense_forward_compile_impl(rank, mode):
+    """Implementation for comparing TP and non-TP model outputs with torch.compile."""
+    model_id = "JackFram/llama-68m"
+
+    torch.manual_seed(0)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")

-    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()

-    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)

-    torch.distributed.barrier()
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()

+    # Compile both models
+    model.forward = torch.compile(model.forward)
+    model_tp.forward = torch.compile(model_tp.forward)
+
+    input_ids = inputs.input_ids.to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits
+
+    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+    )
+
+    dist.barrier()

-def _test_model_save_impl(rank, tmp_dir, is_torchrun):
+
+def _test_model_dense_save_impl(rank, tmp_dir):
     """Implementation of test_model_save for distributed execution."""
     model_id = "JackFram/llama-68m"
-    kwargs = {}

-    if os.environ.get("RANK", None) is not None:
-        kwargs["tp_plan"] = "auto"
+    if dist.is_initialized():
+        kwargs = {"tp_plan": "auto"}
         result_dir = f"{tmp_dir}/tp"
     else:
+        kwargs = {}
         result_dir = f"{tmp_dir}/nontp"

     model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
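The assertions that compare the two saved checkpoints sit below this hunk and are not part of the diff; given the `safe_open` import at the top of the file, the comparison plausibly looks something like the following sketch (the file names and the helper are assumptions, not the test's actual code):

import torch
from safetensors import safe_open


def checkpoints_match(path_a, path_b):
    # Compare two safetensors checkpoints tensor by tensor.
    with safe_open(path_a, framework="pt", device="cpu") as a, safe_open(path_b, framework="pt", device="cpu") as b:
        if set(a.keys()) != set(b.keys()):
            return False
        return all(torch.equal(a.get_tensor(k), b.get_tensor(k)) for k in a.keys())


# e.g. checkpoints_match(f"{tmp_dir}/tp/model.safetensors", f"{tmp_dir}/nontp/model.safetensors")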
@@ -312,46 +387,68 @@ class TestTensorParallelBase(TestCasePlus):
     nproc_per_node = None

     @require_torch_multi_accelerator
-    def test_model_forward(self):
+    def test_model_dense_forward_eval(self):
+        """Test that TP and non-TP models produce the same outputs in eval mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("eval")
+
+    @require_torch_multi_accelerator
+    def test_model_dense_forward_train(self):
+        """Test that TP and non-TP models produce the same outputs in train mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("train")
+
+    @require_torch_multi_accelerator
+    def test_model_dense_backward_pass(self):
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

-        init_distributed(tp=self.nproc_per_node)(_test_model_forward_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_backward_pass_impl)()

     @require_torch_multi_accelerator
-    def test_model_backward_pass(self):
+    def test_model_dense_forward_compile_eval(self):
+        """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

-        init_distributed(tp=self.nproc_per_node)(_test_model_backward_pass_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("eval")

     @require_torch_multi_accelerator
-    def test_model_generate(self):
+    def test_model_dense_forward_compile_train(self):
+        """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

-        init_distributed(tp=self.nproc_per_node)(_test_model_generate_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("train")

     @require_huggingface_hub_greater_or_equal("0.31.4")
     @require_torch_multi_accelerator
-    def test_model_save(self):
+    def test_model_dense_save(self):
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")

         with tempfile.TemporaryDirectory() as tmp_dir:
             # First run with TP (distributed)
-            init_distributed(tp=self.nproc_per_node)(_test_model_save_impl)(tmp_dir, True)
+            init_distributed(tp=self.nproc_per_node)(_test_model_dense_save_impl)(tmp_dir)

             # Then run without TP (non-distributed)
-            _test_model_save_impl(0, tmp_dir, False)
+            _test_model_dense_save_impl(0, tmp_dir)

             non_tp_model_path = os.path.join(tmp_dir, "nontp")
             tp_model_path = os.path.join(tmp_dir, "tp")