11import os
22import json
3+ import threading
4+ import platform
35from typing import Any , Dict
46
57import aiofiles
68
9+ # 跨平台文件锁支持
10+ if platform .system () == 'Windows' :
11+ # Windows平台使用msvcrt
12+ import msvcrt
13+ HAS_MSVCRT = True
14+ HAS_FCNTL = False
15+ else :
16+ # 其他平台尝试使用fcntl,如果不可用则不设文件锁
17+ HAS_MSVCRT = False
18+ try :
19+ import fcntl
20+ HAS_FCNTL = True
21+ except ImportError :
22+ HAS_FCNTL = False
23+
724from .static_ak import StaticAKCredentialsProvider
825from .ecs_ram_role import EcsRamRoleCredentialsProvider
926from .ram_role_arn import RamRoleArnCredentialsProvider
1027from .oidc import OIDCRoleArnCredentialsProvider
1128from .static_sts import StaticSTSCredentialsProvider
1229from .cloud_sso import CloudSSOCredentialsProvider
30+ from .oauth import OAuthCredentialsProvider , OAuthTokenUpdateCallback , OAuthTokenUpdateCallbackAsync
1331from .refreshable import Credentials
1432from alibabacloud_credentials_api import ICredentialsProvider
1533from alibabacloud_credentials .utils import auth_constant as ac
@@ -32,10 +50,13 @@ def _load_config(file_path: str) -> Any:
3250class CLIProfileCredentialsProvider (ICredentialsProvider ):
3351
3452 def __init__ (self , * ,
35- profile_name : str = None ):
36- self ._profile_file = os .path .join (ac .HOME , ".aliyun/config.json" )
53+ profile_name : str = None ,
54+ profile_file : str = None ):
55+ self ._profile_file = profile_file or os .path .join (ac .HOME , ".aliyun/config.json" )
3756 self ._profile_name = profile_name or au .environment_profile_name
3857 self .__innerProvider = None
58+ # 文件锁,用于并发安全
59+ self ._file_lock = threading .RLock ()
3960
4061 def _should_reload_credentials_provider (self ) -> bool :
4162 if self .__innerProvider is None :
@@ -175,10 +196,243 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent
175196 access_token = profile .get ('access_token' ),
176197 access_token_expire = profile .get ('cloud_sso_access_token_expire' ),
177198 )
199+ elif mode == "OAuth" :
200+ # 获取 OAuth 配置
201+ site_type = profile .get ('oauth_site_type' , 'CN' )
202+ oauth_base_url_map = {
203+ 'CN' : 'https://oauth.aliyun.com' ,
204+ 'INTL' : 'https://oauth.alibabacloud.com'
205+ }
206+ sign_in_url = oauth_base_url_map .get (site_type .upper ())
207+ if not sign_in_url :
208+ raise CredentialException ('Invalid OAuth site type, support CN or INTL' )
209+
210+ oauth_client_map = {
211+ 'CN' : '4038181954557748008' ,
212+ 'INTL' : '4103531455503354461'
213+ }
214+ client_id = oauth_client_map .get (site_type .upper ())
215+ if not client_id :
216+ raise CredentialException ('Invalid OAuth site type, support CN or INTL' )
217+
218+ return OAuthCredentialsProvider (
219+ client_id = client_id ,
220+ sign_in_url = sign_in_url ,
221+ access_token = profile .get ('oauth_access_token' ),
222+ access_token_expire = profile .get ('oauth_access_token_expire' ),
223+ refresh_token = profile .get ('oauth_refresh_token' ),
224+ token_update_callback = self ._get_oauth_token_update_callback (),
225+ token_update_callback_async = self ._get_oauth_token_update_callback_async (),
226+ )
178227 else :
179228 raise CredentialException (f"unsupported profile mode '{ mode } ' form cli credentials file." )
180229
181230 raise CredentialException (f"unable to get profile with '{ profile_name } ' form cli credentials file." )
182231
183232 def get_provider_name (self ) -> str :
184233 return 'cli_profile'
234+
235+ def _update_oauth_tokens (self , refresh_token : str , access_token : str , access_key : str , secret : str ,
236+ security_token : str , access_token_expire : int , sts_expire : int ) -> None :
237+ """更新 OAuth 令牌并写回配置文件"""
238+ with self ._file_lock :
239+ try :
240+ # 读取现有配置
241+ config = _load_config (self ._profile_file )
242+
243+ # 找到当前 profile 并更新 OAuth 令牌
244+ profile_name = self ._profile_name
245+ if not profile_name :
246+ profile_name = config .get ('current' )
247+ profiles = config .get ('profiles' , [])
248+ profile_tag = False
249+ for profile in profiles :
250+ if profile .get ('name' ) == profile_name :
251+ profile_tag = True
252+ # 更新 OAuth 令牌
253+ profile ['oauth_refresh_token' ] = refresh_token
254+ profile ['oauth_access_token' ] = access_token
255+ profile ['oauth_access_token_expire' ] = access_token_expire
256+ # 更新 STS 凭据
257+ profile ['access_key_id' ] = access_key
258+ profile ['access_key_secret' ] = secret
259+ profile ['sts_token' ] = security_token
260+ profile ['sts_expiration' ] = sts_expire
261+ break
262+
263+ # 写回配置文件
264+ if not profile_tag :
265+ raise CredentialException (f"unable to get profile with '{ profile_name } ' form cli credentials file." )
266+
267+ self ._write_configuration_to_file (self ._profile_file , config )
268+
269+ except Exception as e :
270+ raise CredentialException (f"failed to update OAuth tokens in config file: { e } " )
271+
272+ def _write_configuration_to_file (self , config_path : str , config : Dict ) -> None :
273+ """将配置写入文件,使用原子写入确保数据完整性"""
274+ temp_file = config_path + '.tmp'
275+ try :
276+ # 序列化配置
277+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
278+
279+ # 写入临时文件
280+ with open (temp_file , 'w' , encoding = 'utf-8' ) as f :
281+ f .write (data )
282+
283+ # 原子性重命名,确保文件完整性
284+ os .rename (temp_file , config_path )
285+
286+ except Exception as e :
287+ # 清理临时文件
288+ if os .path .exists (temp_file ):
289+ os .remove (temp_file )
290+ raise e
291+
292+ def _write_configuration_to_file_with_lock (self , config_path : str , config : Dict ) -> None :
293+ """使用操作系统级别的文件锁写入配置文件"""
294+ try :
295+ # 打开文件用于锁定
296+ with open (config_path , 'r+' ) as f :
297+ # 获取独占锁(阻塞其他进程)
298+ if HAS_MSVCRT :
299+ # Windows使用msvcrt
300+ msvcrt .locking (f .fileno (), msvcrt .LK_NBLCK , 1 )
301+ elif HAS_FCNTL :
302+ # Unix/Linux使用fcntl
303+ fcntl .flock (f .fileno (), fcntl .LOCK_EX )
304+ # 如果都不支持,则跳过文件锁(仅进程内保护)
305+
306+ try :
307+ # 序列化配置
308+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
309+
310+ # 创建临时文件
311+ temp_file = config_path + '.tmp'
312+ with open (temp_file , 'w' , encoding = 'utf-8' ) as temp_f :
313+ temp_f .write (data )
314+
315+ # 原子性重命名
316+ os .rename (temp_file , config_path )
317+
318+ finally :
319+ # 释放锁
320+ if HAS_MSVCRT :
321+ msvcrt .locking (f .fileno (), msvcrt .LK_UNLCK , 1 )
322+ elif HAS_FCNTL :
323+ fcntl .flock (f .fileno (), fcntl .LOCK_UN )
324+
325+ except Exception as e :
326+ # 清理临时文件
327+ temp_file = config_path + '.tmp'
328+ if os .path .exists (temp_file ):
329+ os .remove (temp_file )
330+ raise e
331+
332+ def _get_oauth_token_update_callback (self ) -> OAuthTokenUpdateCallback :
333+ """获取 OAuth 令牌更新回调函数"""
334+ return lambda refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire : self ._update_oauth_tokens (
335+ refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire
336+ )
337+
338+ async def _write_configuration_to_file_async (self , config_path : str , config : Dict ) -> None :
339+ """异步将配置写入文件,使用原子写入确保数据完整性"""
340+ temp_file = config_path + '.tmp'
341+ try :
342+ # 序列化配置
343+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
344+
345+ # 异步写入临时文件
346+ async with aiofiles .open (temp_file , 'w' , encoding = 'utf-8' ) as f :
347+ await f .write (data )
348+
349+ # 原子性重命名
350+ os .rename (temp_file , config_path )
351+
352+ except Exception as e :
353+ if os .path .exists (temp_file ):
354+ os .remove (temp_file )
355+ raise e
356+
357+ async def _write_configuration_to_file_with_lock_async (self , config_path : str , config : Dict ) -> None :
358+ """异步使用操作系统级别的文件锁写入配置文件"""
359+ try :
360+ # 打开文件用于锁定
361+ with open (config_path , 'r+' ) as f :
362+ # 获取独占锁(阻塞其他进程)
363+ if HAS_MSVCRT :
364+ # Windows使用msvcrt
365+ msvcrt .locking (f .fileno (), msvcrt .LK_NBLCK , 1 )
366+ elif HAS_FCNTL :
367+ # Unix/Linux使用fcntl
368+ fcntl .flock (f .fileno (), fcntl .LOCK_EX )
369+ # 如果都不支持,则跳过文件锁(仅进程内保护)
370+
371+ try :
372+ # 序列化配置
373+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
374+
375+ # 创建临时文件
376+ temp_file = config_path + '.tmp'
377+ async with aiofiles .open (temp_file , 'w' , encoding = 'utf-8' ) as temp_f :
378+ await temp_f .write (data )
379+
380+ # 原子性重命名
381+ os .rename (temp_file , config_path )
382+
383+ finally :
384+ # 释放锁
385+ if HAS_MSVCRT :
386+ msvcrt .locking (f .fileno (), msvcrt .LK_UNLCK , 1 )
387+ elif HAS_FCNTL :
388+ fcntl .flock (f .fileno (), fcntl .LOCK_UN )
389+
390+ except Exception as e :
391+ # 清理临时文件
392+ temp_file = config_path + '.tmp'
393+ if os .path .exists (temp_file ):
394+ os .remove (temp_file )
395+ raise e
396+
397+ async def _update_oauth_tokens_async (self , refresh_token : str , access_token : str , access_key : str , secret : str ,
398+ security_token : str , access_token_expire : int , sts_expire : int ) -> None :
399+ """异步更新 OAuth 令牌并写回配置文件"""
400+ try :
401+ with self ._file_lock :
402+ cfg_path = self ._profile_file
403+ conf = await _load_config_async (cfg_path )
404+
405+ # 找到当前 profile 并更新 OAuth 令牌
406+ profile_name = self ._profile_name
407+ if not profile_name :
408+ profile_name = conf .get ('current' )
409+ profiles = conf .get ('profiles' , [])
410+ profile_tag = False
411+ for profile in profiles :
412+ if profile .get ('name' ) == profile_name :
413+ profile_tag = True
414+ # 更新 OAuth 相关字段
415+ profile ['oauth_refresh_token' ] = refresh_token
416+ profile ['oauth_access_token' ] = access_token
417+ profile ['oauth_access_token_expire' ] = access_token_expire
418+ # 更新 STS 凭据
419+ profile ['access_key_id' ] = access_key
420+ profile ['access_key_secret' ] = secret
421+ profile ['sts_token' ] = security_token
422+ profile ['sts_expiration' ] = sts_expire
423+ break
424+
425+ if not profile_tag :
426+ raise CredentialException (f"Profile '{ profile_name } ' not found in config file" )
427+
428+ # 异步写回配置文件
429+ await self ._write_configuration_to_file_with_lock_async (cfg_path , conf )
430+
431+ except Exception as e :
432+ raise CredentialException (f"failed to update OAuth tokens in config file: { e } " )
433+
434+ def _get_oauth_token_update_callback_async (self ) -> OAuthTokenUpdateCallbackAsync :
435+ """获取异步 OAuth 令牌更新回调函数"""
436+ return lambda refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire : self ._update_oauth_tokens_async (
437+ refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire
438+ )
0 commit comments