Skip to content

Commit 5d8649e

Browse files
committed
feat: support OAuth credentials provider
1 parent f4a958d commit 5d8649e

File tree

5 files changed

+655
-13
lines changed

5 files changed

+655
-13
lines changed

alibabacloud_credentials/provider/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .profile import ProfileCredentialsProvider
1111
from .default import DefaultCredentialsProvider
1212
from .cloud_sso import CloudSSOCredentialsProvider
13+
from .oauth import OAuthCredentialsProvider
1314

1415
__all__ = [
1516
'StaticAKCredentialsProvider',
@@ -23,5 +24,6 @@
2324
'CLIProfileCredentialsProvider',
2425
'ProfileCredentialsProvider',
2526
'DefaultCredentialsProvider',
26-
'CloudSSOCredentialsProvider'
27+
'CloudSSOCredentialsProvider',
28+
'OAuthCredentialsProvider'
2729
]

alibabacloud_credentials/provider/cli_profile.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .oidc import OIDCRoleArnCredentialsProvider
1111
from .static_sts import StaticSTSCredentialsProvider
1212
from .cloud_sso import CloudSSOCredentialsProvider
13+
from .oauth import OAuthCredentialsProvider
1314
from .refreshable import Credentials
1415
from alibabacloud_credentials_api import ICredentialsProvider
1516
from alibabacloud_credentials.utils import auth_constant as ac
@@ -175,6 +176,12 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent
175176
access_token=profile.get('access_token'),
176177
access_token_expire=profile.get('cloud_sso_access_token_expire'),
177178
)
179+
elif mode == "OAuth":
180+
return OAuthCredentialsProvider(
181+
site_type=profile.get('oauth_site_type'),
182+
access_token=profile.get('oauth_access_token'),
183+
access_token_expire=profile.get('oauth_access_token_expire'),
184+
)
178185
else:
179186
raise CredentialException(f"unsupported profile mode '{mode}' form cli credentials file.")
180187

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import calendar
2+
import json
3+
import time
4+
from urllib.parse import urlparse
5+
6+
from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier
7+
from alibabacloud_credentials.http import HttpOptions
8+
from Tea.core import TeaCore
9+
from alibabacloud_credentials_api import ICredentialsProvider
10+
from alibabacloud_credentials.utils import parameter_helper as ph
11+
from alibabacloud_credentials.exceptions import CredentialException
12+
13+
_oauth_base_url_map = {
14+
'CN': 'https://oauth.aliyun.com',
15+
'INTL': 'https://oauth.alibabacloud.com'
16+
}
17+
18+
19+
def _get_stale_time(expiration: int) -> int:
20+
if expiration < 0:
21+
return int(time.mktime(time.localtime())) + 60 * 60
22+
return expiration - 15 * 60
23+
24+
25+
class OAuthCredentialsProvider(ICredentialsProvider):
26+
DEFAULT_CONNECT_TIMEOUT = 5000
27+
DEFAULT_READ_TIMEOUT = 10000
28+
29+
def __init__(self, *,
30+
site_type: str = None,
31+
access_token: str = None,
32+
access_token_expire: int = 0,
33+
http_options: HttpOptions = None):
34+
35+
self._site_type = site_type or 'CN'
36+
self._sign_in_url = _oauth_base_url_map.get(self._site_type.upper())
37+
self._access_token = access_token
38+
self._access_token_expire = access_token_expire
39+
40+
if self._sign_in_url is None:
41+
raise CredentialException('Invalid OAuth site type, support CN or INTL')
42+
43+
if self._access_token is None or self._access_token_expire == 0 or self._access_token_expire - int(
44+
time.mktime(time.localtime())) <= 0:
45+
raise ValueError(
46+
'OAuth access token is empty or expired, please re-login with cli')
47+
48+
self._http_options = http_options if http_options is not None else HttpOptions()
49+
self._runtime_options = {
50+
'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else OAuthCredentialsProvider.DEFAULT_CONNECT_TIMEOUT,
51+
'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else OAuthCredentialsProvider.DEFAULT_READ_TIMEOUT,
52+
'httpsProxy': self._http_options.proxy
53+
}
54+
self._credentials_cache = RefreshCachedSupplier(
55+
refresh_callable=self._refresh_credentials,
56+
refresh_callable_async=self._refresh_credentials_async,
57+
)
58+
59+
def get_credentials(self) -> Credentials:
60+
return self._credentials_cache._sync_call()
61+
62+
async def get_credentials_async(self) -> Credentials:
63+
return await self._credentials_cache._async_call()
64+
65+
def _refresh_credentials(self) -> RefreshResult[Credentials]:
66+
r = urlparse(self._sign_in_url)
67+
tea_request = ph.get_new_request()
68+
tea_request.headers['host'] = r.hostname
69+
tea_request.port = r.port
70+
tea_request.protocol = r.scheme
71+
tea_request.method = 'POST'
72+
tea_request.pathname = '/v1/exchange'
73+
74+
tea_request.headers['Content-Type'] = 'application/json'
75+
tea_request.headers['Authorization'] = f'Bearer {self._access_token}'
76+
77+
response = TeaCore.do_action(tea_request, self._runtime_options)
78+
79+
if response.status_code != 200:
80+
raise CredentialException(
81+
f'error refreshing credentials from OAuth, http_code: {response.status_code}, result: {response.body.decode("utf-8")}')
82+
83+
dic = json.loads(response.body.decode('utf-8'))
84+
if 'error' in dic:
85+
raise CredentialException(
86+
f'error retrieving credentials from OAuth result: {response.body.decode("utf-8")}')
87+
88+
if 'accessKeyId' not in dic or 'accessKeySecret' not in dic or 'securityToken' not in dic:
89+
raise CredentialException(
90+
f'error retrieving credentials from OAuth result: {response.body.decode("utf-8")}')
91+
92+
# 先转换为时间数组
93+
time_array = time.strptime(dic.get('expiration'), '%Y-%m-%dT%H:%M:%SZ')
94+
# 转换为时间戳
95+
expiration = calendar.timegm(time_array)
96+
credentials = Credentials(
97+
access_key_id=dic.get('accessKeyId'),
98+
access_key_secret=dic.get('accessKeySecret'),
99+
security_token=dic.get('securityToken'),
100+
expiration=expiration,
101+
provider_name=self.get_provider_name()
102+
)
103+
return RefreshResult(value=credentials,
104+
stale_time=_get_stale_time(expiration))
105+
106+
async def _refresh_credentials_async(self) -> RefreshResult[Credentials]:
107+
r = urlparse(self._sign_in_url)
108+
tea_request = ph.get_new_request()
109+
tea_request.headers['host'] = r.hostname
110+
tea_request.port = r.port
111+
tea_request.protocol = r.scheme
112+
tea_request.method = 'POST'
113+
tea_request.pathname = '/v1/exchange'
114+
115+
tea_request.headers['Content-Type'] = 'application/json'
116+
tea_request.headers['Authorization'] = f'Bearer {self._access_token}'
117+
118+
response = await TeaCore.async_do_action(tea_request, self._runtime_options)
119+
120+
if response.status_code != 200:
121+
raise CredentialException(
122+
f'error refreshing credentials from OAuth, http_code: {response.status_code}, result: {response.body.decode("utf-8")}')
123+
124+
dic = json.loads(response.body.decode('utf-8'))
125+
if 'error' in dic:
126+
raise CredentialException(
127+
f'error retrieving credentials from OAuth result: {response.body.decode("utf-8")}')
128+
129+
if 'accessKeyId' not in dic or 'accessKeySecret' not in dic or 'securityToken' not in dic:
130+
raise CredentialException(
131+
f'error retrieving credentials from OAuth result: {response.body.decode("utf-8")}')
132+
133+
# 先转换为时间数组
134+
time_array = time.strptime(dic.get('expiration'), '%Y-%m-%dT%H:%M:%SZ')
135+
# 转换为时间戳
136+
expiration = calendar.timegm(time_array)
137+
credentials = Credentials(
138+
access_key_id=dic.get('accessKeyId'),
139+
access_key_secret=dic.get('accessKeySecret'),
140+
security_token=dic.get('securityToken'),
141+
expiration=expiration,
142+
provider_name=self.get_provider_name()
143+
)
144+
return RefreshResult(value=credentials,
145+
stale_time=_get_stale_time(expiration))
146+
147+
def get_provider_name(self) -> str:
148+
return 'oauth'

tests/provider/test_cli_profile.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
RamRoleArnCredentialsProvider,
1616
EcsRamRoleCredentialsProvider,
1717
OIDCRoleArnCredentialsProvider,
18-
CloudSSOCredentialsProvider
18+
CloudSSOCredentialsProvider,
19+
OAuthCredentialsProvider
1920
)
2021
from alibabacloud_credentials.utils import auth_constant as ac
2122

@@ -91,6 +92,13 @@ def setUp(self):
9192
"cloud_sso_access_config": "test_access_config",
9293
"access_token": "test_access_token",
9394
"cloud_sso_access_token_expire": int(time.mktime(time.localtime())) + 1000
95+
},
96+
{
97+
"name": "oauth_profile",
98+
"mode": "OAuth",
99+
"oauth_site_type": "CN",
100+
"oauth_access_token": "test_oauth_access_token",
101+
"oauth_access_token_expire": int(time.mktime(time.localtime())) + 1000
94102
}
95103
]
96104
}
@@ -275,9 +283,29 @@ def test_get_credentials_valid_cloud_sso(self):
275283
self.assertEqual(credentials_provider._access_token, 'test_access_token')
276284
self.assertTrue(credentials_provider._access_token_expire > int(time.mktime(time.localtime())))
277285

286+
def test_get_credentials_valid_oauth(self):
287+
"""
288+
Test case 8: Valid input, successfully retrieves credentials for OAuth mode
289+
"""
290+
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', False):
291+
with patch('os.path.exists', return_value=True):
292+
with patch('os.path.isfile', return_value=True):
293+
with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config):
294+
provider = CLIProfileCredentialsProvider(profile_name="oauth_profile")
295+
296+
credentials_provider = provider._get_credentials_provider(config=self.config,
297+
profile_name="oauth_profile")
298+
299+
self.assertIsInstance(credentials_provider, OAuthCredentialsProvider)
300+
301+
self.assertEqual(credentials_provider._site_type, 'CN')
302+
self.assertEqual(credentials_provider._sign_in_url, 'https://oauth.aliyun.com')
303+
self.assertEqual(credentials_provider._access_token, 'test_oauth_access_token')
304+
self.assertTrue(credentials_provider._access_token_expire > int(time.mktime(time.localtime())))
305+
278306
def test_get_credentials_cli_profile_disabled(self):
279307
"""
280-
Test case 8: CLI profile disabled raises CredentialException
308+
Test case 9: CLI profile disabled raises CredentialException
281309
"""
282310
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'True'):
283311
provider = CLIProfileCredentialsProvider(profile_name=self.profile_name)
@@ -289,7 +317,7 @@ def test_get_credentials_cli_profile_disabled(self):
289317

290318
def test_get_credentials_profile_name_not_exists(self):
291319
"""
292-
Test case 9: Profile file does not exist raises CredentialException
320+
Test case 10: Profile file does not exist raises CredentialException
293321
"""
294322
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
295323
provider = CLIProfileCredentialsProvider(profile_name='not_exists')
@@ -301,7 +329,7 @@ def test_get_credentials_profile_name_not_exists(self):
301329

302330
def test_get_credentials_profile_file_not_exists(self):
303331
"""
304-
Test case 10: Profile file does not exist raises CredentialException
332+
Test case 11: Profile file does not exist raises CredentialException
305333
"""
306334
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
307335
with patch('os.path.exists', return_value=False):
@@ -314,7 +342,7 @@ def test_get_credentials_profile_file_not_exists(self):
314342

315343
def test_get_credentials_profile_file_not_file(self):
316344
"""
317-
Test case 11: Profile file is not a file raises CredentialException
345+
Test case 12: Profile file is not a file raises CredentialException
318346
"""
319347
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
320348
with patch('os.path.exists', return_value=True):
@@ -328,7 +356,7 @@ def test_get_credentials_profile_file_not_file(self):
328356

329357
def test_get_credentials_invalid_json_format(self):
330358
"""
331-
Test case 12: Invalid JSON format in profile file raises CredentialException
359+
Test case 13: Invalid JSON format in profile file raises CredentialException
332360
"""
333361
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
334362
with patch('os.path.exists', return_value=True):
@@ -345,7 +373,7 @@ def test_get_credentials_invalid_json_format(self):
345373

346374
def test_get_credentials_empty_json(self):
347375
"""
348-
Test case 13: Empty JSON in profile file raises CredentialException
376+
Test case 14: Empty JSON in profile file raises CredentialException
349377
"""
350378
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
351379
with patch('os.path.exists', return_value=True):
@@ -361,7 +389,7 @@ def test_get_credentials_empty_json(self):
361389

362390
def test_get_credentials_missing_profiles(self):
363391
"""
364-
Test case 14: Missing profiles in JSON raises CredentialException
392+
Test case 15: Missing profiles in JSON raises CredentialException
365393
"""
366394
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
367395
with patch('os.path.exists', return_value=True):
@@ -378,7 +406,7 @@ def test_get_credentials_missing_profiles(self):
378406

379407
def test_get_credentials_invalid_profile_mode(self):
380408
"""
381-
Test case 15: Invalid profile mode raises CredentialException
409+
Test case 16: Invalid profile mode raises CredentialException
382410
"""
383411
invalid_config = {
384412
"current": "invalid_profile",
@@ -406,7 +434,7 @@ def test_get_credentials_invalid_profile_mode(self):
406434

407435
def test_get_credentials_async_valid_ak(self):
408436
"""
409-
Test case 16: Valid input, successfully retrieves credentials for AK mode
437+
Test case 17: Valid input, successfully retrieves credentials for AK mode
410438
"""
411439
with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'):
412440
with patch('os.path.exists', return_value=True):
@@ -430,7 +458,7 @@ def test_get_credentials_async_valid_ak(self):
430458
@patch('builtins.open', new_callable=MagicMock)
431459
def test_load_config_file_not_found(self, mock_open):
432460
"""
433-
Test case 17: File not found raises FileNotFoundError
461+
Test case 18: File not found raises FileNotFoundError
434462
"""
435463
mock_open.side_effect = FileNotFoundError(f"No such file or directory: '{self.profile_file}'")
436464

@@ -442,7 +470,7 @@ def test_load_config_file_not_found(self, mock_open):
442470
@patch('builtins.open', new_callable=MagicMock)
443471
def test_load_config_invalid_json(self, mock_open):
444472
"""
445-
Test case 18: Invalid JSON format raises json.JSONDecodeError
473+
Test case 19: Invalid JSON format raises json.JSONDecodeError
446474
"""
447475
invalid_json = "invalid json content"
448476
mock_open.return_value.__enter__.return_value.read.return_value = invalid_json

0 commit comments

Comments
 (0)