@@ -311,127 +311,82 @@ def _decode(
         return self._token_forward(model_input.input_ids, infer_state)
 
     @torch.no_grad()
-    def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
-        assert batch.batch_size == batch1.batch_size
-        assert batch.mem_indexes.is_cuda
-        assert batch1.mem_indexes.is_cuda
-        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
-
-        def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
-            infer_state = self.infer_state_class()
-            infer_state.is_prefill = False
-            infer_state.batch_size = cur_batch.batch_size
-            infer_state.total_token_num = cur_batch.total_token_num
-            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
-            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0]
-            infer_state.b_req_idx = cur_batch.b_req_idx
-            infer_state.b_seq_len = cur_batch.b_seq_len
-            infer_state.multimodal_params = None
-            infer_state.microbatch_index = batch_index
-
-            infer_state.mem_manager = self.mem_manager
-            infer_state.req_manager = self.req_manager
-
-            infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer_shapedtype = (
-                (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                self.data_type,
-            )
-            infer_state.dist_group = dist_group_manager.get_group(batch_index)
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index
-            )
-            return infer_state
+    def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+        assert model_input0.batch_size == model_input1.batch_size
+        assert model_input0.mem_indexes.is_cuda
+        assert model_input1.mem_indexes.is_cuda
+        input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
-        infer_state = create_inferstate(batch, 0)
-        infer_state1 = create_inferstate(batch1, 1)
+        infer_state0 = self._create_inferstate(model_input0, 0)
+        copy_kv_index_to_req(
+            self.req_manager.req_to_token_indexs, model_input0.b_req_idx, model_input0.b_seq_len, infer_state0.mem_index
+        )
+        infer_state0.init_some_extra_state(self, input_ids0)
 
-        infer_state.init_some_extra_state(self, input_ids)
+        infer_state1 = self._create_inferstate(model_input1, 1)
+        copy_kv_index_to_req(
+            self.req_manager.req_to_token_indexs, model_input1.b_req_idx, model_input1.b_seq_len, infer_state1.mem_index
+        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
-        batch_size = batch.batch_size
-        max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch)
+        batch_size = model_input0.batch_size
+        max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch)
 
         if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
             if self.graph.need_capture(batch_size):
-                infer_state.is_cuda_graph = True
+                infer_state0.is_cuda_graph = True
                 infer_state1.is_cuda_graph = True
 
-                predict_logits, predict_logits1 = self.graph.capture_decode(
+                model_output0, model_output1 = self.graph.capture_decode(
                     self._overlap_tpsp_token_forward,
-                    input_ids,
-                    infer_state,
+                    input_ids0,
+                    infer_state0,
                     input_ids1=input_ids1,
                     infer_state1=infer_state1,
                 )
             else:
-                predict_logits, predict_logits1 = self.graph.replay(
-                    input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+                model_output0, model_output1 = self.graph.replay(
+                    input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
                 )
         else:
-            predict_logits, predict_logits1 = self._overlap_tpsp_token_forward(
-                input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+            model_output0, model_output1 = self._overlap_tpsp_token_forward(
+                input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
             )
-        return predict_logits, predict_logits1
+        return model_output0, model_output1
 
     @torch.no_grad()
-    def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch):
-        assert batch.mem_indexes.is_cuda
-        assert batch1.mem_indexes.is_cuda
-        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
-
-        def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
-            infer_state = self.infer_state_class()
-            infer_state.is_prefill = True
-            infer_state.is_token_healing = self.is_token_healing
-            infer_state.return_all_prompt_logics = self.return_all_prompt_logics
-            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-            infer_state.batch_size = cur_batch.batch_size
-            infer_state.total_token_num = cur_batch.total_token_num
-            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
-            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_seq_len.shape[0]
-            infer_state.b_req_idx = cur_batch.b_req_idx
-            infer_state.b_seq_len = cur_batch.b_seq_len
-            if cur_batch.b_ready_cache_len is not None:
-                infer_state.b_ready_cache_len = cur_batch.b_ready_cache_len
-            else:
-                infer_state.b_ready_cache_len = torch.zeros_like(
-                    cur_batch.b_seq_len, dtype=cur_batch.b_seq_len.dtype, device=cur_batch.b_seq_len.device
-                )
-            infer_state.multimodal_params = cur_batch.multimodal_params
-            infer_state.microbatch_index = batch_index
+    def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
+        assert model_input0.mem_indexes.is_cuda
+        assert model_input1.mem_indexes.is_cuda
+        input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
-            infer_state.mem_manager = self.mem_manager
-            infer_state.req_manager = self.req_manager
-
-            infer_state.mem_index = cur_batch.mem_indexes
-            infer_state.kv_buffer_shapedtype = (
-                (cur_batch.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-                self.data_type,
-            )
-            infer_state.dist_group = dist_group_manager.get_group(batch_index)
-            init_req_to_token_indexes(
-                self.req_manager.req_to_token_indexs,
-                cur_batch.b_req_idx,
-                cur_batch.b_seq_len,
-                infer_state.b_ready_cache_len,
-                cur_batch.max_len_in_batch,
-                infer_state.mem_index,
-            )
-            return infer_state
-
-        infer_state = create_inferstate(batch, 0)
-        infer_state1 = create_inferstate(batch1, 1)
-
-        infer_state.init_some_extra_state(self, input_ids)
+        infer_state0 = self._create_inferstate(model_input0, 0)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input0.b_req_idx,
+            model_input0.b_seq_len,
+            infer_state0.b_ready_cache_len,
+            model_input0.max_len_in_batch,
+            infer_state0.mem_index,
+        )
+        infer_state0.init_some_extra_state(self, input_ids0)
+
+        infer_state1 = self._create_inferstate(model_input1, 1)
+        init_req_to_token_indexes(
+            self.req_manager.req_to_token_indexs,
+            model_input1.b_req_idx,
+            model_input1.b_seq_len,
+            infer_state1.b_ready_cache_len,
+            model_input1.max_len_in_batch,
+            infer_state1.mem_index,
+        )
         infer_state1.init_some_extra_state(self, input_ids1)
 
-        predict_logits, predict_logits1 = self._overlap_tpsp_context_forward(
-            input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+        model_output0, model_output1 = self._overlap_tpsp_context_forward(
+            input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
         )
         dist_group_manager.clear_deepep_buffer()
-        return predict_logits, predict_logits1
+        return model_output0, model_output1
 
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
@@ -508,9 +463,21 @@ def _overlap_tpsp_token_forward(
         predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
             input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
         )
-
+
         g_cache_manager.cache_env_out()
-        return predict_logits, predict_logits1
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
+        model_output = ModelOutput(
+            logits=predict_logits,
+            hidden_states=input_embs if is_return_hidden_states else None,
+        )
+
+        model_output1 = ModelOutput(
+            logits=predict_logits1,
+            hidden_states=input_embs1 if is_return_hidden_states else None,
+        )
+        return model_output, model_output1
 
     @final
     def _overlap_tpsp_context_forward(
@@ -528,7 +495,21 @@ def _overlap_tpsp_context_forward(
             input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
         )
         g_cache_manager.cache_env_out()
-        return predict_logits, predict_logits1
+
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
+        model_output = ModelOutput(
+            logits=predict_logits,
+            hidden_states=input_embs if is_return_hidden_states else None,
+        )
+
+        model_output1 = ModelOutput(
+            logits=predict_logits1,
+            hidden_states=input_embs1 if is_return_hidden_states else None,
+        )
+
+        return model_output, model_output1
 
     @final
     @torch.no_grad()
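Side note, not part of the diff: the hidden-state gating that both `_overlap_tpsp_token_forward` and `_overlap_tpsp_context_forward` now apply before packing a `ModelOutput` can be read in isolation as the sketch below. The `ModelOutput` fields mirror what the diff constructs; the `pack_micro_batch_output` helper and the concrete flag arguments are assumptions added for illustration only.

```python
# Hedged sketch of the hidden-state gating shown in the diff; helper name and
# flag plumbing are illustrative assumptions, not code from this PR.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ModelOutput:
    logits: torch.Tensor
    hidden_states: Optional[torch.Tensor] = None


def pack_micro_batch_output(predict_logits, input_embs, is_mtp, is_mtp_module, last_mtp_module):
    # Hidden states are kept only when a speculative (MTP) draft model, or a
    # non-final MTP module, still needs them downstream; otherwise they are dropped.
    is_return_hidden_states = is_mtp or (is_mtp_module and not last_mtp_module)
    return ModelOutput(
        logits=predict_logits,
        hidden_states=input_embs if is_return_hidden_states else None,
    )


if __name__ == "__main__":
    logits = torch.randn(4, 32000)
    embs = torch.randn(4, 4096)
    out = pack_micro_batch_output(logits, embs, is_mtp=False, is_mtp_module=True, last_mtp_module=True)
    assert out.hidden_states is None  # the last MTP module returns logits only
```

The practical effect, as far as the diff shows, is that each microbatch now returns a structured `ModelOutput` instead of a bare logits tensor, and the embeddings ride along only when a later MTP stage will actually consume them.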