Skip to content

Commit 0ef8f45

Browse files
authored
🔒 Generic Oauth installer (#1150)
* 🤌 Generic oauth installer * 🤌 Generic oauth installer * 🐛 Poetry issues * 🐛 Poetry issues
1 parent eccb538 commit 0ef8f45

File tree

12 files changed

+330
-20
lines changed

12 files changed

+330
-20
lines changed

next/prisma/schema.prisma

+33
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,39 @@ model VerificationToken {
9999
@@unique([identifier, token])
100100
}
101101

102+
model OAuthCredentials {
103+
id String @id @default(cuid())
104+
installation_id String
105+
provider String
106+
token_type String
107+
access_token String
108+
scope String?
109+
data Json
110+
111+
create_date DateTime @default(now())
112+
update_date DateTime? @updatedAt
113+
delete_date DateTime?
114+
115+
@@unique([installation_id])
116+
@@map("oauth_credentials")
117+
}
118+
119+
model OAuthInstallation {
120+
id String @id @default(cuid())
121+
user_id String
122+
organization_id String?
123+
provider String
124+
state String
125+
126+
create_date DateTime @default(now())
127+
update_date DateTime? @updatedAt
128+
delete_date DateTime?
129+
130+
@@unique([user_id, organization_id, provider])
131+
@@index([state])
132+
@@map("oauth_installation")
133+
}
134+
102135
model Agent {
103136
id String @id @default(cuid())
104137
userId String

platform/poetry.lock

+101-16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

platform/pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ httptools = "^0.5.0"
2525
sentry-sdk = "^1.28.1"
2626
loguru = "^0.7.0"
2727
aiokafka = "^0.8.1"
28-
requests = "2.28.0"
28+
requests = "^2.31.0"
2929
langchain = "0.0.218"
3030
openai = "^0.27.8"
3131
wikipedia = "^1.4.0"
@@ -47,6 +47,7 @@ aws-secretsmanager-caching = "^1.1.1.5"
4747
botocore = "^1.29.153"
4848
stripe = "^5.4.0"
4949
tabula-py = "^2.7.0"
50+
slack-sdk = "^3.21.3"
5051

5152
[tool.poetry.dev-dependencies]
5253
autopep8 = "^2.0.2"

platform/reworkd_platform/db/crud/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from typing import TypeVar
2+
13
from sqlalchemy.ext.asyncio import AsyncSession
24

5+
T = TypeVar("T", bound="BaseCrud")
6+
37

48
class BaseCrud:
59
def __init__(self, session: AsyncSession):
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import secrets
2+
from typing import Optional
3+
4+
from fastapi import Depends
5+
from sqlalchemy import select
6+
from sqlalchemy.ext.asyncio import AsyncSession
7+
8+
from reworkd_platform.db.crud.base import BaseCrud
9+
from reworkd_platform.db.dependencies import get_db_session
10+
from reworkd_platform.db.models.auth import OauthInstallation
11+
from reworkd_platform.schemas import UserBase
12+
13+
14+
class OAuthCrud(BaseCrud):
15+
@classmethod
16+
async def inject(
17+
cls,
18+
session: AsyncSession = Depends(get_db_session),
19+
) -> "OAuthCrud":
20+
return cls(session)
21+
22+
async def create_installation(
23+
self, user: UserBase, provider: str
24+
) -> OauthInstallation:
25+
return await OauthInstallation(
26+
user_id=user.id,
27+
organization_id=user.organization_id,
28+
provider=provider,
29+
state=secrets.token_hex(16),
30+
).save(self.session)
31+
32+
async def get_installation_by_state(
33+
self, state: str
34+
) -> Optional[OauthInstallation]:
35+
query = select(OauthInstallation).filter(OauthInstallation.state == state)
36+
37+
return (await self.session.execute(query)).scalar_one_or_none()

platform/reworkd_platform/db/models/auth.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import String
1+
from sqlalchemy import String, JSON
22
from sqlalchemy.orm import mapped_column
33

44
from reworkd_platform.db.base import TrackedModel
@@ -17,3 +17,23 @@ class OrganizationUser(TrackedModel):
1717
user_id = mapped_column(String, nullable=False)
1818
organization_id = mapped_column(String, nullable=False)
1919
role = mapped_column(String, nullable=False, default="member")
20+
21+
22+
class OauthCredentials(TrackedModel):
23+
__tablename__ = "oauth_credentials"
24+
25+
installation_id = mapped_column(String, nullable=False)
26+
provider = mapped_column(String, nullable=False)
27+
token_type = mapped_column(String, nullable=False)
28+
access_token = mapped_column(String, nullable=False)
29+
scope = mapped_column(String, nullable=True)
30+
data = mapped_column(JSON, nullable=False)
31+
32+
33+
class OauthInstallation(TrackedModel):
34+
__tablename__ = "oauth_installation"
35+
36+
user_id = mapped_column(String, nullable=False)
37+
organization_id = mapped_column(String, nullable=True)
38+
provider = mapped_column(String, nullable=False)
39+
state = mapped_column(String, nullable=False)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# from slack.web import WebClient
2+
from abc import ABC, abstractmethod
3+
from typing import TypeVar
4+
5+
from fastapi import Depends, Path
6+
from slack_sdk import WebClient
7+
from slack_sdk.oauth import AuthorizeUrlGenerator
8+
9+
from reworkd_platform.db.crud.oauth import OAuthCrud
10+
from reworkd_platform.db.models.auth import OauthCredentials
11+
from reworkd_platform.schemas import UserBase
12+
from reworkd_platform.settings import Settings, settings as platform_settings
13+
from reworkd_platform.web.api.http_responses import forbidden
14+
15+
T = TypeVar("T", bound="OAuthInstaller")
16+
17+
18+
class OAuthInstaller(ABC):
19+
def __init__(self, crud: OAuthCrud, settings: Settings):
20+
self.crud = crud
21+
self.settings = settings
22+
23+
@abstractmethod
24+
async def install(self, user: UserBase) -> str:
25+
raise NotImplementedError()
26+
27+
@abstractmethod
28+
async def install_callback(self, code: str, state: str) -> None:
29+
raise NotImplementedError()
30+
31+
32+
class SlackInstaller(OAuthInstaller):
33+
PROVIDER = "slack"
34+
35+
async def install(self, user: UserBase) -> str:
36+
installation = await self.crud.create_installation(user, self.PROVIDER)
37+
38+
return AuthorizeUrlGenerator(
39+
client_id=self.settings.slack_client_id,
40+
redirect_uri=self.settings.slack_redirect_uri,
41+
scopes=["chat:write"],
42+
).generate(
43+
state=installation.state,
44+
)
45+
46+
async def install_callback(self, code: str, state: str) -> None:
47+
installation = await self.crud.get_installation_by_state(state)
48+
if not installation:
49+
raise forbidden()
50+
51+
oauth_response = WebClient().oauth_v2_access(
52+
client_id=self.settings.slack_client_id,
53+
client_secret=self.settings.slack_client_secret,
54+
code=code,
55+
state=state,
56+
)
57+
58+
# We should handle token rotation / refresh tokens eventually
59+
# TODO: encode token
60+
await OauthCredentials(
61+
installation_id=installation.id,
62+
provider="slack",
63+
token_type=oauth_response["token_type"],
64+
access_token=oauth_response["access_token"],
65+
scope=oauth_response["scope"],
66+
data=oauth_response.data,
67+
).save(self.crud.session)
68+
69+
70+
integrations = {
71+
SlackInstaller.PROVIDER: SlackInstaller,
72+
}
73+
74+
75+
def installer_factory(
76+
provider: str = Path(description="OAuth Provider"),
77+
crud: OAuthCrud = Depends(OAuthCrud.inject),
78+
) -> OAuthInstaller:
79+
if provider in integrations:
80+
return integrations[provider](crud, platform_settings)
81+
raise NotImplementedError()

platform/reworkd_platform/settings.py

+5
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ class Settings(BaseSettings):
102102
ff_mock_mode_enabled: bool = False # Controls whether calls are mocked
103103
max_loops: int = 25 # Maximum number of loops to run
104104

105+
# Settings for slack
106+
slack_client_id: str = ""
107+
slack_client_secret: str = ""
108+
slack_redirect_uri: str = ""
109+
105110
@property
106111
def kafka_consumer_group(self) -> str:
107112
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from reworkd_platform.services.oauth_installers import installer_factory
4+
5+
6+
def test_installer_factory(mocker):
7+
crud = mocker.Mock()
8+
installer_factory("slack", crud)
9+
10+
11+
def test_integration_dne(mocker):
12+
crud = mocker.Mock()
13+
14+
with pytest.raises(NotImplementedError):
15+
installer_factory("asim", crud)

platform/reworkd_platform/tests/workflow/test_if_condition.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@
4242
],
4343
)
4444
async def test_if_condition_success(value_one, operator, value_two, expected_result):
45+
workflow_id = "123"
4546
block = IfCondition(
4647
input=IfInput(value_one=value_one, operator=operator, value_two=value_two)
4748
)
48-
result = await block.run(curr.workflow_id)
49+
result = await block.run(workflow_id)
4950
assert result == IfOutput(result=expected_result)
5051

5152

@@ -59,8 +60,9 @@ async def test_if_condition_success(value_one, operator, value_two, expected_res
5960
],
6061
)
6162
async def test_if_condition_errors(value_one, operator, value_two):
63+
workflow_id = "123"
6264
block = IfCondition(
6365
input=IfInput(value_one=value_one, operator=operator, value_two=value_two)
6466
)
6567
with pytest.raises(ValueError):
66-
await block.run(curr.workflow_id)
68+
await block.run(workflow_id)

platform/reworkd_platform/web/api/auth/views.py

+25
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
from reworkd_platform.db.crud.organization import OrganizationCrud, OrganizationUsers
66
from reworkd_platform.schemas import UserBase
7+
from reworkd_platform.services.oauth_installers import (
8+
installer_factory,
9+
OAuthInstaller,
10+
)
711
from reworkd_platform.services.sockets import websockets
812
from reworkd_platform.web.api.dependencies import get_current_user
913

@@ -44,3 +48,24 @@ async def pusher_authentication(
4448
user: UserBase = Depends(get_current_user),
4549
) -> Dict[str, str]:
4650
return websockets.authenticate(user, channel_name, socket_id)
51+
52+
53+
@router.get("/{provider}")
54+
async def oauth_install(
55+
user: UserBase = Depends(get_current_user),
56+
installer: OAuthInstaller = Depends(installer_factory),
57+
) -> str:
58+
"""Install an OAuth App"""
59+
url = await installer.install(user)
60+
print(url)
61+
return url
62+
63+
64+
@router.get("/{provider}/callback")
65+
async def oauth_callback(
66+
code: str,
67+
state: str,
68+
installer: OAuthInstaller = Depends(installer_factory),
69+
) -> None:
70+
"""Callback for OAuth App"""
71+
return await installer.install_callback(code, state)

scripts/prepare-sync.sh

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
cd "$(dirname "$0")" || exit 1
2+
git reset --hard
3+
24
git fetch origin
35

46
git checkout main

0 commit comments

Comments
 (0)