Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions src/iam-mcp-server/awslabs/iam_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import argparse
import json
from awslabs.iam_mcp_server.aws_client import get_iam_client
from awslabs.iam_mcp_server.context import Context
from awslabs.iam_mcp_server.context import Context as ServerContext
from awslabs.iam_mcp_server.errors import IamClientError, IamValidationError, handle_iam_error
from awslabs.iam_mcp_server.models import (
AccessKey,
Expand All @@ -37,8 +37,7 @@
UsersListResponse,
)
from loguru import logger
from mcp.server.fastmcp import FastMCP
from mcp.types import CallToolResult
from mcp.server.fastmcp import Context, FastMCP
from pydantic import Field
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -86,7 +85,7 @@

@mcp.tool()
async def list_users(
ctx: CallToolResult,
ctx: Context,
path_prefix: Optional[str] = Field(
description='Path prefix to filter users (e.g., "/division_abc/")', default=None
),
Expand Down Expand Up @@ -154,7 +153,7 @@ async def list_users(

@mcp.tool()
async def get_user(
ctx: CallToolResult, user_name: str = Field(description='The name of the IAM user to retrieve')
ctx: Context, user_name: str = Field(description='The name of the IAM user to retrieve')
) -> UserDetailsResponse:
"""Get detailed information about a specific IAM user.

Expand Down Expand Up @@ -242,7 +241,7 @@ async def get_user(

@mcp.tool()
async def create_user(
ctx: CallToolResult,
ctx: Context,
user_name: str = Field(description='The name of the new IAM user'),
path: str = Field(description='The path for the user', default='/'),
permissions_boundary: Optional[str] = Field(
Expand Down Expand Up @@ -273,7 +272,7 @@ async def create_user(
logger.info(f'Creating IAM user: {user_name}')

# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot create user: server is running in read-only mode')

if not user_name:
Expand Down Expand Up @@ -332,7 +331,7 @@ async def delete_user(
"""
try:
# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot delete user: server is running in read-only mode')

iam = get_iam_client()
Expand Down Expand Up @@ -448,7 +447,7 @@ async def create_role(
"""
try:
# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot create role: server is running in read-only mode')

iam = get_iam_client()
Expand Down Expand Up @@ -633,7 +632,7 @@ async def attach_user_policy(
"""
try:
# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot attach policy: server is running in read-only mode')

iam = get_iam_client()
Expand Down Expand Up @@ -666,7 +665,7 @@ async def detach_user_policy(
"""
try:
# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot detach policy: server is running in read-only mode')

iam = get_iam_client()
Expand Down Expand Up @@ -697,7 +696,7 @@ async def create_access_key(
"""
try:
# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot create access key: server is running in read-only mode')

iam = get_iam_client()
Expand Down Expand Up @@ -737,7 +736,7 @@ async def delete_access_key(
"""
try:
# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError('Cannot delete access key: server is running in read-only mode')

iam = get_iam_client()
Expand Down Expand Up @@ -840,7 +839,7 @@ async def list_groups(
Returns:
GroupsListResponse containing list of groups and metadata
"""
if Context.is_readonly():
if ServerContext.is_readonly():
# List operations are allowed in read-only mode
pass

Expand Down Expand Up @@ -896,7 +895,7 @@ async def get_group(
Returns:
GroupDetailsResponse containing comprehensive group information
"""
if Context.is_readonly():
if ServerContext.is_readonly():
# Get operations are allowed in read-only mode
pass

Expand Down Expand Up @@ -962,7 +961,7 @@ async def create_group(
Returns:
CreateGroupResponse containing the created group details
"""
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamValidationError('Cannot create group in read-only mode')

try:
Expand Down Expand Up @@ -1003,7 +1002,7 @@ async def delete_group(
Returns:
Dictionary containing deletion status
"""
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamValidationError('Cannot delete group in read-only mode')

try:
Expand Down Expand Up @@ -1048,7 +1047,7 @@ async def add_user_to_group(
Returns:
GroupMembershipResponse containing operation status
"""
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamValidationError('Cannot add user to group in read-only mode')

try:
Expand Down Expand Up @@ -1079,7 +1078,7 @@ async def remove_user_from_group(
Returns:
GroupMembershipResponse containing operation status
"""
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamValidationError('Cannot remove user from group in read-only mode')

try:
Expand Down Expand Up @@ -1110,7 +1109,7 @@ async def attach_group_policy(
Returns:
GroupPolicyAttachmentResponse containing operation status
"""
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamValidationError('Cannot attach policy to group in read-only mode')

try:
Expand Down Expand Up @@ -1141,7 +1140,7 @@ async def detach_group_policy(
Returns:
GroupPolicyAttachmentResponse containing operation status
"""
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamValidationError('Cannot detach policy from group in read-only mode')

try:
Expand Down Expand Up @@ -1193,7 +1192,7 @@ async def put_user_policy(
logger.info(f'Creating/updating inline policy {policy_name} for user: {user_name}')

# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError(
'Cannot create/update inline policy: server is running in read-only mode'
)
Expand Down Expand Up @@ -1299,7 +1298,7 @@ async def delete_user_policy(
logger.info(f'Deleting inline policy {policy_name} from user: {user_name}')

# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError(
'Cannot delete inline policy: server is running in read-only mode'
)
Expand Down Expand Up @@ -1355,7 +1354,7 @@ async def put_role_policy(
logger.info(f'Creating/updating inline policy {policy_name} for role: {role_name}')

# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError(
'Cannot create/update inline policy: server is running in read-only mode'
)
Expand Down Expand Up @@ -1461,7 +1460,7 @@ async def delete_role_policy(
logger.info(f'Deleting inline policy {policy_name} from role: {role_name}')

# Check if server is in read-only mode
if Context.is_readonly():
if ServerContext.is_readonly():
raise IamClientError(
'Cannot delete inline policy: server is running in read-only mode'
)
Expand Down Expand Up @@ -1583,7 +1582,7 @@ def main():

# Set read-only mode if specified
if args.readonly:
Context.set_readonly(True)
ServerContext.set_readonly(True)
logger.info('Server started in READ-ONLY mode - all mutating operations are disabled')
else:
logger.info('Server started in FULL ACCESS mode')
Expand Down
92 changes: 92 additions & 0 deletions test-iam-ctx-fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""Test script to verify the ctx parameter fix."""

import asyncio
import os
import sys

# Add the src directory to the path so we can import the modified server
sys.path.insert(0, '/home/plex/development/repos/aws/mcp/src/iam-mcp-server')

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def main():
# Use the local modified version via python -m
server_params = StdioServerParameters(
command="python",
args=["-m", "awslabs.iam_mcp_server.server", "--readonly"],
env={
"AWS_PROFILE": os.environ.get("AWS_PROFILE", "default"),
"AWS_REGION": os.environ.get("AWS_REGION", "us-east-1"),
"PYTHONPATH": "/home/plex/development/repos/aws/mcp/src/iam-mcp-server"
}
)

print("=" * 70)
print("Testing IAM MCP Server ctx Parameter Fix")
print("=" * 70)
print()

async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
print("✓ Connected to server\n")

# Test 1: List tools and check schema
print("Test 1: Check list_users tool schema")
print("-" * 50)
tools_result = await session.list_tools()
list_users_tool = [t for t in tools_result.tools if t.name == "list_users"][0]

schema = list_users_tool.inputSchema
if "properties" in schema:
props = schema["properties"]
if "ctx" in props:
print("❌ FAIL: ctx is still in the schema!")
print(f" Properties: {list(props.keys())}")
return False
else:
print("✅ PASS: ctx is not in the schema")
print(f" Properties: {list(props.keys())}")
print()

# Test 2: Call the tool
print("Test 2: Call list_users tool")
print("-" * 50)
result = await session.call_tool("list_users", arguments={"max_items": 3})

if result.isError:
print(f"❌ FAIL: Tool returned error")
print(f" Error: {result.content[0].text if result.content else 'Unknown'}")
return False
else:
print("✅ PASS: Tool call succeeded!")
if result.content:
import json
from mcp import types
content_block = result.content[0]
if isinstance(content_block, types.TextContent):
data = json.loads(content_block.text)
users = data.get('Users', [])
print(f" Retrieved {len(users)} users")
if users:
print(f" First user: {users[0].get('UserName')}")

print()
print("=" * 70)
print("✅ ALL TESTS PASSED - Fix verified!")
print("=" * 70)
return True


if __name__ == "__main__":
try:
success = asyncio.run(main())
sys.exit(0 if success else 1)
except Exception as e:
print(f"\n❌ Test failed with exception: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
Loading