diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java index d3e9eddf75aa..13bb8ceae05d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/FinalizeBundleHandler.java @@ -24,11 +24,15 @@ import java.util.Collection; import java.util.Comparator; import java.util.PriorityQueue; -import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; @@ -45,14 +49,11 @@ *

See Apache Beam Portability API: How to * Finalize Bundles for further details. */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class FinalizeBundleHandler { /** A {@link BundleFinalizer.Callback} and expiry time pair. */ @AutoValue - abstract static class CallbackRegistration { + public abstract static class CallbackRegistration { public static CallbackRegistration create( Instant expiryTime, BundleFinalizer.Callback callback) { return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback); @@ -64,55 +65,53 @@ public static CallbackRegistration create( } private final ConcurrentMap> bundleFinalizationCallbacks; + private final ReentrantLock cleanupLock = new ReentrantLock(); + private final Condition queueMinChanged = cleanupLock.newCondition(); + + @GuardedBy("cleanupLock") private final PriorityQueue> cleanUpQueue; @SuppressWarnings("unused") - private final Future cleanUpResult; + private final Future cleanUpResult; + @SuppressWarnings("methodref.receiver.bound") public FinalizeBundleHandler(ExecutorService executorService) { this.bundleFinalizationCallbacks = new ConcurrentHashMap<>(); this.cleanUpQueue = new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp)); - // Wait until we have at least one element. We are notified on each element - // being added. - // Wait until the current time has past the expiry time for the head of the - // queue. - // We are notified on each element being added. - // Wait until we have at least one element. We are notified on each element - // being added. - // Wait until the current time has past the expiry time for the head of the - // queue. - // We are notified on each element being added. - cleanUpResult = - executorService.submit( - (Callable) - () -> { - while (true) { - synchronized (cleanUpQueue) { - TimestampedValue expiryTime = cleanUpQueue.peek(); - - // Wait until we have at least one element. We are notified on each element - // being added. - while (expiryTime == null) { - cleanUpQueue.wait(); - expiryTime = cleanUpQueue.peek(); - } - - // Wait until the current time has past the expiry time for the head of the - // queue. - // We are notified on each element being added. - Instant now = Instant.now(); - while (expiryTime.getTimestamp().isAfter(now)) { - Duration timeDifference = new Duration(now, expiryTime.getTimestamp()); - cleanUpQueue.wait(timeDifference.getMillis()); - expiryTime = cleanUpQueue.peek(); - now = Instant.now(); - } - - bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue()); - } - } - }); + + cleanUpResult = executorService.submit(this::cleanupThreadBody); + } + + private void cleanupThreadBody() { + cleanupLock.lock(); + try { + while (true) { + final @Nullable TimestampedValue minValue = cleanUpQueue.peek(); + if (minValue == null) { + // Wait for an element to be added and loop to re-examine the min. + queueMinChanged.await(); + continue; + } + + Instant now = Instant.now(); + Duration timeDifference = new Duration(now, minValue.getTimestamp()); + if (timeDifference.getMillis() > 0 + && queueMinChanged.await(timeDifference.getMillis(), TimeUnit.MILLISECONDS)) { + // If the time didn't elapse, loop to re-examine the min. + continue; + } + + // The minimum element has an expiry time before now. + // It may or may not actually be present in the map if the finalization has already been + // completed. + bundleFinalizationCallbacks.remove(minValue.getValue()); + } + } catch (InterruptedException e) { + // We're being shutdown. + } finally { + cleanupLock.unlock(); + } } public void registerCallbacks(String bundleId, Collection callbacks) { @@ -131,9 +130,18 @@ public void registerCallbacks(String bundleId, Collection for (CallbackRegistration callback : callbacks) { expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis()); } - synchronized (cleanUpQueue) { - cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis))); - cleanUpQueue.notify(); + + cleanupLock.lock(); + try { + TimestampedValue value = TimestampedValue.of(bundleId, new Instant(expiryTimeMillis)); + cleanUpQueue.offer(value); + @SuppressWarnings("ReferenceEquality") + boolean newMin = cleanUpQueue.peek() == value; + if (newMin) { + queueMinChanged.signal(); + } + } finally { + cleanupLock.unlock(); } } @@ -141,13 +149,15 @@ public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.Instructio throws Exception { String bundleId = request.getFinalizeBundle().getInstructionId(); - Collection callbacks = bundleFinalizationCallbacks.remove(bundleId); + final @Nullable Collection callbacks = + bundleFinalizationCallbacks.remove(bundleId); if (callbacks == null) { // We have already processed the callbacks on a prior bundle finalization attempt return BeamFnApi.InstructionResponse.newBuilder() .setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance()); } + // We don't bother removing from the cleanupQueue. Collection failures = new ArrayList<>(); for (CallbackRegistration callback : callbacks) {