1818
1919package  org .apache .flink .api .connector .source .mocks ;
2020
21+ import  org .apache .flink .api .connector .source .ReaderInfo ;
2122import  org .apache .flink .api .connector .source .SourceEvent ;
2223import  org .apache .flink .api .connector .source .SplitEnumerator ;
2324import  org .apache .flink .api .connector .source .SplitEnumeratorContext ;
2829
2930import  java .io .IOException ;
3031import  java .util .ArrayList ;
32+ import  java .util .Collection ;
3133import  java .util .Collections ;
32- import  java .util .Comparator ;
3334import  java .util .HashMap ;
3435import  java .util .HashSet ;
3536import  java .util .List ;
3637import  java .util .Map ;
3738import  java .util .Set ;
38- import  java .util .SortedSet ;
39- import  java .util .TreeSet ;
39+ import  java .util .stream .Collectors ;
4040
4141/** A mock {@link SplitEnumerator} for unit tests. */ 
4242public  class  MockSplitEnumerator 
4343        implements  SplitEnumerator <MockSourceSplit , Set <MockSourceSplit >>, SupportsBatchSnapshot  {
44-     private  final  SortedSet <MockSourceSplit > unassignedSplits ;
44+     private  final  Map <Integer , Set <MockSourceSplit >> pendingSplitAssignment ;
45+     private  final  Map <String , Integer > globalSplitAssignment ;
4546    private  final  SplitEnumeratorContext <MockSourceSplit > enumContext ;
4647    private  final  List <SourceEvent > handledSourceEvent ;
4748    private  final  List <Long > successfulCheckpoints ;
@@ -50,22 +51,24 @@ public class MockSplitEnumerator
5051
5152    public  MockSplitEnumerator (int  numSplits , SplitEnumeratorContext <MockSourceSplit > enumContext ) {
5253        this (new  HashSet <>(), enumContext );
54+         List <MockSourceSplit > unassignedSplits  = new  ArrayList <>();
5355        for  (int  i  = 0 ; i  < numSplits ; i ++) {
5456            unassignedSplits .add (new  MockSourceSplit (i ));
5557        }
58+         recalculateAssignments (unassignedSplits );
5659    }
5760
/**
 * Creates an enumerator from an explicit set of not-yet-assigned splits. The given splits
 * are queued per subtask via {@code recalculateAssignments} once all bookkeeping
 * structures have been initialised.
 *
 * @param unassignedSplits splits that still have to be handed out to readers
 * @param enumContext the enumerator context used for parallelism lookup and assignment
 */
public MockSplitEnumerator(
        Set<MockSourceSplit> unassignedSplits,
        SplitEnumeratorContext<MockSourceSplit> enumContext) {
    // Lifecycle flags and recorders first; they do not depend on anything else.
    this.started = false;
    this.closed = false;
    this.handledSourceEvent = new ArrayList<>();
    this.successfulCheckpoints = new ArrayList<>();

    // Assignment bookkeeping must exist before recalculateAssignments touches it.
    this.enumContext = enumContext;
    this.globalSplitAssignment = new HashMap<>();
    this.pendingSplitAssignment = new HashMap<>();

    recalculateAssignments(unassignedSplits);
}
7073
7174    @ Override 
@@ -83,25 +86,36 @@ public void handleSourceEvent(int subtaskId, SourceEvent sourceEvent) {
8386
8487    @ Override 
8588    public  void  addSplitsBack (List <MockSourceSplit > splits , int  subtaskId ) {
86-         unassignedSplits .addAll (splits );
89+         // add back to same subtaskId. 
90+         putPendingAssignments (subtaskId , splits );
8791    }
8892
8993    @ Override 
9094    public  void  addReader (int  subtaskId ) {
91-         List <MockSourceSplit > assignment  = new  ArrayList <>();
92-         for  (MockSourceSplit  split  : unassignedSplits ) {
93-             if  (Integer .parseInt (split .splitId ()) % enumContext .currentParallelism () == subtaskId ) {
94-                 assignment .add (split );
95+         ReaderInfo  readerInfo  = enumContext .registeredReaders ().get (subtaskId );
96+         List <MockSourceSplit > splitsOnRecovery  = readerInfo .getReportedSplitsOnRegistration ();
97+ 
98+         List <MockSourceSplit > redistributedSplits  = new  ArrayList <>();
99+         List <MockSourceSplit > addBackSplits  = new  ArrayList <>();
100+         for  (MockSourceSplit  split  : splitsOnRecovery ) {
101+             if  (!globalSplitAssignment .containsKey (split .splitId ())) {
102+                 // if the split is not present in globalSplitAssignment, it means that this split is 
103+                 // being registered for the first time and is eligible for redistribution. 
104+                 redistributedSplits .add (split );
105+             } else  if  (!globalSplitAssignment .containsKey (split .splitId ())) {
106+                 //  if split is already assigned to other sub-task, just ignore it. Otherwise, add 
107+                 // back to this sub-task again. 
108+                 addBackSplits .add (split );
95109            }
96110        }
97-         enumContext . assignSplits ( 
98-                  new   SplitsAssignment <>( Collections . singletonMap ( subtaskId , assignment )) );
99-         unassignedSplits . removeAll ( assignment );
111+         recalculateAssignments ( redistributedSplits ); 
112+         putPendingAssignments ( subtaskId , addBackSplits );
113+         assignAllSplits ( );
100114    }
101115
102116    @ Override 
103117    public  Set <MockSourceSplit > snapshotState (long  checkpointId ) {
104-         return  unassignedSplits ;
118+         return  getUnassignedSplits () ;
105119    }
106120
107121    @ Override 
@@ -114,11 +128,6 @@ public void close() throws IOException {
114128        this .closed  = true ;
115129    }
116130
117-     public  void  addNewSplits (List <MockSourceSplit > newSplits ) {
118-         unassignedSplits .addAll (newSplits );
119-         assignAllSplits ();
120-     }
121- 
122131    // -------------------- 
123132
124133    public  boolean  started () {
@@ -130,7 +139,9 @@ public boolean closed() {
130139    }
131140
132141    public  Set <MockSourceSplit > getUnassignedSplits () {
133-         return  unassignedSplits ;
142+         return  pendingSplitAssignment .values ().stream ()
143+                 .flatMap (Set ::stream )
144+                 .collect (Collectors .toSet ());
134145    }
135146
136147    public  List <SourceEvent > getHandledSourceEvent () {
@@ -145,17 +156,27 @@ public List<Long> getSuccessfulCheckpoints() {
145156
146157    private  void  assignAllSplits () {
147158        Map <Integer , List <MockSourceSplit >> assignment  = new  HashMap <>();
148-         unassignedSplits .forEach (
149-                 split  -> {
150-                     int  subtaskId  =
151-                             Integer .parseInt (split .splitId ()) % enumContext .currentParallelism ();
152-                     if  (enumContext .registeredReaders ().containsKey (subtaskId )) {
153-                         assignment 
154-                                 .computeIfAbsent (subtaskId , ignored  -> new  ArrayList <>())
155-                                 .add (split );
156-                     }
157-                 });
159+         for  (Map .Entry <Integer , Set <MockSourceSplit >> iter  : pendingSplitAssignment .entrySet ()) {
160+             Integer  subtaskId  = iter .getKey ();
161+             if  (enumContext .registeredReaders ().containsKey (subtaskId )) {
162+                 assignment .put (subtaskId , new  ArrayList <>(iter .getValue ()));
163+             }
164+         }
158165        enumContext .assignSplits (new  SplitsAssignment <>(assignment ));
159-         assignment .values ().forEach (l  -> unassignedSplits .removeAll (l ));
166+         assignment .keySet ().forEach (pendingSplitAssignment ::remove );
167+     }
168+ 
169+     private  void  recalculateAssignments (Collection <MockSourceSplit > newSplits ) {
170+         for  (MockSourceSplit  split  : newSplits ) {
171+             int  subtaskId  = Integer .parseInt (split .splitId ()) % enumContext .currentParallelism ();
172+             putPendingAssignments (subtaskId , Collections .singletonList (split ));
173+         }
174+     }
175+ 
176+     private  void  putPendingAssignments (int  subtaskId , Collection <MockSourceSplit > splits ) {
177+         Set <MockSourceSplit > pendingSplits  =
178+                 pendingSplitAssignment .computeIfAbsent (subtaskId , HashSet ::new );
179+         pendingSplits .addAll (splits );
180+         splits .forEach (split  -> globalSplitAssignment .put (split .splitId (), subtaskId ));
160181    }
161182}
0 commit comments