+
+
+
+
+""".encode("utf-8")
+
+bad_request_page: bytes = (
+ "
400 - Bad page request.
".encode(
+ "utf-8"
+ )
+)
+
+unknown_page: bytes = (
+ "
404 - Page not found.
".encode("utf-8")
+)
+
+
+def response_page(
+ refresh_token: str, access_token: str, response: dict[str, Any]
+) -> bytes:
+ return f"""
+
+
+
+
+
+
Refresh Token:
+
{refresh_token}
+
Access Token:
+
{access_token}
+
Complete Response:
+
{response}
+
+
+""".encode("utf-8")
+
+
+class RequestHandler(BaseHTTPRequestHandler):
+ def do_GET(self) -> None:
+ # Serve root page with sign in link
+ if self.path == "/":
+ self.send_response(200)
+ self.send_header("Content-type", "text/html; charset=utf-8")
+ self.end_headers()
+ self.wfile.write(root_page)
+
+ return
+
+ if self.path.startswith("/submit"):
+ # Parse query components from url
+ query = urlparse(self.path).query
+ query_components = dict(qc.split("=") for qc in query.split("&"))
+
+ credentials.key = query_components["apiKey"]
+ credentials.secret = query_components["apiSecret"]
+ credentials.scopes = query_components["scopes"].replace("+", "%20")
+
+ # Redirect to login page using API key submitted by user
+ self.send_response(302)
+ self.send_header("Location", get_access_url(credentials))
+ self.end_headers()
+
+ return
+
+ if self.path.startswith("/?code"):
+ # Check if query path contains case insensitive "code="
+ code_match = re.search(r"code=(.+)", self.path, re.I)
+
+ if code_match and credentials.key and credentials.secret:
+ user_auth_code = code_match[1]
+ token_access = convert_auth_code(credentials, user_auth_code)
+
+ # Clear stored info
+ credentials.clear()
+
+ access_token = token_access["access_token"]
+ refresh_token = token_access["refresh_token"]
+
+ self.send_response(200)
+ self.send_header("Content-type", "text/html; charset=utf-8")
+ self.end_headers()
+
+ token_page = response_page(refresh_token, access_token, token_access)
+ self.wfile.write(token_page)
+ sys.exit(0)
+
+ else:
+ self.send_response(400)
+ self.send_header("Content-type", "text/html; charset=utf-8")
+ self.end_headers()
+ self.wfile.write(bad_request_page)
+
+ return
+
+ # Send 404 error if path is none of the above
+ self.send_response(404)
+ self.send_header("Content-type", "text/html; charset=utf-8")
+ self.end_headers()
+ self.wfile.write(unknown_page)
+
+
+def login() -> None:
+ """
+ Starts a local HTTP server and opens the browser to OAuth login.
+ """
+ httpd = HTTPServer(("", PORT), RequestHandler)
+ print(f"Opening url: {REDIRECT_URI}")
+ webbrowser.open(REDIRECT_URI)
+ httpd.serve_forever()
+
+
+if __name__ == "__main__":
+ login()
diff --git a/tradestation/session.py b/tradestation/session.py
index 1f210ae..3634d68 100644
--- a/tradestation/session.py
+++ b/tradestation/session.py
@@ -1,3 +1,175 @@
+import json
+from typing import Any
+
+import httpx
+import jwt
+from httpx import AsyncClient, Client
+from typing_extensions import Self
+
+from tradestation import API_URL_SIM, API_URL_V2, API_URL_V3, OAUTH_URL, logger
+from tradestation.account import Account
+from tradestation.utils import TradestationError, _validate_and_parse, validate_response
+
+
class Session:
- id: int
+ """
+ Contains a local user login which can then be used to interact with the
+ remote API.
+
+ :param api_key: Tradestation API key (client ID)
+ :param secret_key: Tradestation secret key (client secret)
+ :param refresh_token:
+ Tradestation refresh token used to obtain new access tokens; can be
+ acquired initially by calling :any:`tradestation.oauth.login`
+ :param access_token:
+ previously generated access token; if absent, refresh token will be
+ used to generate a new one automatically
+ :param id_token:
+ previously generated ID token; if absent, you won't be able to access
+ the `user_info` property until refreshing
+ :param is_test: whether to use the simulated API endpoints, default False
+ :param use_v2: whether to use the older v2 endpoints instead of the v3 ones
+ """
+
+ def __init__(
+ self,
+ api_key: str,
+ secret_key: str,
+ refresh_token: str,
+ access_token: str | None = None,
+ id_token: str | None = None,
+ is_test: bool = False,
+ use_v2: bool = False,
+ ):
+ if is_test and use_v2:
+ raise TradestationError(
+ "The simulation environment doesn't support v2 URLs!"
+ )
+ if is_test:
+ self.base_url = API_URL_SIM
+ elif use_v2:
+ self.base_url = API_URL_V2
+ else:
+ self.base_url = API_URL_V3
+
+ #: Tradestation client ID
+ self.api_key = api_key
+ #: Tradestation client secret
+ self.secret_key = secret_key
+ #: Access token for authenticating requests. By default, is valid for
+ #: 20 minutes, and then needs to be replaced.
+ self.access_token = access_token
+ #: Refresh token for generating new access tokens
+ self.refresh_token = refresh_token
+ #: ID token containing personal info like name, email
+ self.id_token = id_token
+
+ #: Whether this is a simulated or real session
+ self.is_test = is_test
+ # The headers to use for API requests
+ headers = {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.access_token}",
+ }
+ #: httpx client for sync requests
+ self.sync_client = Client(base_url=self.base_url, headers=headers)
+ #: httpx client for async requests
+ self.async_client = AsyncClient(base_url=self.base_url, headers=headers)
+
+ if not access_token:
+ self.refresh()
+
+ async def _a_get(self, url, **kwargs) -> Any:
+ response = await self.async_client.get(url, **kwargs)
+ return _validate_and_parse(response)
+
+ def _get(self, url, **kwargs) -> Any:
+ response = self.sync_client.get(url, **kwargs)
+ return _validate_and_parse(response)
+
+ def refresh(self) -> None:
+ """
+ Refreshes the acccess token using the stored refresh token.
+ """
+ response = httpx.post(
+ f"{OAUTH_URL}/oauth/token",
+ data={
+ "grant_type": "refresh_token",
+ "client_id": self.api_key,
+ "client_secret": self.secret_key,
+ "refresh_token": self.refresh_token,
+ },
+ )
+ data: dict[str, str] = _validate_and_parse(response)
+ # update the relevant tokens
+ self.access_token = data["access_token"]
+ self.id_token = data["id_token"]
+ expires_in = data.get("expires_in", "?")
+ logger.debug(f"Refreshed token, expires in {expires_in} seconds.")
+ auth_headers = {"Authorization": f"Bearer {self.access_token}"}
+ # update the httpx clients with the new token
+ self.sync_client.headers.update(auth_headers)
+ self.async_client.headers.update(auth_headers)
+
+ def revoke(self) -> None:
+ """
+ Revokes all valid refresh tokens.
+ """
+ response = httpx.post(
+ f"{OAUTH_URL}/oauth/revoke",
+ data={
+ "client_id": self.api_key,
+ "client_secret": self.secret_key,
+ "token": self.refresh_token,
+ },
+ )
+ validate_response(response)
+ logger.debug("Successfully revoked refresh tokens!")
+
+ def serialize(self) -> str:
+ """
+ Serializes the session to a string, useful for storing
+ a session for later use.
+ Could be used with pickle, Redis, etc.
+ """
+ attrs = self.__dict__.copy()
+ del attrs["async_client"]
+ del attrs["sync_client"]
+ return json.dumps(attrs)
+
+ @classmethod
+ def deserialize(cls, serialized: str) -> Self:
+ """
+ Create a new Session object from a serialized string.
+ """
+ deserialized = json.loads(serialized)
+ self = cls.__new__(cls)
+ self.__dict__ = deserialized
+ headers = {
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.access_token}",
+ }
+ self.sync_client = Client(base_url=self.base_url, headers=headers)
+ self.async_client = AsyncClient(base_url=self.base_url, headers=headers)
+ return self
+
+ @property
+ def user_info(self) -> dict[str, str]:
+ """
+ Contains user info depending on the OAuth scopes provided (eg email, name)
+ If you call this on a session where you didn't provide an `id_token` and you
+ haven't refreshed yet, this will return an empty dict.
+ """
+ if self.id_token is None:
+ return {}
+ return jwt.decode(self.id_token, options={"verify_signature": False})
+
+ def get_accounts(self) -> list[Account]:
+ data = self._get("/brokerage/accounts")
+ return [Account(**item) for item in data["Accounts"]]
+ async def a_get_accounts(self) -> list[Account]:
+ data = await self._a_get("/brokerage/accounts")
+ return [Account(**item) for item in data["Accounts"]]
diff --git a/tradestation/utils.py b/tradestation/utils.py
new file mode 100644
index 0000000..a9171a7
--- /dev/null
+++ b/tradestation/utils.py
@@ -0,0 +1,37 @@
+from typing import Any
+from httpx._models import Response
+from pydantic import BaseModel, ConfigDict
+from pydantic.alias_generators import to_pascal
+
+
+class TradestationError(Exception):
+ """
+ An internal error raised by the Tradestation SDK.
+ """
+
+ pass
+
+
+class TradestationModel(BaseModel):
+ """
+ A pydantic dataclass that converts keys from snake case to Pascal case
+ and performs type validation and coercion.
+ """
+
+ model_config = ConfigDict(alias_generator=to_pascal, populate_by_name=True)
+
+
+def validate_response(response: Response) -> None:
+ """
+ Checks if the given code is an error; if so, raises an exception.
+
+ :param response: response to check for errors
+ """
+ if response.status_code // 100 != 2:
+ data = response.json()
+ raise TradestationError(f"{data['error']}: {data['error_description']}")
+
+
+def _validate_and_parse(response: Response) -> Any:
+ validate_response(response)
+ return response.json()
diff --git a/uv.lock b/uv.lock
index ecfd6cc..d9f9a4d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -288,6 +288,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 },
]
+[[package]]
+name = "pyjwt"
+version = "2.10.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 },
+]
+
[[package]]
name = "pyright"
version = "1.1.391"
@@ -423,6 +432,7 @@ source = { editable = "." }
dependencies = [
{ name = "httpx" },
{ name = "pydantic" },
+ { name = "pyjwt" },
]
[package.dev-dependencies]
@@ -438,6 +448,7 @@ dev = [
requires-dist = [
{ name = "httpx", specifier = ">=0.27.2" },
{ name = "pydantic", specifier = ">=2.9.2" },
+ { name = "pyjwt", specifier = ">=2.10.1" },
]
[package.metadata.requires-dev]