Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: do not create many redis subscriptions per token - fast yielding background #4419

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,7 @@ def __init__(self, namespace: str, app: App):
super().__init__(namespace)
self.app = app

def on_connect(self, sid, environ):
async def on_connect(self, sid, environ):
"""Event for when the websocket is connected.

Args:
Expand All @@ -1486,7 +1486,7 @@ def on_connect(self, sid, environ):
"""
pass

def on_disconnect(self, sid):
async def on_disconnect(self, sid):
"""Event for when the websocket disconnects.

Args:
Expand All @@ -1495,6 +1495,7 @@ def on_disconnect(self, sid):
disconnect_token = self.sid_to_token.pop(sid, None)
if disconnect_token:
self.token_to_sid.pop(disconnect_token, None)
await self.app.state_manager.disconnect(sid)

async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.
Expand Down
42 changes: 41 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2826,6 +2826,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""
yield self.state()

async def disconnect(self, token: str) -> None:
    """Handle a client disconnect for the given token.

    The base implementation intentionally does nothing; subclasses may
    override it to drop per-token resources.

    Args:
        token: The token to disconnect.
    """
    return None


class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
Expand Down Expand Up @@ -2895,6 +2903,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
yield state
await self.set_state(token, state)

@override
async def disconnect(self, token: str) -> None:
    """Drop the in-memory state and lock bookkeeping for a token.

    The cached state for ``token`` is removed if present. The per-token lock
    is removed only when no task currently holds it.

    Args:
        token: The token to disconnect.
    """
    # Forget the cached state; an unknown token is not an error.
    self.states.pop(token, None)

    lock = self._states_locks.get(token)
    if lock is None:
        return
    if lock.locked():
        # Another task holds this lock. asyncio.Lock.release() does not check
        # ownership, so releasing it here would let waiters into the holder's
        # critical section and make the holder's own release raise
        # RuntimeError. Leave the entry in place for the holder.
        return
    del self._states_locks[token]


def _default_token_expiration() -> int:
"""Get the default token expiration time.
Expand Down Expand Up @@ -3183,6 +3205,9 @@ class StateManagerRedis(StateManager):
b"evicted",
}

# This lock is used to ensure we only subscribe to keyspace events once per token and worker
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})

async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
Expand Down Expand Up @@ -3458,7 +3483,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
# Some redis servers only allow out-of-band configuration, so ignore errors here.
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
raise
async with self.redis.pubsub() as pubsub:
if lock_key not in self._pubsub_locks:
self._pubsub_locks[lock_key] = asyncio.Lock()
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
await pubsub.psubscribe(lock_key_channel)
while not state_is_locked:
# wait for the lock to be released
Expand All @@ -3475,6 +3502,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
break
state_is_locked = await self._try_get_lock(lock_key, lock_id)

@override
async def disconnect(self, token: str):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably wrong: clients can reconnect. We need to check whether there are any unexpired Redis tokens.

Copy link
Collaborator

@masenf masenf Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: oops, I thought this was the StateManagerMemory.disconnect function. Either way, we need to be careful about reconnects, which can be triggered by a simple page refresh.

I had a draft comment in my in-progress review calling out this function, but I suppose your testing found it first. We need to keep the in-memory states around. We do have an internal bug to add an expiry for the memory state manager, but it just hasn't been high priority because we assume no one is using the memory state manager in production.

Copy link
Contributor Author

@benedikt-bartscher benedikt-bartscher Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a single asyncio Lock is enough; then we wouldn't need to garbage-collect the locks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: subscribe to Redis expire events to clean up locks; this may be easier with #4459.

"""Disconnect the token from the redis client.

Args:
token: The token to disconnect.
"""
lock_key = self._lock_key(token)
if lock := self._pubsub_locks.get(lock_key):
if lock.locked():
lock.release()
del self._pubsub_locks[lock_key]

@contextlib.asynccontextmanager
async def _lock(self, token: str):
"""Obtain a redis lock for a token.
Expand Down
35 changes: 35 additions & 0 deletions tests/integration/test_background_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ async def handle_event_yield_only(self):
yield State.increment() # type: ignore
await asyncio.sleep(0.005)

@rx.event(background=True)
async def fast_yielding(self):
    """Yield 1000 increment events back-to-back, with no delay in between."""
    remaining = 1000
    while remaining > 0:
        yield State.increment()
        remaining -= 1

@rx.event
def increment(self):
self.counter += 1
Expand Down Expand Up @@ -169,6 +174,11 @@ def index() -> rx.Component:
on_click=State.yield_in_async_with_self,
id="yield-in-async-with-self",
),
rx.button(
"Fast Yielding",
on_click=State.fast_yielding,
id="fast-yielding",
),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)

Expand Down Expand Up @@ -375,3 +385,28 @@ def test_yield_in_async_with_self(

yield_in_async_with_self_button.click()
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)


def test_fast_yielding(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
) -> None:
    """Check that clicking the fast-yielding button brings the counter to 1000.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    # locate the button that starts the task, then the counter display
    trigger = driver.find_element(By.ID, "fast-yielding")
    counter_display = driver.find_element(By.ID, "counter")

    def counter_shows(expected: str):
        # build a predicate that re-reads the counter text on every poll
        return lambda: counter_display.text == expected

    # the counter must start from zero before the task is triggered
    assert background_task._poll_for(counter_shows("0"), timeout=5)

    trigger.click()
    assert background_task._poll_for(counter_shows("1000"), timeout=50)
Loading