Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@
import java.util.Collection;
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.FinalizeBundleResponse;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
Expand All @@ -45,14 +49,11 @@
* <p>See <a href="https://s.apache.org/beam-finalizing-bundles">Apache Beam Portability API: How to
* Finalize Bundles</a> for further details.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class FinalizeBundleHandler {

/** A {@link BundleFinalizer.Callback} and expiry time pair. */
@AutoValue
abstract static class CallbackRegistration {
public abstract static class CallbackRegistration {
public static CallbackRegistration create(
Instant expiryTime, BundleFinalizer.Callback callback) {
return new AutoValue_FinalizeBundleHandler_CallbackRegistration(expiryTime, callback);
Expand All @@ -64,55 +65,53 @@ public static CallbackRegistration create(
}

private final ConcurrentMap<String, Collection<CallbackRegistration>> bundleFinalizationCallbacks;
private final ReentrantLock cleanupLock = new ReentrantLock();
private final Condition queueMinChanged = cleanupLock.newCondition();

@GuardedBy("cleanupLock")
private final PriorityQueue<TimestampedValue<String>> cleanUpQueue;

@SuppressWarnings("unused")
private final Future<Void> cleanUpResult;
private final Future<?> cleanUpResult;

@SuppressWarnings("methodref.receiver.bound")
public FinalizeBundleHandler(ExecutorService executorService) {
this.bundleFinalizationCallbacks = new ConcurrentHashMap<>();
this.cleanUpQueue =
new PriorityQueue<>(11, Comparator.comparing(TimestampedValue::getTimestamp));
// Wait until we have at least one element. We are notified on each element
// being added.
// Wait until the current time has past the expiry time for the head of the
// queue.
// We are notified on each element being added.
// Wait until we have at least one element. We are notified on each element
// being added.
// Wait until the current time has past the expiry time for the head of the
// queue.
// We are notified on each element being added.
cleanUpResult =
executorService.submit(
(Callable<Void>)
() -> {
while (true) {
synchronized (cleanUpQueue) {
TimestampedValue<String> expiryTime = cleanUpQueue.peek();

// Wait until we have at least one element. We are notified on each element
// being added.
while (expiryTime == null) {
cleanUpQueue.wait();
expiryTime = cleanUpQueue.peek();
}

// Wait until the current time has past the expiry time for the head of the
// queue.
// We are notified on each element being added.
Instant now = Instant.now();
while (expiryTime.getTimestamp().isAfter(now)) {
Duration timeDifference = new Duration(now, expiryTime.getTimestamp());
cleanUpQueue.wait(timeDifference.getMillis());
expiryTime = cleanUpQueue.peek();
now = Instant.now();
}

bundleFinalizationCallbacks.remove(cleanUpQueue.poll().getValue());
}
}
});

cleanUpResult = executorService.submit(this::cleanupThreadBody);
}

private void cleanupThreadBody() {
cleanupLock.lock();
try {
while (true) {
final @Nullable TimestampedValue<String> minValue = cleanUpQueue.peek();
if (minValue == null) {
// Wait for an element to be added and loop to re-examine the min.
queueMinChanged.await();
continue;
}

Instant now = Instant.now();
Duration timeDifference = new Duration(now, minValue.getTimestamp());
if (timeDifference.getMillis() > 0
&& queueMinChanged.await(timeDifference.getMillis(), TimeUnit.MILLISECONDS)) {
// If the time didn't elapse, loop to re-examine the min.
continue;
}

// The minimum element has an expiry time before now.
// It may or may not actually be present in the map if the finalization has already been
// completed.
bundleFinalizationCallbacks.remove(minValue.getValue());
}
} catch (InterruptedException e) {
// We're being shutdown.
} finally {
cleanupLock.unlock();
}
}

public void registerCallbacks(String bundleId, Collection<CallbackRegistration> callbacks) {
Expand All @@ -131,23 +130,34 @@ public void registerCallbacks(String bundleId, Collection<CallbackRegistration>
for (CallbackRegistration callback : callbacks) {
expiryTimeMillis = Math.max(expiryTimeMillis, callback.getExpiryTime().getMillis());
}
synchronized (cleanUpQueue) {
cleanUpQueue.offer(TimestampedValue.of(bundleId, new Instant(expiryTimeMillis)));
cleanUpQueue.notify();

cleanupLock.lock();
try {
TimestampedValue<String> value = TimestampedValue.of(bundleId, new Instant(expiryTimeMillis));
cleanUpQueue.offer(value);
@SuppressWarnings("ReferenceEquality")
boolean newMin = cleanUpQueue.peek() == value;
if (newMin) {
queueMinChanged.signal();
}
} finally {
cleanupLock.unlock();
}
}

public BeamFnApi.InstructionResponse.Builder finalizeBundle(BeamFnApi.InstructionRequest request)
throws Exception {
String bundleId = request.getFinalizeBundle().getInstructionId();

Collection<CallbackRegistration> callbacks = bundleFinalizationCallbacks.remove(bundleId);
final @Nullable Collection<CallbackRegistration> callbacks =
bundleFinalizationCallbacks.remove(bundleId);

if (callbacks == null) {
// We have already processed the callbacks on a prior bundle finalization attempt
return BeamFnApi.InstructionResponse.newBuilder()
.setFinalizeBundle(FinalizeBundleResponse.getDefaultInstance());
}
// We don't bother removing from the cleanupQueue.

Collection<Exception> failures = new ArrayList<>();
for (CallbackRegistration callback : callbacks) {
Expand Down
Loading