@@ -475,24 +475,23 @@ def test_two_sequences_finish_same_time_as_new_arrive(
475475 )
476476
477477
478- @pytest .mark .a
479478@pytest .mark .cb
480479@pytest .mark .parametrize ("model" , get_spyre_model_list ())
481480@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
482- def test_new_sequence_joins_during_decode (
483- model : str , backend : str , monkeypatch : pytest .MonkeyPatch ):
481+ def test_new_sequence_joins_during_decode (model : str , backend : str ,
482+ monkeypatch : pytest .MonkeyPatch ):
484483 """ Scenario where a new sequence joins while decoding other sequences
485484
486485 Configuration:
487486 * max_num_seqs: 4
488487 * number of prompts: 4
489488 * 1: len = 49, max tokens = 119, step joining = 0
490489 * 2: len = 14, max tokens = 52, step joining = 0
491- * 3: len = 89, max tokens = 101 , step joining = 32
490+ * 3: len = 89, max tokens = 104 , step joining = 32
492491 * 4: len = 9, max tokens = 65, step joining = 131
493492 """
494493
495- seqs_max_tokens = [119 , 52 , 101 , 65 ]
494+ seqs_max_tokens = [119 , 52 , 104 , 65 ]
496495 prompts_lengths = [49 , 14 , 89 , 9 ]
497496 steps_add_reqs = [0 , 0 , 32 , 131 ]
498497 available_blocks = - 1 # no restriction
@@ -510,7 +509,6 @@ def test_new_sequence_joins_during_decode(
510509 },
511510 {
512511 # Prefill sequence 0
513- # total blocks in use: 1
514512 "step" : 1 ,
515513 "tkv" : 64 ,
516514 "waiting" : ["1" ],
@@ -521,7 +519,6 @@ def test_new_sequence_joins_during_decode(
521519 },
522520 {
523521 # Prefill sequence 1
524- # total blocks in use: 1 + 1 = 2
525522 "step" : 2 ,
526523 "tkv" : 64 ,
527524 "waiting" : [],
@@ -532,7 +529,6 @@ def test_new_sequence_joins_during_decode(
532529 },
533530 {
534531 # Decode sequences 0 and 1
535- # total blocks in use: 2 + 2 = 4
536532 "step" : 3 ,
537533 "tkv" : 65 ,
538534 "waiting" : [],
@@ -553,18 +549,16 @@ def test_new_sequence_joins_during_decode(
553549 },
554550 {
555551 # Prefill sequence 2
556- # total blocks in use: 4 + 2 (long prefill) = 6
557552 "step" : 33 ,
558553 "tkv" : 94 ,
559554 "waiting" : [],
560555 "running" : ["2" , "1" , "0" ],
561556 "request_outputs" : ["2" ],
562- "n_reserved_blocks" : 9 , # prefill (2 block) + 100 decode (2 block)
557+ "n_reserved_blocks" : 9 , # prefill (2 block) + 103 decode (2 block)
563558 "n_used_blocks" : 6
564559 },
565560 {
566561 # Decode sequences 0, 1, and 2
567- # total blocks in use: 6
568562 "step" : 34 ,
569563 "tkv" : 95 ,
570564 "waiting" : [],
@@ -576,7 +570,6 @@ def test_new_sequence_joins_during_decode(
576570 {
577571 # Sequence 1 finishes at step 54
578572 # (start step + 2 prefills + 51 decodes - 1) = 2 + 2 + 51 - 1 = 54
579- # total blocks in use: 6
580573 "step" : 54 ,
581574 "tkv" : 115 ,
582575 "waiting" : [],
@@ -588,7 +581,6 @@ def test_new_sequence_joins_during_decode(
588581 },
589582 {
590583 # Decode sequences 0 and 2
591- # total blocks in use: 4
592584 "step" : 55 ,
593585 "tkv" : 116 ,
594586 "waiting" : [],
@@ -599,7 +591,6 @@ def test_new_sequence_joins_during_decode(
599591 },
600592 {
601593 # Decode sequences 0 and 2, tkv arrives to new block
602- # total blocks in use: 6
603594 "step" : 68 ,
604595 "tkv" : 129 ,
605596 "waiting" : [],
@@ -611,7 +602,6 @@ def test_new_sequence_joins_during_decode(
611602 {
612603 # Sequence 0 finishes at step 121
613604 # (start step + 3 prefills + 118 decode - 1) = 1 + 3 + 118 - 1 = 121
614- # total blocks in use: 6
615605 "step" : 121 ,
616606 "tkv" : 182 ,
617607 "waiting" : [],
@@ -623,7 +613,6 @@ def test_new_sequence_joins_during_decode(
623613 },
624614 {
625615 # Decode sequence 2
626- # total blocks in use: 3
627616 "step" : 122 ,
628617 "tkv" : 183 ,
629618 "waiting" : [],
@@ -634,7 +623,6 @@ def test_new_sequence_joins_during_decode(
634623 },
635624 {
636625 # Sequence 3 joins: one iteration in waiting queue
637- # total blocks in use: 3
638626 "step" : 131 ,
639627 "tkv" : 192 ,
640628 "waiting" : ["3" ],
@@ -643,6 +631,71 @@ def test_new_sequence_joins_during_decode(
643631 "n_reserved_blocks" : 4 ,
644632 "n_used_blocks" : 3
645633 },
634+ {
635+ # Prefill sequence 3
636+ "step" : 132 ,
637+ "tkv" : 192 ,
638+ "waiting" : [],
639+ "running" : ["3" , "2" ],
640+ "request_outputs" : ["3" ],
641+ "n_reserved_blocks" : 8 , # prefill (3 blocks) + 64 decode (1 block)
642+ "n_used_blocks" : 6 # prefill (3 block)
643+ },
644+ {
645+ # Decode sequences 2 and 3
646+ "step" : 133 ,
647+ "tkv" : 193 ,
648+ "waiting" : [],
649+ "running" : ["3" , "2" ],
650+ "request_outputs" : ["3" , "2" ],
651+ "n_reserved_blocks" : 8 , # prefill (3 blocks) + 64 decode (1 block)
652+ "n_used_blocks" : 8 # 2 blocks extended, one for each sequence
653+ },
654+ {
655+ # Sequence 2 finishes at step 137
656+ # (start step + 2 prefills + 103 decodes - 1) = 33 + 2 + 103 - 1 = 137
657+ "step" : 137 ,
658+ "tkv" : 197 ,
659+ "waiting" : [],
660+ "running" : ["3" ],
661+ "request_outputs" : ["3" , "2" ],
662+ "finished_requests" : ["2" ],
663+ "n_reserved_blocks" : 8 ,
664+ "n_used_blocks" : 8
665+ },
666+ {
667+ # Decode sequence 3
668+ "step" : 138 ,
669+ "tkv" : 70 ,
670+ "waiting" : [],
671+ "running" : ["3" ],
672+ "request_outputs" : ["3" ],
673+ # 6 blocks freed: finished sequence (4) + left padding stripping (2)
674+ "n_reserved_blocks" : 2 ,
675+ "n_used_blocks" : 2
676+ },
677+ {
678+ # Sequence 3 finishes at step 196
679+ # (start step + 1 prefill + 64 decodes - 1) = 132 + 1 + 64 - 1 = 196
680+ "step" : 196 ,
681+ "tkv" : 128 ,
682+ "waiting" : [],
683+ "running" : [],
684+ "request_outputs" : ["3" ],
685+ "finished_requests" : ["3" ],
686+ "n_reserved_blocks" : 2 ,
687+ "n_used_blocks" : 2
688+ },
689+ {
690+ # Tkv should be cleared one step later
691+ "step" : 197 ,
692+ "tkv" : 0 ,
693+ "waiting" : [],
694+ "running" : [],
695+ "request_outputs" : [],
696+ "n_reserved_blocks" : 0 ,
697+ "n_used_blocks" : 0
698+ },
646699 ]
647700
648701 check_scheduler_inference_steps (
@@ -658,6 +711,7 @@ def test_new_sequence_joins_during_decode(
658711 use_cb = True ,
659712 )
660713
714+
661715@pytest .mark .cb
662716@pytest .mark .parametrize ("model" , get_spyre_model_list ())
663717@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
0 commit comments