11import os
22import json
3+ import threading
4+ import fcntl
35from typing import Any , Dict
46
57import aiofiles
1012from .oidc import OIDCRoleArnCredentialsProvider
1113from .static_sts import StaticSTSCredentialsProvider
1214from .cloud_sso import CloudSSOCredentialsProvider
15+ from .oauth import OAuthCredentialsProvider , OAuthTokenUpdateCallback , OAuthTokenUpdateCallbackAsync
1316from .refreshable import Credentials
1417from alibabacloud_credentials_api import ICredentialsProvider
1518from alibabacloud_credentials .utils import auth_constant as ac
@@ -32,10 +35,13 @@ def _load_config(file_path: str) -> Any:
3235class CLIProfileCredentialsProvider (ICredentialsProvider ):
3336
3437 def __init__ (self , * ,
35- profile_name : str = None ):
36- self ._profile_file = os .path .join (ac .HOME , ".aliyun/config.json" )
38+ profile_name : str = None ,
39+ profile_file : str = None ):
40+ self ._profile_file = profile_file or os .path .join (ac .HOME , ".aliyun/config.json" )
3741 self ._profile_name = profile_name or au .environment_profile_name
3842 self .__innerProvider = None
43+ # 文件锁,用于并发安全
44+ self ._file_lock = threading .RLock ()
3945
4046 def _should_reload_credentials_provider (self ) -> bool :
4147 if self .__innerProvider is None :
@@ -175,10 +181,225 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent
175181 access_token = profile .get ('access_token' ),
176182 access_token_expire = profile .get ('cloud_sso_access_token_expire' ),
177183 )
184+ elif mode == "OAuth" :
185+ # 获取 OAuth 配置
186+ site_type = profile .get ('oauth_site_type' , 'CN' )
187+ oauth_base_url_map = {
188+ 'CN' : 'https://oauth.aliyun.com' ,
189+ 'INTL' : 'https://oauth.alibabacloud.com'
190+ }
191+ sign_in_url = oauth_base_url_map .get (site_type .upper ())
192+ if not sign_in_url :
193+ raise CredentialException ('Invalid OAuth site type, support CN or INTL' )
194+
195+ oauth_client_map = {
196+ 'CN' : '4038181954557748008' ,
197+ 'INTL' : '4103531455503354461'
198+ }
199+ client_id = oauth_client_map .get (site_type .upper ())
200+ if not client_id :
201+ raise CredentialException ('Invalid OAuth site type, support CN or INTL' )
202+
203+ return OAuthCredentialsProvider (
204+ client_id = client_id ,
205+ sign_in_url = sign_in_url ,
206+ access_token = profile .get ('oauth_access_token' ),
207+ access_token_expire = profile .get ('oauth_access_token_expire' ),
208+ refresh_token = profile .get ('oauth_refresh_token' ),
209+ token_update_callback = self ._get_oauth_token_update_callback (),
210+ token_update_callback_async = self ._get_oauth_token_update_callback_async (),
211+ )
178212 else :
179213 raise CredentialException (f"unsupported profile mode '{ mode } ' form cli credentials file." )
180214
181215 raise CredentialException (f"unable to get profile with '{ profile_name } ' form cli credentials file." )
182216
183217 def get_provider_name (self ) -> str :
184218 return 'cli_profile'
219+
220+ def _update_oauth_tokens (self , refresh_token : str , access_token : str , access_key : str , secret : str ,
221+ security_token : str , access_token_expire : int , sts_expire : int ) -> None :
222+ """更新 OAuth 令牌并写回配置文件"""
223+ with self ._file_lock :
224+ try :
225+ # 读取现有配置
226+ config = _load_config (self ._profile_file )
227+
228+ # 找到当前 profile 并更新 OAuth 令牌
229+ profile_name = self ._profile_name
230+ if not profile_name :
231+ profile_name = config .get ('current' )
232+ profiles = config .get ('profiles' , [])
233+ profile_tag = False
234+ for profile in profiles :
235+ if profile .get ('name' ) == profile_name :
236+ profile_tag = True
237+ # 更新 OAuth 令牌
238+ profile ['oauth_refresh_token' ] = refresh_token
239+ profile ['oauth_access_token' ] = access_token
240+ profile ['oauth_access_token_expire' ] = access_token_expire
241+ # 更新 STS 凭据
242+ profile ['access_key_id' ] = access_key
243+ profile ['access_key_secret' ] = secret
244+ profile ['sts_token' ] = security_token
245+ profile ['sts_expiration' ] = sts_expire
246+ break
247+
248+ # 写回配置文件
249+ if not profile_tag :
250+ raise CredentialException (f"unable to get profile with '{ profile_name } ' form cli credentials file." )
251+
252+ self ._write_configuration_to_file (self ._profile_file , config )
253+
254+ except Exception as e :
255+ raise CredentialException (f"failed to update OAuth tokens in config file: { e } " )
256+
257+ def _write_configuration_to_file (self , config_path : str , config : Dict ) -> None :
258+ """将配置写入文件,使用原子写入确保数据完整性"""
259+ temp_file = config_path + '.tmp'
260+ try :
261+ # 序列化配置
262+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
263+
264+ # 写入临时文件
265+ with open (temp_file , 'w' , encoding = 'utf-8' ) as f :
266+ f .write (data )
267+
268+ # 原子性重命名,确保文件完整性
269+ os .rename (temp_file , config_path )
270+
271+ except Exception as e :
272+ # 清理临时文件
273+ if os .path .exists (temp_file ):
274+ os .remove (temp_file )
275+ raise e
276+
277+ def _write_configuration_to_file_with_lock (self , config_path : str , config : Dict ) -> None :
278+ """使用操作系统级别的文件锁写入配置文件"""
279+ try :
280+ # 打开文件用于锁定
281+ with open (config_path , 'r+' ) as f :
282+ # 获取独占锁(阻塞其他进程)
283+ fcntl .flock (f .fileno (), fcntl .LOCK_EX )
284+
285+ try :
286+ # 序列化配置
287+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
288+
289+ # 创建临时文件
290+ temp_file = config_path + '.tmp'
291+ with open (temp_file , 'w' , encoding = 'utf-8' ) as temp_f :
292+ temp_f .write (data )
293+
294+ # 原子性重命名
295+ os .rename (temp_file , config_path )
296+
297+ finally :
298+ # 释放锁
299+ fcntl .flock (f .fileno (), fcntl .LOCK_UN )
300+
301+ except Exception as e :
302+ # 清理临时文件
303+ temp_file = config_path + '.tmp'
304+ if os .path .exists (temp_file ):
305+ os .remove (temp_file )
306+ raise e
307+
308+ def _get_oauth_token_update_callback (self ) -> OAuthTokenUpdateCallback :
309+ """获取 OAuth 令牌更新回调函数"""
310+ return lambda refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire : self ._update_oauth_tokens (
311+ refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire
312+ )
313+
314+ async def _write_configuration_to_file_async (self , config_path : str , config : Dict ) -> None :
315+ """异步将配置写入文件,使用原子写入确保数据完整性"""
316+ temp_file = config_path + '.tmp'
317+ try :
318+ # 序列化配置
319+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
320+
321+ # 异步写入临时文件
322+ async with aiofiles .open (temp_file , 'w' , encoding = 'utf-8' ) as f :
323+ await f .write (data )
324+
325+ # 原子性重命名
326+ os .rename (temp_file , config_path )
327+
328+ except Exception as e :
329+ if os .path .exists (temp_file ):
330+ os .remove (temp_file )
331+ raise e
332+
333+ async def _write_configuration_to_file_with_lock_async (self , config_path : str , config : Dict ) -> None :
334+ """异步使用操作系统级别的文件锁写入配置文件"""
335+ try :
336+ # 打开文件用于锁定
337+ with open (config_path , 'r+' ) as f :
338+ # 获取独占锁(阻塞其他进程)
339+ fcntl .flock (f .fileno (), fcntl .LOCK_EX )
340+
341+ try :
342+ # 序列化配置
343+ data = json .dumps (config , indent = 4 , ensure_ascii = False )
344+
345+ # 创建临时文件
346+ temp_file = config_path + '.tmp'
347+ async with aiofiles .open (temp_file , 'w' , encoding = 'utf-8' ) as temp_f :
348+ await temp_f .write (data )
349+
350+ # 原子性重命名
351+ os .rename (temp_file , config_path )
352+
353+ finally :
354+ # 释放锁
355+ fcntl .flock (f .fileno (), fcntl .LOCK_UN )
356+
357+ except Exception as e :
358+ # 清理临时文件
359+ temp_file = config_path + '.tmp'
360+ if os .path .exists (temp_file ):
361+ os .remove (temp_file )
362+ raise e
363+
364+ async def _update_oauth_tokens_async (self , refresh_token : str , access_token : str , access_key : str , secret : str ,
365+ security_token : str , access_token_expire : int , sts_expire : int ) -> None :
366+ """异步更新 OAuth 令牌并写回配置文件"""
367+ try :
368+ with self ._file_lock :
369+ cfg_path = self ._profile_file
370+ conf = await _load_config_async (cfg_path )
371+
372+ # 找到当前 profile 并更新 OAuth 令牌
373+ profile_name = self ._profile_name
374+ if not profile_name :
375+ profile_name = conf .get ('current' )
376+ profiles = conf .get ('profiles' , [])
377+ profile_tag = False
378+ for profile in profiles :
379+ if profile .get ('name' ) == profile_name :
380+ profile_tag = True
381+ # 更新 OAuth 相关字段
382+ profile ['oauth_refresh_token' ] = refresh_token
383+ profile ['oauth_access_token' ] = access_token
384+ profile ['oauth_access_token_expire' ] = access_token_expire
385+ # 更新 STS 凭据
386+ profile ['access_key_id' ] = access_key
387+ profile ['access_key_secret' ] = secret
388+ profile ['sts_token' ] = security_token
389+ profile ['sts_expiration' ] = sts_expire
390+ break
391+
392+ if not profile_tag :
393+ raise CredentialException (f"Profile '{ profile_name } ' not found in config file" )
394+
395+ # 异步写回配置文件
396+ await self ._write_configuration_to_file_with_lock_async (cfg_path , conf )
397+
398+ except Exception as e :
399+ raise CredentialException (f"failed to update OAuth tokens in config file: { e } " )
400+
401+ def _get_oauth_token_update_callback_async (self ) -> OAuthTokenUpdateCallbackAsync :
402+ """获取异步 OAuth 令牌更新回调函数"""
403+ return lambda refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire : self ._update_oauth_tokens_async (
404+ refresh_token , access_token , access_key , secret , security_token , access_token_expire , sts_expire
405+ )
0 commit comments