Skip to content

Commit 95b6c31

Browse files
xiangyan99pvaneck
andauthored
Identity token cred env (Azure#41060)
* Added `AZURE_TOKEN_CREDENTIALS` environment variable support * add trimming * update * updates * udpate tests * fix typo * Update sdk/identity/azure-identity/tests/test_token_credentials_env_async.py Co-authored-by: Paul Van Eck <[email protected]> * Update sdk/identity/azure-identity/tests/test_token_credentials_env.py Co-authored-by: Paul Van Eck <[email protected]> * Update sdk/identity/azure-identity/tests/test_token_credentials_env.py Co-authored-by: Paul Van Eck <[email protected]> * Update sdk/identity/azure-identity/tests/test_token_credentials_env_async.py Co-authored-by: Paul Van Eck <[email protected]> * Update sdk/identity/azure-identity/CHANGELOG.md Co-authored-by: Paul Van Eck <[email protected]> * add tests --------- Co-authored-by: Paul Van Eck <[email protected]>
1 parent 4fe13f2 commit 95b6c31

File tree

7 files changed

+288
-6
lines changed

7 files changed

+288
-6
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# Release History
22

3-
## 1.22.1 (Unreleased)
3+
## 1.23.0 (Unreleased)
44

55
### Features Added
66

7+
- Added `AZURE_TOKEN_CREDENTIALS` environment variable to `DefaultAzureCredential` to allow for choosing groups of credentials.
8+
- `prod` for `EnvironmentCredential`, `WorkloadIdentityCredential`, and `ManagedIdentityCredential`.
9+
- `dev` for `SharedTokenCacheCredential`, `AzureCliCredential`, `AzurePowershellCredential`, and `AzureDeveloperCliCredential`.
10+
711
### Breaking Changes
812

913
### Bugs Fixed

sdk/identity/azure-identity/azure/identity/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,5 @@ class EnvironmentVariables:
6464
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"
6565

6666
AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
67+
AZURE_TOKEN_CREDENTIALS = "AZURE_TOKEN_CREDENTIALS"
6768
WORKLOAD_IDENTITY_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)

sdk/identity/azure-identity/azure/identity/_credentials/default.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
133133

134134
process_timeout = kwargs.pop("process_timeout", 10)
135135

136+
token_credentials_env = os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS, "").strip().lower()
136137
exclude_workload_identity_credential = kwargs.pop("exclude_workload_identity_credential", False)
137138
exclude_environment_credential = kwargs.pop("exclude_environment_credential", False)
138139
exclude_managed_identity_credential = kwargs.pop("exclude_managed_identity_credential", False)
@@ -143,6 +144,26 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
143144
exclude_interactive_browser_credential = kwargs.pop("exclude_interactive_browser_credential", True)
144145
exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False)
145146

147+
if token_credentials_env == "dev":
148+
# In dev mode, use only developer credentials
149+
exclude_environment_credential = True
150+
exclude_managed_identity_credential = True
151+
exclude_workload_identity_credential = True
152+
elif token_credentials_env == "prod":
153+
# In prod mode, use only production credentials
154+
exclude_shared_token_cache_credential = True
155+
exclude_visual_studio_code_credential = True
156+
exclude_cli_credential = True
157+
exclude_developer_cli_credential = True
158+
exclude_powershell_credential = True
159+
exclude_interactive_browser_credential = True
160+
elif token_credentials_env != "":
161+
# If the environment variable is set to something other than dev or prod, raise an error
162+
raise ValueError(
163+
f"Invalid value for {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS}: {token_credentials_env}. "
164+
"Valid values are 'dev' or 'prod'."
165+
)
166+
146167
credentials: List[SupportsTokenInfo] = []
147168
within_dac.set(True)
148169
if not exclude_environment_credential:
@@ -155,15 +176,15 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
155176
client_id=cast(str, client_id),
156177
tenant_id=workload_identity_tenant_id,
157178
token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
158-
**kwargs
179+
**kwargs,
159180
)
160181
)
161182
if not exclude_managed_identity_credential:
162183
credentials.append(
163184
ManagedIdentityCredential(
164185
client_id=managed_identity_client_id,
165186
_exclude_workload_identity_credential=exclude_workload_identity_credential,
166-
**kwargs
187+
**kwargs,
167188
)
168189
)
169190
if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported():

sdk/identity/azure-identity/azure/identity/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
VERSION = "1.22.1"
5+
VERSION = "1.23.0"

sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
125125

126126
process_timeout = kwargs.pop("process_timeout", 10)
127127

128+
token_credentials_env = os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS, "").strip().lower()
128129
exclude_workload_identity_credential = kwargs.pop("exclude_workload_identity_credential", False)
129130
exclude_visual_studio_code_credential = kwargs.pop("exclude_visual_studio_code_credential", True)
130131
exclude_developer_cli_credential = kwargs.pop("exclude_developer_cli_credential", False)
@@ -134,6 +135,25 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
134135
exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False)
135136
exclude_powershell_credential = kwargs.pop("exclude_powershell_credential", False)
136137

138+
if token_credentials_env == "dev":
139+
# In dev mode, use only developer credentials
140+
exclude_environment_credential = True
141+
exclude_managed_identity_credential = True
142+
exclude_workload_identity_credential = True
143+
elif token_credentials_env == "prod":
144+
# In prod mode, use only production credentials
145+
exclude_shared_token_cache_credential = True
146+
exclude_visual_studio_code_credential = True
147+
exclude_cli_credential = True
148+
exclude_developer_cli_credential = True
149+
exclude_powershell_credential = True
150+
elif token_credentials_env != "":
151+
# If the environment variable is set to something other than dev or prod, raise an error
152+
raise ValueError(
153+
f"Invalid value for {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS}: {token_credentials_env}. "
154+
"Valid values are 'dev' or 'prod'."
155+
)
156+
137157
credentials: List[AsyncSupportsTokenInfo] = []
138158
within_dac.set(True)
139159
if not exclude_environment_credential:
@@ -146,15 +166,15 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
146166
client_id=cast(str, client_id),
147167
tenant_id=workload_identity_tenant_id,
148168
token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
149-
**kwargs
169+
**kwargs,
150170
)
151171
)
152172
if not exclude_managed_identity_credential:
153173
credentials.append(
154174
ManagedIdentityCredential(
155175
client_id=managed_identity_client_id,
156176
_exclude_workload_identity_credential=exclude_workload_identity_credential,
157-
**kwargs
177+
**kwargs,
158178
)
159179
)
160180
if not exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported():
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import os
6+
from unittest.mock import patch
7+
8+
import pytest
9+
10+
from azure.identity import (
11+
AzureCliCredential,
12+
AzureDeveloperCliCredential,
13+
AzurePowerShellCredential,
14+
DefaultAzureCredential,
15+
EnvironmentCredential,
16+
ManagedIdentityCredential,
17+
SharedTokenCacheCredential,
18+
WorkloadIdentityCredential,
19+
)
20+
from azure.identity._constants import EnvironmentVariables
21+
22+
23+
def test_token_credentials_env_dev():
24+
"""With AZURE_TOKEN_CREDENTIALS=dev, DefaultAzureCredential should use only developer credentials"""
25+
26+
prod_credentials = {EnvironmentCredential, WorkloadIdentityCredential, ManagedIdentityCredential}
27+
28+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "dev"}, clear=False):
29+
credential = DefaultAzureCredential()
30+
31+
# Get the actual credential classes in the chain
32+
actual_classes = {c.__class__ for c in credential.credentials}
33+
34+
# All dev credentials should be present (if supported)
35+
if SharedTokenCacheCredential.supported():
36+
assert SharedTokenCacheCredential in actual_classes
37+
38+
# Other developer credentials should be present
39+
assert AzureCliCredential in actual_classes
40+
assert AzureDeveloperCliCredential in actual_classes
41+
assert AzurePowerShellCredential in actual_classes
42+
43+
# Production credentials should NOT be present
44+
for cred_class in prod_credentials:
45+
if cred_class == WorkloadIdentityCredential:
46+
# Skip this check unless env vars are set
47+
if not all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
48+
continue
49+
assert cred_class not in actual_classes
50+
51+
52+
def test_token_credentials_env_prod():
53+
"""With AZURE_TOKEN_CREDENTIALS=prod, DefaultAzureCredential should use only production credentials"""
54+
55+
dev_credentials = {
56+
SharedTokenCacheCredential,
57+
AzureCliCredential,
58+
AzureDeveloperCliCredential,
59+
AzurePowerShellCredential,
60+
}
61+
62+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "prod"}, clear=False):
63+
# Print to verify the environment variable is set in the test
64+
print(f"AZURE_TOKEN_CREDENTIALS={os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS)}")
65+
66+
credential = DefaultAzureCredential()
67+
68+
# Get the actual credential classes in the chain
69+
actual_classes = {c.__class__ for c in credential.credentials}
70+
71+
# Print which credentials are actually in the chain
72+
print("Credentials in chain:")
73+
for cls in actual_classes:
74+
print(f" - {cls.__name__}")
75+
76+
# Production credentials should be present
77+
assert EnvironmentCredential in actual_classes
78+
assert ManagedIdentityCredential in actual_classes
79+
80+
# Check WorkloadIdentityCredential only if env vars are set
81+
if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
82+
assert WorkloadIdentityCredential in actual_classes
83+
84+
# Developer credentials should NOT be present
85+
for cred_class in dev_credentials:
86+
assert cred_class not in actual_classes
87+
88+
89+
def test_token_credentials_env_case_insensitive():
90+
"""AZURE_TOKEN_CREDENTIALS should be case insensitive"""
91+
92+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "DeV"}, clear=False):
93+
credential = DefaultAzureCredential()
94+
95+
# Get the actual credential classes in the chain
96+
actual_classes = {c.__class__ for c in credential.credentials}
97+
98+
# EnvironmentCredential (prod) should not be present
99+
assert EnvironmentCredential not in actual_classes
100+
101+
# AzureCliCredential (dev) should be present
102+
assert AzureCliCredential in actual_classes
103+
104+
105+
def test_token_credentials_env_invalid():
106+
"""Invalid AZURE_TOKEN_CREDENTIALS value should raise an error"""
107+
108+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "invalid"}, clear=False):
109+
with pytest.raises(ValueError):
110+
credential = DefaultAzureCredential()
111+
112+
113+
def test_token_credentials_env_with_exclude():
114+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "prod"}, clear=False):
115+
credential = DefaultAzureCredential(exclude_environment_credential=True)
116+
actual_classes = {c.__class__ for c in credential.credentials}
117+
118+
assert EnvironmentCredential not in actual_classes
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import os
6+
from unittest.mock import patch
7+
8+
import pytest
9+
10+
from azure.identity.aio import (
11+
AzureCliCredential,
12+
AzureDeveloperCliCredential,
13+
AzurePowerShellCredential,
14+
DefaultAzureCredential,
15+
EnvironmentCredential,
16+
ManagedIdentityCredential,
17+
SharedTokenCacheCredential,
18+
WorkloadIdentityCredential,
19+
)
20+
from azure.identity._constants import EnvironmentVariables
21+
22+
23+
def test_token_credentials_env_dev():
24+
"""With AZURE_TOKEN_CREDENTIALS=dev, DefaultAzureCredential should use only developer credentials"""
25+
26+
prod_credentials = {EnvironmentCredential, WorkloadIdentityCredential, ManagedIdentityCredential}
27+
28+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "dev"}, clear=False):
29+
credential = DefaultAzureCredential()
30+
31+
# Get the actual credential classes in the chain
32+
actual_classes = {c.__class__ for c in credential.credentials}
33+
34+
# All dev credentials should be present (if supported)
35+
if SharedTokenCacheCredential.supported():
36+
assert SharedTokenCacheCredential in actual_classes
37+
38+
# Other developer credentials should be present
39+
assert AzureCliCredential in actual_classes
40+
assert AzureDeveloperCliCredential in actual_classes
41+
assert AzurePowerShellCredential in actual_classes
42+
43+
# Production credentials should NOT be present
44+
for cred_class in prod_credentials:
45+
if cred_class == WorkloadIdentityCredential:
46+
# Skip this check unless env vars are set
47+
if not all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
48+
continue
49+
assert cred_class not in actual_classes
50+
51+
52+
def test_token_credentials_env_prod():
53+
"""With AZURE_TOKEN_CREDENTIALS=prod, DefaultAzureCredential should use only production credentials"""
54+
55+
dev_credentials = {
56+
SharedTokenCacheCredential,
57+
AzureCliCredential,
58+
AzureDeveloperCliCredential,
59+
AzurePowerShellCredential,
60+
}
61+
62+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "prod"}, clear=False):
63+
# Print to verify the environment variable is set in the test
64+
print(f"AZURE_TOKEN_CREDENTIALS={os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS)}")
65+
66+
credential = DefaultAzureCredential()
67+
68+
# Get the actual credential classes in the chain
69+
actual_classes = {c.__class__ for c in credential.credentials}
70+
71+
# Print which credentials are actually in the chain
72+
print("Credentials in chain:")
73+
for cls in actual_classes:
74+
print(f" - {cls.__name__}")
75+
76+
# Production credentials should be present
77+
assert EnvironmentCredential in actual_classes
78+
assert ManagedIdentityCredential in actual_classes
79+
80+
# Check WorkloadIdentityCredential only if env vars are set
81+
if all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS):
82+
assert WorkloadIdentityCredential in actual_classes
83+
84+
# Developer credentials should NOT be present
85+
for cred_class in dev_credentials:
86+
assert cred_class not in actual_classes
87+
88+
89+
def test_token_credentials_env_case_insensitive():
90+
"""AZURE_TOKEN_CREDENTIALS should be case insensitive"""
91+
92+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "DeV"}, clear=False):
93+
credential = DefaultAzureCredential()
94+
95+
# Get the actual credential classes in the chain
96+
actual_classes = {c.__class__ for c in credential.credentials}
97+
98+
# EnvironmentCredential (prod) should not be present
99+
assert EnvironmentCredential not in actual_classes
100+
101+
# AzureCliCredential (dev) should be present
102+
assert AzureCliCredential in actual_classes
103+
104+
105+
def test_token_credentials_env_invalid():
106+
"""Invalid AZURE_TOKEN_CREDENTIALS value should raise an error"""
107+
108+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "invalid"}, clear=False):
109+
with pytest.raises(ValueError):
110+
credential = DefaultAzureCredential()
111+
112+
113+
def test_token_credentials_env_with_exclude():
114+
with patch.dict("os.environ", {EnvironmentVariables.AZURE_TOKEN_CREDENTIALS: "prod"}, clear=False):
115+
credential = DefaultAzureCredential(exclude_environment_credential=True)
116+
actual_classes = {c.__class__ for c in credential.credentials}
117+
118+
assert EnvironmentCredential not in actual_classes

0 commit comments

Comments
 (0)