Skip to content

Commit ac28512

Browse files
committed
feat: support OAuth credentials provider
1 parent d65cdba commit ac28512

19 files changed

+3066
-316
lines changed

alibabacloud_credentials/provider/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .profile import ProfileCredentialsProvider
1111
from .default import DefaultCredentialsProvider
1212
from .cloud_sso import CloudSSOCredentialsProvider
13+
from .oauth import OAuthCredentialsProvider
1314

1415
__all__ = [
1516
'StaticAKCredentialsProvider',
@@ -23,5 +24,6 @@
2324
'CLIProfileCredentialsProvider',
2425
'ProfileCredentialsProvider',
2526
'DefaultCredentialsProvider',
26-
'CloudSSOCredentialsProvider'
27+
'CloudSSOCredentialsProvider',
28+
'OAuthCredentialsProvider'
2729
]

alibabacloud_credentials/provider/cli_profile.py

Lines changed: 400 additions & 2 deletions
Large diffs are not rendered by default.
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import calendar
2+
import json
3+
import logging
4+
import time
5+
from urllib.parse import urlparse, urlencode
6+
from typing import Callable, Optional
7+
8+
from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier
9+
from alibabacloud_credentials.http import HttpOptions
10+
from Tea.core import TeaCore
11+
from alibabacloud_credentials_api import ICredentialsProvider
12+
from alibabacloud_credentials.utils import parameter_helper as ph
13+
from alibabacloud_credentials.exceptions import CredentialException
14+
15+
log = logging.getLogger('credentials')
16+
log.setLevel(logging.INFO)
17+
ch = logging.StreamHandler()
18+
log.addHandler(ch)
19+
20+
# OAuth 令牌更新回调函数类型
21+
OAuthTokenUpdateCallback = Callable[[str, str, str, str, str, int, int], None]
22+
OAuthTokenUpdateCallbackAsync = Callable[[str, str, str, str, str, int, int], None]
23+
24+
25+
def _get_stale_time(expiration: int) -> int:
26+
if expiration < 0:
27+
return int(time.mktime(time.localtime())) + 60 * 60
28+
return expiration - 15 * 60
29+
30+
31+
class OAuthCredentialsProvider(ICredentialsProvider):
32+
DEFAULT_CONNECT_TIMEOUT = 5000
33+
DEFAULT_READ_TIMEOUT = 10000
34+
35+
def __init__(self, *,
36+
client_id: str = None,
37+
sign_in_url: str = None,
38+
access_token: str = None,
39+
access_token_expire: int = 0,
40+
refresh_token: str = None,
41+
http_options: HttpOptions = None,
42+
token_update_callback: Optional[OAuthTokenUpdateCallback] = None,
43+
token_update_callback_async: Optional[OAuthTokenUpdateCallbackAsync] = None):
44+
45+
if not client_id:
46+
raise ValueError('the ClientId is empty')
47+
48+
if not sign_in_url:
49+
raise ValueError('the url for sign-in is empty')
50+
51+
if not refresh_token:
52+
raise ValueError('OAuth access token is empty or expired, please re-login with cli')
53+
54+
self._client_id = client_id
55+
self._sign_in_url = sign_in_url
56+
self._access_token = access_token
57+
self._access_token_expire = access_token_expire
58+
self._refresh_token = refresh_token
59+
self._token_update_callback = token_update_callback
60+
self._token_update_callback_async = token_update_callback_async
61+
62+
self._http_options = http_options if http_options is not None else HttpOptions()
63+
self._runtime_options = {
64+
'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else OAuthCredentialsProvider.DEFAULT_CONNECT_TIMEOUT,
65+
'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else OAuthCredentialsProvider.DEFAULT_READ_TIMEOUT,
66+
'httpsProxy': self._http_options.proxy
67+
}
68+
self._credentials_cache = RefreshCachedSupplier(
69+
refresh_callable=self._refresh_credentials,
70+
refresh_callable_async=self._refresh_credentials_async,
71+
)
72+
73+
def get_credentials(self) -> Credentials:
74+
return self._credentials_cache._sync_call()
75+
76+
async def get_credentials_async(self) -> Credentials:
77+
return await self._credentials_cache._async_call()
78+
79+
def _try_refresh_oauth_token(self) -> None:
80+
current_time = int(time.mktime(time.localtime()))
81+
# 构建刷新令牌请求
82+
r = urlparse(self._sign_in_url)
83+
tea_request = ph.get_new_request()
84+
tea_request.headers['host'] = r.hostname
85+
tea_request.port = r.port
86+
tea_request.protocol = r.scheme
87+
tea_request.method = 'POST'
88+
tea_request.pathname = '/v1/token'
89+
90+
# 设置请求体
91+
body_data = {
92+
'grant_type': 'refresh_token',
93+
'refresh_token': self._refresh_token,
94+
'client_id': self._client_id,
95+
'Timestamp': ph.get_iso_8061_date()
96+
}
97+
tea_request.body = urlencode(body_data)
98+
tea_request.headers['Content-Type'] = 'application/x-www-form-urlencoded'
99+
100+
response = TeaCore.do_action(tea_request, self._runtime_options)
101+
102+
if response.status_code != 200:
103+
raise CredentialException(f'failed to refresh OAuth token, status code: {response.status_code}')
104+
105+
# 解析响应
106+
dic = json.loads(response.body.decode('utf-8'))
107+
if 'access_token' not in dic or 'refresh_token' not in dic:
108+
raise CredentialException(f"failed to refresh OAuth token: {response.body.decode('utf-8')}")
109+
110+
# 更新令牌
111+
new_access_token = dic.get('access_token')
112+
new_refresh_token = dic.get('refresh_token')
113+
expires_in = dic.get('expires_in', 3600)
114+
new_access_token_expire = current_time + expires_in
115+
116+
self._access_token = new_access_token
117+
self._refresh_token = new_refresh_token
118+
self._access_token_expire = new_access_token_expire
119+
120+
async def _try_refresh_oauth_token_async(self) -> None:
121+
current_time = int(time.mktime(time.localtime()))
122+
# 构建刷新令牌请求
123+
r = urlparse(self._sign_in_url)
124+
tea_request = ph.get_new_request()
125+
tea_request.headers['host'] = r.hostname
126+
tea_request.port = r.port
127+
tea_request.protocol = r.scheme
128+
tea_request.method = 'POST'
129+
tea_request.pathname = '/v1/token'
130+
131+
# 设置请求体
132+
body_data = {
133+
'grant_type': 'refresh_token',
134+
'refresh_token': self._refresh_token,
135+
'client_id': self._client_id,
136+
'Timestamp': ph.get_iso_8061_date()
137+
}
138+
tea_request.body = urlencode(body_data)
139+
tea_request.headers['Content-Type'] = 'application/x-www-form-urlencoded'
140+
141+
response = await TeaCore.async_do_action(tea_request, self._runtime_options)
142+
143+
if response.status_code != 200:
144+
raise CredentialException(f'failed to refresh OAuth token, status code: {response.status_code}')
145+
146+
# 解析响应
147+
dic = json.loads(response.body.decode('utf-8'))
148+
if 'access_token' not in dic or 'refresh_token' not in dic:
149+
raise CredentialException(f"failed to refresh OAuth token: {response.body.decode('utf-8')}")
150+
151+
# 更新令牌
152+
new_access_token = dic.get('access_token')
153+
new_refresh_token = dic.get('refresh_token')
154+
expires_in = dic.get('expires_in', 3600)
155+
new_access_token_expire = current_time + expires_in
156+
157+
self._access_token = new_access_token
158+
self._refresh_token = new_refresh_token
159+
self._access_token_expire = new_access_token_expire
160+
161+
def _refresh_credentials(self) -> RefreshResult[Credentials]:
162+
if self._access_token is None or self._access_token_expire <= 0 or self._access_token_expire - int(
163+
time.mktime(time.localtime())) <= 180:
164+
self._try_refresh_oauth_token()
165+
166+
r = urlparse(self._sign_in_url)
167+
tea_request = ph.get_new_request()
168+
tea_request.headers['host'] = r.hostname
169+
tea_request.port = r.port
170+
tea_request.protocol = r.scheme
171+
tea_request.method = 'POST'
172+
tea_request.pathname = '/v1/exchange'
173+
174+
tea_request.headers['Content-Type'] = 'application/json'
175+
tea_request.headers['Authorization'] = f'Bearer {self._access_token}'
176+
177+
response = TeaCore.do_action(tea_request, self._runtime_options)
178+
179+
if response.status_code != 200:
180+
raise CredentialException(
181+
f"error refreshing credentials from OAuth, http_code: {response.status_code}, result: {response.body.decode('utf-8')}")
182+
183+
dic = json.loads(response.body.decode('utf-8'))
184+
if 'error' in dic:
185+
raise CredentialException(
186+
f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}")
187+
188+
if 'accessKeyId' not in dic or 'accessKeySecret' not in dic or 'securityToken' not in dic:
189+
raise CredentialException(
190+
f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}")
191+
192+
# 先转换为时间数组
193+
time_array = time.strptime(dic.get('expiration'), '%Y-%m-%dT%H:%M:%SZ')
194+
# 转换为时间戳
195+
expiration = calendar.timegm(time_array)
196+
credentials = Credentials(
197+
access_key_id=dic.get('accessKeyId'),
198+
access_key_secret=dic.get('accessKeySecret'),
199+
security_token=dic.get('securityToken'),
200+
expiration=expiration,
201+
provider_name=self.get_provider_name()
202+
)
203+
204+
# 调用令牌更新回调函数
205+
if self._token_update_callback:
206+
try:
207+
self._token_update_callback(
208+
self._refresh_token,
209+
self._access_token,
210+
credentials.get_access_key_id(),
211+
credentials.get_access_key_secret(),
212+
credentials.get_security_token(),
213+
self._access_token_expire,
214+
expiration
215+
)
216+
except Exception as e:
217+
log.warning(f'failed to update OAuth tokens in config file: {e}')
218+
219+
return RefreshResult(value=credentials,
220+
stale_time=_get_stale_time(expiration))
221+
222+
async def _refresh_credentials_async(self) -> RefreshResult[Credentials]:
223+
if self._access_token is None or self._access_token_expire <= 0 or self._access_token_expire - int(
224+
time.mktime(time.localtime())) <= 180:
225+
await self._try_refresh_oauth_token_async()
226+
227+
r = urlparse(self._sign_in_url)
228+
tea_request = ph.get_new_request()
229+
tea_request.headers['host'] = r.hostname
230+
tea_request.port = r.port
231+
tea_request.protocol = r.scheme
232+
tea_request.method = 'POST'
233+
tea_request.pathname = '/v1/exchange'
234+
235+
tea_request.headers['Content-Type'] = 'application/json'
236+
tea_request.headers['Authorization'] = f'Bearer {self._access_token}'
237+
238+
response = await TeaCore.async_do_action(tea_request, self._runtime_options)
239+
240+
if response.status_code != 200:
241+
raise CredentialException(
242+
f"error refreshing credentials from OAuth, http_code: {response.status_code}, result: {response.body.decode('utf-8')}")
243+
244+
dic = json.loads(response.body.decode('utf-8'))
245+
if 'error' in dic:
246+
raise CredentialException(
247+
f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}")
248+
249+
if 'accessKeyId' not in dic or 'accessKeySecret' not in dic or 'securityToken' not in dic:
250+
raise CredentialException(
251+
f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}")
252+
253+
# 先转换为时间数组
254+
time_array = time.strptime(dic.get('expiration'), '%Y-%m-%dT%H:%M:%SZ')
255+
# 转换为时间戳
256+
expiration = calendar.timegm(time_array)
257+
credentials = Credentials(
258+
access_key_id=dic.get('accessKeyId'),
259+
access_key_secret=dic.get('accessKeySecret'),
260+
security_token=dic.get('securityToken'),
261+
expiration=expiration,
262+
provider_name=self.get_provider_name()
263+
)
264+
265+
if self._token_update_callback_async:
266+
try:
267+
await self._token_update_callback_async(
268+
self._refresh_token,
269+
self._access_token,
270+
credentials.get_access_key_id(),
271+
credentials.get_access_key_secret(),
272+
credentials.get_security_token(),
273+
self._access_token_expire,
274+
expiration
275+
)
276+
except Exception as e:
277+
log.warning(f'failed to update OAuth tokens in config file: {e}')
278+
279+
return RefreshResult(value=credentials,
280+
stale_time=_get_stale_time(expiration))
281+
282+
def _get_client_id(self) -> str:
283+
"""获取客户端ID"""
284+
return self._client_id
285+
286+
def get_provider_name(self) -> str:
287+
return 'oauth'

tests/provider/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Provider tests package

0 commit comments

Comments
 (0)