@@ -33,6 +33,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
3333 steps_add_reqs = [0 , 0 , 0 ] # add all requests in the beginning
3434 available_blocks = - 1 # no restriction
3535 max_num_seqs = 2
36+ max_model_len = 256
3637
3738 checked_steps = [
3839 {
@@ -170,6 +171,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
170171 steps_add_reqs = steps_add_reqs ,
171172 checked_steps = checked_steps ,
172173 max_num_seqs = max_num_seqs ,
174+ max_model_len = max_model_len ,
173175 available_blocks = available_blocks ,
174176 use_cb = True ,
175177 )
@@ -197,6 +199,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
197199 steps_add_reqs = [0 , 0 , 0 ] # add all requests in the beginning
198200 available_blocks = - 1 # no restriction
199201 max_num_seqs = 2
202+ max_model_len = 256
200203
201204 checked_steps = [
202205 {
@@ -332,6 +335,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
332335 steps_add_reqs = steps_add_reqs ,
333336 checked_steps = checked_steps ,
334337 max_num_seqs = max_num_seqs ,
338+ max_model_len = max_model_len ,
335339 available_blocks = available_blocks ,
336340 use_cb = True ,
337341 )
@@ -358,6 +362,7 @@ def test_two_sequences_finish_same_time_as_new_arrive(
358362 steps_add_reqs = [0 , 0 , 31 ]
359363 available_blocks = - 1 # no restriction
360364 max_num_seqs = 2
365+ max_model_len = 256
361366
362367 checked_steps = [
363368 {
@@ -470,6 +475,270 @@ def test_two_sequences_finish_same_time_as_new_arrive(
470475 steps_add_reqs = steps_add_reqs ,
471476 checked_steps = checked_steps ,
472477 max_num_seqs = max_num_seqs ,
478+ max_model_len = max_model_len ,
479+ available_blocks = available_blocks ,
480+ use_cb = True ,
481+ )
482+
483+
484+ @pytest .mark .cb
485+ @pytest .mark .parametrize ("model" , get_spyre_model_list ())
486+ @pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
487+ def test_new_sequence_joins_during_decode (model : str , backend : str ,
488+ monkeypatch : pytest .MonkeyPatch ):
489+ """ Scenario where a new sequence joins while decoding other sequences
490+
491+ Configuration:
492+ * max_num_seqs: 4
493+ * number of prompts: 4
494+ * 1: len = 49, max tokens = 119, step joining = 0
495+ * 2: len = 14, max tokens = 52, step joining = 0
496+ * 3: len = 89, max tokens = 104, step joining = 32
497+ * 4: len = 9, max tokens = 64, step joining = 131
498+ """
499+ # TODO change to 65 max_tokens for last prompt if ever possible
500+
501+ seqs_max_tokens = [119 , 52 , 104 , 64 ]
502+ prompts_lengths = [49 , 14 , 89 , 9 ]
503+ steps_add_reqs = [0 , 0 , 32 , 131 ]
504+ available_blocks = - 1 # no restriction
505+ max_num_seqs = 4
506+ max_model_len = 256
507+
508+ checked_steps = [
509+ {
510+ "step" : 0 ,
511+ "tkv" : 0 ,
512+ "waiting" : ["0" , "1" ],
513+ "running" : [],
514+ "request_outputs" : [],
515+ "n_reserved_blocks" : 0 ,
516+ "n_used_blocks" : 0
517+ },
518+ {
519+ # Prefill sequence 0
520+ "step" : 1 ,
521+ "tkv" : 64 ,
522+ "waiting" : ["1" ],
523+ "running" : ["0" ],
524+ "request_outputs" : ["0" ],
525+ "n_reserved_blocks" : 3 , # prefill (1 block) + 118 decodes (2 blocks)
526+ "n_used_blocks" : 1
527+ },
528+ {
529+ # Prefill sequence 1
530+ "step" : 2 ,
531+ "tkv" : 64 ,
532+ "waiting" : [],
533+ "running" : ["1" , "0" ],
534+ "request_outputs" : ["1" ],
535+ "n_reserved_blocks" : 5 , # prefill (1 block) + 51 decodes (1 block)
536+ "n_used_blocks" : 2
537+ },
538+ {
539+ # Decode sequences 0 and 1
540+ "step" : 3 ,
541+ "tkv" : 65 ,
542+ "waiting" : [],
543+ "running" : ["1" , "0" ],
544+ "request_outputs" : ["1" , "0" ],
545+ "n_reserved_blocks" : 5 ,
546+ "n_used_blocks" : 4 # 2 blocks extended, one for each sequence
547+ },
548+ {
549+ # Sequence 2 joins: one iteration in waiting queue
550+ "step" : 32 ,
551+ "tkv" : 94 ,
552+ "waiting" : ["2" ],
553+ "running" : ["1" , "0" ],
554+ "request_outputs" : ["1" , "0" ],
555+ "n_reserved_blocks" : 5 ,
556+ "n_used_blocks" : 4
557+ },
558+ {
559+ # Prefill sequence 2
560+ "step" : 33 ,
561+ "tkv" : 94 ,
562+ "waiting" : [],
563+ "running" : ["2" , "1" , "0" ],
564+ "request_outputs" : ["2" ],
565+ "n_reserved_blocks" : 9 , # prefill (2 blocks) + 103 decodes (2 blocks)
566+ "n_used_blocks" : 6
567+ },
568+ {
569+ # Decode sequences 0, 1, and 2
570+ "step" : 34 ,
571+ "tkv" : 95 ,
572+ "waiting" : [],
573+ "running" : ["2" , "1" , "0" ],
574+ "request_outputs" : ["2" , "1" , "0" ],
575+ "n_reserved_blocks" : 9 ,
576+ "n_used_blocks" : 6
577+ },
578+ {
579+ # Sequence 1 finishes at step 54
580+ # (start step + 2 prefills + 51 decodes - 1) = 2 + 2 + 51 - 1 = 54
581+ "step" : 54 ,
582+ "tkv" : 115 ,
583+ "waiting" : [],
584+ "running" : ["2" , "0" ],
585+ "request_outputs" : ["2" , "1" , "0" ],
586+ "finished_requests" : ["1" ],
587+ "n_reserved_blocks" : 9 ,
588+ "n_used_blocks" : 6
589+ },
590+ {
591+ # Decode sequences 0 and 2
592+ "step" : 55 ,
593+ "tkv" : 116 ,
594+ "waiting" : [],
595+ "running" : ["2" , "0" ],
596+ "request_outputs" : ["2" , "0" ],
597+ "n_reserved_blocks" : 7 , # two blocks released
598+ "n_used_blocks" : 4 # two blocks released
599+ },
600+ {
601+ # Decode sequences 0 and 2, tkv arrives to new block
602+ "step" : 68 ,
603+ "tkv" : 129 ,
604+ "waiting" : [],
605+ "running" : ["2" , "0" ],
606+ "request_outputs" : ["2" , "0" ],
607+ "n_reserved_blocks" : 7 ,
608+ "n_used_blocks" : 6 # 2 blocks extended, one for each sequence
609+ },
610+ {
611+ # Sequence 0 finishes at step 121
612+ # (start step + 3 prefills + 118 decodes - 1) = 1 + 3 + 118 - 1 = 121
613+ "step" : 121 ,
614+ "tkv" : 182 ,
615+ "waiting" : [],
616+ "running" : ["2" ],
617+ "request_outputs" : ["2" , "0" ],
618+ "finished_requests" : ["0" ],
619+ "n_reserved_blocks" : 7 ,
620+ "n_used_blocks" : 6
621+ },
622+ {
623+ # Decode sequence 2
624+ "step" : 122 ,
625+ "tkv" : 183 ,
626+ "waiting" : [],
627+ "running" : ["2" ],
628+ "request_outputs" : ["2" ],
629+ "n_reserved_blocks" : 4 , # 3 blocks released
630+ "n_used_blocks" : 3 # 3 blocks released
631+ },
632+ {
633+ # Sequence 3 joins: one iteration in waiting queue
634+ "step" : 131 ,
635+ "tkv" : 192 ,
636+ "waiting" : ["3" ],
637+ "running" : ["2" ],
638+ "request_outputs" : ["2" ],
639+ "n_reserved_blocks" : 4 ,
640+ "n_used_blocks" : 3
641+ },
642+ {
643+ # Prefill sequence 3
644+ "step" : 132 ,
645+ "tkv" : 192 ,
646+ "waiting" : [],
647+ "running" : ["3" , "2" ],
648+ "request_outputs" : ["3" ],
649+ "n_reserved_blocks" : 8 , # prefill (3 blocks) + 63 decodes (1 block)
650+ "n_used_blocks" : 6 # prefill (3 blocks)
651+ },
652+ {
653+ # Decode sequences 2 and 3
654+ "step" : 133 ,
655+ "tkv" : 193 ,
656+ "waiting" : [],
657+ "running" : ["3" , "2" ],
658+ "request_outputs" : ["3" , "2" ],
659+ "n_reserved_blocks" : 8 ,
660+ "n_used_blocks" : 8 # 2 blocks extended, one for each sequence
661+ },
662+ {
663+ # Sequence 2 finishes at step 137
664+ # (start step + 2 prefills + 103 decodes - 1) = 33 + 2 + 103 - 1 = 137
665+ "step" : 137 ,
666+ "tkv" : 197 ,
667+ "waiting" : [],
668+ "running" : ["3" ],
669+ "request_outputs" : ["3" , "2" ],
670+ "finished_requests" : ["2" ],
671+ "n_reserved_blocks" : 8 ,
672+ "n_used_blocks" : 8
673+ },
674+ {
675+ # Decode sequence 3
676+ "step" : 138 ,
677+ "tkv" : 70 ,
678+ "waiting" : [],
679+ "running" : ["3" ],
680+ "request_outputs" : ["3" ],
681+ # 6 blocks freed: finished sequence (4) + left padding stripping (2)
682+ "n_reserved_blocks" : 2 ,
683+ "n_used_blocks" : 2
684+ },
685+ {
686+ # Sequence 3 finishes at step 195
687+ # (start step + 1 prefill + 63 decodes - 1) = 132 + 1 + 63 - 1 = 195
688+ "step" : 195 ,
689+ "tkv" : 127 ,
690+ "waiting" : [],
691+ "running" : [],
692+ "request_outputs" : ["3" ],
693+ "finished_requests" : ["3" ],
694+ "n_reserved_blocks" : 2 ,
695+ "n_used_blocks" : 2
696+ },
697+ {
698+ # Tkv should be cleared one step later
699+ "step" : 196 ,
700+ "tkv" : 0 ,
701+ "waiting" : [],
702+ "running" : [],
703+ "request_outputs" : [],
704+ "n_reserved_blocks" : 0 ,
705+ "n_used_blocks" : 0
706+ },
707+ # TODO this is when max_tokens = 65 for last prompt
708+ # {
709+ # # Sequence 3 finishes at step 196
710+ # # (start step + 1 prefill + 64 decodes - 1) = 132 + 1 + 64 - 1 = 196
711+ # "step": 196,
712+ # "tkv": 128,
713+ # "waiting": [],
714+ # "running": [],
715+ # "request_outputs": ["3"],
716+ # "finished_requests": ["3"],
717+ # "n_reserved_blocks": 2,
718+ # "n_used_blocks": 2
719+ # },
720+ # {
721+ # # Tkv should be cleared one step later
722+ # "step": 197,
723+ # "tkv": 0,
724+ # "waiting": [],
725+ # "running": [],
726+ # "request_outputs": [],
727+ # "n_reserved_blocks": 0,
728+ # "n_used_blocks": 0
729+ # },
730+ ]
731+
732+ check_scheduler_inference_steps (
733+ model = model ,
734+ backend = backend ,
735+ monkeypatch = monkeypatch ,
736+ seqs_max_tokens = seqs_max_tokens ,
737+ prompts_lengths = prompts_lengths ,
738+ steps_add_reqs = steps_add_reqs ,
739+ checked_steps = checked_steps ,
740+ max_num_seqs = max_num_seqs ,
741+ max_model_len = max_model_len ,
473742 available_blocks = available_blocks ,
474743 use_cb = True ,
475744 )
@@ -494,6 +763,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
494763 steps_add_reqs = [0 , 0 ]
495764 available_blocks = - 1 # no restriction
496765 max_num_seqs = 2
766+ max_model_len = 256
497767
498768 checked_steps = [
499769 {
@@ -617,6 +887,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
617887 steps_add_reqs = steps_add_reqs ,
618888 checked_steps = checked_steps ,
619889 max_num_seqs = max_num_seqs ,
890+ max_model_len = max_model_len ,
620891 available_blocks = available_blocks ,
621892 use_cb = True ,
622893 )
@@ -642,6 +913,7 @@ def test_requested_tokens_not_fitting_remaining_space(
642913 steps_add_reqs = [0 , 0 , 0 ]
643914 available_blocks = - 1 # no restriction
644915 max_num_seqs = 2
916+ max_model_len = 256
645917
646918 checked_steps = [
647919 {
@@ -802,6 +1074,7 @@ def test_requested_tokens_not_fitting_remaining_space(
8021074 steps_add_reqs = steps_add_reqs ,
8031075 checked_steps = checked_steps ,
8041076 max_num_seqs = max_num_seqs ,
1077+ max_model_len = max_model_len ,
8051078 available_blocks = available_blocks ,
8061079 use_cb = True ,
8071080 )
@@ -830,6 +1103,8 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
8301103 # total number of blocks needed if scheduled together : 4 * (1 + 1) = 8
8311104 available_blocks = 8
8321105 max_num_seqs = 4
1106+ max_model_len = 256
1107+
8331108 checked_steps = [
8341109 {
8351110 "step" : 0 ,
@@ -933,6 +1208,7 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
9331208 steps_add_reqs = steps_add_reqs ,
9341209 checked_steps = checked_steps ,
9351210 max_num_seqs = max_num_seqs ,
1211+ max_model_len = max_model_len ,
9361212 available_blocks = available_blocks ,
9371213 use_cb = True ,
9381214 )
@@ -962,6 +1238,8 @@ def test_requests_use_more_than_available_blocks(
9621238 # total number of blocks needed if scheduled together : 4 * (1 + 1) = 8
9631239 available_blocks = 4
9641240 max_num_seqs = 4
1241+ max_model_len = 256
1242+
9651243 checked_steps = [
9661244 {
9671245 "step" : 0 ,
@@ -1090,6 +1368,7 @@ def test_requests_use_more_than_available_blocks(
10901368 steps_add_reqs = steps_add_reqs ,
10911369 checked_steps = checked_steps ,
10921370 max_num_seqs = max_num_seqs ,
1371+ max_model_len = max_model_len ,
10931372 available_blocks = available_blocks ,
10941373 use_cb = True ,
10951374 )
0 commit comments