Skip to content

Commit

Permalink
change create to __await__
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Dec 10, 2024
1 parent 0fdb1ec commit 9d4bfc7
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 60 deletions.
2 changes: 2 additions & 0 deletions docs/account-streamer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,5 @@ This callback can then be used when creating the streamer:
async with AlertStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer:
# ...
The reconnection uses `websockets`' exponential backoff algorithm, which can be configured through environment variables `here <https://websockets.readthedocs.io/en/14.1/reference/variables.html>`_.
6 changes: 4 additions & 2 deletions docs/data-streamer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ You can create a streamer using an active production session:
.. code-block:: python
from tastytrade import DXLinkStreamer
streamer = await DXLinkStreamer.create(session)
streamer = await DXLinkStreamer(session)
Or, you can create a streamer using an asynchronous context manager:

Expand Down Expand Up @@ -110,7 +110,7 @@ For example, we can use the streamer to create an option chain that will continu
# the `streamer_symbol` property is the symbol used by the streamer
streamer_symbols = [o.streamer_symbol for o in options]
streamer = await DXLinkStreamer.create(session)
streamer = await DXLinkStreamer(session)
# subscribe to quotes and greeks for all options on that date
await streamer.subscribe(Quote, [symbol] + streamer_symbols)
await streamer.subscribe(Greeks, streamer_symbols)
Expand Down Expand Up @@ -165,3 +165,5 @@ This callback can then be used when creating the streamer:
async with DXLinkStreamer(session, reconnect_fn=callback, reconnect_args=(arg1, arg2)) as streamer:
# ...
The reconnection uses `websockets`' exponential backoff algorithm, which can be configured through environment variables `here <https://websockets.readthedocs.io/en/14.1/reference/variables.html>`_.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"httpx>=0.27.2",
"pandas-market-calendars>=4.4.1",
"pydantic>=2.9.2",
"websockets>=14.1",
"websockets>=14.1,<15",
]

[project.urls]
Expand All @@ -29,7 +29,7 @@ dev-dependencies = [
"pytest-aio>=1.5.0",
"pytest-cov>=5.0.0",
"ruff>=0.6.9",
"pyright>=1.1.384",
"pyright>=1.1.390",
]

[tool.setuptools.package-data]
Expand Down
88 changes: 37 additions & 51 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import json
from asyncio import Lock, Queue, QueueEmpty
from asyncio import Queue, QueueEmpty
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
Expand Down Expand Up @@ -167,8 +167,8 @@ class AlertStreamer:
"""
Used to subscribe to account-level updates (balances, orders, positions),
public watchlist updates, quote alerts, and user-level messages. It should
always be initialized as an async context manager, or with the `create`
function, since the object cannot be fully instantiated without async.
always be initialized as an async context manager, or by awaiting it,
since the object cannot be fully instantiated without async.
Example usage::
Expand All @@ -188,6 +188,10 @@ class AlertStreamer:
async for order in streamer.listen(PlacedOrder):
print(order)
Or::
streamer = await AlertStreamer(session)
"""

def __init__(
Expand Down Expand Up @@ -221,20 +225,13 @@ async def __aenter__(self):

return self

@classmethod
async def create(
cls,
session: Session,
reconnect_args: tuple[Any, ...] = (),
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
) -> "AlertStreamer":
self = cls(session, reconnect_args=reconnect_args, reconnect_fn=reconnect_fn)
return await self.__aenter__()
def __await__(self):
return self.__aenter__().__await__()

async def __aexit__(self, *exc):
self.close()

def close(self):
def close(self) -> None:
"""
Closes the websocket connection and cancels the pending tasks.
"""
Expand Down Expand Up @@ -266,10 +263,10 @@ async def _connect(self) -> None:
type_str = data.get("type")
if type_str is not None:
await self._map_message(type_str, data["data"])
except ConnectionClosed:
logger.debug("Websocket connection closed, retrying...")
reconnecting = True
continue
except ConnectionClosed as e:
logger.error(f"Websocket connection closed with {e}")
logger.debug("Websocket connection closed, retrying...")
reconnecting = True

async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]:
"""
Expand All @@ -285,7 +282,7 @@ async def listen(self, alert_class: Type[T]) -> AsyncIterator[T]:
while True:
yield await self._queues[cls_str].get()

async def _map_message(self, type_str: str, data: dict):
async def _map_message(self, type_str: str, data: dict) -> None:
"""
I'm not sure what the user-status messages look like, so they're absent.
"""
Expand Down Expand Up @@ -355,8 +352,8 @@ class DXLinkStreamer:
"""
A :class:`DXLinkStreamer` object is used to fetch quotes or greeks for a
given symbol or list of symbols. It should always be initialized as an
async context manager, or with the `create` function, since the object
cannot be fully instantiated without async.
async context manager, or by awaiting it, since the object cannot be
fully instantiated without async.
Example usage::
Expand All @@ -370,6 +367,10 @@ class DXLinkStreamer:
quote = await streamer.get_event(Quote)
print(quote)
Or::
streamer = await DXLinkStreamer(session)
"""

def __init__(
Expand All @@ -379,8 +380,6 @@ def __init__(
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
ssl_context: SSLContext = create_default_context(),
):
self._counter = 0
self._lock: Lock = Lock()
self._queues: dict[str, Queue] = defaultdict(Queue)
self._channels: dict[str, int] = {
"Candle": 1,
Expand All @@ -400,17 +399,15 @@ def __init__(
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args

# The unique client identifier received from the server
self._session = session
self._authenticated = False
self._wss_url = session.dxlink_url
self._auth_token = session.streamer_token
self._ssl_context = ssl_context

self._connect_task = asyncio.create_task(self._connect())
self._reconnect_task = None

async def __aenter__(self):
self._connect_task = asyncio.create_task(self._connect())
time_out = 100
while not self._authenticated:
await asyncio.sleep(0.1)
Expand All @@ -420,26 +417,13 @@ async def __aenter__(self):

return self

@classmethod
async def create(
cls,
session: Session,
reconnect_args: tuple[Any, ...] = (),
reconnect_fn: Optional[Callable[..., Coroutine[Any, Any, None]]] = None,
ssl_context: SSLContext = create_default_context(),
) -> "DXLinkStreamer":
self = cls(
session,
reconnect_args=reconnect_args,
reconnect_fn=reconnect_fn,
ssl_context=ssl_context,
)
return await self.__aenter__()
def __await__(self):
return self.__aenter__().__await__()

async def __aexit__(self, *exc):
self.close()

def close(self):
def close(self) -> None:
"""
Closes the websocket connection and cancels the heartbeat task.
"""
Expand Down Expand Up @@ -472,6 +456,8 @@ async def _connect(self) -> None:
)
# run reconnect hook upon auth completion
if reconnecting and self.reconnect_fn is not None:
self._subscription_state.clear()
reconnecting = False
self._reconnect_task = asyncio.create_task(
self.reconnect_fn(self, *self.reconnect_args)
)
Expand All @@ -481,15 +467,15 @@ async def _connect(self) -> None:
for k, v in self._channels.items()
if v == message["channel"]
)
self._subscription_state[channel] = message["type"]
self._subscription_state[channel] = "CHANNEL_OPENED"
logger.debug("Channel opened: %s", message)
elif message["type"] == "CHANNEL_CLOSED":
channel = next(
k
for k, v in self._channels.items()
if v == message["channel"]
)
self._subscription_state[channel] = message["type"]
del self._subscription_state[channel]
logger.debug("Channel closed: %s", message)
elif message["type"] == "FEED_CONFIG":
logger.debug("Feed configured: %s", message)
Expand All @@ -498,13 +484,13 @@ async def _connect(self) -> None:
elif message["type"] == "KEEPALIVE":
pass
else:
raise TastytradeError("Unknown message type:", message)
except ConnectionClosed:
logger.debug("Websocket connection closed, retrying...")
reconnecting = True
continue
logger.error(f"Streamer error: {message}")
except ConnectionClosed as e:
logger.error(f"Websocket connection closed with {e}")
logger.debug("Websocket connection closed, retrying...")
reconnecting = True

async def _setup_connection(self):
async def _setup_connection(self) -> None:
message = {
"type": "SETUP",
"channel": 0,
Expand All @@ -514,7 +500,7 @@ async def _setup_connection(self):
}
await self._websocket.send(json.dumps(message))

async def _authenticate_connection(self):
async def _authenticate_connection(self) -> None:
message = {
"type": "AUTH",
"channel": 0,
Expand Down Expand Up @@ -744,7 +730,7 @@ async def unsubscribe_candle(
}
await self._websocket.send(json.dumps(message))

async def _map_message(self, message) -> None: # pragma: no cover
async def _map_message(self, message) -> None:
"""
Takes the raw JSON data, parses the events and places them into their
respective queues.
Expand Down
36 changes: 36 additions & 0 deletions tests/test_streamer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime, timedelta

from tastytrade import Account, AlertStreamer, DXLinkStreamer
Expand Down Expand Up @@ -29,3 +30,38 @@ async def test_dxlink_streamer(session):
await streamer.unsubscribe_candle(subs[0], "1d")
await streamer.unsubscribe(Quote, [subs[0]])
await streamer.unsubscribe_all(Quote)


async def reconnect_alerts(streamer: AlertStreamer, ref: dict[str, bool]):
await streamer.subscribe_quote_alerts()
ref["test"] = True


async def test_account_streamer_reconnect(session):
ref = {}
streamer = await AlertStreamer(
session, reconnect_args=(ref,), reconnect_fn=reconnect_alerts
)
await streamer.subscribe_public_watchlists()
await streamer.subscribe_user_messages(session)
accounts = Account.get_accounts(session)
await streamer.subscribe_accounts(accounts)
await streamer._websocket.close() # type: ignore
await asyncio.sleep(3)
assert "test" in ref
streamer.close()


async def reconnect_trades(streamer: DXLinkStreamer):
await streamer.subscribe(Trade, ["SPX"])


async def test_dxlink_streamer_reconnect(session):
streamer = await DXLinkStreamer(session, reconnect_fn=reconnect_trades)
await streamer.subscribe(Quote, ["SPY"])
_ = await streamer.get_event(Quote)
await streamer._websocket.close()
await asyncio.sleep(3)
trade = await streamer.get_event(Trade)
assert trade.event_symbol == "SPX"
streamer.close()
10 changes: 5 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9d4bfc7

Please sign in to comment.