Skip to content

Commit 128e31a

Browse files
committed
feat: support OAuth credentials provider
1 parent d65cdba commit 128e31a

19 files changed

+2846
-316
lines changed

alibabacloud_credentials/provider/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .profile import ProfileCredentialsProvider
1111
from .default import DefaultCredentialsProvider
1212
from .cloud_sso import CloudSSOCredentialsProvider
13+
from .oauth import OAuthCredentialsProvider
1314

1415
__all__ = [
1516
'StaticAKCredentialsProvider',
@@ -23,5 +24,6 @@
2324
'CLIProfileCredentialsProvider',
2425
'ProfileCredentialsProvider',
2526
'DefaultCredentialsProvider',
26-
'CloudSSOCredentialsProvider'
27+
'CloudSSOCredentialsProvider',
28+
'OAuthCredentialsProvider'
2729
]

alibabacloud_credentials/provider/cli_profile.py

Lines changed: 256 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,33 @@
11
import os
22
import json
3+
import threading
4+
import platform
35
from typing import Any, Dict
46

57
import aiofiles
68

9+
# 跨平台文件锁支持
10+
if platform.system() == 'Windows':
11+
# Windows平台使用msvcrt
12+
import msvcrt
13+
HAS_MSVCRT = True
14+
HAS_FCNTL = False
15+
else:
16+
# 其他平台尝试使用fcntl,如果不可用则不设文件锁
17+
HAS_MSVCRT = False
18+
try:
19+
import fcntl
20+
HAS_FCNTL = True
21+
except ImportError:
22+
HAS_FCNTL = False
23+
724
from .static_ak import StaticAKCredentialsProvider
825
from .ecs_ram_role import EcsRamRoleCredentialsProvider
926
from .ram_role_arn import RamRoleArnCredentialsProvider
1027
from .oidc import OIDCRoleArnCredentialsProvider
1128
from .static_sts import StaticSTSCredentialsProvider
1229
from .cloud_sso import CloudSSOCredentialsProvider
30+
from .oauth import OAuthCredentialsProvider, OAuthTokenUpdateCallback, OAuthTokenUpdateCallbackAsync
1331
from .refreshable import Credentials
1432
from alibabacloud_credentials_api import ICredentialsProvider
1533
from alibabacloud_credentials.utils import auth_constant as ac
@@ -32,10 +50,13 @@ def _load_config(file_path: str) -> Any:
3250
class CLIProfileCredentialsProvider(ICredentialsProvider):
3351

3452
def __init__(self, *,
35-
profile_name: str = None):
36-
self._profile_file = os.path.join(ac.HOME, ".aliyun/config.json")
53+
profile_name: str = None,
54+
profile_file: str = None):
55+
self._profile_file = profile_file or os.path.join(ac.HOME, ".aliyun/config.json")
3756
self._profile_name = profile_name or au.environment_profile_name
3857
self.__innerProvider = None
58+
# 文件锁,用于并发安全
59+
self._file_lock = threading.RLock()
3960

4061
def _should_reload_credentials_provider(self) -> bool:
4162
if self.__innerProvider is None:
@@ -175,10 +196,243 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent
175196
access_token=profile.get('access_token'),
176197
access_token_expire=profile.get('cloud_sso_access_token_expire'),
177198
)
199+
elif mode == "OAuth":
200+
# 获取 OAuth 配置
201+
site_type = profile.get('oauth_site_type', 'CN')
202+
oauth_base_url_map = {
203+
'CN': 'https://oauth.aliyun.com',
204+
'INTL': 'https://oauth.alibabacloud.com'
205+
}
206+
sign_in_url = oauth_base_url_map.get(site_type.upper())
207+
if not sign_in_url:
208+
raise CredentialException('Invalid OAuth site type, support CN or INTL')
209+
210+
oauth_client_map = {
211+
'CN': '4038181954557748008',
212+
'INTL': '4103531455503354461'
213+
}
214+
client_id = oauth_client_map.get(site_type.upper())
215+
if not client_id:
216+
raise CredentialException('Invalid OAuth site type, support CN or INTL')
217+
218+
return OAuthCredentialsProvider(
219+
client_id=client_id,
220+
sign_in_url=sign_in_url,
221+
access_token=profile.get('oauth_access_token'),
222+
access_token_expire=profile.get('oauth_access_token_expire'),
223+
refresh_token=profile.get('oauth_refresh_token'),
224+
token_update_callback=self._get_oauth_token_update_callback(),
225+
token_update_callback_async=self._get_oauth_token_update_callback_async(),
226+
)
178227
else:
179228
raise CredentialException(f"unsupported profile mode '{mode}' form cli credentials file.")
180229

181230
raise CredentialException(f"unable to get profile with '{profile_name}' form cli credentials file.")
182231

183232
def get_provider_name(self) -> str:
184233
return 'cli_profile'
234+
235+
def _update_oauth_tokens(self, refresh_token: str, access_token: str, access_key: str, secret: str,
236+
security_token: str, access_token_expire: int, sts_expire: int) -> None:
237+
"""更新 OAuth 令牌并写回配置文件"""
238+
with self._file_lock:
239+
try:
240+
# 读取现有配置
241+
config = _load_config(self._profile_file)
242+
243+
# 找到当前 profile 并更新 OAuth 令牌
244+
profile_name = self._profile_name
245+
if not profile_name:
246+
profile_name = config.get('current')
247+
profiles = config.get('profiles', [])
248+
profile_tag = False
249+
for profile in profiles:
250+
if profile.get('name') == profile_name:
251+
profile_tag = True
252+
# 更新 OAuth 令牌
253+
profile['oauth_refresh_token'] = refresh_token
254+
profile['oauth_access_token'] = access_token
255+
profile['oauth_access_token_expire'] = access_token_expire
256+
# 更新 STS 凭据
257+
profile['access_key_id'] = access_key
258+
profile['access_key_secret'] = secret
259+
profile['sts_token'] = security_token
260+
profile['sts_expiration'] = sts_expire
261+
break
262+
263+
# 写回配置文件
264+
if not profile_tag:
265+
raise CredentialException(f"unable to get profile with '{profile_name}' form cli credentials file.")
266+
267+
self._write_configuration_to_file(self._profile_file, config)
268+
269+
except Exception as e:
270+
raise CredentialException(f"failed to update OAuth tokens in config file: {e}")
271+
272+
def _write_configuration_to_file(self, config_path: str, config: Dict) -> None:
273+
"""将配置写入文件,使用原子写入确保数据完整性"""
274+
temp_file = config_path + '.tmp'
275+
try:
276+
# 序列化配置
277+
data = json.dumps(config, indent=4, ensure_ascii=False)
278+
279+
# 写入临时文件
280+
with open(temp_file, 'w', encoding='utf-8') as f:
281+
f.write(data)
282+
283+
# 原子性重命名,确保文件完整性
284+
os.rename(temp_file, config_path)
285+
286+
except Exception as e:
287+
# 清理临时文件
288+
if os.path.exists(temp_file):
289+
os.remove(temp_file)
290+
raise e
291+
292+
def _write_configuration_to_file_with_lock(self, config_path: str, config: Dict) -> None:
293+
"""使用操作系统级别的文件锁写入配置文件"""
294+
try:
295+
# 打开文件用于锁定
296+
with open(config_path, 'r+') as f:
297+
# 获取独占锁(阻塞其他进程)
298+
if HAS_MSVCRT:
299+
# Windows使用msvcrt
300+
msvcrt.locking(f.fileno(), msvcrt.LK_NBLCK, 1)
301+
elif HAS_FCNTL:
302+
# Unix/Linux使用fcntl
303+
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
304+
# 如果都不支持,则跳过文件锁(仅进程内保护)
305+
306+
try:
307+
# 序列化配置
308+
data = json.dumps(config, indent=4, ensure_ascii=False)
309+
310+
# 创建临时文件
311+
temp_file = config_path + '.tmp'
312+
with open(temp_file, 'w', encoding='utf-8') as temp_f:
313+
temp_f.write(data)
314+
315+
# 原子性重命名
316+
os.rename(temp_file, config_path)
317+
318+
finally:
319+
# 释放锁
320+
if HAS_MSVCRT:
321+
msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
322+
elif HAS_FCNTL:
323+
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
324+
325+
except Exception as e:
326+
# 清理临时文件
327+
temp_file = config_path + '.tmp'
328+
if os.path.exists(temp_file):
329+
os.remove(temp_file)
330+
raise e
331+
332+
def _get_oauth_token_update_callback(self) -> OAuthTokenUpdateCallback:
333+
"""获取 OAuth 令牌更新回调函数"""
334+
return lambda refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire: self._update_oauth_tokens(
335+
refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire
336+
)
337+
338+
async def _write_configuration_to_file_async(self, config_path: str, config: Dict) -> None:
339+
"""异步将配置写入文件,使用原子写入确保数据完整性"""
340+
temp_file = config_path + '.tmp'
341+
try:
342+
# 序列化配置
343+
data = json.dumps(config, indent=4, ensure_ascii=False)
344+
345+
# 异步写入临时文件
346+
async with aiofiles.open(temp_file, 'w', encoding='utf-8') as f:
347+
await f.write(data)
348+
349+
# 原子性重命名
350+
os.rename(temp_file, config_path)
351+
352+
except Exception as e:
353+
if os.path.exists(temp_file):
354+
os.remove(temp_file)
355+
raise e
356+
357+
async def _write_configuration_to_file_with_lock_async(self, config_path: str, config: Dict) -> None:
358+
"""异步使用操作系统级别的文件锁写入配置文件"""
359+
try:
360+
# 打开文件用于锁定
361+
with open(config_path, 'r+') as f:
362+
# 获取独占锁(阻塞其他进程)
363+
if HAS_MSVCRT:
364+
# Windows使用msvcrt
365+
msvcrt.locking(f.fileno(), msvcrt.LK_NBLCK, 1)
366+
elif HAS_FCNTL:
367+
# Unix/Linux使用fcntl
368+
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
369+
# 如果都不支持,则跳过文件锁(仅进程内保护)
370+
371+
try:
372+
# 序列化配置
373+
data = json.dumps(config, indent=4, ensure_ascii=False)
374+
375+
# 创建临时文件
376+
temp_file = config_path + '.tmp'
377+
async with aiofiles.open(temp_file, 'w', encoding='utf-8') as temp_f:
378+
await temp_f.write(data)
379+
380+
# 原子性重命名
381+
os.rename(temp_file, config_path)
382+
383+
finally:
384+
# 释放锁
385+
if HAS_MSVCRT:
386+
msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
387+
elif HAS_FCNTL:
388+
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
389+
390+
except Exception as e:
391+
# 清理临时文件
392+
temp_file = config_path + '.tmp'
393+
if os.path.exists(temp_file):
394+
os.remove(temp_file)
395+
raise e
396+
397+
async def _update_oauth_tokens_async(self, refresh_token: str, access_token: str, access_key: str, secret: str,
398+
security_token: str, access_token_expire: int, sts_expire: int) -> None:
399+
"""异步更新 OAuth 令牌并写回配置文件"""
400+
try:
401+
with self._file_lock:
402+
cfg_path = self._profile_file
403+
conf = await _load_config_async(cfg_path)
404+
405+
# 找到当前 profile 并更新 OAuth 令牌
406+
profile_name = self._profile_name
407+
if not profile_name:
408+
profile_name = conf.get('current')
409+
profiles = conf.get('profiles', [])
410+
profile_tag = False
411+
for profile in profiles:
412+
if profile.get('name') == profile_name:
413+
profile_tag = True
414+
# 更新 OAuth 相关字段
415+
profile['oauth_refresh_token'] = refresh_token
416+
profile['oauth_access_token'] = access_token
417+
profile['oauth_access_token_expire'] = access_token_expire
418+
# 更新 STS 凭据
419+
profile['access_key_id'] = access_key
420+
profile['access_key_secret'] = secret
421+
profile['sts_token'] = security_token
422+
profile['sts_expiration'] = sts_expire
423+
break
424+
425+
if not profile_tag:
426+
raise CredentialException(f"Profile '{profile_name}' not found in config file")
427+
428+
# 异步写回配置文件
429+
await self._write_configuration_to_file_with_lock_async(cfg_path, conf)
430+
431+
except Exception as e:
432+
raise CredentialException(f"failed to update OAuth tokens in config file: {e}")
433+
434+
def _get_oauth_token_update_callback_async(self) -> OAuthTokenUpdateCallbackAsync:
435+
"""获取异步 OAuth 令牌更新回调函数"""
436+
return lambda refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire: self._update_oauth_tokens_async(
437+
refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire
438+
)

0 commit comments

Comments
 (0)