From 904a334187c9c32febbe34ba81673832a8c775e5 Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Mon, 13 Jan 2025 19:32:03 -0500 Subject: [PATCH] sessions & oauth --- README.md | 44 ++++- pyproject.toml | 1 + tests/test_session.py | 13 +- tradestation/__init__.py | 24 ++- tradestation/account.py | 22 +++ tradestation/oauth.py | 364 +++++++++++++++++++++++++++++++++++++++ tradestation/session.py | 174 ++++++++++++++++++- tradestation/utils.py | 37 ++++ uv.lock | 11 ++ 9 files changed, 684 insertions(+), 6 deletions(-) create mode 100644 tradestation/account.py create mode 100644 tradestation/oauth.py create mode 100644 tradestation/utils.py diff --git a/README.md b/README.md index f5756d5..e206e3c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,44 @@ +[![Docs](https://readthedocs.org/projects/tradestation/badge/?version=latest)](https://tradestation.readthedocs.io/en/latest/?badge=latest) +[![PyPI](https://img.shields.io/pypi/v/tradestation)](https://pypi.org/project/tradestation) +[![Downloads](https://static.pepy.tech/badge/tradestation)](https://pepy.tech/project/tradestation) +[![Release)](https://img.shields.io/github/v/release/tastyware/tradestation?label=release%20notes)](https://github.com/tastyware/tradestation/releases) + # tradestation -An unofficial Python SDK for Tradestation! +A simple, unofficial, sync/async SDK for Tradestation built on their public API. This will allow you to create trading algorithms for whatever strategies you may have quickly and painlessly in Python. + +## Features + +- Up to 10x less code than using the API directly +- Sync/async functions for all endpoints +- Powerful websocket implementation for account alerts and data streaming, with support for auto-reconnection and reconnection callbacks +- 100% typed, with Pydantic models for all JSON responses from the API +- 95%+ unit test coverage +- Comprehensive documentation +- Utility functions for timezone calculations, futures monthly expiration dates, and more + +## Installation + +```console +$ pip install tradestation +``` + +## Initial setup + +Tradestation uses OAuth for secure authentication to the API. In order to obtain access tokens, you need to authenticate with OAuth 2's authorization code flow, which requires a local HTTP server running to handle the callback. Fortunately, the SDK makes doing this easy: + +```python +from tradestation.oauth import login +login() +``` + +This will let you authenticate in your local browser. Fortunately, this only needs to be done once, as afterwards you can use the refresh token to obtain new access tokens indefinitely. + +## Creating a session + +A session object is required to authenticate your requests to the Tradestation API. +You can create a simulation session by passing `is_test=True`. + +```python +from tradestation import Session +session = Session('api_key', 'secret_key', 'refresh_token') +``` diff --git a/pyproject.toml b/pyproject.toml index b35a942..f982003 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ authors = [ dependencies = [ "httpx>=0.27.2", "pydantic>=2.9.2", + "pyjwt>=2.10.1", ] [project.urls] diff --git a/tests/test_session.py b/tests/test_session.py index eb428f3..9e5c77e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,7 +1,14 @@ +import os + from tradestation.session import Session def test_session(): - session = Session() - assert session is not None - + api_key = os.getenv("TS_API_KEY") + secret_key = os.getenv("TS_SECRET_KEY") + refresh_token = os.getenv("TS_REFRESH") + assert api_key is not None + assert secret_key is not None + assert refresh_token is not None + session = Session(api_key, secret_key, refresh_token) + assert session.user_info != {} diff --git a/tradestation/__init__.py b/tradestation/__init__.py index d49e0c5..1b83bfa 100644 --- a/tradestation/__init__.py +++ b/tradestation/__init__.py @@ -3,6 +3,27 @@ API_URL_V3 = "https://api.tradestation.com/v3" API_URL_V2 = "https://api.tradestation.com/v2" API_URL_SIM = "https://sim-api.tradestation.com/v3" +OAUTH_SCOPES = [ + # Requests access to lookup or stream Market Data. + "MarketData", + # Requests access to view Brokerage Accounts belonging to the current user. + "ReadAccount", + # Requests access to execute orders on behalf of the current user's account(s). + "Trade", + # Request access to execute options related endpoints. + "OptionSpreads", + # Request access to execute market depth related endpoints. + "Matrix", + # Returns the sub claim, which uniquely identifies the user. In an ID Token, iss, aud, exp, iat, and at_hash claims will also be present. + "openid", + # Allows for use of Refresh Tokens. + "offline_access", + # Returns claims in the ID Token that represent basic profile information, including name, family_name, given_name, middle_name, nickname, picture, and updated_at. + "profile", + # Returns the email claim in the ID Token, which contains the user's email address, and email_verified, which is a boolean indicating whether the email address was verified by the user. + "email", +] +OAUTH_URL = "https://signin.tradestation.com" VERSION = "0.2" logger = logging.getLogger(__name__) @@ -10,6 +31,7 @@ # ruff: noqa: E402 +from .account import Account from .session import Session -__all__ = ["Session"] +__all__ = ["Account", "Session"] diff --git a/tradestation/account.py b/tradestation/account.py new file mode 100644 index 0000000..948e2df --- /dev/null +++ b/tradestation/account.py @@ -0,0 +1,22 @@ +from pydantic import Field + +from tradestation.utils import TradestationModel + + +class AccountDetail(TradestationModel): + day_trading_qualified: bool + enrolled_in_reg_t_program: bool + is_stock_locate_eligible: bool + option_approval_level: int + pattern_day_trader: bool + requires_buying_power_warning: bool + + +class Account(TradestationModel): + account_detail: AccountDetail | None = None + account_id: str = Field(alias="AccountID") + account_type: str + alias: str | None = None + alt_id: str | None = Field(default=None, alias="AltID") + currency: str + status: str diff --git a/tradestation/oauth.py b/tradestation/oauth.py new file mode 100644 index 0000000..22b7873 --- /dev/null +++ b/tradestation/oauth.py @@ -0,0 +1,364 @@ +# Based on https://community.tradestation.com/Discussions/Topic.aspx?Topic_ID=205209 +# This file is designed to provide a simple way to obtain a refresh +# token and an initial access token using v3 of the Web API. + +import re +import sys +from typing import Any +import webbrowser +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.parse import urlparse + +import httpx +from pydantic import BaseModel + +from tradestation import OAUTH_SCOPES, OAUTH_URL + +AUDIENCE = "https://api.tradestation.com" +PORT = 3001 +REDIRECT_URI = f"http://localhost:{PORT}" +SCOPES = " ".join(OAUTH_SCOPES) + + +class Credentials(BaseModel): + key: str = "" + secret: str = "" + scopes: str = SCOPES + + def clear(self) -> None: + self.key = "" + self.secret = "" + self.scopes = SCOPES + + +credentials = Credentials() + + +def get_access_url(credentials: Credentials) -> str: + query_string = "&".join( + [ + "response_type=code", + f"audience={AUDIENCE}", + f"redirect_uri={REDIRECT_URI}", + f"client_id={credentials.key}", + f"scope={credentials.scopes}", + ] + ) + access_url = f"{OAUTH_URL}/authorize?{query_string}" + return access_url + + +def convert_auth_code(credentials: Credentials, auth_code: str) -> dict[str, Any]: + """ + Uses an api key, a secret key and authorization code to obtain a response + containing an access token, refresh token, user id, and expriation time + """ + post_data = { + "grant_type": "authorization_code", + "client_id": credentials.key, + "client_secret": credentials.secret, + "redirect_uri": REDIRECT_URI, + "code": auth_code, + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + response = httpx.post(f"{OAUTH_URL}/oauth/token", headers=headers, data=post_data) + if response.status_code != 200: + raise Exception( + "Could not load access and refresh tokens from authorization code!" + ) + + return response.json() + + +root_page: bytes = f""" + + + + Web API + + + +
+
+
+
WEB API V3
+
+
+
+
+
+ +
+
+
+
+
+
+ +
+
+
+
+
+
+ +
+
+
+ +
+
+
+
 
+
+
+
+
+ + + +""".encode("utf-8") + +bad_request_page: bytes = ( + "
400 - Bad page request.
".encode( + "utf-8" + ) +) + +unknown_page: bytes = ( + "
404 - Page not found.
".encode("utf-8") +) + + +def response_page( + refresh_token: str, access_token: str, response: dict[str, Any] +) -> bytes: + return f""" + + + + + +
Refresh Token:
+
{refresh_token}
+Access Token: +
{access_token}
+Complete Response: +
{response}
+ + +""".encode("utf-8") + + +class RequestHandler(BaseHTTPRequestHandler): + def do_GET(self) -> None: + # Serve root page with sign in link + if self.path == "/": + self.send_response(200) + self.send_header("Content-type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write(root_page) + + return + + if self.path.startswith("/submit"): + # Parse query components from url + query = urlparse(self.path).query + query_components = dict(qc.split("=") for qc in query.split("&")) + + credentials.key = query_components["apiKey"] + credentials.secret = query_components["apiSecret"] + credentials.scopes = query_components["scopes"].replace("+", "%20") + + # Redirect to login page using API key submitted by user + self.send_response(302) + self.send_header("Location", get_access_url(credentials)) + self.end_headers() + + return + + if self.path.startswith("/?code"): + # Check if query path contains case insensitive "code=" + code_match = re.search(r"code=(.+)", self.path, re.I) + + if code_match and credentials.key and credentials.secret: + user_auth_code = code_match[1] + token_access = convert_auth_code(credentials, user_auth_code) + + # Clear stored info + credentials.clear() + + access_token = token_access["access_token"] + refresh_token = token_access["refresh_token"] + + self.send_response(200) + self.send_header("Content-type", "text/html; charset=utf-8") + self.end_headers() + + token_page = response_page(refresh_token, access_token, token_access) + self.wfile.write(token_page) + sys.exit(0) + + else: + self.send_response(400) + self.send_header("Content-type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write(bad_request_page) + + return + + # Send 404 error if path is none of the above + self.send_response(404) + self.send_header("Content-type", "text/html; charset=utf-8") + self.end_headers() + self.wfile.write(unknown_page) + + +def login() -> None: + """ + Starts a local HTTP server and opens the browser to OAuth login. + """ + httpd = HTTPServer(("", PORT), RequestHandler) + print(f"Opening url: {REDIRECT_URI}") + webbrowser.open(REDIRECT_URI) + httpd.serve_forever() + + +if __name__ == "__main__": + login() diff --git a/tradestation/session.py b/tradestation/session.py index 1f210ae..3634d68 100644 --- a/tradestation/session.py +++ b/tradestation/session.py @@ -1,3 +1,175 @@ +import json +from typing import Any + +import httpx +import jwt +from httpx import AsyncClient, Client +from typing_extensions import Self + +from tradestation import API_URL_SIM, API_URL_V2, API_URL_V3, OAUTH_URL, logger +from tradestation.account import Account +from tradestation.utils import TradestationError, _validate_and_parse, validate_response + + class Session: - id: int + """ + Contains a local user login which can then be used to interact with the + remote API. + + :param api_key: Tradestation API key (client ID) + :param secret_key: Tradestation secret key (client secret) + :param refresh_token: + Tradestation refresh token used to obtain new access tokens; can be + acquired initially by calling :any:`tradestation.oauth.login` + :param access_token: + previously generated access token; if absent, refresh token will be + used to generate a new one automatically + :param id_token: + previously generated ID token; if absent, you won't be able to access + the `user_info` property until refreshing + :param is_test: whether to use the simulated API endpoints, default False + :param use_v2: whether to use the older v2 endpoints instead of the v3 ones + """ + + def __init__( + self, + api_key: str, + secret_key: str, + refresh_token: str, + access_token: str | None = None, + id_token: str | None = None, + is_test: bool = False, + use_v2: bool = False, + ): + if is_test and use_v2: + raise TradestationError( + "The simulation environment doesn't support v2 URLs!" + ) + if is_test: + self.base_url = API_URL_SIM + elif use_v2: + self.base_url = API_URL_V2 + else: + self.base_url = API_URL_V3 + + #: Tradestation client ID + self.api_key = api_key + #: Tradestation client secret + self.secret_key = secret_key + #: Access token for authenticating requests. By default, is valid for + #: 20 minutes, and then needs to be replaced. + self.access_token = access_token + #: Refresh token for generating new access tokens + self.refresh_token = refresh_token + #: ID token containing personal info like name, email + self.id_token = id_token + + #: Whether this is a simulated or real session + self.is_test = is_test + # The headers to use for API requests + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {self.access_token}", + } + #: httpx client for sync requests + self.sync_client = Client(base_url=self.base_url, headers=headers) + #: httpx client for async requests + self.async_client = AsyncClient(base_url=self.base_url, headers=headers) + + if not access_token: + self.refresh() + + async def _a_get(self, url, **kwargs) -> Any: + response = await self.async_client.get(url, **kwargs) + return _validate_and_parse(response) + + def _get(self, url, **kwargs) -> Any: + response = self.sync_client.get(url, **kwargs) + return _validate_and_parse(response) + + def refresh(self) -> None: + """ + Refreshes the acccess token using the stored refresh token. + """ + response = httpx.post( + f"{OAUTH_URL}/oauth/token", + data={ + "grant_type": "refresh_token", + "client_id": self.api_key, + "client_secret": self.secret_key, + "refresh_token": self.refresh_token, + }, + ) + data: dict[str, str] = _validate_and_parse(response) + # update the relevant tokens + self.access_token = data["access_token"] + self.id_token = data["id_token"] + expires_in = data.get("expires_in", "?") + logger.debug(f"Refreshed token, expires in {expires_in} seconds.") + auth_headers = {"Authorization": f"Bearer {self.access_token}"} + # update the httpx clients with the new token + self.sync_client.headers.update(auth_headers) + self.async_client.headers.update(auth_headers) + + def revoke(self) -> None: + """ + Revokes all valid refresh tokens. + """ + response = httpx.post( + f"{OAUTH_URL}/oauth/revoke", + data={ + "client_id": self.api_key, + "client_secret": self.secret_key, + "token": self.refresh_token, + }, + ) + validate_response(response) + logger.debug("Successfully revoked refresh tokens!") + + def serialize(self) -> str: + """ + Serializes the session to a string, useful for storing + a session for later use. + Could be used with pickle, Redis, etc. + """ + attrs = self.__dict__.copy() + del attrs["async_client"] + del attrs["sync_client"] + return json.dumps(attrs) + + @classmethod + def deserialize(cls, serialized: str) -> Self: + """ + Create a new Session object from a serialized string. + """ + deserialized = json.loads(serialized) + self = cls.__new__(cls) + self.__dict__ = deserialized + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + "Authorization": f"Bearer {self.access_token}", + } + self.sync_client = Client(base_url=self.base_url, headers=headers) + self.async_client = AsyncClient(base_url=self.base_url, headers=headers) + return self + + @property + def user_info(self) -> dict[str, str]: + """ + Contains user info depending on the OAuth scopes provided (eg email, name) + If you call this on a session where you didn't provide an `id_token` and you + haven't refreshed yet, this will return an empty dict. + """ + if self.id_token is None: + return {} + return jwt.decode(self.id_token, options={"verify_signature": False}) + + def get_accounts(self) -> list[Account]: + data = self._get("/brokerage/accounts") + return [Account(**item) for item in data["Accounts"]] + async def a_get_accounts(self) -> list[Account]: + data = await self._a_get("/brokerage/accounts") + return [Account(**item) for item in data["Accounts"]] diff --git a/tradestation/utils.py b/tradestation/utils.py new file mode 100644 index 0000000..a9171a7 --- /dev/null +++ b/tradestation/utils.py @@ -0,0 +1,37 @@ +from typing import Any +from httpx._models import Response +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_pascal + + +class TradestationError(Exception): + """ + An internal error raised by the Tradestation SDK. + """ + + pass + + +class TradestationModel(BaseModel): + """ + A pydantic dataclass that converts keys from snake case to Pascal case + and performs type validation and coercion. + """ + + model_config = ConfigDict(alias_generator=to_pascal, populate_by_name=True) + + +def validate_response(response: Response) -> None: + """ + Checks if the given code is an error; if so, raises an exception. + + :param response: response to check for errors + """ + if response.status_code // 100 != 2: + data = response.json() + raise TradestationError(f"{data['error']}: {data['error_description']}") + + +def _validate_and_parse(response: Response) -> Any: + validate_response(response) + return response.json() diff --git a/uv.lock b/uv.lock index ecfd6cc..d9f9a4d 100644 --- a/uv.lock +++ b/uv.lock @@ -288,6 +288,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 }, +] + [[package]] name = "pyright" version = "1.1.391" @@ -423,6 +432,7 @@ source = { editable = "." } dependencies = [ { name = "httpx" }, { name = "pydantic" }, + { name = "pyjwt" }, ] [package.dev-dependencies] @@ -438,6 +448,7 @@ dev = [ requires-dist = [ { name = "httpx", specifier = ">=0.27.2" }, { name = "pydantic", specifier = ">=2.9.2" }, + { name = "pyjwt", specifier = ">=2.10.1" }, ] [package.metadata.requires-dev]