Skip to content

Commit 608ba62

Browse files
committed
finished test
Signed-off-by: Sophie du Couédic <[email protected]>
1 parent 84db72b commit 608ba62

File tree

1 file changed

+71
-17
lines changed

1 file changed

+71
-17
lines changed

tests/e2e/test_spyre_cb_scheduler_steps.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -475,24 +475,23 @@ def test_two_sequences_finish_same_time_as_new_arrive(
475475
)
476476

477477

478-
@pytest.mark.a
479478
@pytest.mark.cb
480479
@pytest.mark.parametrize("model", get_spyre_model_list())
481480
@pytest.mark.parametrize("backend", get_spyre_backend_list())
482-
def test_new_sequence_joins_during_decode(
483-
model: str, backend: str, monkeypatch: pytest.MonkeyPatch):
481+
def test_new_sequence_joins_during_decode(model: str, backend: str,
482+
monkeypatch: pytest.MonkeyPatch):
484483
""" Scenario where a new sequence joins while decoding other sequences
485484
486485
Configuration:
487486
* max_num_seqs: 4
488487
* number of prompts: 4
489488
* 1: len = 49, max tokens = 119, step joining = 0
490489
* 2: len = 14, max tokens = 52, step joining = 0
491-
* 3: len = 89, max tokens = 101, step joining = 32
490+
* 3: len = 89, max tokens = 104, step joining = 32
492491
* 4: len = 9, max tokens = 65, step joining = 131
493492
"""
494493

495-
seqs_max_tokens = [119, 52, 101, 65]
494+
seqs_max_tokens = [119, 52, 104, 65]
496495
prompts_lengths = [49, 14, 89, 9]
497496
steps_add_reqs = [0, 0, 32, 131]
498497
available_blocks = -1 # no restriction
@@ -510,7 +509,6 @@ def test_new_sequence_joins_during_decode(
510509
},
511510
{
512511
# Prefill sequence 0
513-
# total blocks in use: 1
514512
"step": 1,
515513
"tkv": 64,
516514
"waiting": ["1"],
@@ -521,7 +519,6 @@ def test_new_sequence_joins_during_decode(
521519
},
522520
{
523521
# Prefill sequence 1
524-
# total blocks in use: 1 + 1 = 2
525522
"step": 2,
526523
"tkv": 64,
527524
"waiting": [],
@@ -532,7 +529,6 @@ def test_new_sequence_joins_during_decode(
532529
},
533530
{
534531
# Decode sequences 0 and 1
535-
# total blocks in use: 2 + 2 = 4
536532
"step": 3,
537533
"tkv": 65,
538534
"waiting": [],
@@ -553,18 +549,16 @@ def test_new_sequence_joins_during_decode(
553549
},
554550
{
555551
# Prefill sequence 2
556-
# total blocks in use: 4 + 2 (long prefill) = 6
557552
"step": 33,
558553
"tkv": 94,
559554
"waiting": [],
560555
"running": ["2", "1", "0"],
561556
"request_outputs": ["2"],
562-
"n_reserved_blocks": 9, # prefill (2 block) + 100 decode (2 block)
557+
"n_reserved_blocks": 9, # prefill (2 block) + 103 decode (2 block)
563558
"n_used_blocks": 6
564559
},
565560
{
566561
# Decode sequences 0, 1, and 2
567-
# total blocks in use: 6
568562
"step": 34,
569563
"tkv": 95,
570564
"waiting": [],
@@ -576,7 +570,6 @@ def test_new_sequence_joins_during_decode(
576570
{
577571
# Sequence 1 finishes at step 54
578572
# (start step + 2 prefills + 51 decodes - 1) = 2 + 2 + 51 - 1 = 54
579-
# total blocks in use: 6
580573
"step": 54,
581574
"tkv": 115,
582575
"waiting": [],
@@ -588,7 +581,6 @@ def test_new_sequence_joins_during_decode(
588581
},
589582
{
590583
# Decode sequences 0 and 2
591-
# total blocks in use: 4
592584
"step": 55,
593585
"tkv": 116,
594586
"waiting": [],
@@ -599,7 +591,6 @@ def test_new_sequence_joins_during_decode(
599591
},
600592
{
601593
# Decode sequences 0 and 2, tkv arrives to new block
602-
# total blocks in use: 6
603594
"step": 68,
604595
"tkv": 129,
605596
"waiting": [],
@@ -611,7 +602,6 @@ def test_new_sequence_joins_during_decode(
611602
{
612603
# Sequence 0 finishes at step 121
613604
# (start step + 3 prefills + 118 decode - 1) = 1 + 3 + 118 - 1 = 121
614-
# total blocks in use: 6
615605
"step": 121,
616606
"tkv": 182,
617607
"waiting": [],
@@ -623,7 +613,6 @@ def test_new_sequence_joins_during_decode(
623613
},
624614
{
625615
# Decode sequence 2
626-
# total blocks in use: 3
627616
"step": 122,
628617
"tkv": 183,
629618
"waiting": [],
@@ -634,7 +623,6 @@ def test_new_sequence_joins_during_decode(
634623
},
635624
{
636625
# Sequence 3 joins: one iteration in waiting queue
637-
# total blocks in use: 3
638626
"step": 131,
639627
"tkv": 192,
640628
"waiting": ["3"],
@@ -643,6 +631,71 @@ def test_new_sequence_joins_during_decode(
643631
"n_reserved_blocks": 4,
644632
"n_used_blocks": 3
645633
},
634+
{
635+
# Prefill sequence 3
636+
"step": 132,
637+
"tkv": 192,
638+
"waiting": [],
639+
"running": ["3", "2"],
640+
"request_outputs": ["3"],
641+
"n_reserved_blocks": 8, # prefill (3 blocks) + 64 decode (1 block)
642+
"n_used_blocks": 6 # prefill (3 block)
643+
},
644+
{
645+
# Decode sequences 2 and 3
646+
"step": 133,
647+
"tkv": 193,
648+
"waiting": [],
649+
"running": ["3", "2"],
650+
"request_outputs": ["3", "2"],
651+
"n_reserved_blocks": 8, # prefill (3 blocks) + 64 decode (1 block)
652+
"n_used_blocks": 8 # 2 blocks extended, one for each sequence
653+
},
654+
{
655+
# Sequence 2 finishes at step 137
656+
# (start step + 2 prefills + 103 decodes - 1) = 33 + 2 + 103 - 1 = 137
657+
"step": 137,
658+
"tkv": 197,
659+
"waiting": [],
660+
"running": ["3"],
661+
"request_outputs": ["3", "2"],
662+
"finished_requests": ["2"],
663+
"n_reserved_blocks": 8,
664+
"n_used_blocks": 8
665+
},
666+
{
667+
# Decode sequence 3
668+
"step": 138,
669+
"tkv": 70,
670+
"waiting": [],
671+
"running": ["3"],
672+
"request_outputs": ["3"],
673+
# 6 blocks freed: finished sequence (4) + left padding stripping (2)
674+
"n_reserved_blocks": 2,
675+
"n_used_blocks": 2
676+
},
677+
{
678+
# Sequence 3 finishes at step 196
679+
# (start step + 1 prefill + 64 decodes - 1) = 132 + 1 + 64 - 1 = 196
680+
"step": 196,
681+
"tkv": 128,
682+
"waiting": [],
683+
"running": [],
684+
"request_outputs": ["3"],
685+
"finished_requests": ["3"],
686+
"n_reserved_blocks": 2,
687+
"n_used_blocks": 2
688+
},
689+
{
690+
# Tkv should be cleared one step later
691+
"step": 197,
692+
"tkv": 0,
693+
"waiting": [],
694+
"running": [],
695+
"request_outputs": [],
696+
"n_reserved_blocks": 0,
697+
"n_used_blocks": 0
698+
},
646699
]
647700

648701
check_scheduler_inference_steps(
@@ -658,6 +711,7 @@ def test_new_sequence_joins_during_decode(
658711
use_cb=True,
659712
)
660713

714+
661715
@pytest.mark.cb
662716
@pytest.mark.parametrize("model", get_spyre_model_list())
663717
@pytest.mark.parametrize("backend", get_spyre_backend_list())

0 commit comments

Comments
 (0)