diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherCache.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherCache.java new file mode 100644 index 0000000000..fad7855823 --- /dev/null +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherCache.java @@ -0,0 +1,261 @@ +/* + * Copyright © 2024 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscriber; +import io.servicetalk.concurrent.PublisherSource.Subscription; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nullable; + +/** + * A cache of {@link Publisher}s. Keys must correctly implement Object::hashCode and Object::equals. + * The cache can be created with either multicast or replay configurations. Subscriptions + * are tracked by the cache and Publishers are removed when there are no more subscriptions. + * + * @param a key type suitable for use as a Map key, that correctly implements hashCode/equals. + * @param the type of the publisher that is cached. + */ +public final class PublisherCache { + private final MulticastStrategy multicastStrategy; + private final Map> publisherCache; + + PublisherCache(final MulticastStrategy multicastStrategy) { + this.multicastStrategy = multicastStrategy; + this.publisherCache = new HashMap<>(); + } + + /** + * Create a new PublisherCache where the cached publishers must be configured with a multicast or replay operator + * by the multicastSupplier function. + * + * @param a key type suitable for use as a {@link Map} key. + * @param the type of the {@link Publisher} contained in the cache. + * @return a new PublisherCache that will not wrap cached values. + */ + public static PublisherCache create() { + return new PublisherCache<>(MulticastStrategy.identity()); + } + + /** + * Create a new PublisherCache where the cached publishers will be configured for multicast for all + * consumers of a specific key. + * + * @param a key type suitable for use as a {@link Map} key. + * @param the type of the {@link Publisher} contained in the cache. + * @return a new PublisherCache that will wrap cached values with multicast operator. + */ + public static PublisherCache multicast() { + return new PublisherCache<>(MulticastStrategy.wrapMulticast()); + } + + /** + * Create a new PublisherCache where the cached publishers will be configured for multicast for all + * consumers of a specific key. + * + * @param minSubscribers The upstream subscribe operation will not happen until after this many {@link Subscriber} + * subscribe to the return value. + * @param a key type suitable for use as a {@link Map} key. + * @param the type of the {@link Publisher} contained in the cache. + * @return a new PublisherCache that will wrap cached values with multicast operator. + */ + public static PublisherCache multicast(final int minSubscribers) { + return new PublisherCache<>(MulticastStrategy.wrapMulticast(minSubscribers)); + } + + /** + * Create a new PublisherCache where the cached publishers will be configured for multicast for all + * consumers of a specific key. + * + * @param minSubscribers The upstream subscribe operation will not happen until after this many {@link Subscriber} + * subscribe to the return value. + * @param queueLimit The number of elements which will be queued for each {@link Subscriber} in order to compensate + * for unequal demand. + * @param a key type suitable for use as a {@link Map} key. + * @param the type of the {@link Publisher} contained in the cache. + * @return aa new PublisherCache that will wrap cached values with multicast operator. + */ + public static PublisherCache multicast(final int minSubscribers, final int queueLimit) { + return new PublisherCache<>(MulticastStrategy.wrapMulticast(minSubscribers, queueLimit)); + } + + /** + * Create a new PublisherCache where the cacehed publishers will be configured for replay given the privided + * ReplayStrategy. + * + * @param replayStrategy a {@link ReplayStrategy} that determines the replay behavior and history retention logic. + * @param a key type suitable for use as a {@link Map} key. + * @param the type of the {@link Publisher} contained in the cache. + * @return a new PublisherCache that will wrap cached values with replay operator. + */ + public static PublisherCache replay(ReplayStrategy replayStrategy) { + return new PublisherCache<>(MulticastStrategy.wrapReplay(replayStrategy)); + } + + /** + * Retrieve the value for the given key, if no value exists in the cache it will be synchronously created. + * + * @param key a key corresponding to the requested {@link Publisher}. + * @param publisherSupplier if the key does not exist in the cache, used to create a new publisher. + * @return a new {@link Publisher} from the publisherSupplier if not contained in the cache, otherwise + * the cached publisher. + */ + public Publisher get(final K key, final Function> publisherSupplier) { + return Publisher.defer(() -> { + synchronized (publisherCache) { + if (publisherCache.containsKey(key)) { + return publisherCache.get(key).publisher; + } + + final Holder item2 = new Holder<>(); + publisherCache.put(key, item2); + + final Publisher newPublisher = publisherSupplier.apply(key) + .liftSync(subscriber -> new Subscriber() { + @Override + public void onSubscribe(Subscription subscription) { + subscriber.onSubscribe(new Subscription() { + @Override + public void request(long n) { + subscription.request(n); + } + + @Override + public void cancel() { + try { + assert Thread.holdsLock(publisherCache); + publisherCache.remove(key, item2); + } finally { + subscription.cancel(); + } + } + }); + } + + @Override + public void onNext(@Nullable T next) { + subscriber.onNext(next); + } + + @Override + public void onError(Throwable t) { + lockRemoveFromMap(); + subscriber.onError(t); + } + + @Override + public void onComplete() { + lockRemoveFromMap(); + subscriber.onComplete(); + } + + private void lockRemoveFromMap() { + synchronized (publisherCache) { + publisherCache.remove(key, item2); + } + } + }); + + item2.publisher = multicastStrategy.apply(newPublisher) + .liftSync(subscriber -> new Subscriber() { + @Override + public void onSubscribe(Subscription subscription) { + subscriber.onSubscribe(new Subscription() { + @Override + public void request(long n) { + subscription.request(n); + } + + @Override + public void cancel() { + synchronized (publisherCache) { + subscription.cancel(); + } + } + }); + } + + @Override + public void onNext(@Nullable T next) { + subscriber.onNext(next); + } + + @Override + public void onError(Throwable t) { + try { + subscriber.onError(t); + } finally { + lockRemoveFromMap(); + } + } + + @Override + public void onComplete() { + try { + subscriber.onComplete(); + } finally { + lockRemoveFromMap(); + } + } + + private void lockRemoveFromMap() { + synchronized (publisherCache) { + publisherCache.remove(key, item2); + } + } + }); + return item2.publisher; + } + }); + } + + private static final class Holder { + @Nullable + Publisher publisher; + } + + /** + * A series of strategies used to make cached {@link Publisher}s "shareable". + * + * @param the type of the {@link Publisher}. + */ + @FunctionalInterface + private interface MulticastStrategy { + Publisher apply(Publisher cached); + + static MulticastStrategy identity() { + return cached -> cached; + } + + static MulticastStrategy wrapMulticast() { + return cached -> cached.multicast(1, true); + } + + static MulticastStrategy wrapMulticast(int minSubscribers) { + return cached -> cached.multicast(minSubscribers, true); + } + + static MulticastStrategy wrapMulticast(int minSubscribers, int queueLimit) { + return cached -> cached.multicast(minSubscribers, queueLimit); + } + + static MulticastStrategy wrapReplay(final ReplayStrategy strategy) { + return cached -> cached.replay(strategy); + } + } +} diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherCacheTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherCacheTest.java new file mode 100644 index 0000000000..1983a77702 --- /dev/null +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherCacheTest.java @@ -0,0 +1,211 @@ +/* + * Copyright © 2024 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.concurrent.api; + +import io.servicetalk.concurrent.PublisherSource.Subscription; +import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static io.servicetalk.concurrent.api.SourceAdapters.toSource; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +class PublisherCacheTest { + private TestPublisher testPublisher; + private AtomicInteger upstreamSubscriptionCount; + private AtomicBoolean isUpstreamUnsubsrcibed; + private AtomicInteger upstreamCacheRequestCount; + private Function> publisherSupplier; + + @BeforeEach + public void setup() { + this.testPublisher = new TestPublisher<>(); + this.upstreamSubscriptionCount = new AtomicInteger(0); + this.isUpstreamUnsubsrcibed = new AtomicBoolean(false); + this.upstreamCacheRequestCount = new AtomicInteger(0); + this.publisherSupplier = (_ignore) -> { + upstreamCacheRequestCount.incrementAndGet(); + return testPublisher.afterOnSubscribe(subscription -> + upstreamSubscriptionCount.incrementAndGet() + ).afterFinally(() -> isUpstreamUnsubsrcibed.set(true)); + }; + } + + @Test + void multipleSubscribersToSameKeyReceiveMulticastEvents() { + final PublisherCache publisherCache = PublisherCache.multicast(); + + final TestPublisherSubscriber subscriber1 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber1); + final Subscription subscription1 = subscriber1.awaitSubscription(); + + // the first subscriber receives the initial event + subscription1.request(1); + testPublisher.onNext(1); + assertThat(subscriber1.takeOnNext(), is(1)); + + // the second subscriber receives the cached publisher + final TestPublisherSubscriber subscriber2 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber2); + final Subscription subscription2 = subscriber2.awaitSubscription(); + + // all subscribers receive all subsequent events + subscription1.request(1); + subscription2.request(1); + testPublisher.onNext(2); + assertThat(subscriber1.takeOnNext(), is(2)); + assertThat(subscriber2.takeOnNext(), is(2)); + + // subscribe with the first request + assertThat(upstreamSubscriptionCount.get(), is(1)); + + // make sure we still only have subscribed once + assertThat(upstreamSubscriptionCount.get(), is(1)); + assertThat(upstreamCacheRequestCount.get(), is(1)); + } + + @Test + void minSubscribersMulticastPolicySubscribesUpstreamOnce() { + final PublisherCache publisherCache = PublisherCache.multicast(); + + final TestPublisherSubscriber subscriber1 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber1); + final Subscription subscription1 = subscriber1.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(1)); + + final TestPublisherSubscriber subscriber2 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber2); + final Subscription subscription2 = subscriber2.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(1)); + + subscription1.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(false)); + + subscription2.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(true)); + + assertThat(upstreamCacheRequestCount.get(), is(1)); + } + + @Test + void unSubscribeUpstreamAndInvalidateCacheWhenEmpty() { + final PublisherCache publisherCache = PublisherCache.multicast(); + + final TestPublisherSubscriber subscriber1 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber1); + final Subscription subscription1 = subscriber1.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(1)); + + final TestPublisherSubscriber subscriber2 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber2); + final Subscription subscription2 = subscriber2.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(1)); + + subscription1.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(false)); + + subscription2.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(true)); + + assertThat(upstreamCacheRequestCount.get(), is(1)); + } + + @Test + void cacheSubscriptionAndUnsubscriptionWithMulticastMinSubscribers() { + final PublisherCache publisherCache = PublisherCache.multicast(2); + + final TestPublisherSubscriber subscriber1 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber1); + final Subscription subscription1 = subscriber1.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(0)); + + final TestPublisherSubscriber subscriber2 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber2); + final Subscription subscription2 = subscriber2.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(1)); + + subscription1.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(false)); + + subscription2.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(true)); + + assertThat(upstreamCacheRequestCount.get(), is(1)); + } + + @Test + void cacheSubscriptionAndUnsubscriptionWithReplay() { + final PublisherCache publisherCache = PublisherCache.replay( + ReplayStrategies.historyBuilder(1) + .cancelUpstream(true) + .minSubscribers(1) + .build()); + + final TestPublisherSubscriber subscriber1 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber1); + final Subscription subscription1 = subscriber1.awaitSubscription(); + + subscription1.request(3); + testPublisher.onNext(1, 2, 3); + assertThat(subscriber1.takeOnNext(3), is(Arrays.asList(1, 2, 3))); + + final TestPublisherSubscriber subscriber2 = new TestPublisherSubscriber<>(); + toSource(publisherCache.get("foo", publisherSupplier)).subscribe(subscriber2); + final Subscription subscription2 = subscriber2.awaitSubscription(); + + assertThat(upstreamSubscriptionCount.get(), is(1)); + + subscription2.request(1); + assertThat(subscriber2.takeOnNext(), is(3)); + + subscription1.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(false)); + + subscription2.cancel(); + assertThat(isUpstreamUnsubsrcibed.get(), is(true)); + + assertThat(upstreamCacheRequestCount.get(), is(1)); + } + + @Test + void testErrorFromUpstreamInvalidatesCacheEntryAndRequestsANewStream() { + final PublisherCache publisherCache = PublisherCache.multicast(); + + final TestPublisherSubscriber subscriber1 = new TestPublisherSubscriber<>(); + // use an "always" retry policy to force the situation + final Publisher discovery = publisherCache.get("foo", publisherSupplier).retry((i, cause) -> true); + toSource(discovery).subscribe(subscriber1); + final Subscription subscription1 = subscriber1.awaitSubscription(); + + subscription1.request(2); + testPublisher.onError(new Exception("bad stuff happened")); + + assertThat(upstreamCacheRequestCount.get(), is(2)); + } +}