-
Notifications
You must be signed in to change notification settings - Fork 28
feat: add agent authorization in agent callback #303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
efb7dfa
0ab0f32
ba66c4a
a642866
c370908
8f674d1
bca1468
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,14 +25,17 @@ | |
| import aiohttp | ||
| import volcenginesdkid | ||
| import volcenginesdkcore | ||
| import volcenginesdksts | ||
|
|
||
| from veadk.integrations.ve_identity.models import ( | ||
| AssumeRoleCredential, | ||
| DCRRegistrationRequest, | ||
| DCRRegistrationResponse, | ||
| OAuth2TokenResponse, | ||
| WorkloadToken, | ||
| ) | ||
| from veadk.auth.veauth.utils import get_credential_from_vefaas_iam | ||
| from veadk.configs.auth_configs import VeIdentityConfig | ||
|
|
||
| from veadk.utils.logger import get_logger | ||
|
|
||
|
|
@@ -77,6 +80,20 @@ def _refresh_creds(self: IdentityClient): | |
| except Exception as e: | ||
| logger.warning(f"Failed to retrieve credentials from VeFaaS IAM: {e}") | ||
|
|
||
| # If there is no session_token and role_trn is configured, execute AssumeRole | ||
| if not session_token and self._identity_config.role_trn and ak and sk: | ||
| try: | ||
| logger.info( | ||
| f"No session token found, attempting AssumeRole with role: {self._identity_config.role_trn}" | ||
| ) | ||
| sts_credentials = self._assume_role(ak, sk) | ||
| ak = sts_credentials.access_key_id | ||
| sk = sts_credentials.secret_access_key | ||
| session_token = sts_credentials.session_token | ||
| logger.info("Successfully assumed role and obtained STS credentials") | ||
| except Exception as e: | ||
| logger.warning(f"Failed to assume role: {e}") | ||
|
|
||
| # Update configuration with the credentials | ||
| self._api_client.api_client.configuration.ak = ak | ||
| self._api_client.api_client.configuration.sk = sk | ||
|
|
@@ -115,6 +132,7 @@ def __init__( | |
| secret_key: Optional[str] = None, | ||
| session_token: Optional[str] = None, | ||
| region: str = "cn-beijing", | ||
| identity_config: Optional[VeIdentityConfig] = None, | ||
| ): | ||
| """Initialize the identity client. | ||
|
|
||
|
|
@@ -128,6 +146,8 @@ def __init__( | |
| KeyError: If required environment variables are not set. | ||
| """ | ||
| self.region = region | ||
| self._identity_config = identity_config or VeIdentityConfig() | ||
|
||
|
|
||
| # Store initial credentials for fallback | ||
| self._initial_access_key = access_key or os.getenv("VOLCENGINE_ACCESS_KEY", "") | ||
| self._initial_secret_key = secret_key or os.getenv("VOLCENGINE_SECRET_KEY", "") | ||
|
|
@@ -146,6 +166,56 @@ def __init__( | |
| volcenginesdkcore.ApiClient(configuration) | ||
| ) | ||
|
|
||
| def _assume_role(self, access_key: str, secret_key: str) -> AssumeRoleCredential: | ||
| """Execute AssumeRole to get STS temporary credentials. | ||
|
|
||
| Args: | ||
| access_key: VolcEngine access key | ||
| secret_key: VolcEngine secret key | ||
|
|
||
| Returns: | ||
| AssumeRoleCredential containing temporary credentials | ||
|
|
||
| Raises: | ||
| Exception: If AssumeRole fails | ||
| """ | ||
| # Create STS client configuration | ||
| sts_config = volcenginesdkcore.Configuration() | ||
| sts_config.region = self.region | ||
| sts_config.ak = access_key | ||
| sts_config.sk = secret_key | ||
|
|
||
| # Create an STS API client | ||
| sts_client = volcenginesdksts.STSApi(volcenginesdkcore.ApiClient(sts_config)) | ||
|
|
||
| # Construct an AssumeRole request | ||
| assume_role_request = volcenginesdksts.AssumeRoleRequest( | ||
| role_trn=self._identity_config.role_trn, | ||
| role_session_name=self._identity_config.role_session_name, | ||
| ) | ||
|
|
||
| logger.info( | ||
| f"Executing AssumeRole for role: {self._identity_config.role_trn}, " | ||
| f"session: {self._identity_config.role_session_name}" | ||
| ) | ||
|
|
||
| response: volcenginesdksts.AssumeRoleResponse = sts_client.assume_role( | ||
| assume_role_request | ||
| ) | ||
|
|
||
| if not response.credentials: | ||
| raise Exception("AssumeRole returned no credentials") | ||
|
|
||
| access_key = response["access_key_id"] | ||
| secret_key = response["secret_access_key"] | ||
| session_token = response["session_token"] | ||
|
|
||
| return AssumeRoleCredential( | ||
| access_key_id=access_key, | ||
| secret_access_key=secret_key, | ||
| session_token=session_token, | ||
| ) | ||
|
|
||
| @refresh_credentials | ||
| def create_oauth2_credential_provider( | ||
| self, request_params: Dict[str, Any] | ||
|
|
@@ -533,3 +603,38 @@ async def create_oauth2_credential_provider_with_dcr( | |
|
|
||
| # Create the credential provider with updated config | ||
| return self.create_oauth2_credential_provider(request_params) | ||
|
|
||
| @refresh_credentials | ||
| def check_permission( | ||
| self, principal_id, operation, resource_id, namespace="default" | ||
| ) -> bool: | ||
| """Check if the principal has permission to perform the operation on the resource. | ||
|
|
||
| Args: | ||
| principal_id: The ID of the principal (user or service). | ||
| operation: The operation to check permission for. | ||
| resource_id: The ID of the resource. | ||
| namespace: The namespace of the resource. Defaults to "default". | ||
|
|
||
| Returns: | ||
| True if the principal has permission, False otherwise. | ||
| """ | ||
| logger.info( | ||
| f"Checking permission for principal {principal_id} on resource {resource_id} for operation {operation}..." | ||
| ) | ||
|
|
||
| request = volcenginesdkid.CheckPermissionRequest( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里request的字段不太对 |
||
| principal_id=principal_id, | ||
| operation=operation, | ||
| resource_id=resource_id, | ||
| namespace=namespace, | ||
| ) | ||
|
|
||
| response: volcenginesdkid.CheckPermissionResponse = ( | ||
| self._api_client.check_permission(request) | ||
| ) | ||
|
|
||
| logger.info( | ||
| f"Permission check result for principal {principal_id} on resource {resource_id}: {response.allowed}" | ||
| ) | ||
| return response.allowed | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| # Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Optional | ||
|
|
||
| from google.genai import types | ||
| from google.adk.agents.callback_context import CallbackContext | ||
|
|
||
| from veadk.integrations.ve_identity.auth_config import _get_default_region | ||
| from veadk.integrations.ve_identity.identity_client import IdentityClient | ||
| from veadk.integrations.ve_identity.token_manager import get_workload_token | ||
| from veadk.utils.logger import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| region = _get_default_region() | ||
| identity_client = IdentityClient(region=region) | ||
|
|
||
|
|
||
| async def check_agent_authorization( | ||
| callback_context: CallbackContext, | ||
| ) -> Optional[types.Content]: | ||
| """Check if the agent is authorized to run using VeIdentity.""" | ||
| workload_token = await get_workload_token( | ||
|
||
| tool_context=callback_context, | ||
| identity_client=identity_client, | ||
| ) | ||
|
|
||
| # Parse role_id from workload_token | ||
| # Format: trn:id:${Region}:${Account}:workloadpool/default/workload/${RoleId} | ||
| role_id = None | ||
| if workload_token: | ||
|
||
| try: | ||
| role_id = workload_token.split("/")[-1] | ||
| logger.debug(f"Parsed role_id: {role_id}") | ||
| except Exception as e: | ||
| logger.warning(f"Failed to parse role_id from workload_token: {e}") | ||
|
|
||
| agent_name = callback_context.agent_name | ||
| user_id = callback_context._invocation_context.user_id | ||
|
|
||
| namespace = "default" | ||
|
||
| user_id = user_id | ||
| action = "invoke" | ||
| workload_id = role_id if role_id else agent_name | ||
|
|
||
| allowed = identity_client.check_permission( | ||
| principal_id=user_id, | ||
|
||
| operation=action, | ||
| resource_id=workload_id, | ||
| namespace=namespace, | ||
| ) | ||
|
|
||
| if allowed: | ||
| logger.debug("Agent is authorized to run.") | ||
| return None | ||
| else: | ||
| logger.warning("Agent is not authorized to run.") | ||
| return types.Content( | ||
| parts=[types.Part(text=f"Agent {agent_name} is not authorized to run.")], | ||
| role="model", | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉后面得给这个 sts_credentials 做个缓存每次调用identity接口都需要请求assume role开销有点大