Skip to content

Commit

Permalink
add proxies (#216)
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 authored Feb 17, 2025
1 parent c862105 commit dc24afb
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 84 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"httpx>=0.27.2",
"pandas-market-calendars>=4.4.1",
"pydantic>=2.9.2",
"websockets>=14.1,<15",
"websockets>=15",
]
dynamic = ["version"]

Expand All @@ -37,6 +37,7 @@ dev-dependencies = [
"sphinx-rtd-theme>=3.0.2",
"enum-tools[sphinx]>=0.12.0",
"autodoc-pydantic>=2.2.0",
"proxy-py>=2.4.9",
]

[tool.setuptools.package-data]
Expand Down
2 changes: 1 addition & 1 deletion tastytrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
BACKTEST_URL = "https://backtester.vast.tastyworks.com"
CERT_URL = "https://api.cert.tastyworks.com"
VAST_URL = "https://vast.tastyworks.com"
VERSION = "9.9"
VERSION = "9.10"

__version__ = VERSION

Expand Down
24 changes: 20 additions & 4 deletions tastytrade/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ class Session:
user's device
:param dxfeed_tos_compliant:
whether to use the dxfeed TOS-compliant API endpoint for the streamer
:param proxy:
if provided, all requests will be made through this proxy, as well as
web socket connections for streamers.
"""

def __init__(
Expand All @@ -291,6 +294,7 @@ def __init__(
is_test: bool = False,
two_factor_authentication: Optional[str] = None,
dxfeed_tos_compliant: bool = False,
proxy: Optional[str] = None,
):
body = {"login": login, "remember-me": remember_me}
if password is not None:
Expand All @@ -303,14 +307,16 @@ def __init__(
)
#: Whether this is a cert or real session
self.is_test = is_test
#: Proxy URL to use for requests and web sockets
self.proxy = proxy
# The headers to use for API requests
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
#: httpx client for sync requests
self.sync_client = Client(
base_url=(CERT_URL if is_test else API_URL), headers=headers
base_url=(CERT_URL if is_test else API_URL), headers=headers, proxy=proxy
)
if two_factor_authentication is not None:
response = self.sync_client.post(
Expand All @@ -330,7 +336,9 @@ def __init__(
self.sync_client.headers.update({"Authorization": self.session_token})
#: httpx client for async requests
self.async_client = AsyncClient(
base_url=self.sync_client.base_url, headers=self.sync_client.headers.copy()
base_url=self.sync_client.base_url,
headers=self.sync_client.headers.copy(),
proxy=proxy,
)

# Pull streamer tokens and urls
Expand All @@ -345,6 +353,12 @@ def __init__(
#: URL for dxfeed websocket
self.dxlink_url = data["dxlink-url"]

def __enter__(self):
return self

def __exit__(self, *exc):
self.destroy()

async def _a_get(self, url, **kwargs) -> dict[str, Any]:
response = await self.async_client.get(url, timeout=30, **kwargs)
return validate_and_parse(response)
Expand Down Expand Up @@ -468,6 +482,8 @@ def deserialize(cls, serialized: str) -> Self:
"Content-Type": "application/json",
"Authorization": self.session_token,
}
self.sync_client = Client(base_url=base_url, headers=headers)
self.async_client = AsyncClient(base_url=base_url, headers=headers)
self.sync_client = Client(base_url=base_url, headers=headers, proxy=self.proxy)
self.async_client = AsyncClient(
base_url=base_url, headers=headers, proxy=self.proxy
)
return self
12 changes: 10 additions & 2 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def __init__(
self.reconnect_fn = reconnect_fn
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args
#: The proxy URL, if any, associated with the session
self.proxy = session.proxy

self._queues: dict[str, Queue] = defaultdict(Queue)
self._websocket: Optional[ClientConnection] = None
Expand Down Expand Up @@ -251,7 +253,9 @@ async def _connect(self) -> None:
"""
headers = {"Authorization": f"Bearer {self.token}"}
reconnecting = False
async for websocket in connect(self.base_url, additional_headers=headers):
async for websocket in connect(
self.base_url, additional_headers=headers, proxy=self.proxy
):
self._websocket = websocket
self._heartbeat_task = asyncio.create_task(self._heartbeat())
logger.debug("Websocket connection established.")
Expand Down Expand Up @@ -413,6 +417,8 @@ def __init__(
self.reconnect_fn = reconnect_fn
#: Variable number of arguments to pass to the reconnect function
self.reconnect_args = reconnect_args
#: The proxy URL, if any, associated with the session
self.proxy = session.proxy

self._authenticated = False
self._wss_url = session.dxlink_url
Expand Down Expand Up @@ -456,7 +462,9 @@ async def _connect(self) -> None:
authorization token provided during initialization.
"""
reconnecting = False
async for websocket in connect(self._wss_url, ssl=self._ssl_context):
async for websocket in connect(
self._wss_url, ssl=self._ssl_context, proxy=self.proxy
):
self._websocket = websocket
await self._setup_connection()
try:
Expand Down
10 changes: 7 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def credentials() -> tuple[str, str]:
async def session(
credentials: tuple[str, str], aiolib: str
) -> AsyncGenerator[Session, None]:
session = Session(*credentials)
yield session
session.destroy()
with Session(*credentials) as session:
yield session


@fixture(scope="class")
def inject_credentials(request, credentials: tuple[str, str]):
request.cls.credentials = credentials
15 changes: 15 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest
from proxy import TestCase

from tastytrade import Session


Expand Down Expand Up @@ -31,3 +34,15 @@ def test_serialize_deserialize(session: Session):
data = session.serialize()
obj = Session.deserialize(data)
assert set(obj.__dict__.keys()) == set(session.__dict__.keys())


@pytest.mark.usefixtures("inject_credentials")
class TestProxy(TestCase):
def test_session_with_proxy(self):
assert self.PROXY is not None
session = Session(
*self.credentials, # type: ignore
proxy=f"http://127.0.0.1:{self.PROXY.flags.port}",
)
assert session.validate()
session.destroy()
19 changes: 19 additions & 0 deletions tests/test_streamer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import asyncio
from datetime import datetime, timedelta
from unittest import IsolatedAsyncioTestCase

import pytest
from proxy import TestCase

from tastytrade import Account, AlertStreamer, DXLinkStreamer, Session
from tastytrade.dxfeed import Candle, Quote, Trade
Expand Down Expand Up @@ -65,3 +69,18 @@ async def test_dxlink_streamer_reconnect(session: Session):
trade = await streamer.get_event(Trade)
assert trade.event_symbol == "SPX"
await streamer.close()


@pytest.mark.usefixtures("inject_credentials")
class TestProxy(TestCase, IsolatedAsyncioTestCase):
@pytest.mark.asyncio
async def test_streamer_with_proxy(self):
assert self.PROXY is not None
with Session(
*self.credentials, # type: ignore
proxy=f"http://127.0.0.1:{self.PROXY.flags.port}",
) as session:
assert session.validate()
async with DXLinkStreamer(session) as streamer:
await streamer.subscribe(Quote, ["SPY"])
_ = await streamer.get_event(Quote)
Loading

0 comments on commit dc24afb

Please sign in to comment.