 # Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
 # Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
 # Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
-# Run spefic test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_forward
-
+# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
+# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
 import os
 import tempfile
 import warnings
 
 from safetensors import safe_open
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_available
-from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
+from transformers.integrations.tensor_parallel import get_packed_weights, get_tensor_shard, repack_weights
 from transformers.testing_utils import (
     TestCasePlus,
     backend_device_count,
@@ -37,6 +37,7 @@
 
 if is_torch_available():
     import torch
+    import torch.distributed as dist
     import torch.multiprocessing as mp
 
 
@@ -53,14 +54,14 @@ def setup_dist_env(rank, world_size, port):
 
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
-        torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
     else:
-        torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
 
     func(rank, *func_args, **func_kwargs)
 
-    torch.distributed.barrier()
-    torch.distributed.destroy_process_group()
+    dist.barrier()
+    dist.destroy_process_group()
 
 
 def init_distributed(tp: int):
@@ -211,95 +212,169 @@ def test_tp_plan_none_handling(self):
 
 
 # ====== TEST FUNCTIONS ======
-def _test_model_forward_impl(rank):
-    """Implementation of test_model_forward for distributed execution."""
+def _test_model_dense_forward_impl(rank, mode):
+    """Implementation for comparing TP and non-TP model outputs."""
     model_id = "JackFram/llama-68m"
 
-    int(os.environ["RANK"])
-    int(os.environ["WORLD_SIZE"])
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
-    torch.distributed.barrier()
-
-    has_dtensor = 0
-    for name, parameter in model.named_parameters():
-        if isinstance(parameter.data, torch.distributed.tensor.DTensor):
-            has_dtensor = 1
-            break
-
-    assert has_dtensor == 1, "TP model must has DTensor"
+    # Ensure same random seed for reproducibility
+    torch.manual_seed(0)
 
+    # Load tokenizer and prepare inputs - same for both models
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")
+
+    # Load TP model first to determine device
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()
+
+    # Load non-TP model and move to same device as TP model
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
+
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()
+
+    # Prepare inputs on the same device
+    input_ids = inputs.input_ids.to(device)
+
+    # Run forward pass on both models
+    with torch.no_grad():
+        # Non-TP model output
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        # TP model output
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits
 
-    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    outputs = model(inputs)
+    # Compare outputs - they should match
+    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+    )
 
-    next_token_logits = outputs[0][:, -1, :]
-    next_token = torch.argmax(next_token_logits, dim=-1)
-    response = tokenizer.decode(next_token)
-    assert response == "with"
-    print("response:", response)
-    torch.distributed.barrier()
+    dist.barrier()
 
 
-def _test_model_backward_pass_impl(rank):
-    """Implementation of test_model_backward_pass for distributed execution."""
+def _test_model_dense_backward_pass_impl(rank):
+    """Implementation for comparing TP and non-TP model backward passes."""
     model_id = "JackFram/llama-68m"
 
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
-    torch.distributed.barrier()
+    torch.manual_seed(0)
 
-    # Dummy forward and backward pass
-    # Note that loss.backward() will fail if there is a bug in the TP implementation
-    inputs = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
-    labels = torch.randint(0, model.config.vocab_size, (2, 10), device=model.device)
-    loss = model(inputs, labels=labels).loss
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
+    dist.barrier()
+    model_tp.train()
+
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
+    model = model.to(device)
+    model.train()
+
+    batch_size, seq_length = 2, 10
+    torch.manual_seed(42)  # Different seed for inputs to ensure they're deterministic
+    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
+    labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
+
+    outputs = model(input_ids, labels=labels)
+    loss = outputs.loss
     loss.backward()
 
-    torch.distributed.barrier()
+    outputs_tp = model_tp(input_ids, labels=labels)
+    loss_tp = outputs_tp.loss
+    loss_tp.backward()
 
+    assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
+    )
 
-def _test_model_generate_impl(rank):
-    """Implementation of test_model_generate for distributed execution."""
-    model_id = "JackFram/llama-68m"
+    # Compare gradients for matching parameters
+    # Note: TP model may have sharded parameters (DTensors), so we slice the reference gradient to match
+    for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
+        if param.grad is not None and param_tp.grad is not None:
+            grad = param.grad
+            grad_tp = param_tp.grad
 
-    int(os.environ["RANK"])
-    int(os.environ["WORLD_SIZE"])
+            if isinstance(param_tp.data, dist.tensor.DTensor):
+                placement = param_tp.data.placements[0]
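+                # Shard placements carry a sharding dim; Replicate placements do not, so the replicated gradient is compared in full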
+                if hasattr(placement, "dim") and placement.dim is not None:
+                    grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)
+                else:
+                    grad_shard = grad
+            else:
+                grad_shard = grad
 
-    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
-    torch.distributed.barrier()
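+            # DTensor gradients hold only this rank's local shard, so compare it against the matching slice of the reference gradient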
+            grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp
 
-    model.forward = torch.compile(model.forward)
+            assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
+                f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
+            )
 
-    has_dtensor = 0
-    for name, parameter in model.named_parameters():
-        if isinstance(parameter.data, torch.distributed.tensor.DTensor):
-            has_dtensor = 1
-            break
+    dist.barrier()
 
-    assert has_dtensor == 1, "TP model must has DTensor"
 
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+def _test_model_dense_forward_compile_impl(rank, mode):
+    """Implementation for comparing TP and non-TP model outputs with torch.compile."""
+    model_id = "JackFram/llama-68m"
+
+    torch.manual_seed(0)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
     prompt = "Can I help"
+    inputs = tokenizer(prompt, return_tensors="pt")
 
-    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
-    outputs = model.generate(inputs, max_new_tokens=10, cache_implementation="static")
+    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", tp_plan="auto")
+    dist.barrier()
+    if mode == "eval":
+        model_tp.eval()
+    else:
+        model_tp.train()
 
-    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    assert output_text[0].startswith(prompt), f"Expected output to start with '{prompt}', got '{output_text[0]}'"
+    device = model_tp.device
+    model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
+    model = model.to(device)
 
-    torch.distributed.barrier()
+    if mode == "eval":
+        model.eval()
+    else:
+        model.train()
 
+    # Compile both models
+    model.forward = torch.compile(model.forward)
+    model_tp.forward = torch.compile(model_tp.forward)
+
+    input_ids = inputs.input_ids.to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits
+
+        outputs_tp = model_tp(input_ids)
+        logits_tp = outputs_tp.logits
+
+    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
+        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
+    )
+
+    dist.barrier()
 
-def _test_model_save_impl(rank, tmp_dir, is_torchrun):
+
+def _test_model_dense_save_impl(rank, tmp_dir):
     """Implementation of test_model_save for distributed execution."""
     model_id = "JackFram/llama-68m"
-    kwargs = {}
 
-    if os.environ.get("RANK", None) is not None:
-        kwargs["tp_plan"] = "auto"
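+    # When launched via init_distributed the process group is already initialized, so load with the TP plan;
+    # the plain (non-distributed) call below saves the non-TP reference checkpoint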
+    if dist.is_initialized():
+        kwargs = {"tp_plan": "auto"}
         result_dir = f"{tmp_dir}/tp"
     else:
+        kwargs = {}
         result_dir = f"{tmp_dir}/nontp"
 
     model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
@@ -312,46 +387,68 @@ class TestTensorParallelBase(TestCasePlus):
     nproc_per_node = None
 
     @require_torch_multi_accelerator
-    def test_model_forward(self):
+    def test_model_dense_forward_eval(self):
+        """Test that TP and non-TP models produce the same outputs in eval mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("eval")
+
+    @require_torch_multi_accelerator
+    def test_model_dense_forward_train(self):
+        """Test that TP and non-TP models produce the same outputs in train mode."""
+        if self.nproc_per_node is None:
+            self.skipTest("nproc_per_node not set")
+        if backend_device_count(torch_device) < self.nproc_per_node:
+            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
+
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("train")
+
+    @require_torch_multi_accelerator
+    def test_model_dense_backward_pass(self):
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
 
-        init_distributed(tp=self.nproc_per_node)(_test_model_forward_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_backward_pass_impl)()
 
     @require_torch_multi_accelerator
-    def test_model_backward_pass(self):
+    def test_model_dense_forward_compile_eval(self):
+        """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
 
-        init_distributed(tp=self.nproc_per_node)(_test_model_backward_pass_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("eval")
 
     @require_torch_multi_accelerator
-    def test_model_generate(self):
+    def test_model_dense_forward_compile_train(self):
+        """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
         if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
 
-        init_distributed(tp=self.nproc_per_node)(_test_model_generate_impl)()
+        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("train")
 
     @require_huggingface_hub_greater_or_equal("0.31.4")
     @require_torch_multi_accelerator
-    def test_model_save(self):
+    def test_model_dense_save(self):
        if self.nproc_per_node is None:
             self.skipTest("nproc_per_node not set")
         if backend_device_count(torch_device) < self.nproc_per_node:
             self.skipTest(f"Need at least {self.nproc_per_node} devices, have {backend_device_count(torch_device)}")
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             # First run with TP (distributed)
-            init_distributed(tp=self.nproc_per_node)(_test_model_save_impl)(tmp_dir, True)
+            init_distributed(tp=self.nproc_per_node)(_test_model_dense_save_impl)(tmp_dir)
 
             # Then run without TP (non-distributed)
-            _test_model_save_impl(0, tmp_dir, False)
+            _test_model_dense_save_impl(0, tmp_dir)
 
             non_tp_model_path = os.path.join(tmp_dir, "nontp")
             tp_model_path = os.path.join(tmp_dir, "tp")