Skip to content

Commit 87ebd85

Browse files
committed
Fix race condition in StepContribution skip/filter count
Signed-off-by: mugeon <pos04167@kakao.com>
1 parent 2cc7890 commit 87ebd85

2 files changed

Lines changed: 161 additions & 16 deletions

File tree

spring-batch-core/src/main/java/org/springframework/batch/core/step/StepContribution.java

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.io.Serializable;
1919
import java.util.Objects;
20+
import java.util.concurrent.atomic.AtomicLong;
2021

2122
import org.springframework.batch.core.ExitStatus;
2223

@@ -36,15 +37,15 @@ public class StepContribution implements Serializable {
3637

3738
private long writeCount = 0;
3839

39-
private long filterCount = 0;
40+
private final AtomicLong filterCount = new AtomicLong(0);
4041

4142
private final long parentSkipCount;
4243

4344
private long readSkipCount;
4445

4546
private long writeSkipCount;
4647

47-
private long processSkipCount;
48+
private final AtomicLong processSkipCount = new AtomicLong(0);
4849

4950
private ExitStatus exitStatus = ExitStatus.EXECUTING;
5051

@@ -86,7 +87,7 @@ public void incrementFilterCount() {
8687
* @param count The {@code long} amount to increment by.
8788
*/
8889
public void incrementFilterCount(long count) {
89-
filterCount += count;
90+
filterCount.addAndGet(count);
9091
}
9192

9293
/**
@@ -125,23 +126,23 @@ public long getWriteCount() {
125126
* @return the filter counter.
126127
*/
127128
public long getFilterCount() {
128-
return filterCount;
129+
return filterCount.get();
129130
}
130131

131132
/**
132133
* @return the sum of skips accumulated in the parent {@link StepExecution} and this
133134
* <code>StepContribution</code>.
134135
*/
135136
public long getStepSkipCount() {
136-
return readSkipCount + writeSkipCount + processSkipCount + parentSkipCount;
137+
return readSkipCount + writeSkipCount + processSkipCount.get() + parentSkipCount;
137138
}
138139

139140
/**
140141
* @return the number of skips collected in this <code>StepContribution</code> (not
141142
* including skips accumulated in the parent {@link StepExecution}).
142143
*/
143144
public long getSkipCount() {
144-
return readSkipCount + writeSkipCount + processSkipCount;
145+
return readSkipCount + writeSkipCount + processSkipCount.get();
145146
}
146147

147148
/**
@@ -179,11 +180,11 @@ public void incrementWriteSkipCount(long count) {
179180
*
180181
*/
181182
public void incrementProcessSkipCount() {
182-
processSkipCount++;
183+
processSkipCount.incrementAndGet();
183184
}
184185

185186
public void incrementProcessSkipCount(long count) {
186-
processSkipCount += count;
187+
processSkipCount.addAndGet(count);
187188
}
188189

189190
/**
@@ -207,7 +208,7 @@ public long getWriteSkipCount() {
207208
* @return the process skip count.
208209
*/
209210
public long getProcessSkipCount() {
210-
return processSkipCount;
211+
return processSkipCount.get();
211212
}
212213

213214
/**
@@ -220,25 +221,26 @@ public StepExecution getStepExecution() {
220221

221222
@Override
222223
public String toString() {
223-
return "[StepContribution: read=" + readCount + ", written=" + writeCount + ", filtered=" + filterCount
224+
return "[StepContribution: read=" + readCount + ", written=" + writeCount + ", filtered=" + filterCount.get()
224225
+ ", readSkips=" + readSkipCount + ", writeSkips=" + writeSkipCount + ", processSkips="
225-
+ processSkipCount + ", exitStatus=" + exitStatus.getExitCode() + "]";
226+
+ processSkipCount.get() + ", exitStatus=" + exitStatus.getExitCode() + "]";
226227
}
227228

228229
@Override
229230
public boolean equals(Object o) {
230231
if (!(o instanceof StepContribution that))
231232
return false;
232-
return readCount == that.readCount && writeCount == that.writeCount && filterCount == that.filterCount
233-
&& parentSkipCount == that.parentSkipCount && readSkipCount == that.readSkipCount
234-
&& writeSkipCount == that.writeSkipCount && processSkipCount == that.processSkipCount
233+
return readCount == that.readCount && writeCount == that.writeCount
234+
&& filterCount.get() == that.filterCount.get() && parentSkipCount == that.parentSkipCount
235+
&& readSkipCount == that.readSkipCount && writeSkipCount == that.writeSkipCount
236+
&& processSkipCount.get() == that.processSkipCount.get()
235237
&& Objects.equals(stepExecution, that.stepExecution) && Objects.equals(exitStatus, that.exitStatus);
236238
}
237239

238240
@Override
239241
public int hashCode() {
240-
return Objects.hash(stepExecution, readCount, writeCount, filterCount, parentSkipCount, readSkipCount,
241-
writeSkipCount, processSkipCount, exitStatus);
242+
return Objects.hash(stepExecution, readCount, writeCount, filterCount.get(), parentSkipCount, readSkipCount,
243+
writeSkipCount, processSkipCount.get(), exitStatus);
242244
}
243245

244246
}

spring-batch-core/src/test/java/org/springframework/batch/core/step/item/ChunkOrientedStepTests.java

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,4 +319,147 @@ class SkippableException extends RuntimeException {
319319
assertEquals(1, stepExecution.getSkipCount());
320320
}
321321

322+
@Test
323+
void testFilterCountAccuracyInConcurrentMode() throws Exception {
324+
// given
325+
int itemCount = 10;
326+
AtomicInteger readCounter = new AtomicInteger(0);
327+
328+
ItemReader<Integer> reader = () -> {
329+
int current = readCounter.incrementAndGet();
330+
return current <= itemCount ? current : null;
331+
};
332+
333+
ItemProcessor<Integer, Integer> filteringProcessor = item -> null;
334+
335+
ItemWriter<Integer> writer = chunk -> {
336+
};
337+
338+
JobRepository jobRepository = new ResourcelessJobRepository();
339+
ChunkOrientedStep<Integer, Integer> step = new ChunkOrientedStep<>("step", 100, reader, writer, jobRepository);
340+
step.setItemProcessor(filteringProcessor);
341+
step.setTaskExecutor(new SimpleAsyncTaskExecutor());
342+
step.afterPropertiesSet();
343+
344+
JobInstance jobInstance = new JobInstance(1L, "job");
345+
JobExecution jobExecution = new JobExecution(1L, jobInstance, new JobParameters());
346+
StepExecution stepExecution = new StepExecution(1L, "step", jobExecution);
347+
348+
// when
349+
step.execute(stepExecution);
350+
351+
// then
352+
assertEquals(itemCount, stepExecution.getFilterCount(), "Race condition detected! Expected " + itemCount
353+
+ " filtered items, but got " + stepExecution.getFilterCount());
354+
}
355+
356+
@Test
357+
void testFilterCountAccuracyInSequentialMode() throws Exception {
358+
// given
359+
int itemCount = 10;
360+
AtomicInteger readCounter = new AtomicInteger(0);
361+
362+
ItemReader<Integer> reader = () -> {
363+
int current = readCounter.incrementAndGet();
364+
return current <= itemCount ? current : null;
365+
};
366+
367+
ItemProcessor<Integer, Integer> filteringProcessor = item -> null;
368+
ItemWriter<Integer> writer = chunk -> {
369+
};
370+
371+
JobRepository jobRepository = new ResourcelessJobRepository();
372+
ChunkOrientedStep<Integer, Integer> step = new ChunkOrientedStep<>("step", 100, reader, writer, jobRepository);
373+
step.setItemProcessor(filteringProcessor);
374+
step.afterPropertiesSet();
375+
376+
JobInstance jobInstance = new JobInstance(1L, "job");
377+
JobExecution jobExecution = new JobExecution(1L, jobInstance, new JobParameters());
378+
StepExecution stepExecution = new StepExecution(1L, "step", jobExecution);
379+
380+
// when
381+
step.execute(stepExecution);
382+
383+
// then
384+
assertEquals(itemCount, stepExecution.getFilterCount(), "Sequential mode should have accurate filter count");
385+
}
386+
387+
@Test
388+
void testProcessSkipCountAccuracyInConcurrentMode() throws Exception {
389+
// given
390+
int itemCount = 10;
391+
AtomicInteger readCounter = new AtomicInteger(0);
392+
393+
ItemReader<Integer> reader = () -> {
394+
int current = readCounter.incrementAndGet();
395+
return current <= itemCount ? current : null;
396+
};
397+
398+
ItemProcessor<Integer, Integer> failingProcessor = item -> {
399+
throw new RuntimeException("Simulated processing failure");
400+
};
401+
402+
ItemWriter<Integer> writer = chunk -> {
403+
};
404+
405+
JobRepository jobRepository = new ResourcelessJobRepository();
406+
ChunkOrientedStep<Integer, Integer> step = new ChunkOrientedStep<>("step", 100, reader, writer, jobRepository);
407+
step.setItemProcessor(failingProcessor);
408+
step.setTaskExecutor(new SimpleAsyncTaskExecutor());
409+
step.setFaultTolerant(true);
410+
step.setRetryPolicy(RetryPolicy.withMaxRetries(1));
411+
step.setSkipPolicy((throwable, skipCount) -> throwable instanceof RuntimeException);
412+
413+
step.afterPropertiesSet();
414+
415+
JobInstance jobInstance = new JobInstance(1L, "job");
416+
JobExecution jobExecution = new JobExecution(1L, jobInstance, new JobParameters());
417+
StepExecution stepExecution = new StepExecution(1L, "step", jobExecution);
418+
419+
// when
420+
step.execute(stepExecution);
421+
422+
// then
423+
assertEquals(itemCount, stepExecution.getProcessSkipCount(), "Race condition detected! Expected " + itemCount
424+
+ " process skips, but got " + stepExecution.getProcessSkipCount());
425+
}
426+
427+
@Test
428+
void testProcessSkipCountAccuracyInSequentialMode() throws Exception {
429+
// given
430+
int itemCount = 10;
431+
AtomicInteger readCounter = new AtomicInteger(0);
432+
433+
ItemReader<Integer> reader = () -> {
434+
int current = readCounter.incrementAndGet();
435+
return current <= itemCount ? current : null;
436+
};
437+
438+
ItemProcessor<Integer, Integer> failingProcessor = item -> {
439+
throw new RuntimeException("Simulated processing failure");
440+
};
441+
442+
ItemWriter<Integer> writer = chunk -> {
443+
};
444+
445+
JobRepository jobRepository = new ResourcelessJobRepository();
446+
ChunkOrientedStep<Integer, Integer> step = new ChunkOrientedStep<>("step", 100, reader, writer, jobRepository);
447+
step.setItemProcessor(failingProcessor);
448+
step.setFaultTolerant(true);
449+
step.setRetryPolicy(RetryPolicy.withMaxRetries(1));
450+
step.setSkipPolicy((throwable, skipCount) -> throwable instanceof RuntimeException);
451+
step.afterPropertiesSet();
452+
453+
JobInstance jobInstance = new JobInstance(1L, "job");
454+
JobExecution jobExecution = new JobExecution(1L, jobInstance, new JobParameters());
455+
StepExecution stepExecution = new StepExecution(1L, "step", jobExecution);
456+
457+
// when
458+
step.execute(stepExecution);
459+
460+
// then
461+
assertEquals(itemCount, stepExecution.getProcessSkipCount(),
462+
"Sequential mode should have accurate process skip count");
463+
}
464+
322465
}

0 commit comments

Comments
 (0)