Skip to content

Commit 83d6d82

Browse files
committed
feat: support OAuth credentials provider
1 parent d65cdba commit 83d6d82

17 files changed

+2341
-251
lines changed

alibabacloud_credentials/provider/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .profile import ProfileCredentialsProvider
1111
from .default import DefaultCredentialsProvider
1212
from .cloud_sso import CloudSSOCredentialsProvider
13+
from .oauth import OAuthCredentialsProvider
1314

1415
__all__ = [
1516
'StaticAKCredentialsProvider',
@@ -23,5 +24,6 @@
2324
'CLIProfileCredentialsProvider',
2425
'ProfileCredentialsProvider',
2526
'DefaultCredentialsProvider',
26-
'CloudSSOCredentialsProvider'
27+
'CloudSSOCredentialsProvider',
28+
'OAuthCredentialsProvider'
2729
]

alibabacloud_credentials/provider/cli_profile.py

Lines changed: 223 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
import json
3+
import threading
4+
import fcntl
35
from typing import Any, Dict
46

57
import aiofiles
@@ -10,6 +12,7 @@
1012
from .oidc import OIDCRoleArnCredentialsProvider
1113
from .static_sts import StaticSTSCredentialsProvider
1214
from .cloud_sso import CloudSSOCredentialsProvider
15+
from .oauth import OAuthCredentialsProvider, OAuthTokenUpdateCallback, OAuthTokenUpdateCallbackAsync
1316
from .refreshable import Credentials
1417
from alibabacloud_credentials_api import ICredentialsProvider
1518
from alibabacloud_credentials.utils import auth_constant as ac
@@ -32,10 +35,13 @@ def _load_config(file_path: str) -> Any:
3235
class CLIProfileCredentialsProvider(ICredentialsProvider):
3336

3437
def __init__(self, *,
35-
profile_name: str = None):
36-
self._profile_file = os.path.join(ac.HOME, ".aliyun/config.json")
38+
profile_name: str = None,
39+
profile_file: str = None):
40+
self._profile_file = profile_file or os.path.join(ac.HOME, ".aliyun/config.json")
3741
self._profile_name = profile_name or au.environment_profile_name
3842
self.__innerProvider = None
43+
# 文件锁,用于并发安全
44+
self._file_lock = threading.RLock()
3945

4046
def _should_reload_credentials_provider(self) -> bool:
4147
if self.__innerProvider is None:
@@ -175,10 +181,225 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent
175181
access_token=profile.get('access_token'),
176182
access_token_expire=profile.get('cloud_sso_access_token_expire'),
177183
)
184+
elif mode == "OAuth":
185+
# 获取 OAuth 配置
186+
site_type = profile.get('oauth_site_type', 'CN')
187+
oauth_base_url_map = {
188+
'CN': 'https://oauth.aliyun.com',
189+
'INTL': 'https://oauth.alibabacloud.com'
190+
}
191+
sign_in_url = oauth_base_url_map.get(site_type.upper())
192+
if not sign_in_url:
193+
raise CredentialException('Invalid OAuth site type, support CN or INTL')
194+
195+
oauth_client_map = {
196+
'CN': '4038181954557748008',
197+
'INTL': '4103531455503354461'
198+
}
199+
client_id = oauth_client_map.get(site_type.upper())
200+
if not client_id:
201+
raise CredentialException('Invalid OAuth site type, support CN or INTL')
202+
203+
return OAuthCredentialsProvider(
204+
client_id=client_id,
205+
sign_in_url=sign_in_url,
206+
access_token=profile.get('oauth_access_token'),
207+
access_token_expire=profile.get('oauth_access_token_expire'),
208+
refresh_token=profile.get('oauth_refresh_token'),
209+
token_update_callback=self._get_oauth_token_update_callback(),
210+
token_update_callback_async=self._get_oauth_token_update_callback_async(),
211+
)
178212
else:
179213
raise CredentialException(f"unsupported profile mode '{mode}' form cli credentials file.")
180214

181215
raise CredentialException(f"unable to get profile with '{profile_name}' form cli credentials file.")
182216

183217
def get_provider_name(self) -> str:
184218
return 'cli_profile'
219+
220+
def _update_oauth_tokens(self, refresh_token: str, access_token: str, access_key: str, secret: str,
221+
security_token: str, access_token_expire: int, sts_expire: int) -> None:
222+
"""更新 OAuth 令牌并写回配置文件"""
223+
with self._file_lock:
224+
try:
225+
# 读取现有配置
226+
config = _load_config(self._profile_file)
227+
228+
# 找到当前 profile 并更新 OAuth 令牌
229+
profile_name = self._profile_name
230+
if not profile_name:
231+
profile_name = config.get('current')
232+
profiles = config.get('profiles', [])
233+
profile_tag = False
234+
for profile in profiles:
235+
if profile.get('name') == profile_name:
236+
profile_tag = True
237+
# 更新 OAuth 令牌
238+
profile['oauth_refresh_token'] = refresh_token
239+
profile['oauth_access_token'] = access_token
240+
profile['oauth_access_token_expire'] = access_token_expire
241+
# 更新 STS 凭据
242+
profile['access_key_id'] = access_key
243+
profile['access_key_secret'] = secret
244+
profile['sts_token'] = security_token
245+
profile['sts_expiration'] = sts_expire
246+
break
247+
248+
# 写回配置文件
249+
if not profile_tag:
250+
raise CredentialException(f"unable to get profile with '{profile_name}' form cli credentials file.")
251+
252+
self._write_configuration_to_file(self._profile_file, config)
253+
254+
except Exception as e:
255+
raise CredentialException(f"failed to update OAuth tokens in config file: {e}")
256+
257+
def _write_configuration_to_file(self, config_path: str, config: Dict) -> None:
258+
"""将配置写入文件,使用原子写入确保数据完整性"""
259+
temp_file = config_path + '.tmp'
260+
try:
261+
# 序列化配置
262+
data = json.dumps(config, indent=4, ensure_ascii=False)
263+
264+
# 写入临时文件
265+
with open(temp_file, 'w', encoding='utf-8') as f:
266+
f.write(data)
267+
268+
# 原子性重命名,确保文件完整性
269+
os.rename(temp_file, config_path)
270+
271+
except Exception as e:
272+
# 清理临时文件
273+
if os.path.exists(temp_file):
274+
os.remove(temp_file)
275+
raise e
276+
277+
def _write_configuration_to_file_with_lock(self, config_path: str, config: Dict) -> None:
278+
"""使用操作系统级别的文件锁写入配置文件"""
279+
try:
280+
# 打开文件用于锁定
281+
with open(config_path, 'r+') as f:
282+
# 获取独占锁(阻塞其他进程)
283+
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
284+
285+
try:
286+
# 序列化配置
287+
data = json.dumps(config, indent=4, ensure_ascii=False)
288+
289+
# 创建临时文件
290+
temp_file = config_path + '.tmp'
291+
with open(temp_file, 'w', encoding='utf-8') as temp_f:
292+
temp_f.write(data)
293+
294+
# 原子性重命名
295+
os.rename(temp_file, config_path)
296+
297+
finally:
298+
# 释放锁
299+
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
300+
301+
except Exception as e:
302+
# 清理临时文件
303+
temp_file = config_path + '.tmp'
304+
if os.path.exists(temp_file):
305+
os.remove(temp_file)
306+
raise e
307+
308+
def _get_oauth_token_update_callback(self) -> OAuthTokenUpdateCallback:
309+
"""获取 OAuth 令牌更新回调函数"""
310+
return lambda refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire: self._update_oauth_tokens(
311+
refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire
312+
)
313+
314+
async def _write_configuration_to_file_async(self, config_path: str, config: Dict) -> None:
315+
"""异步将配置写入文件,使用原子写入确保数据完整性"""
316+
temp_file = config_path + '.tmp'
317+
try:
318+
# 序列化配置
319+
data = json.dumps(config, indent=4, ensure_ascii=False)
320+
321+
# 异步写入临时文件
322+
async with aiofiles.open(temp_file, 'w', encoding='utf-8') as f:
323+
await f.write(data)
324+
325+
# 原子性重命名
326+
os.rename(temp_file, config_path)
327+
328+
except Exception as e:
329+
if os.path.exists(temp_file):
330+
os.remove(temp_file)
331+
raise e
332+
333+
async def _write_configuration_to_file_with_lock_async(self, config_path: str, config: Dict) -> None:
334+
"""异步使用操作系统级别的文件锁写入配置文件"""
335+
try:
336+
# 打开文件用于锁定
337+
with open(config_path, 'r+') as f:
338+
# 获取独占锁(阻塞其他进程)
339+
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
340+
341+
try:
342+
# 序列化配置
343+
data = json.dumps(config, indent=4, ensure_ascii=False)
344+
345+
# 创建临时文件
346+
temp_file = config_path + '.tmp'
347+
async with aiofiles.open(temp_file, 'w', encoding='utf-8') as temp_f:
348+
await temp_f.write(data)
349+
350+
# 原子性重命名
351+
os.rename(temp_file, config_path)
352+
353+
finally:
354+
# 释放锁
355+
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
356+
357+
except Exception as e:
358+
# 清理临时文件
359+
temp_file = config_path + '.tmp'
360+
if os.path.exists(temp_file):
361+
os.remove(temp_file)
362+
raise e
363+
364+
async def _update_oauth_tokens_async(self, refresh_token: str, access_token: str, access_key: str, secret: str,
365+
security_token: str, access_token_expire: int, sts_expire: int) -> None:
366+
"""异步更新 OAuth 令牌并写回配置文件"""
367+
try:
368+
with self._file_lock:
369+
cfg_path = self._profile_file
370+
conf = await _load_config_async(cfg_path)
371+
372+
# 找到当前 profile 并更新 OAuth 令牌
373+
profile_name = self._profile_name
374+
if not profile_name:
375+
profile_name = conf.get('current')
376+
profiles = conf.get('profiles', [])
377+
profile_tag = False
378+
for profile in profiles:
379+
if profile.get('name') == profile_name:
380+
profile_tag = True
381+
# 更新 OAuth 相关字段
382+
profile['oauth_refresh_token'] = refresh_token
383+
profile['oauth_access_token'] = access_token
384+
profile['oauth_access_token_expire'] = access_token_expire
385+
# 更新 STS 凭据
386+
profile['access_key_id'] = access_key
387+
profile['access_key_secret'] = secret
388+
profile['sts_token'] = security_token
389+
profile['sts_expiration'] = sts_expire
390+
break
391+
392+
if not profile_tag:
393+
raise CredentialException(f"Profile '{profile_name}' not found in config file")
394+
395+
# 异步写回配置文件
396+
await self._write_configuration_to_file_with_lock_async(cfg_path, conf)
397+
398+
except Exception as e:
399+
raise CredentialException(f"failed to update OAuth tokens in config file: {e}")
400+
401+
def _get_oauth_token_update_callback_async(self) -> OAuthTokenUpdateCallbackAsync:
402+
"""获取异步 OAuth 令牌更新回调函数"""
403+
return lambda refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire: self._update_oauth_tokens_async(
404+
refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire
405+
)

0 commit comments

Comments
 (0)